Cross-Attention Makes Inference Cumbersome in Text-to-Image Diffusion Models
Authors: Wentian Zhang, Haozhe Liu, Jinheng Xie, Francesco Faccio, Mike Zheng Shou, Jürgen Schmidhuber
What
This paper investigates the role of cross-attention in text-to-image diffusion models during inference and finds that cross-attention outputs converge to a fixed point within a few steps, making cross-attention redundant in later inference steps.
Why
This paper is important because it challenges the assumption that cross-attention is crucial for every inference step in text-to-image diffusion models, offering a potential path to significantly reduce computational cost without sacrificing image quality.
How
The authors analyze the role of cross-attention by replacing text embeddings with null embeddings at various stages of the inference process. They then quantitatively evaluate the impact of this replacement on generation quality using FID scores on the MS-COCO validation set, and visualize the generated images at different inference steps to understand the dynamics of cross-attention. A rough sketch of this probing setup is shown below.
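The following is a minimal sketch of that probing experiment, not the authors' code: it assumes the Hugging Face diffusers library, omits classifier-free guidance for brevity, and `switch_step` plus the model id are illustrative choices.

```python
import torch
from diffusers import StableDiffusionPipeline

# Minimal sketch of the probing experiment (not the authors' code):
# after `switch_step` denoising steps, the text embedding fed to the
# UNet's cross-attention is swapped for the null (empty-prompt)
# embedding. Classifier-free guidance is omitted for brevity.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

def encode(text: str) -> torch.Tensor:
    """Encode a prompt into CLIP text-encoder hidden states."""
    ids = pipe.tokenizer(
        text,
        padding="max_length",
        max_length=pipe.tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids.to("cuda")
    return pipe.text_encoder(ids)[0]

text_emb = encode("a photo of an astronaut riding a horse")
null_emb = encode("")  # the null (empty-prompt) embedding

pipe.scheduler.set_timesteps(50)
latents = torch.randn(1, 4, 64, 64, device="cuda", dtype=torch.float16)
latents = latents * pipe.scheduler.init_noise_sigma

switch_step = 10  # step at which the text condition is dropped
for i, t in enumerate(pipe.scheduler.timesteps):
    cond = text_emb if i < switch_step else null_emb
    model_in = pipe.scheduler.scale_model_input(latents, t)
    noise_pred = pipe.unet(model_in, t, encoder_hidden_states=cond).sample
    latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

image = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
```

Sweeping `switch_step` and computing FID at each setting reproduces the kind of analysis the authors describe: if quality is unaffected for small `switch_step`, the text condition is only needed early.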
Result
The key finding is that cross-attention outputs converge to a fixed point early in the inference process. The authors leverage this to develop TGATE, a training-free method that caches cross-attention outputs from early inference steps and reuses them thereafter, reducing computational cost (up to a 50% reduction in latency) while even slightly improving FID scores compared to baseline models. Notably, TGATE is effective across model architectures, including both convolutional and transformer-based diffusion models.
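A conceptual sketch of the caching idea follows; this is our illustrative wrapper, not the released implementation (which, among other details, also handles the classifier-free-guidance branches). `GatedCrossAttention` and `gate_step` are hypothetical names.

```python
import torch
import torch.nn as nn

# Conceptual sketch of the TGATE caching idea (illustrative, not the
# released implementation): wrap each cross-attention module so that,
# once `gate_step` inference steps have run, its output is frozen and
# reused, skipping text-conditioned attention in later steps.
class GatedCrossAttention(nn.Module):
    def __init__(self, cross_attn: nn.Module, gate_step: int):
        super().__init__()
        self.cross_attn = cross_attn  # the original cross-attention block
        self.gate_step = gate_step    # step after which the output is fixed
        self.step = 0                 # current inference step for this module
        self.cache = None             # converged cross-attention output

    def forward(self, hidden_states, encoder_hidden_states=None, **kwargs):
        if self.step < self.gate_step:
            out = self.cross_attn(hidden_states, encoder_hidden_states, **kwargs)
            if self.step == self.gate_step - 1:
                self.cache = out.detach()  # cache the converged output
        else:
            out = self.cache  # fidelity-improving stage: reuse, don't recompute
        self.step += 1
        return out
```

Because the cached tensor no longer depends on the text embedding, the text-conditioned attention computation can be skipped entirely in the fidelity-improving stage, which is where the reported latency saving comes from.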
Limitations and Future Work
The authors acknowledge that while TGATE brings quantitative improvements in FID scores, the visual differences in generated images may be subtle to users. For future work, they suggest exploring how scaling token length and image resolution affects the efficiency gains of TGATE, hinting at its potential benefits for the emerging trend of larger input sizes in diffusion models.
Abstract
This study explores the role of cross-attention during inference in text-conditional diffusion models. We find that cross-attention outputs converge to a fixed point after a few inference steps. Accordingly, the time point of convergence naturally divides the entire inference process into two stages: an initial semantics-planning stage, during which the model relies on cross-attention to plan text-oriented visual semantics, and a subsequent fidelity-improving stage, during which the model tries to generate images from the previously planned semantics. Surprisingly, ignoring text conditions in the fidelity-improving stage not only reduces computational complexity but also maintains model performance. This yields a simple and training-free method called TGATE for efficient generation, which caches the cross-attention output once it converges and keeps it fixed during the remaining inference steps. Our empirical study on the MS-COCO validation set confirms its effectiveness. The source code of TGATE is available at https://github.com/HaozheLiu-ST/T-GATE.