This repository contains the official implementation of AR-RAG: Autoregressive Retrieval Augmentation for Image Generation.
AR-RAG introduces a novel retrieval augmentation paradigm that enhances modern photorealistic image generation by augmenting image predictions with k-nearest neighbor (k-NN) retrievals at the patch level. Unlike existing approaches that rely on full-image retrieval conditioned on textual captions, AR-RAG retrieves locally similar patches based on their surrounding visual context, enabling caption-free retrieval while enforcing spatial coherence and semantic consistency for higher-quality image generation.
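
Patch-level retrieval conditioned on visual context can be illustrated with a minimal NumPy sketch. It is not the repository's implementation. The query for the next patch is built only from patches that raster-order decoding has already produced. Several details are assumptions made for illustration: which neighbours form the context (here left, up-left, up, up-right), the use of concatenated embeddings as the query, exact L2 search, and the names `context_query` and `knn_patches`. The repository itself builds a FAISS index with `build_retriever.sh`.

```python
import numpy as np

# Raster-order neighbours that are already generated when decoding (row, col).
# This particular choice of four offsets is an assumption for illustration.
NEIGHBOUR_OFFSETS = [(0, -1), (-1, -1), (-1, 0), (-1, 1)]  # left, up-left, up, up-right


def context_query(grid, row, col, pad):
    """Concatenate embeddings of the already-generated neighbours of (row, col).

    grid: (H, W, D) patch embeddings; only positions that precede (row, col)
    in raster order are read. Out-of-grid neighbours are replaced by `pad`.
    """
    h, w, _ = grid.shape
    parts = []
    for dr, dc in NEIGHBOUR_OFFSETS:
        r, c = row + dr, col + dc
        parts.append(grid[r, c] if 0 <= r < h and 0 <= c < w else pad)
    return np.concatenate(parts)


def knn_patches(query, db_keys, db_tokens, k):
    """Exact k-NN by squared L2 distance: the k closest patch tokens and their distances."""
    dists = np.sum((db_keys - query) ** 2, axis=1)
    nearest = np.argsort(dists)[:k]
    return db_tokens[nearest], dists[nearest]


# Toy usage with random data (all sizes are arbitrary).
rng = np.random.default_rng(0)
dim = 8
grid = rng.normal(size=(4, 4, dim))          # embeddings of generated patches
pad = np.zeros(dim)
db_keys = rng.normal(size=(100, 4 * dim))    # context vectors of database patches
db_tokens = rng.integers(0, 1024, size=100)  # codebook id of each database patch

query = context_query(grid, row=1, col=2, pad=pad)
db_keys[3] = query                           # plant an exact match to check retrieval
tokens, dists = knn_patches(query, db_keys, db_tokens, k=5)
```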

We propose two parallel frameworks:

- Distribution-Augmentation in Decoding (DAiD): a training-free decoding strategy that directly merges the distribution of model-predicted patches with the distribution of retrieved patches.
- Feature-Augmentation in Decoding (FAiD): a parameter-efficient fine-tuning method that smoothly integrates retrieved patches into the generation process via convolution operations.
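
Both strategies can be illustrated with a minimal NumPy sketch. It is not the repository's implementation. The names `retrieval_distribution`, `daid_merge` and `faid_blend` are hypothetical. The following are assumptions rather than the method as published: weighting neighbours by a softmax over negative distances, linear interpolation with weight `lam` (in the style of kNN-LM), a single kernel shared across channels, and a scalar `gate`. In FAiD the convolution and blending are learned during fine-tuning. Here the kernel is fixed so that the effect is easy to inspect.

```python
import numpy as np


def retrieval_distribution(tokens, dists, vocab_size, temperature=1.0):
    """Turn k retrieved patch tokens into a distribution over the codebook.

    Each neighbour is weighted by softmax(-dist / temperature); neighbours that
    share a token id have their weights summed.
    """
    logits = -np.asarray(dists, dtype=float) / temperature
    weights = np.exp(logits - logits.max())  # subtract max for numerical stability
    weights /= weights.sum()
    p = np.zeros(vocab_size)
    np.add.at(p, np.asarray(tokens), weights)  # accumulates repeated ids correctly
    return p


def daid_merge(p_model, p_retr, lam=0.3):
    """DAiD-style sketch: linear interpolation of model and retrieval distributions."""
    return (1.0 - lam) * p_model + lam * p_retr


def faid_blend(hidden, retrieved, kernel, gate=0.5):
    """FAiD-style sketch: convolve retrieved patch features over the H x W token
    grid ('same' size, zero padding) and add them to the hidden states.

    hidden, retrieved: (H, W, D) float arrays; kernel: (kh, kw) with odd sides,
    shared by all channels. Like deep-learning conv layers, this computes a
    cross-correlation (the kernel is not flipped).
    """
    h, w, _ = retrieved.shape
    kh, kw = kernel.shape
    padded = np.pad(retrieved, ((kh // 2, kh // 2), (kw // 2, kw // 2), (0, 0)))
    smoothed = np.zeros(retrieved.shape, dtype=float)
    for i in range(kh):
        for j in range(kw):
            smoothed += kernel[i, j] * padded[i:i + h, j:j + w]
    return hidden + gate * smoothed


# DAiD toy example: a uniform model distribution over 16 codes and three neighbours.
vocab = 16
p_model = np.full(vocab, 1.0 / vocab)
p_retr = retrieval_distribution([3, 3, 7], [0.1, 0.2, 0.4], vocab)
p = daid_merge(p_model, p_retr, lam=0.5)
next_token = int(np.argmax(p))  # greedy choice; sampling from p works the same way

# FAiD toy example: a 3x3 mean filter spreads retrieved features to nearby positions.
hidden = np.zeros((4, 4, 2))
retrieved = np.ones((4, 4, 2))
out = faid_blend(hidden, retrieved, np.full((3, 3), 1.0 / 9.0), gate=1.0)
```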

Both methods raise the overall scores of the Janus-Pro baseline on GenEval and DPG-Bench and lower its FID on MSCOCO and Midjourney. Not every sub-category improves; for example, FAiD drops on Counting.

GenEval:

| Method | Single Obj. | Two Obj. | Counting | Colors | Position | Color Attr. | Overall ↑ |
|---|---|---|---|---|---|---|---|
| Janus-Pro | 0.98 | 0.77 | 0.52 | 0.84 | 0.61 | 0.55 | 0.71 |
| DAiD (ours) | 0.98 | 0.82 | 0.54 | 0.87 | 0.63 | 0.49 | 0.72 |
| FAiD (ours) | 1.00 | 0.92 | 0.41 | 0.87 | 0.71 | 0.60 | 0.75 |

DPG-Bench:

| Method | Global | Entity | Attribute | Relation | Other | Overall ↑ |
|---|---|---|---|---|---|---|
| Janus-Pro | 81.76 | 84.53 | 84.34 | 92.22 | 75.20 | 77.26 |
| DAiD (ours) | 83.58 | 84.46 | 84.76 | 91.49 | 76.40 | 77.88 |
| FAiD (ours) | 82.67 | 85.80 | 85.38 | 92.30 | 76.80 | 79.36 |

FID (lower is better):

| Model | MSCOCO FID ↓ | Midjourney FID ↓ |
|---|---|---|
| Janus-Pro | 19.59 | 12.81 |
| DAiD (ours) | 18.02 | 11.93 |
| FAiD (ours) | 17.60 | 9.31 |

| Model | Description | Size | HF Link |
|---|---|---|---|
| AR-RAG-FAiD | Fine-tuned model with smooth feature blending (FAiD) | 1.2B | 🤗 Model |

| Data Source | # Images | Suggested GPU Memory | HF Link |
|---|---|---|---|
| JourneyDB | 1M | 12 GB | ZIP |
| CC12M | 12M | 96 GB | ZIP |
| DataComp | 70M | - | 🤗 Coming soon |

```bash
git clone https://github.com/PLUM-Lab/AR-RAG.git
cd AR-RAG
# Create and activate the conda environment
conda env create -f arrag.yml
conda activate <env-name>  # use the environment name defined in arrag.yml
```

Download the VQ-VAE checkpoint from LlamaGen:

```bash
wget -P arrag/Janus/janus https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt
```

Build the retriever:

```bash
bash arrag/build_retriever/build_retriever.sh
```

The output FAISS index will be written to `data/retriever/index_L`.

```bash
# Download the pre-built retrieval database
wget http://nlplab1.cs.vt.edu/~jingyuan/AR-RAG/retrieval_db.zip
```

Train FAiD:

```bash
bash ./arrag/train/train_FAiD.sh
```

The default output checkpoint path is `result/ckpts/ckpts_FAiD_bx_hx`.

Generate an image with DAiD (training-free):

```bash
bash arrag/t2i_example/t2i_daid_L.sh
```

The default output image path is `result/generated_imgs/example_t2i_daid.jpg`.

Generate an image with FAiD:

```bash
bash arrag/t2i_example/t2i_faid_L.sh
```

The default output image path is `result/generated_imgs/example_t2i_faid.jpg`.

This project is licensed under the MIT License. See the LICENSE file for details.

