This repository contains the code for our paper:
Closed-Loop Training for Projected GAN
Jiangwei Zhao, Liang Zhang, Lili Pan, Hongliang Li
IEEE Signal Processing Letters
Abstract: Projected GAN, a pre-trained GAN, has been found to perform well in generating images with only a few training samples. However, it struggles with extended training, which may lead to decreased performance over time. This is because the pre-trained discriminator consistently surpasses the generator, creating an unstable training environment. In this work, we propose a solution to this issue by introducing closed-loop control (CLC) into the dynamics of Projected GAN, which stabilizes training and improves generation performance. Our proposed method consistently reduces the Fréchet Inception Distance (FID) of previous methods; for example, it reduces the FID of Projected GAN by 4.31 on the Obama dataset. Our finding is fundamental and can be applied to other pre-trained GANs.
Dependencies
- 64-bit Python 3.8
- PyTorch 1.9.0 (or later). See https://pytorch.org for PyTorch install instructions.
Installation
First, clone this repository:

```bash
git clone https://github.com/learninginvision/ProjectedGAN-CLC
```

Then create a virtual environment with conda and activate it:

```bash
conda env create -f environment.yaml
conda activate pg-clc
```

For a quick start, you can download the few-shot datasets provided by the authors of FastGAN (available through the FastGAN repository). To prepare a dataset at the desired resolution, run, for example:
```bash
python dataset_tool.py --source=./data/pokemon --dest=./data/pokemon256.zip \
  --resolution=256x256 --transform=center-crop
```
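The `--transform=center-crop` option crops each image to a centered square before resizing it to the target resolution. The sketch below reproduces only the crop-box arithmetic, assuming the StyleGAN3-style convention of cropping to the shorter side; `center_crop_box` is a hypothetical helper, not a function in `dataset_tool.py`.

```python
from typing import Tuple


def center_crop_box(width: int, height: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of the largest centered square."""
    crop = min(width, height)  # side length of the square: the shorter image side
    left = (width - crop) // 2
    top = (height - crop) // 2
    right = (width + crop) // 2
    bottom = (height + crop) // 2
    return left, top, right, bottom


# A 300x200 image keeps its full height and loses 50 px on each side.
print(center_crop_box(300, 200))  # (50, 0, 250, 200)
```

Under that assumption, the cropped square is then resized to 256x256 for `--resolution=256x256`.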
You can obtain the other datasets used in our paper, AFHQ and Landscape, from their respective websites.
To train your own PG-CLC on Pokemon using 2 GPUs, run:

```bash
python train.py --outdir=./training-runs/ --cfg=fastgan --data=./data/pokemon256.zip \
  --gpus=2 --batch=64 --mirror=1 --snap=50 --batch-gpu=16 --kimg=10000
```
--batch specifies the overall batch size; --batch-gpu specifies the batch size per GPU.
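When --batch is larger than --gpus times --batch-gpu, the StyleGAN-family training loop that this code builds on accumulates gradients over several rounds per optimizer step. The hypothetical helper below (not part of `train.py`) works through that arithmetic for the command above, assuming the upstream rule that --batch must be a multiple of --gpus times --batch-gpu.

```python
from typing import Tuple


def accumulation_plan(batch: int, gpus: int, batch_gpu: int) -> Tuple[int, int]:
    """Return (images per GPU per step, accumulation rounds per step)."""
    if batch % (gpus * batch_gpu) != 0:
        raise ValueError("--batch must be a multiple of --gpus times --batch-gpu")
    per_gpu = batch // gpus        # each GPU's share of the overall batch
    rounds = per_gpu // batch_gpu  # forward/backward passes before each update
    return per_gpu, rounds


# Settings from the Pokemon command: 64 images over 2 GPUs, 16 at a time.
print(accumulation_plan(batch=64, gpus=2, batch_gpu=16))  # (32, 2)
```

With these settings, each GPU would process 32 images per step in two rounds of 16, so the effective batch stays at 64 while per-GPU memory use corresponds to 16 images.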
The command above uses --cfg=fastgan. A lightweight version of FastGAN is also available (--cfg=fastgan_lite); this backbone trains fast in terms of wall-clock time and yields better results on small datasets such as Pokemon.
Samples and metrics are saved in the directory given by --outdir. To monitor training progress, you can inspect metric-fid50k_full.jsonl in the run directory or run TensorBoard on training-runs.
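To pick the best snapshot from the tracked metric, a short script can scan the metric log. This sketch assumes the StyleGAN3 logging format, in which each evaluation appends one JSON object per line containing a `results` dictionary and a `snapshot_pkl` field; adjust the keys if your log differs. `best_snapshot` is a hypothetical helper, not part of this repository.

```python
import json
from typing import Iterable, Tuple


def best_snapshot(lines: Iterable[str], metric: str = "fid50k_full") -> Tuple[str, float]:
    """Return (snapshot_pkl, score) for the entry with the lowest `metric` value."""
    best = None
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        entry = json.loads(line)
        score = entry["results"][metric]
        if best is None or score < best[1]:
            best = (entry["snapshot_pkl"], score)
    if best is None:
        raise ValueError("no metric entries found")
    return best


# Usage with a real run (the path is illustrative):
# with open("training-runs/<run>/metric-fid50k_full.jsonl") as f:
#     print(best_snapshot(f))
```

Since lower FID is better, the function keeps the minimum; the returned pickle name can then be passed to calc_metrics.py.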
You can change the CLC configuration in train.py#L240-L243.
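For intuition only, the sketch below shows one generic form of closed-loop feedback on a discriminator: a quadratic penalty that pulls the current outputs toward a reference computed from a fixed-size queue of recent outputs, scaled by a loss weight. This is an illustrative assumption, not the formulation in the paper or the code in train.py. The class `FeedbackPenalty` and its parameters `loss_weight` and `queue_size` are hypothetical, and `queue_size` is not necessarily what the pretrained-model table calls the queue factor.

```python
from collections import deque
from statistics import fmean
from typing import Sequence


class FeedbackPenalty:
    """Hypothetical closed-loop penalty on discriminator outputs (illustration only)."""

    def __init__(self, loss_weight: float, queue_size: int) -> None:
        self.loss_weight = loss_weight
        # Fixed-size history: once full, the oldest outputs are dropped automatically.
        self.history = deque(maxlen=queue_size)

    def __call__(self, outputs: Sequence[float]) -> float:
        # Reference signal: mean of recently seen outputs (0.0 before any history exists).
        reference = fmean(self.history) if self.history else 0.0
        # Negative feedback: penalize squared deviation of current outputs from the reference.
        penalty = self.loss_weight * fmean((x - reference) ** 2 for x in outputs)
        # Record the current outputs so later steps compare against them.
        self.history.extend(outputs)
        return penalty


clc = FeedbackPenalty(loss_weight=0.1, queue_size=4)
print(clc([1.0, -1.0]))  # 0.1
print(clc([2.0, 2.0]))   # 0.4
```

In a real training loop such a term would be computed on tensors and added to the discriminator loss; the plain-float version here only illustrates the feedback idea.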
We provide the following pretrained models (pass the URL as PATH_TO_NETWORK_PKL):
| Dataset | Loss Weight | Queue Factor | FID | PATH |
|---|---|---|---|---|
| Pokemon | 0.1 | 100 | 25.04 | https://drive.google.com/file/d/18-678PSsr4sYX28qtIkdkOd3TtdpKCWf |
| Art-Paint | 0.05 | 200 | 26.91 | https://drive.google.com/file/d/1if_qohz0PYtSzuSlL72nE71oATxuSmVT |
| Flowers | 0.05 | 200 | 12.82 | https://drive.google.com/file/d/1B844ooziyOhk3dGbS389XWujIPjTpYbN |
| landscapes | 0.05 | 100 | 6.55 | https://drive.google.com/file/d/1RpDg4vRPgD6UXajzmWDNSuyxkS2F_pwK |
| Obama | 0.05 | 100 | 20.12 | https://drive.google.com/file/d/1A0SbqW3xvHMfWVs_Pp7nUs8Ih5Uj9aYL |
By default, train.py tracks FID50k during training. To calculate metrics for a specific network snapshot, run:

```bash
python calc_metrics.py --metrics=fid50k_full --network=PATH_TO_NETWORK_PKL
```
To see the available metrics, run:

```bash
python calc_metrics.py --help
```
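FID fits a Gaussian to Inception features of real and generated images and compares them as FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^(1/2)); lower is better, and fid50k_full compares 50k generated images against the full training set. The toy sketch below assumes diagonal covariances so that it needs only the standard library. In that case the trace term reduces to the sum over dimensions of (σ_r,i − σ_g,i)². This is not how calc_metrics.py computes FID, which uses full covariance matrices of Inception features; `fid_diagonal` is a hypothetical helper.

```python
import math
from typing import Sequence


def fid_diagonal(mu_r: Sequence[float], var_r: Sequence[float],
                 mu_g: Sequence[float], var_g: Sequence[float]) -> float:
    """FID between two Gaussians with diagonal covariances (toy case)."""
    if not (len(mu_r) == len(var_r) == len(mu_g) == len(var_g)):
        raise ValueError("all inputs must have the same dimension")
    # Squared distance between the two means.
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))
    # Diagonal covariances commute, so Tr(S_r + S_g - 2 (S_r S_g)^(1/2))
    # becomes the sum of (sqrt(var_r) - sqrt(var_g))^2 per dimension.
    cov_term = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(var_r, var_g))
    return mean_term + cov_term


print(fid_diagonal([0.0, 0.0], [1.0, 4.0], [3.0, 4.0], [1.0, 4.0]))  # 25.0
```

Identical distributions give an FID of 0, and both a shift in the means and a mismatch in spread increase the score.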
```bibtex
@ARTICLE{10334000,
  author={Zhao, Jiangwei and Zhang, Liang and Pan, Lili and Li, Hongliang},
  journal={IEEE Signal Processing Letters},
  title={Closed-Loop Training for Projected GAN},
  year={2024},
  volume={31},
  number={},
  pages={106-110},
  doi={10.1109/LSP.2023.3337711}}
```
Our codebase builds on and extends the awesome StyleGAN2-ADA, ProjectedGAN, and StyleGAN3 repos.

