stylegan2-pytorch
Implementation of Analyzing and Improving the Image Quality of StyleGAN (StyleGAN 2) in PyTorch
Top Related Projects
StyleGAN2 - Official TensorFlow Implementation
Official PyTorch implementation of StyleGAN3
Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement
🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
A latent text-to-image diffusion model
CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image
Quick Overview
StyleGAN2-PyTorch is an unofficial PyTorch implementation of StyleGAN2, a state-of-the-art generative adversarial network (GAN) for high-quality image synthesis. This repository provides a PyTorch-based implementation of the original TensorFlow code, making it more accessible to researchers and developers who prefer PyTorch.
Pros
- Implements StyleGAN2 in PyTorch, which is more widely used in the research community
- Includes pre-trained models for easy experimentation and transfer learning
- Supports distributed training for faster model training on multiple GPUs
- Provides tools for image generation, style mixing, and model conversion
Cons
- May have slight differences in performance compared to the original TensorFlow implementation
- Requires significant computational resources for training large models
- Limited documentation compared to the original StyleGAN2 repository
- Might not include all the latest improvements and features of StyleGAN2-ADA
Code Examples
- Generate images using a pre-trained model:
import torch
from model import Generator
generator = Generator(size=1024, style_dim=512, n_mlp=8)
generator.load_state_dict(torch.load("stylegan2-ffhq-config-f.pt")["g_ema"])
z = torch.randn(1, 512)
with torch.no_grad():
    img, _ = generator([z])  # latents are passed as a list; the generator returns (image, latent)
- Perform style mixing:
import torch
from model import Generator
generator = Generator(size=1024, style_dim=512, n_mlp=8)
generator.load_state_dict(torch.load("stylegan2-ffhq-config-f.pt")["g_ema"])
z1 = torch.randn(1, 512)
z2 = torch.randn(1, 512)
with torch.no_grad():
    img, _ = generator([z1, z2], truncation=0.7, truncation_latent=generator.mean_latent(4096))
- Convert an official TensorFlow checkpoint to PyTorch (convert_weight.py is a command-line script and requires a clone of the official repository, see the README below):
python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl
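The image-generation examples above return an image tensor in roughly [-1, 1]. A minimal sketch for saving it to disk (an assumption based on how generate.py in this repository normalizes samples):
from torchvision import utils

# `img` is the (1, 3, H, W) tensor returned by the generator above.
# Older torchvision versions name the `value_range` argument `range`.
utils.save_image(img, "sample.png", normalize=True, value_range=(-1, 1))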
Getting Started
- Clone the repository:
git clone https://github.com/rosinality/stylegan2-pytorch.git
cd stylegan2-pytorch
- Install dependencies:
pip install torch torchvision tqdm
- Download pre-trained models:
wget https://github.com/rosinality/stylegan2-pytorch/releases/download/v0.1/stylegan2-ffhq-config-f.pt
- Generate images:
import torch
from model import Generator

generator = Generator(size=1024, style_dim=512, n_mlp=8)
generator.load_state_dict(torch.load("stylegan2-ffhq-config-f.pt")["g_ema"])
z = torch.randn(1, 512)
with torch.no_grad():
    img, _ = generator([z])
Competitor Comparisons
StyleGAN2 - Official TensorFlow Implementation
Pros of stylegan2
- Official NVIDIA implementation, ensuring high fidelity to the original paper
- Optimized for NVIDIA GPUs, potentially offering better performance
- Comprehensive documentation and examples provided by NVIDIA
Cons of stylegan2
- Requires TensorFlow 1.x, which is outdated and less flexible
- Limited compatibility with non-NVIDIA hardware
- Steeper learning curve for those not familiar with TensorFlow
Code Comparison
stylegan2 (TensorFlow):
import pickle
import tensorflow as tf
import dnnlib
import dnnlib.tflib as tflib

def generate_images(network_pkl, seeds, truncation_psi):
    tflib.init_tf()
    with dnnlib.util.open_url(network_pkl) as f:
        _G, _D, Gs = pickle.load(f)
stylegan2-pytorch (PyTorch):
import torch
from model import Generator
def generate_images(ckpt, seeds, truncation_psi):
    generator = Generator(size, style_dim, n_mlp).to(device)
    ckpt = torch.load(ckpt)
    generator.load_state_dict(ckpt['g_ema'])
The stylegan2-pytorch implementation offers a more modern PyTorch-based approach, making it easier to integrate with other PyTorch projects and potentially more accessible to a wider range of developers. However, the NVIDIA implementation may provide better performance on NVIDIA hardware and closer adherence to the original paper's specifications.
Official PyTorch implementation of StyleGAN3
Pros of StyleGAN3
- Improved image quality and reduced artifacts compared to StyleGAN2
- Better performance and faster training times
- More flexible architecture with alias-free generator
Cons of StyleGAN3
- Higher computational requirements for training and inference
- Less community-contributed extensions and modifications
- Steeper learning curve for implementation and fine-tuning
Code Comparison
StyleGAN2-PyTorch:
import torch
from model import Generator
generator = Generator(size, style_dim, n_mlp).to(device)
noise = torch.randn(batch, style_dim, device=device)
fake_img, _ = generator([noise])  # latents are passed as a list; returns (image, latent)
StyleGAN3:
import torch
import dnnlib
import legacy
network_pkl = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-ffhq-1024x1024.pkl'
device = torch.device('cuda')
with dnnlib.util.open_url(network_pkl) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device)
z = torch.randn([1, G.z_dim]).to(device)
img = G(z, None)
Simplest working implementation of Stylegan2, state of the art generative adversarial network, in Pytorch. Enabling everyone to experience disentanglement
Pros of stylegan2-pytorch (lucidrains)
- More actively maintained with recent updates
- Includes additional features like tiling and truncation tricks
- Better documentation and code organization
Cons of stylegan2-pytorch (lucidrains)
- May have slightly higher memory usage
- Less established/tested in production environments
- Some users report occasional stability issues
Code Comparison
stylegan2-pytorch (rosinality):
def make_noise(batch, latent_dim, n_noise, device):
    if n_noise == 1:
        return torch.randn(batch, latent_dim, device=device)
    noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0)
    return noises
stylegan2-pytorch (lucidrains):
def noise(n, latent_dim, device):
    return torch.randn(n, latent_dim).to(device)

def noise_list(n, layers, latent_dim, device):
    return [(noise(n, latent_dim, device), layers)]
The lucidrains implementation offers a more concise and flexible approach to noise generation, allowing for easy creation of noise lists for different layers.
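For illustration, here is a hypothetical helper (not code from either repository) showing how such (noise, layer-count) tuples could be combined so that different latents drive different layer ranges, which is the basis of style mixing:
import torch

def mixed_noise_list(n, layers, latent_dim, device):
    # Hypothetical: split the layer range at a random point and draw an
    # independent latent for each segment; the resulting list of
    # (latent, num_layers) tuples acts as a style-mixing input.
    split = int(torch.randint(1, layers, (1,)).item())
    return (noise_list(n, split, latent_dim, device)
            + noise_list(n, layers - split, latent_dim, device))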
🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
Pros of diffusers
- Broader scope, supporting multiple diffusion models and techniques
- Extensive documentation and integration with the Hugging Face ecosystem
- Active development and frequent updates
Cons of diffusers
- Higher complexity due to supporting multiple models
- Potentially slower inference for specific models compared to specialized implementations
Code comparison
stylegan2-pytorch:
import torch
from model import Generator
generator = Generator(size, style_dim, n_mlp).to(device)
z = torch.randn(batch, style_dim, device=device)
generated_images, _ = generator([z])  # latents are passed as a list; returns (image, latent)
diffusers:
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
Summary
diffusers offers a more versatile and well-supported framework for working with various diffusion models, while stylegan2-pytorch provides a specialized implementation of StyleGAN2. The choice between them depends on the specific requirements of your project, such as the need for multiple models or a focus on StyleGAN2 specifically.
A latent text-to-image diffusion model
Pros of Stable-Diffusion
- More versatile, capable of generating diverse images from text prompts
- Supports inpainting and image-to-image translation tasks
- Actively maintained with frequent updates and improvements
Cons of Stable-Diffusion
- Requires more computational resources and longer training time
- More complex architecture, potentially harder to understand and modify
- May produce less consistent results compared to StyleGAN2
Code Comparison
StyleGAN2-PyTorch:
generator = Generator(size, style_dim, n_mlp).to(device)
discriminator = Discriminator(size).to(device)
Stable-Diffusion:
model = create_model('./v1-5-pruned.ckpt').to(device)
sampler = DDIMSampler(model)
Key Differences
- StyleGAN2-PyTorch focuses on generating high-quality images from random noise
- Stable-Diffusion allows for text-guided image generation and manipulation
- StyleGAN2 uses a GAN architecture, while Stable-Diffusion is based on diffusion models
- Stable-Diffusion offers more flexibility in terms of input and output modalities
- StyleGAN2 may be better suited for specific tasks like face generation
CLIP (Contrastive Language-Image Pretraining), Predict the most relevant text snippet given an image
Pros of CLIP
- Versatile multimodal learning: CLIP can understand both images and text, enabling various applications like image search and classification
- Zero-shot capabilities: CLIP can perform tasks without fine-tuning on specific datasets
- Robust performance across diverse domains due to its large-scale pre-training
Cons of CLIP
- Higher computational requirements for inference compared to StyleGAN2
- Less focused on image generation, primarily designed for image-text understanding
- May require more complex integration for certain image manipulation tasks
Code Comparison
CLIP example:
import torch
from PIL import Image
import clip
model, preprocess = clip.load("ViT-B/32", device="cuda")
image = preprocess(Image.open("image.jpg")).unsqueeze(0).to("cuda")
text = clip.tokenize(["a dog", "a cat"]).to("cuda")
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
StyleGAN2 example:
import torch
from model import Generator
generator = Generator(size=1024, style_dim=512, n_mlp=8).to("cuda")
generator.load_state_dict(torch.load("stylegan2-ffhq-config-f.pt")["g_ema"])
z = torch.randn(1, 512).to("cuda")
with torch.no_grad():
    img, _ = generator([z])
README
StyleGAN 2 in PyTorch
Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch
Notice
I have tried to match the official implementation as closely as possible, but there may be some details I missed, so please use this implementation with care.
Requirements
I have tested on:
- PyTorch 1.3.1
- CUDA 10.1/10.2
Usage
First, create LMDB datasets:
python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH
This will convert images to JPEG and pre-resize them. This implementation does not use progressive growing, but you can create datasets at multiple resolutions by passing a comma-separated list of sizes, in case you want to try other resolutions later.
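For example, a hypothetical invocation (placeholder paths and worker count) that builds 128px, 256px, and 512px datasets in one pass:
python prepare_data.py --out data/ffhq_lmdb --n_worker 8 --size 128,256,512 path/to/images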
Then you can train the model in a distributed setting:
python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH
train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the command.
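For example, a hypothetical invocation on 4 GPUs with logging enabled (batch size and port are placeholders):
python -m torch.distributed.launch --nproc_per_node=4 --master_port=12345 train.py --batch 16 --wandb LMDB_PATH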
SWAGAN
This implementation experimentally supports SWAGAN: A Style-based Wavelet-driven Generative Model (https://arxiv.org/abs/2102.06108). You can train SWAGAN by using
python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --arch swagan --batch BATCH_SIZE LMDB_PATH
As noted in the paper, SWAGAN trains much faster (about 2x faster at 256px).
Convert weight from official checkpoints
You need to clone the official repository (https://github.com/NVlabs/stylegan2), as it is required to load the official checkpoints.
For example, if you cloned the repository into ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, you can convert it like this:
python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl
This will create a converted stylegan2-ffhq-config-f.pt file.
Generate samples
python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT
You should change the --size argument (for example, --size 256) if you trained at a different resolution.
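For example, a hypothetical invocation for a 256px model (the checkpoint path is a placeholder):
python generate.py --sample 1 --pics 20 --size 256 --ckpt checkpoint/550000.pt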
Project images to latent spaces
python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ...
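For example, with hypothetical file names:
python projector.py --ckpt stylegan2-ffhq-config-f.pt --size 1024 face1.png face2.png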
Closed-Form Factorization (https://arxiv.org/abs/2007.06600)
You can use closed_form_factorization.py and apply_factor.py to discover meaningful latent semantic factors or directions in an unsupervised manner.
First, you need to extract the eigenvectors of the weight matrices using closed_form_factorization.py:
python closed_form_factorization.py [CHECKPOINT]
This will create a factor file that contains the eigenvectors (default: factor.pt). You can then use apply_factor.py to test the meaning of the extracted directions:
python apply_factor.py -i [INDEX_OF_EIGENVECTOR] -d [DEGREE_OF_MOVE] -n [NUMBER_OF_SAMPLES] --ckpt [CHECKPOINT] [FACTOR_FILE]
For example,
python apply_factor.py -i 19 -d 5 -n 10 --ckpt [CHECKPOINT] factor.pt
will generate 10 random samples along with samples generated from latents moved along the 19th eigenvector with degree ±5.
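For reference, the factorization step itself is conceptually simple. A minimal sketch of the idea (an illustrative approximation, assuming the checkpoint stores the EMA generator under the key g_ema as in this repository; the exact key filtering in closed_form_factorization.py may differ):
import torch

ckpt = torch.load("checkpoint.pt")  # hypothetical checkpoint path

# Collect the style-modulation weight matrices (each maps the style code to per-channel scales)
modulation_weights = [
    v for k, v in ckpt["g_ema"].items()
    if "modulation.weight" in k and "to_rgb" not in k
]

# Stack them and take the right singular vectors as candidate semantic directions
W = torch.cat(modulation_weights, dim=0)  # (total channels, style_dim)
eigvec = torch.svd(W).V                   # columns are directions in style space

torch.save({"ckpt": "checkpoint.pt", "eigvec": eigvec}, "factor.pt")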
Pretrained Checkpoints
I have trained the 256px model on FFHQ for 550k iterations and got an FID of about 4.5. Differences in data preprocessing, resolution, or the training loop could account for this, but currently I don't know the exact reason for the FID difference.
Samples
Sample from FFHQ at 110,000 iterations (trained on 3.52M images).
Sample from MetFaces with non-leaking augmentations at 150,000 iterations (trained on 4.8M images).
Samples from converted weights
Sample from FFHQ (1024px)
Sample from LSUN Church (256px)
License
Model details and custom CUDA kernel code are from the official repository: https://github.com/NVlabs/stylegan2
Code for Learned Perceptual Image Patch Similarity (LPIPS) came from https://github.com/richzhang/PerceptualSimilarity
To match FID scores more closely to the official TensorFlow implementation, I have used the FID Inception V3 implementation from https://github.com/mseitzer/pytorch-fid