contrastive-unpaired-translation
Contrastive unpaired image-to-image translation, faster and lighter training than cyclegan (ECCV 2020, in PyTorch)
Top Related Projects
Software that can generate photos from paintings, turn horses into zebras, perform style transfer, and more.
Multimodal Unsupervised Image-to-Image Translation
Synthesizing and manipulating 2048x1024 images with conditional GANs
PyTorch implementation of SwAV https://arxiv.org/abs/2006.09882
A latent text-to-image diffusion model
CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image
Quick Overview
The contrastive-unpaired-translation repository is an implementation of CUT (Contrastive Unpaired Translation) and FastCUT, which are novel frameworks for unsupervised image-to-image translation. These methods leverage contrastive learning to capture correspondences between input and output images without paired data, offering faster training and improved results compared to traditional approaches like CycleGAN.
Pros
- Achieves high-quality image translation without paired data
- Faster training and inference compared to CycleGAN
- Versatile application across various image translation tasks
- Provides both CUT and FastCUT implementations for different use cases
Cons
- Requires significant computational resources for training
- May struggle with complex, multi-modal translations
- Limited documentation for advanced customization
- Potential difficulty in fine-tuning for specific domains
Code Examples
- Loading a pre-trained model and performing inference:
from options.test_options import TestOptions
from models import create_model
import torch

opt = TestOptions().parse()   # reads --dataroot, --name, --model cut, ... from the command line
model = create_model(opt)
model.setup(opt)              # loads the networks and any pretrained weights
model.eval()

# Example input tensor standing in for a preprocessed 256x256 RGB image
input_image = torch.randn(1, 3, 256, 256).to(model.device)

# Perform inference with the generator
with torch.no_grad():
    output = model.netG(input_image)
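In practice, the random tensor would be replaced by a real image preprocessed the same way the repo does (resize, then normalize to [-1, 1]); a sketch using torchvision, with the file name and size as placeholders:
from PIL import Image
import torchvision.transforms as T

# Load and preprocess a real image instead of the random tensor above.
to_tensor = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),   # maps pixel values to [-1, 1]
])
input_image = to_tensor(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(model.device)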
- Defining a custom dataset for training:
from data.base_dataset import BaseDataset

class MyDataset(BaseDataset):
    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        # Custom dataset initialization (e.g., collect the image paths for domains A and B)

    def __getitem__(self, index):
        # Custom data loading logic that produces A_img, B_img, A_path, B_path
        return {'A': A_img, 'B': B_img, 'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        # Return the number of images in the dataset
        return self.dataset_size
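Note that the CycleGAN/pix2pix codebase this repository builds on discovers datasets by name: the file is expected to live at data/<mode>_dataset.py and define a class named <Mode>Dataset, which is then selected with --dataset_mode <mode>. Treat the exact convention as an assumption and check data/__init__.py before relying on it.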
- Training a CUT model:
from options.train_options import TrainOptions
from models import create_model
from data import create_dataset

opt = TrainOptions().parse()
dataset = create_dataset(opt)
model = create_model(opt)

for epoch in range(opt.n_epochs):
    for i, data in enumerate(dataset):
        if epoch == 0 and i == 0:
            model.data_dependent_initialize(data)  # CUT initializes the feature network from the first batch
            model.setup(opt)                       # create schedulers and load/print the networks
        model.set_input(data)           # unpack data and move it to the device
        model.optimize_parameters()     # compute losses, gradients, and update weights
    if epoch % opt.save_epoch_freq == 0:
        model.save_networks('latest')
        model.save_networks(epoch)
Getting Started
- Clone the repository:
git clone https://github.com/taesungp/contrastive-unpaired-translation.git
cd contrastive-unpaired-translation
- Install dependencies:
pip install -r requirements.txt
- Prepare your dataset in the required format (see the dataset preparation guide in the repository and the layout check sketched after this list).
- Train the model:
python train.py --dataroot ./datasets/your_dataset --name your_experiment_name --model cut
- Test the model:
python test.py --dataroot ./datasets/your_dataset --name your_experiment_name --model cut
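The repository follows the unaligned dataset layout of the CycleGAN/pix2pix codebase it builds on: a dataroot with trainA/trainB (and optionally testA/testB) sub-folders holding images from the two domains. Below is a minimal sketch that checks this layout before training; the folder names follow that convention, so treat them as an assumption for your own dataset.
import os

# Minimal layout check for the unaligned dataset convention (trainA/trainB/testA/testB).
dataroot = "./datasets/your_dataset"
for split in ["trainA", "trainB", "testA", "testB"]:
    path = os.path.join(dataroot, split)
    count = len(os.listdir(path)) if os.path.isdir(path) else 0
    print(f"{split}: {count} files" if count else f"{split}: missing or empty")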
Competitor Comparisons
Software that can generate photos from paintings, turn horses into zebras, perform style transfer, and more.
Pros of CycleGAN
- Well-established and widely recognized in the field of image-to-image translation
- Extensive documentation and a large community for support
- Includes pre-trained models for various applications
Cons of CycleGAN
- May struggle with complex transformations or preserving fine details
- Can sometimes produce artifacts or unrealistic results
- Limited flexibility in terms of loss functions and network architectures
Code Comparison
CycleGAN:
class CycleGANModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        self.visual_names = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B']
Contrastive Unpaired Translation:
class CUTModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ['G_GAN', 'D_real', 'D_fake', 'G', 'NCE', 'NCE_Y']
        self.visual_names = ['real_A', 'fake_B', 'real_B']
The code snippets show differences in loss functions and visual outputs between the two models, reflecting their distinct approaches to unpaired image translation.
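Conceptually, the extra loss names come from CycleGAN's reconstruction (cycle-consistency) term, which CUT drops in favor of a patchwise contrastive term. A schematic sketch of the cycle term (not code from either repository; see the PatchNCE pseudo code later on this page for CUT's side):
import torch.nn.functional as F

# Schematic only: CycleGAN penalizes the L1 distance between an input image
# and its reconstruction rec_A = G_B(G_A(real_A)), weighted by lambda_A.
def cycle_consistency_loss(real_A, rec_A, lambda_A=10.0):
    return lambda_A * F.l1_loss(rec_A, real_A)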
Multimodal Unsupervised Image-to-Image Translation
Pros of MUNIT
- Offers more flexible and diverse image-to-image translations
- Supports multi-modal outputs for a single input image
- Provides better disentanglement of content and style
Cons of MUNIT
- May require more computational resources due to its complexity
- Can be more challenging to train and fine-tune
- Potentially less stable in some scenarios compared to CUT
Code Comparison
MUNIT:
def forward(self, x_a, x_b):
    c_a = self.gen_a.encode(x_a)
    c_b = self.gen_b.encode(x_b)
    s_a = self.gen_a.encode_style(x_a)
    s_b = self.gen_b.encode_style(x_b)
    x_ba = self.gen_a.decode(c_b, s_a)
    x_ab = self.gen_b.decode(c_a, s_b)
CUT:
def forward(self, real):
    fake = self.netG(real)
    feat_q = self.netF(fake)
    feat_k = self.netF(real)
    loss = self.contrastive_loss(feat_q, feat_k)
    return fake, loss
The code snippets highlight MUNIT's focus on content and style encoding/decoding, while CUT emphasizes contrastive learning between real and generated features.
Synthesizing and manipulating 2048x1024 images with conditional GANs
Pros of pix2pixHD
- High-resolution image synthesis capability (up to 2048x1024 pixels)
- Multi-scale generator and discriminator architecture for improved results
- Includes feature matching loss for enhanced stability and visual quality
Cons of pix2pixHD
- Requires paired training data, limiting its applicability to certain domains
- May struggle with preserving fine details in complex scenes
- Potential mode collapse issues in some scenarios
Code Comparison
pix2pixHD:
class GlobalGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9,
                 norm_layer=nn.BatchNorm2d, padding_type='reflect'):
        super(GlobalGenerator, self).__init__()
        activation = nn.ReLU(True)
contrastive-unpaired-translation:
class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 n_blocks=6, padding_type='reflect'):
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
The code snippets show the generator class definitions for both projects. pix2pixHD uses a GlobalGenerator with downsampling and multiple blocks, while contrastive-unpaired-translation employs a ResnetGenerator with a variable number of blocks and optional dropout.
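For a sense of how the CUT generator is used, a ResnetGenerator like the one above can be instantiated and run on a dummy batch. This is only an illustrative sketch: it assumes the class is importable from models/networks.py with the signature shown above (in the repo the generator is normally built through a helper such as define_G), so treat the import path and defaults as assumptions.
import torch
import torch.nn as nn
from models.networks import ResnetGenerator  # assumed import path

# Build a 9-block generator for 3-channel images and run a dummy 256x256 input.
netG = ResnetGenerator(input_nc=3, output_nc=3, ngf=64,
                       norm_layer=nn.InstanceNorm2d, use_dropout=False, n_blocks=9)
with torch.no_grad():
    out = netG(torch.randn(1, 3, 256, 256))
print(out.shape)  # the ResNet-based generator preserves spatial size: [1, 3, 256, 256]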
PyTorch implementation of SwAV https://arxiv.org/abs/2006.09882
Pros of swav
- Focuses on self-supervised learning for computer vision tasks
- Implements a novel contrastive learning approach (SwAV)
- Provides pre-trained models and evaluation scripts
Cons of swav
- Limited to image-based tasks, unlike CUT's focus on unpaired image translation
- Requires more computational resources for training
- Less flexibility in terms of input data types and domains
Code Comparison
swav:
loss = swav_loss(scores, temperature, sinkhorn_iterations)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
contrastive-unpaired-translation:
loss_G = self.compute_G_loss()
self.optimizer_G.zero_grad()
loss_G.backward()
self.optimizer_G.step()
Both repositories use PyTorch and implement custom loss functions. swav focuses on contrastive learning for self-supervised tasks, while contrastive-unpaired-translation (CUT) is designed for unpaired image-to-image translation. swav's code is more specific to its SwAV algorithm, while CUT's code is more general for GAN-based image translation tasks.
A latent text-to-image diffusion model
Pros of Stable-diffusion
- Generates high-quality images from text descriptions
- Supports various image manipulation tasks (inpainting, outpainting, etc.)
- Large and active community with frequent updates and improvements
Cons of Stable-diffusion
- Requires significant computational resources for training and inference
- May produce biased or inappropriate content without proper safeguards
- Limited to image generation and manipulation tasks
Code comparison
Stable-diffusion (image generation):
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
Contrastive-unpaired-translation (image translation):
from models import create_model
import util.util as util

model = create_model(opt)                         # opt comes from the repo's option parser
model.setup(opt)
fake_B = model.netG(data['A'].to(model.device))   # translate the input-domain tensor
fake_B_img = util.tensor2im(fake_B)               # convert the output tensor to a displayable image
Both repositories focus on image processing tasks, but Stable-diffusion excels in text-to-image generation, while Contrastive-unpaired-translation specializes in unpaired image-to-image translation. Stable-diffusion offers more versatility and a larger community, but requires more computational resources. Contrastive-unpaired-translation is more focused on a specific task and may be lighter in terms of resource requirements.
CLIP (Contrastive Language-Image Pretraining): predict the most relevant text snippet given an image
Pros of CLIP
- Broader applicability across various vision-language tasks
- Pre-trained on a massive dataset of 400 million image-text pairs
- Demonstrates zero-shot capabilities for image classification
Cons of CLIP
- Requires significant computational resources for training and inference
- May struggle with fine-grained visual distinctions or domain-specific tasks
- Limited ability to generate or manipulate images
Code Comparison
CLIP:
import torch
from PIL import Image
import clip

model, preprocess = clip.load("ViT-B/32", device="cuda")
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
text = clip.tokenize(["a dog", "a cat"]).to("cuda")

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
Contrastive-unpaired-translation:
from options.train_options import TrainOptions
from models import create_model
from data import create_dataset

opt = TrainOptions().parse()
dataset = create_dataset(opt)
model = create_model(opt)
model.setup(opt)
for i, data in enumerate(dataset):
    model.set_input(data)
    model.optimize_parameters()
The code snippets highlight the different focus areas of the two projects. CLIP emphasizes encoding images and text for comparison, while Contrastive-unpaired-translation focuses on image-to-image translation using unpaired data.
README
Contrastive Unpaired Translation (CUT)
video (1m) | video (10m) | website | paper
We provide our PyTorch implementation of unpaired image-to-image translation based on patchwise contrastive learning and adversarial learning. No hand-crafted loss or inverse network is used. Compared to CycleGAN, our model training is faster and less memory-intensive. In addition, our method can be extended to single-image training, where each "domain" is only a single image.
Contrastive Learning for Unpaired Image-to-Image Translation
Taesung Park, Alexei A. Efros, Richard Zhang, Jun-Yan Zhu
UC Berkeley and Adobe Research
In ECCV 2020
Pseudo code
import torch
cross_entropy_loss = torch.nn.CrossEntropyLoss()
# Input: f_q (BxCxS) are sampled features from H(G_enc(x))
# Input: f_k (BxCxS) are sampled features from H(G_enc(G(x)))
# Input: tau is the temperature used in PatchNCE loss.
# Output: PatchNCE loss
def PatchNCELoss(f_q, f_k, tau=0.07):
    # batch size, channel size, and number of sample locations
    B, C, S = f_q.shape

    # calculate v * v+: BxSx1
    l_pos = (f_k * f_q).sum(dim=1)[:, :, None]

    # calculate v * v-: BxSxS
    l_neg = torch.bmm(f_q.transpose(1, 2), f_k)

    # The diagonal entries are not negatives. Remove them.
    identity_matrix = torch.eye(S, dtype=torch.bool, device=f_q.device)[None, :, :]
    l_neg.masked_fill_(identity_matrix, -float('inf'))

    # calculate logits: (B)x(S)x(S+1)
    logits = torch.cat((l_pos, l_neg), dim=2) / tau

    # return PatchNCE loss
    predictions = logits.flatten(0, 1)
    targets = torch.zeros(B * S, dtype=torch.long, device=f_q.device)
    return cross_entropy_loss(predictions, targets)
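As a quick sanity check, the function above can be evaluated on random features of the right shape (the value itself is meaningless; during training, f_q and f_k come from the feature network H applied to generator encodings):
# B=2 images, C=256-dimensional features, S=64 sampled patch locations
f_q = torch.randn(2, 256, 64)
f_k = torch.randn(2, 256, 64)
print(PatchNCELoss(f_q, f_k, tau=0.07))  # a scalar tensor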
Example Results
Unpaired Image-to-Image Translation
Single Image Unpaired Translation
Russian Blue Cat to Grumpy Cat
Parisian Street to Burano's painted houses
Prerequisites
- Linux or macOS
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN
Update log
9/12/2020: Added single-image translation.
Getting started
- Clone this repo:
git clone https://github.com/taesungp/contrastive-unpaired-translation CUT
cd CUT
- Install PyTorch 1.1 and other dependencies (e.g., torchvision, visdom, dominate, gputil).
For pip users, please type the command:
pip install -r requirements.txt
For Conda users, you can create a new Conda environment using:
conda env create -f environment.yml
CUT and FastCUT Training and Test
- Download the grumpifycat dataset (Fig 8 of the paper. Russian Blue -> Grumpy Cats):
bash ./datasets/download_cut_dataset.sh grumpifycat
The dataset is downloaded and unzipped at ./datasets/grumpifycat/.
- To view training results and loss plots, run python -m visdom.server and click the URL http://localhost:8097.
- Train the CUT model:
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT
Or train the FastCUT model:
python train.py --dataroot ./datasets/grumpifycat --name grumpycat_FastCUT --CUT_mode FastCUT
The checkpoints will be stored at ./checkpoints/grumpycat_*/web.
- Test the CUT model:
python test.py --dataroot ./datasets/grumpifycat --name grumpycat_CUT --CUT_mode CUT --phase train
The test results will be saved to an HTML file here: ./results/grumpifycat/latest_train/index.html.
CUT, FastCUT, and CycleGAN
CUT is trained with the identity preservation loss and with lambda_NCE=1, while FastCUT is trained without the identity loss but with a higher lambda_NCE=10.0. Compared to CycleGAN, CUT learns to perform more powerful distribution matching, while FastCUT is designed as a lighter (half the GPU memory, can fit a larger image) and faster (twice as fast to train) alternative to CycleGAN. Please refer to the paper for more details.
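Based on the description above, the --CUT_mode flag essentially selects different defaults for these options. A rough sketch of the mapping (the option names --lambda_NCE and --nce_idt are assumptions; the values follow the description above, and the authoritative mapping lives in models/cut_model.py):
# Sketch of how CUT_mode maps to training options (assumption; see models/cut_model.py).
CUT_MODE_DEFAULTS = {
    "CUT":     {"lambda_NCE": 1.0,  "nce_idt": True},   # identity-preservation NCE term enabled
    "FastCUT": {"lambda_NCE": 10.0, "nce_idt": False},  # no identity term, stronger PatchNCE weight
}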
In the corresponding figure of the paper, we measure the percentage of pixels belonging to the horse/zebra bodies, using a pre-trained semantic segmentation model. We find a distribution mismatch between the sizes of horses and zebras in the images -- zebras usually appear larger (36.8% vs. 17.9%). Our full method CUT has the flexibility to enlarge the horses, as a means of matching the training statistics better than CycleGAN does. FastCUT behaves more conservatively, like CycleGAN.
Training using our launcher scripts
Please see experiments/grumpifycat_launcher.py, which generates the above command-line arguments. The launcher scripts are useful for configuring the rather complicated command-line arguments for training and testing.
Using the launcher, the command below generates the training command of CUT and FastCUT.
python -m experiments grumpifycat train 0 # CUT
python -m experiments grumpifycat train 1 # FastCUT
To test using the launcher,
python -m experiments grumpifycat test 0 # CUT
python -m experiments grumpifycat test 1 # FastCUT
Possible commands are run, run_test, launch, close, and so on. Please see experiments/__main__.py for all commands. The launcher is easy and quick to define and use. For example, the grumpifycat launcher is defined in a few lines:
from .tmux_launcher import Options, TmuxLauncher


class Launcher(TmuxLauncher):
    def common_options(self):
        return [
            Options(  # Command 0
                dataroot="./datasets/grumpifycat",
                name="grumpifycat_CUT",
                CUT_mode="CUT"
            ),

            Options(  # Command 1
                dataroot="./datasets/grumpifycat",
                name="grumpifycat_FastCUT",
                CUT_mode="FastCUT",
            )
        ]

    def commands(self):
        return ["python train.py " + str(opt) for opt in self.common_options()]

    def test_commands(self):
        # Russian Blue -> Grumpy Cats dataset does not have a test split.
        # Therefore, let's set the test split to be the "train" set.
        return ["python test.py " + str(opt.set(phase='train')) for opt in self.common_options()]
Apply a pre-trained CUT model and evaluate FID
To run the pretrained models, run the following.
# Download and unzip the pretrained models. The weights should be located at
# checkpoints/horse2zebra_cut_pretrained/latest_net_G.pth, for example.
wget http://efrosgans.eecs.berkeley.edu/CUT/pretrained_models.tar
tar -xf pretrained_models.tar
# Generate outputs. The dataset paths might need to be adjusted.
# To do this, modify the lines of experiments/pretrained_launcher.py
# [id] corresponds to the respective commands defined in pretrained_launcher.py
# 0 - CUT on Cityscapes
# 1 - FastCUT on Cityscapes
# 2 - CUT on Horse2Zebra
# 3 - FastCUT on Horse2Zebra
# 4 - CUT on Cat2Dog
# 5 - FastCUT on Cat2Dog
python -m experiments pretrained run_test [id]
# Evaluate FID. To do this, first install pytorch-fid of https://github.com/mseitzer/pytorch-fid
# pip install pytorch-fid
# For example, to evaluate horse2zebra FID of CUT,
# python -m pytorch_fid ./datasets/horse2zebra/testB/ results/horse2zebra_cut_pretrained/test_latest/images/fake_B/
# To evaluate Cityscapes FID of FastCUT,
# python -m pytorch_fid ./datasets/cityscapes/valA/ ~/projects/contrastive-unpaired-translation/results/cityscapes_fastcut_pretrained/test_latest/images/fake_B/
# Note that a special dataset needs to be used for the Cityscapes model. Please read below.
python -m pytorch_fid [path to real test images] [path to generated images]
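If you prefer computing FID from Python rather than the command line, pytorch-fid also exposes a function for it. A minimal sketch, assuming the calculate_fid_given_paths signature of recent pytorch-fid releases and the horse2zebra output paths used above:
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths  # pip install pytorch-fid

fid = calculate_fid_given_paths(
    ["./datasets/horse2zebra/testB",
     "./results/horse2zebra_cut_pretrained/test_latest/images/fake_B"],
    batch_size=50,
    device="cuda" if torch.cuda.is_available() else "cpu",
    dims=2048,
)
print(f"FID: {fid:.2f}")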
Note: the Cityscapes pretrained model was trained and evaluated on a resized and JPEG-compressed version of the original Cityscapes dataset. To perform evaluation, please download this validation set and perform evaluation.
SinCUT Single Image Unpaired Training
To train SinCUT (single-image translation, shown in Fig 9, 13 and 14 of the paper), you need to
- set the --model option as --model sincut, which invokes the configuration and code at ./models/sincut_model.py, and
- specify the dataset directory of one image in each domain, such as the example dataset included in this repo at ./datasets/single_image_monet_etretat/.
For example, to train a model for the Etretat cliff (first image of Figure 13), please use the following command.
python train.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat
or by using the experiment launcher script,
python -m experiments singleimage run 0
For single-image translation, we adopt network architectural components of StyleGAN2, as well as the pixel identity preservation loss used in DTN and CycleGAN. In particular, we adopted the code of rosinality, which exists at models/stylegan_networks.py.
The training takes several hours. To generate the final image using the checkpoint,
python test.py --model sincut --name singleimage_monet_etretat --dataroot ./datasets/single_image_monet_etretat
or simply
python -m experiments singleimage run_test 0
Datasets
Download CUT/CycleGAN/pix2pix datasets. For example,
bash datasets/download_cut_datasets.sh horse2zebra
The Cat2Dog dataset is prepared from the AFHQ dataset. Please visit https://github.com/clovaai/stargan-v2 and download the AFHQ dataset by running bash download.sh afhq-dataset in that repo. Then reorganize the directories as follows.
mkdir datasets/cat2dog
# ln -s <target> <link name>: make datasets/cat2dog/* point at the AFHQ folders
ln -s [path_to_afhq]/train/cat datasets/cat2dog/trainA
ln -s [path_to_afhq]/train/dog datasets/cat2dog/trainB
ln -s [path_to_afhq]/test/cat datasets/cat2dog/testA
ln -s [path_to_afhq]/test/dog datasets/cat2dog/testB
The Cityscapes dataset can be downloaded from https://cityscapes-dataset.com.
After that, use the script ./datasets/prepare_cityscapes_dataset.py to prepare the dataset.
Preprocessing of input images
The preprocessing of the input images, such as resizing or random cropping, is controlled by the options --preprocess, --load_size, and --crop_size. The usage follows the CycleGAN/pix2pix repo.
For example, the default setting --preprocess resize_and_crop --load_size 286 --crop_size 256 resizes the input image to 286x286 and then makes a random crop of size 256x256 as a way to perform data augmentation. Other preprocessing options can be specified, and they are defined in base_dataset.py. Below are some example options.
- --preprocess none: does not perform any preprocessing. Note that the image is still scaled to the closest multiple of 4, because the convolutional generator cannot maintain the same image size otherwise.
- --preprocess scale_width --load_size 768: scales the width of the image to 768 pixels.
- --preprocess scale_shortside_and_crop: scales the image, preserving the aspect ratio, so that the short side is load_size, and then performs a random crop of window size crop_size.
More preprocessing options can be added by modifying get_transform() of base_dataset.py.
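For reference, a sketch of how these options are typically consumed inside a dataset's __getitem__ via get_transform (opt is the parsed options object carrying --preprocess, --load_size, and --crop_size; the exact call signature is an assumption based on the CycleGAN/pix2pix codebase):
from PIL import Image
from data.base_dataset import get_transform

# opt comes from TrainOptions().parse() or TestOptions().parse()
transform = get_transform(opt)   # builds the resize/crop/flip/normalize pipeline from the options
img = Image.open("./datasets/your_dataset/trainA/example.jpg").convert("RGB")
tensor = transform(img)          # e.g. a 3 x crop_size x crop_size tensor for resize_and_crop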
Citation
If you use this code for your research, please cite our paper.
@inproceedings{park2020cut,
title={Contrastive Learning for Unpaired Image-to-Image Translation},
author={Taesung Park and Alexei A. Efros and Richard Zhang and Jun-Yan Zhu},
booktitle={European Conference on Computer Vision},
year={2020}
}
If you use the original pix2pix and CycleGAN model included in this repo, please cite the following papers
@inproceedings{CycleGAN2017,
title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
booktitle={IEEE International Conference on Computer Vision (ICCV)},
year={2017}
}
@inproceedings{isola2017image,
title={Image-to-Image Translation with Conditional Adversarial Networks},
author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year={2017}
}
Acknowledgments
We thank Allan Jabri and Phillip Isola for helpful discussion and feedback. Our code is developed based on pytorch-CycleGAN-and-pix2pix. We also thank pytorch-fid for FID computation, drn for mIoU computation, and stylegan2-pytorch for the PyTorch implementation of StyleGAN2 used in our single-image translation setting.