Diffusion-TTA: Test-time Adaptation of Discriminative Models via Generative Feedback

Carnegie Mellon University
NeurIPS 2023

Abstract

The advancements in generative modeling, particularly the advent of diffusion models, have sparked a fundamental question: how can these models be effectively used for discriminative tasks? In this work, we find that generative models can be great test-time adapters for discriminative models. Our method, Diffusion-TTA, adapts pre-trained discriminative models such as image classifiers, segmenters, and depth predictors to each unlabelled example in the test set using generative feedback from a diffusion model.

We achieve this by modulating the conditioning of the diffusion model with the output of the discriminative model. We then maximize the image likelihood objective by backpropagating the gradients to the discriminative model's parameters. We show that Diffusion-TTA significantly enhances the accuracy of various large-scale pre-trained discriminative models, such as ImageNet classifiers, CLIP models, image pixel labellers, and image depth predictors. Diffusion-TTA outperforms existing test-time adaptation methods, including TTT-MAE and TENT, and particularly shines in online adaptation setups, where the discriminative model is continually adapted to each example in the test set.

Diffusion-TTA

Generative diffusion models are great test-time adapters for discriminative models. Our method consists of discriminative and generative modules. Given an image \(\mathbf x\), the discriminative model \(f_{\theta}\) predicts the task output \(\mathbf y\). The task output \(\mathbf y\) is transformed into a condition \(\mathbf c\). Finally, we use the generative diffusion model \(\epsilon_{\phi}\) to measure the likelihood of the input image, conditioned on \(\mathbf c\): the diffusion model \(\epsilon_{\phi}\) predicts the added noise \(\epsilon\) from the noisy image \(\mathbf x_{t}\) and the condition \(\mathbf c\). We maximize the image likelihood using the diffusion loss by updating the discriminative and generative model weights via backpropagation.
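The adaptation step can be sketched in PyTorch-style pseudocode as below. This is a minimal sketch, not the released implementation: the helper names (`eps_model`, `class_embeddings`, `alphas_cumprod`) are illustrative assumptions, the noising is written in pixel space for brevity, and for classification the condition \(\mathbf c\) is taken to be the probability-weighted sum of the generative model's class embeddings.

```python
import torch
import torch.nn.functional as F

def diffusion_tta_step(x, classifier, eps_model, class_embeddings,
                       alphas_cumprod, optimizer):
    """One Diffusion-TTA update on a batch of views of a single test image."""
    probs = classifier(x).softmax(dim=-1)                    # task output y: (B, K)
    c = probs @ class_embeddings                             # condition c:   (B, D)

    # Standard DDPM forward process: noise the image at a random timestep t.
    t = torch.randint(0, alphas_cumprod.shape[0], (x.shape[0],), device=x.device)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    noise = torch.randn_like(x)
    x_t = a_bar.sqrt() * x + (1.0 - a_bar).sqrt() * noise

    # Diffusion (noise-prediction) loss; its gradient flows through c back into
    # the classifier parameters theta (and, if included in the optimizer, phi).
    loss = F.mse_loss(eps_model(x_t, t, c), noise)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Averaging this loss over several sampled timesteps and noise vectors per update is a natural way to reduce the variance of the single-image objective.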

Test-time adaptation of ImageNet-trained classifiers

We use Diffusion-TTA to adapt multiple ImageNet classifiers with varying backbone architectures and sizes: ResNet-18, ResNet-50, ViT-B/32, and ConvNext-Tiny/Large. For fair comparison, we use the Diffusion Transformer (DiT-XL/2) as our class-conditional generative model; it is trained on ImageNet from scratch, so all models are trained on ImageNet without additional data. We consider the following datasets for TTA of ImageNet classifiers: ImageNet and its out-of-distribution counterparts ImageNet-C (severity-5 Gaussian noise), ImageNet-A, ImageNet-R, ImageNetV2, and Stylized ImageNet.
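Single-sample adaptation (reported below) resets the classifier to its pre-trained weights after every test image, whereas online adaptation keeps updating it across the test stream. A hedged sketch of the single-sample protocol, reusing the illustrative `diffusion_tta_step` above; `num_steps` and the SGD learning rate are placeholder hyperparameters, not the paper's settings.

```python
import copy
import torch

@torch.no_grad()
def predict(classifier, x):
    return classifier(x).argmax(dim=-1)

def single_sample_tta(test_loader, classifier, eps_model, class_embeddings,
                      alphas_cumprod, num_steps=5, lr=1e-3):
    base_state = copy.deepcopy(classifier.state_dict())
    predictions = []
    for x, _ in test_loader:                       # ground-truth labels are never used
        classifier.load_state_dict(base_state)     # reset weights before every test image
        optimizer = torch.optim.SGD(classifier.parameters(), lr=lr)
        for _ in range(num_steps):
            diffusion_tta_step(x, classifier, eps_model, class_embeddings,
                               alphas_cumprod, optimizer)
        predictions.append(predict(classifier, x))
    return torch.cat(predictions)
```

For the online setup, the per-sample weight reset is simply dropped so the model keeps adapting across examples.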

Classification accuracy for single-sample test-time adaptation of ImageNet-trained classifiers.
We observe consistent and significant performance gains across all types of classifiers and distribution shifts.
ImageNet is in-distribution (ID); all other columns are out-of-distribution (OOD).
Model / TTA method ImageNet ImageNet-A ImageNet-R ImageNet-C ImageNet-V2 ImageNet-S
Customized ViT-L/16 82.1 14.4 33.0 17.5 72.5 11.9
TTT-MAE 82.0 21.3 39.2 27.5 72.3 20.2
ResNet18 69.5 1.4 34.6 2.6 57.1 7.7
TENT 63.0 0.6 34.7 12.1 52.0 9.8
COTTA 63.0 0.7 34.7 11.7 52.1 9.7
Diffusion-TTA 77.2 (+7.7) 6.1 (+4.7) 39.7 (+5.1) 4.5 (+1.9) 63.8 (+6.7) 12.3 (+4.6)
ViT-B/32 75.7 9.0 45.2 39.5 61.0 15.8
TENT 75.7 9.0 45.3 38.9 61.1 10.4
COTTA 75.8 8.6 45.0 40.0 60.9 1.1
Diffusion-TTA 77.6 (+1.9) 11.2 (+2.2) 46.5 (+1.3) 41.4 (+1.9) 64.4 (+3.4) 21.3 (+5.5)
ConvNext-Tiny 81.9 22.7 47.8 16.4 70.9 20.2
TENT 79.3 10.6 42.7 2.7 69.0 19.9
COTTA 80.5 13.2 47.2 13.7 68.9 19.3
Diffusion-TTA 83.1 (+1.2) 25.8 (+3.1) 49.7 (+1.9) 21.0 (+4.6) 71.5 (+0.6) 22.6 (+2.4)

Classification accuracy for online test-time adaptation of ImageNet-trained classifiers on ImageNet-C.
Model / TTA method Gaussian Noise Fog Pixelate Snow Contrast
Customized ViT-L/16 17.1 38.7 47.1 35.6 6.9
TTT-MAE 37.9 51.1 65.7 56.5 10.0
ResNet50 6.3 25.2 26.5 16.7 3.6
TENT 12.3 43.2 41.8 28.4 12.0
COTTA 12.2 42.4 41.7 28.6 11.9
Diffusion-TTA 19.0 (+12.7) 43.2 (+18.0) 50.2 (+23.7) 33.6 (+16.9) 2.7 (-0.9)
ViT-B/32 39.5 35.9 55.0 30.0 31.5
TENT 38.9 35.8 55.5 30.7 32.1
COTTA 40.0 34.6 54.5 29.7 32.0
Diffusion-TTA 46.5 (+7.0) 56.2 (+20.3) 64.7 (+9.7) 50.4 (+20.4) 33.6 (+2.1)
ConvNext-Tiny 16.4 32.3 37.2 38.3 32.0
TENT 2.7 5.0 43.9 15.2 40.7
COTTA 13.7 29.8 37.3 26.6 32.6
Diffusion-TTA 47.4 (+31.0) 65.9 (+33.6) 69.0 (+31.8) 62.6 (+24.3) 46.2 (+14.2)
ConvNext-Large 33.0 34.4 49.3 44.5 39.8
TENT 30.8 53.5 51.1 44.6 52.4
COTTA 33.3 15.1 34.6 7.7 10.7
Diffusion-TTA 54.9 (+21.9) 67.7 (+33.3) 71.7 (+22.4) 64.8 (+20.3) 55.7 (+15.9)

Test-time adaptation of CLIP classifiers

Under the single-sample test-time adaptation setup, our method consistently improves CLIP classifiers of different sizes across the following datasets: ImageNet, CIFAR-100, Food-101, Flowers-102, FGVC Aircraft, and Oxford-IIIT Pets. The CLIP classifiers have not seen test images from these datasets during training.
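For CLIP, a text-conditioned diffusion model can consume the classifier's output by mixing per-class prompt embeddings with the predicted probabilities. The sketch below is an illustrative assumption of that wiring, not the released code: pooling each prompt sequence at its last token and the fixed `logit_scale` are simplifications, and `clip_model` / `prompt_embeddings` are assumed to share the embedding width expected by the diffusion model's text conditioning.

```python
import torch

def clip_condition(image, clip_model, prompt_embeddings, logit_scale=100.0):
    """prompt_embeddings: (K, S, D) text-encoder token features, one sequence
    per class prompt such as 'a photo of a <class>'."""
    image_feat = clip_model.encode_image(image)                        # (B, D)
    image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
    # Pooled per-class text features (here simply the last token) for zero-shot logits.
    class_feat = prompt_embeddings[:, -1, :]
    class_feat = class_feat / class_feat.norm(dim=-1, keepdim=True)
    probs = (logit_scale * image_feat @ class_feat.T).softmax(dim=-1)  # (B, K)
    # Probability-weighted mix of the per-class prompt sequences -> condition c
    # for a text-conditioned diffusion model.
    c = torch.einsum("bk,ksd->bsd", probs, prompt_embeddings)
    return probs, c
```

The resulting condition c then feeds the same diffusion-loss update sketched earlier, with gradients flowing back into the CLIP image encoder.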

Classification accuracy for single-sample test-time adaptation of CLIP classifiers.
Model / TTA method Food101 CIFAR100 FGVC Pets Flowers ImageNet
CLIP ViT-B/32 82.6 61.2 17.8 82.2 66.7 57.5
Diffusion-TTA 86.2 (+3.6) 62.8 (+1.6) 21.0 (+3.2) 84.9 (+2.7) 67.7 (+1.0) 60.8 (+3.3)
CLIP ViT-B/16 88.1 69.4 22.8 85.5 69.3 62.3
Diffusion-TTA 88.8 (+0.7) 69.0 (-0.4) 24.6 (+1.8) 86.1 (+0.6) 71.5 (+2.2) 63.8 (+1.5)
CLIP ViT-L/14 93.1 79.6 32.0 91.9 78.8 70.0
Diffusion-TTA 93.1 (0.0) 80.6 (+1.0) 33.4 (+1.4) 92.3 (+0.4) 79.2 (+0.4) 71.2 (+1.2)

Improved Semantic Segmentation and Depth Estimation

We use Diffusion-TTA to adapt SegFormer semantic segmentation models and DenseDepth depth predictors. For each task, we use a latent diffusion model pre-trained on ADE20K (segmentation) or NYU Depth v2 (depth). The diffusion model is conditioned by concatenating the predicted segmentation/depth map with the noisy image. We test adaptation performance under different distribution shifts.
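For these dense tasks the condition is no longer a class embedding but the predicted map itself. Below is a minimal sketch of one adaptation step under that conditioning, written in pixel space for brevity (the actual models are latent diffusion models) and with `eps_model` standing in for a hypothetical UNet whose input channels accommodate the concatenation.

```python
import torch
import torch.nn.functional as F

def dense_tta_step(x, predictor, eps_model, alphas_cumprod, optimizer):
    """Adapt a segmentation or depth model on one test image x of shape (B, 3, H, W)."""
    y = predictor(x)                                # dense prediction: (B, C_task, H, W)

    # DDPM forward process on the image.
    t = torch.randint(0, alphas_cumprod.shape[0], (x.shape[0],), device=x.device)
    a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
    noise = torch.randn_like(x)
    x_t = a_bar.sqrt() * x + (1.0 - a_bar).sqrt() * noise

    # Condition by channel-wise concatenation of the noisy image and the prediction.
    loss = F.mse_loss(eps_model(torch.cat([x_t, y], dim=1), t), noise)
    optimizer.zero_grad()
    loss.backward()                                 # updates the predictor's weights
    optimizer.step()
    return loss.item()
```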

Qualitative results: Input / Before TTA / After TTA / Ground-truth.

Single-sample test-time adaptation on semantic segmentation and depth estimation
Task: Model / TTA method Clean Gaussian Noise Fog Frost Snow Contrast Shot Noise
Segmentation: SegFormer 66.1 65.3 63.0 58.0 55.2 65.3 63.3
Diffusion-TTA 66.1 (0.0) 66.4 (+1.1) 65.1 (+2.1) 58.9 (+0.9) 56.6 (+1.4) 66.4 (+1.1) 63.7 (+0.4)
Depth: DenseDepth 92.4 79.1 81.6 72.2 72.7 77.4 81.3
Diffusion-TTA 92.6 (+0.3) 82.1 (+3.0) 84.4 (+2.8) 73.0 (+0.8) 74.1 (+1.4) 77.4 (0.0) 82.1 (+0.8)

BibTeX

@inproceedings{prabhudesai2023difftta,
      title={Test-time Adaptation of Discriminative Models via Diffusion Generative Feedback},
      author={Prabhudesai, Mihir and Ke, Tsung-Wei and Li, Alexander C. and Pathak, Deepak and Fragkiadaki, Katerina},
      year={2023},
      booktitle={Conference on Neural Information Processing Systems},
}