The advancements in generative modeling, particularly the advent of diffusion models, have sparked a fundamental question: how can these models be effectively used for discriminative tasks? In this work, we find that generative models can be great test-time adapters for discriminative models. Our method, Diffusion-TTA, adapts pre-trained discriminative models such as image classifiers, segmenters and depth predictors, to each unlabelled example in the test set using generative feedback from a diffusion model.
We achieve this by modulating the conditioning of the diffusion model using the output of the discriminative model. We then maximize the image likelihood objective by backpropagating the gradients to the discriminative model's parameters. We show that Diffusion-TTA significantly enhances the accuracy of various large-scale pre-trained discriminative models, such as ImageNet classifiers, CLIP models, image pixel labellers, and image depth predictors. Diffusion-TTA outperforms existing test-time adaptation methods, including TTT-MAE and TENT, and particularly shines in online adaptation setups, where the discriminative model is continually adapted to each example in the test set.
Generative diffusion models are great test-time adapters for discriminative models. Our method consists of discriminative and generative modules. Given an image \(\mathbf{x}\), the discriminative model \(f_{\theta}\) predicts the task output \(\mathbf{y}\). The task output \(\mathbf{y}\) is transformed into a condition \(\mathbf{c}\). Finally, we use the generative diffusion model \(\boldsymbol{\epsilon}_{\phi}\) to measure the likelihood of the input image, conditioned on \(\mathbf{c}\): the diffusion model \(\boldsymbol{\epsilon}_{\phi}\) predicts the added noise \(\boldsymbol{\epsilon}\) from the noisy image \(\mathbf{x}_{t}\) and the condition \(\mathbf{c}\). We maximize the image likelihood by minimizing the diffusion loss, updating the discriminative and generative model weights via backpropagation.
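To make the adaptation loop concrete, below is a minimal PyTorch-style sketch of one Diffusion-TTA update on a single test image. The model interfaces (`classifier`, `diffusion`, `diffusion.q_sample`) and the probability-weighted combination of class embeddings used to form the condition \(\mathbf{c}\) are illustrative assumptions, not the exact released implementation.

```python
import torch
import torch.nn.functional as F

def diffusion_tta_step(x, classifier, diffusion, class_embeddings, optimizer,
                       n_noise_samples=4, num_timesteps=1000):
    """One Diffusion-TTA update on an unlabelled test image x.

    classifier:       pre-trained discriminative model f_theta (returns class logits)
    diffusion:        pre-trained conditional diffusion model eps_phi(x_t, t, c)
    class_embeddings: per-class condition vectors of the diffusion model, shape (K, D)
    All interfaces here are hypothetical placeholders for the actual models.
    """
    optimizer.zero_grad()

    # 1) Discriminative prediction y = f_theta(x), kept soft so gradients can flow.
    probs = classifier(x).softmax(dim=-1)                  # (B, K)

    # 2) Transform the prediction into a diffusion condition c
    #    (here: probability-weighted sum of class embeddings).
    c = probs @ class_embeddings                           # (B, D)

    # 3) Diffusion loss: noise the image at random timesteps and ask
    #    eps_phi to predict the added noise, conditioned on c.
    loss = 0.0
    for _ in range(n_noise_samples):
        t = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device)
        noise = torch.randn_like(x)
        x_t = diffusion.q_sample(x, t, noise)              # forward (noising) process
        pred_noise = diffusion(x_t, t, c)
        loss = loss + F.mse_loss(pred_noise, noise)
    loss = loss / n_noise_samples

    # 4) Maximize image likelihood: backpropagate the diffusion loss into the
    #    classifier (and optionally the diffusion model) parameters.
    loss.backward()
    optimizer.step()
    return probs.detach()
```

Which weights actually get updated is controlled by which parameters are registered in `optimizer`; passing only the classifier's parameters adapts the discriminative model alone, while adding the diffusion model's parameters adapts both.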
We use Diffusion-TTA to adapt multiple ImageNet classifiers with varying backbone architectures and sizes: ResNet-18, ResNet-50, ViT-B/32, and ConvNext-Tiny/Large. For fair comparison, we use the Diffusion Transformer (DiT-XL/2) as our class-conditional generative model, which is trained from scratch on ImageNet; thus all models are trained on ImageNet without additional data. We consider the following datasets for TTA of ImageNet classifiers: ImageNet and its out-of-distribution counterparts ImageNet-C (severity-5 Gaussian noise), ImageNet-A, ImageNet-R, ImageNet-V2, and Stylized-ImageNet (ImageNet-S).
Accuracy (%) on in-distribution (ID) ImageNet and its out-of-distribution (OOD) variants:

| Model | TTA | ImageNet (ID) | ImageNet-A | ImageNet-R | ImageNet-C | ImageNet-V2 | ImageNet-S |
|---|---|---|---|---|---|---|---|
| Customized ViT-L/16 | ✗ | 82.1 | 14.4 | 33.0 | 17.5 | 72.5 | 11.9 |
| Customized ViT-L/16 | TTT-MAE | 82.0 | 21.3 | 39.2 | 27.5 | 72.3 | 20.2 |
| ResNet18 | ✗ | 69.5 | 1.4 | 34.6 | 2.6 | 57.1 | 7.7 |
| ResNet18 | TENT | 63.0 | 0.6 | 34.7 | 12.1 | 52.0 | 9.8 |
| ResNet18 | COTTA | 63.0 | 0.7 | 34.7 | 11.7 | 52.1 | 9.7 |
| ResNet18 | Diffusion-TTA | 77.2 (+7.7) | 6.1 (+4.7) | 39.7 (+5.1) | 4.5 (+1.9) | 63.8 (+6.7) | 12.3 (+4.6) |
| ViT-B/32 | ✗ | 75.7 | 9.0 | 45.2 | 39.5 | 61.0 | 15.8 |
| ViT-B/32 | TENT | 75.7 | 9.0 | 45.3 | 38.9 | 61.1 | 10.4 |
| ViT-B/32 | COTTA | 75.8 | 8.6 | 45.0 | 40.0 | 60.9 | 1.1 |
| ViT-B/32 | Diffusion-TTA | 77.6 (+1.9) | 11.2 (+2.2) | 46.5 (+1.3) | 41.4 (+1.9) | 64.4 (+3.4) | 21.3 (+5.5) |
| ConvNext-Tiny | ✗ | 81.9 | 22.7 | 47.8 | 16.4 | 70.9 | 20.2 |
| ConvNext-Tiny | TENT | 79.3 | 10.6 | 42.7 | 2.7 | 69.0 | 19.9 |
| ConvNext-Tiny | COTTA | 80.5 | 13.2 | 47.2 | 13.7 | 68.9 | 19.3 |
| ConvNext-Tiny | Diffusion-TTA | 83.1 (+1.2) | 25.8 (+3.1) | 49.7 (+1.9) | 21.0 (+4.6) | 71.5 (+0.6) | 22.6 (+2.4) |
Accuracy (%) under individual ImageNet-C corruption types:

| Model | TTA | Gaussian Noise | Fog | Pixelate | Snow | Contrast |
|---|---|---|---|---|---|---|
| Customized ViT-L/16 | ✗ | 17.1 | 38.7 | 47.1 | 35.6 | 6.9 |
| Customized ViT-L/16 | TTT-MAE | 37.9 | 51.1 | 65.7 | 56.5 | 10.0 |
| ResNet50 | ✗ | 6.3 | 25.2 | 26.5 | 16.7 | 3.6 |
| ResNet50 | TENT | 12.3 | 43.2 | 41.8 | 28.4 | 12.0 |
| ResNet50 | COTTA | 12.2 | 42.4 | 41.7 | 28.6 | 11.9 |
| ResNet50 | Diffusion-TTA | 19.0 (+12.7) | 43.2 (+18.0) | 50.2 (+23.7) | 33.6 (+16.9) | 2.7 (-0.9) |
| ViT-B/32 | ✗ | 39.5 | 35.9 | 55.0 | 30.0 | 31.5 |
| ViT-B/32 | TENT | 38.9 | 35.8 | 55.5 | 30.7 | 32.1 |
| ViT-B/32 | COTTA | 40.0 | 34.6 | 54.5 | 29.7 | 32.0 |
| ViT-B/32 | Diffusion-TTA | 46.5 (+7.0) | 56.2 (+20.3) | 64.7 (+9.7) | 50.4 (+20.4) | 33.6 (+2.1) |
| ConvNext-Tiny | ✗ | 16.4 | 32.3 | 37.2 | 38.3 | 32.0 |
| ConvNext-Tiny | TENT | 2.7 | 5.0 | 43.9 | 15.2 | 40.7 |
| ConvNext-Tiny | COTTA | 13.7 | 29.8 | 37.3 | 26.6 | 32.6 |
| ConvNext-Tiny | Diffusion-TTA | 47.4 (+31.0) | 65.9 (+33.6) | 69.0 (+31.8) | 62.6 (+24.3) | 46.2 (+14.2) |
| ConvNext-Large | ✗ | 33.0 | 34.4 | 49.3 | 44.5 | 39.8 |
| ConvNext-Large | TENT | 30.8 | 53.5 | 51.1 | 44.6 | 52.4 |
| ConvNext-Large | COTTA | 33.3 | 15.1 | 34.6 | 7.7 | 10.7 |
| ConvNext-Large | Diffusion-TTA | 54.9 (+21.9) | 67.7 (+33.3) | 71.7 (+22.4) | 64.8 (+20.3) | 55.7 (+15.9) |
Under the single-sample test-time adaptation setup, our method consistently improves CLIP classifiers of different sizes across the following datasets: ImageNet, CIFAR-100, Food-101, Flowers-102, FGVC-Aircraft, and Oxford-IIIT Pets. The CLIP classifiers have not seen test images from these datasets during training.
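As a rough illustration of what the single-sample (episodic) setup entails, here is a hypothetical evaluation driver that resets the classifier before every test image and reuses the `diffusion_tta_step` sketch above. For CLIP classifiers, the class embeddings would be text-prompt embeddings from the diffusion model's text encoder; the hyperparameters shown are placeholders, not the paper's values.

```python
import copy
import torch

def evaluate_single_sample_tta(classifier, diffusion, class_embeddings, test_loader,
                               adapt_steps=5, lr=1e-5):
    """Episodic (single-sample) TTA: the classifier is reset before each test image."""
    initial_state = copy.deepcopy(classifier.state_dict())
    correct, total = 0, 0
    for x, y in test_loader:
        # Reset weights so adaptation on one image never leaks into the next.
        classifier.load_state_dict(initial_state)
        optimizer = torch.optim.SGD(classifier.parameters(), lr=lr)
        for _ in range(adapt_steps):
            diffusion_tta_step(x, classifier, diffusion, class_embeddings, optimizer)
        with torch.no_grad():
            pred = classifier(x).argmax(dim=-1)
        correct += (pred == y).sum().item()
        total += y.numel()
    return correct / total
```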
| Model | TTA | Food101 | CIFAR100 | FGVC | Pets | Flowers | ImageNet |
|---|---|---|---|---|---|---|---|
| CLIP ViT-B/32 | ✗ | 82.6 | 61.2 | 17.8 | 82.2 | 66.7 | 57.5 |
| CLIP ViT-B/32 | ✓ | 86.2 (+3.6) | 62.8 (+1.6) | 21.0 (+3.2) | 84.9 (+2.7) | 67.7 (+1.0) | 60.8 (+3.3) |
| CLIP ViT-B/16 | ✗ | 88.1 | 69.4 | 22.8 | 85.5 | 69.3 | 62.3 |
| CLIP ViT-B/16 | ✓ | 88.8 (+0.7) | 69.0 (-0.4) | 24.6 (+1.8) | 86.1 (+0.6) | 71.5 (+2.2) | 63.8 (+1.5) |
| CLIP ViT-L/14 | ✗ | 93.1 | 79.6 | 32.0 | 91.9 | 78.8 | 70.0 |
| CLIP ViT-L/14 | ✓ | 93.1 (0.0) | 80.6 (+1.0) | 33.4 (+1.4) | 92.3 (+0.4) | 79.2 (+0.4) | 71.2 (+1.2) |
We use Diffusion-TTA to adapt the SegFormer semantic segmentation model and the DenseDepth depth predictor. For each task, we use a latent diffusion model pre-trained on ADE20K and NYU Depth v2, respectively. The diffusion model is conditioned by concatenating the segmentation/depth map with the noisy image. We test adaptation performance under different distribution shifts.
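A minimal sketch of this concatenation-based conditioning, under assumptions, is below. For simplicity it operates directly in pixel space (a latent diffusion model would encode the image first), the predicted map is assumed to match the image resolution, and the `diffusion` interface is hypothetical.

```python
import torch
import torch.nn.functional as F

def diffusion_tta_step_dense(x, dense_model, diffusion, optimizer, num_timesteps=1000):
    """One Diffusion-TTA update for dense prediction (segmentation or depth)."""
    optimizer.zero_grad()

    # 1) Dense prediction: soft per-pixel class probabilities or a depth map.
    y = dense_model(x)                                   # (B, C_y, H, W)

    # 2) Diffusion loss, with the prediction channel-wise concatenated
    #    to the noisy image as the condition.
    t = torch.randint(0, num_timesteps, (x.shape[0],), device=x.device)
    noise = torch.randn_like(x)
    x_t = diffusion.q_sample(x, t, noise)                # forward (noising) process
    pred_noise = diffusion(torch.cat([x_t, y], dim=1), t)
    loss = F.mse_loss(pred_noise, noise)

    # 3) Backpropagate into the dense predictor (and optionally the diffusion model).
    loss.backward()
    optimizer.step()
    return y.detach()
```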
Qualitative results: for each example we show the input, the prediction before TTA, the prediction after TTA, and the ground truth.
| Task: Model | TTA | Clean | Gaussian Noise | Fog | Frost | Snow | Contrast | Shot Noise |
|---|---|---|---|---|---|---|---|---|
| Segmentation: SegFormer | ✗ | 66.1 | 65.3 | 63.0 | 58.0 | 55.2 | 65.3 | 63.3 |
| Segmentation: SegFormer | ✓ | 66.1 (0.0) | 66.4 (+1.1) | 65.1 (+2.1) | 58.9 (+0.9) | 56.6 (+1.4) | 66.4 (+1.1) | 63.7 (+0.4) |
| Depth: DenseDepth | ✗ | 92.4 | 79.1 | 81.6 | 72.2 | 72.7 | 77.4 | 81.3 |
| Depth: DenseDepth | ✓ | 92.6 (+0.3) | 82.1 (+3.0) | 84.4 (+2.8) | 73.0 (+0.8) | 74.1 (+1.4) | 77.4 (0.0) | 82.1 (+0.8) |
@inproceedings{prabhudesai2023difftta,
title={Test-time Adaptation of Discriminative Models via Diffusion Generative Feedback},
author={Prabhudesai, Mihir and Ke, Tsung-Wei and Li, Alexander C. and Pathak, Deepak and Fragkiadaki, Katerina},
year={2023},
booktitle={Conference on Neural Information Processing Systems},
}