krasserm/super-resolution

Tensorflow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

Quick Overview

The krasserm/super-resolution repository is a TensorFlow 2.x implementation of EDSR, WDSR and SRGAN for single image super-resolution (SISR), that is, for upscaling a low-resolution image to a higher resolution. It provides pre-trained weights and a high-level training API.

Pros

  • Implements multiple super-resolution models (EDSR, WDSR and SRGAN) in TensorFlow 2.x
  • Provides pre-trained models for easy use and comparison
  • Includes detailed documentation and examples for each implemented model
  • Supports custom training and evaluation on user-provided datasets

Cons

  • Limited to single image super-resolution, not covering other super-resolution tasks
  • May require significant computational resources for training and inference
  • Dependency on specific versions of TensorFlow and TensorFlow Addons may cause compatibility issues
  • Some implemented models may be outdated compared to the latest research

Code Examples

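  1. Loading a pre-trained EDSR model and performing super-resolution:
from model import resolve_single
from model.edsr import edsr
from utils import load_image, plot_sample

# Build the EDSR x4 baseline and load its pre-trained weights.
model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

# Pass the single LR image through resolve_single rather than model.predict.
lr = load_image('demo/0851x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)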
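  2. Training an EDSR model on DIV2K:
from data import DIV2K
from model.edsr import edsr
from train import EdsrTrainer

# DIV2K training (images 001 - 800) and validation (images 801 - 900) sets.
train_ds = DIV2K(scale=4, downgrade='bicubic', subset='train').dataset(
    batch_size=16, random_transform=True, repeat_count=None)
valid_ds = DIV2K(scale=4, downgrade='bicubic', subset='valid').dataset(
    batch_size=1, random_transform=False, repeat_count=1)

trainer = EdsrTrainer(model=edsr(scale=4, num_res_blocks=16),
                      checkpoint_dir='.ckpt/edsr-16-x4')

# Train for 300,000 steps, evaluate every 1000 steps on the first 10
# validation images and keep only checkpoints that improve PSNR.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000,
              evaluate_every=1000,
              save_best_only=True)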
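  3. Evaluating a WDSR model on the DIV2K validation set:
import tensorflow as tf
from data import DIV2K
from model import resolve_single
from model.wdsr import wdsr_b

model = wdsr_b(scale=4, num_res_blocks=32)
model.load_weights('weights/wdsr-b-32-x4/weights.h5')

valid_ds = DIV2K(scale=4, downgrade='bicubic', subset='valid').dataset(
    batch_size=1, random_transform=False, repeat_count=1)

# Average the per-image PSNR over the validation set (batch size is 1).
psnr_values = [tf.image.psnr(hr[0], resolve_single(model, lr[0]), max_val=255)
               for lr, hr in valid_ds]
psnr = tf.reduce_mean(psnr_values)
print(f'PSNR: {psnr.numpy():.2f}')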

Getting Started

To get started with the super-resolution models:

  1. Clone the repository:

    git clone https://github.com/krasserm/super-resolution.git
    cd super-resolution
    
  2. Create and activate the conda environment:

    conda env create -f environment.yml
    conda activate sisr
    
  3. Download the pre-trained weights listed in the README below and extract them in the root folder of the project:

    tar xvfz weights-<...>.tar.gz
    
  4. Run a simple example:

    from model import resolve_single
    from model.edsr import edsr
    from utils import load_image, plot_sample
    
    model = edsr(scale=4, num_res_blocks=16)
    model.load_weights('weights/edsr-16-x4/weights.h5')
    
    lr = load_image('demo/0851x4-crop.png')
    sr = resolve_single(model, lr)
    plot_sample(lr, sr)
    

Competitor Comparisons

🔎 Super-scale your images and run experiments with Residual Dense and Adversarial Networks.

Pros of image-super-resolution

  • More comprehensive documentation and usage examples
  • Includes pre-trained models for easier implementation
  • Supports multiple network architectures (RDN and RRDN), optionally trained with an adversarial loss

Cons of image-super-resolution

  • Less frequent updates and maintenance
  • Limited to 4x upscaling factor
  • Fewer advanced features compared to super-resolution

Code Comparison

image-super-resolution:

from ISR.models import RDN
import numpy as np
from PIL import Image

img = Image.open('input.png')
lr_img = np.array(img)
model = RDN(weights='psnr-small')
sr_img = model.predict(lr_img)

super-resolution:

from model import resolve_single
from model.edsr import edsr
from utils import load_image, plot_sample

# edsr is the model factory; resolve_single is a function, not a module.
model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

lr = load_image('input.png')
sr = resolve_single(model, lr)
plot_sample(lr, sr)

Both repositories offer image super-resolution capabilities, but they differ in implementation and interface. image-super-resolution wraps its networks in model classes that load pre-trained weights by name, while super-resolution exposes plain Keras models together with a step-based training API for EDSR, WDSR and SRGAN.

README

Single Image Super-Resolution with EDSR, WDSR and SRGAN

A Tensorflow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution.

This is a complete re-write of the old Keras/Tensorflow 1.x based implementation available here. Some parts are still work in progress but you can already train models as described in the papers via a high-level training API. Furthermore, you can also fine-tune EDSR and WDSR models in an SRGAN context. Training and usage examples are given in the example notebooks.

A DIV2K data provider automatically downloads DIV2K training and validation images of given scale (2, 3, 4 or 8) and downgrade operator ("bicubic", "unknown", "mild" or "difficult").

Important: if you want to evaluate the pre-trained models with a dataset other than DIV2K please read this comment (and replies) first.

Environment setup

Create a new conda environment with

conda env create -f environment.yml

and activate it with

conda activate sisr

Introduction

You can find an introduction to single-image super-resolution in this article. It also demonstrates how EDSR and WDSR models can be fine-tuned with SRGAN (see also this section).

Getting started

Examples in this section require the following pre-trained weights (see also the example notebooks):

Pre-trained weights

  • weights-edsr-16-x4.tar.gz
    • EDSR x4 baseline as described in the EDSR paper: 16 residual blocks, 64 filters, 1.52M parameters.
    • PSNR on DIV2K validation set = 28.89 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-wdsr-b-32-x4.tar.gz
    • WDSR B x4 custom model: 32 residual blocks, 32 filters, expansion factor 6, 0.62M parameters.
    • PSNR on DIV2K validation set = 28.91 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-srgan.tar.gz
    • SRGAN as described in the SRGAN paper: 1.55M parameters, trained with VGG54 content loss.

After download, extract them in the root folder of the project with

tar xvfz weights-<...>.tar.gz
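The 1.52M parameter figure quoted for the EDSR x4 baseline can be checked with a little arithmetic. The sketch below is only an illustration: the layer layout (a head convolution, residual blocks with two 3x3 convolutions each, one convolution after the blocks, two x2 sub-pixel upsampling stages and a tail convolution) is an assumption based on the usual EDSR baseline design and is not spelled out in this README. The helper names are hypothetical.

```python
def conv_params(kernel, in_channels, out_channels):
    """Weights plus biases of a 2D convolution with a square kernel."""
    return kernel * kernel * in_channels * out_channels + out_channels


def edsr_baseline_x4_params(num_res_blocks=16, filters=64, channels=3):
    # Assumed layout (not stated in the README): head conv, residual blocks
    # of two convs each, one conv after the blocks, two x2 sub-pixel
    # upsampling stages, and a tail conv back to RGB.
    head = conv_params(3, channels, filters)
    res_blocks = num_res_blocks * 2 * conv_params(3, filters, filters)
    post_blocks = conv_params(3, filters, filters)
    # Each x2 stage expands to 4 * filters channels before the pixel shuffle.
    upsample = 2 * conv_params(3, filters, 4 * filters)
    tail = conv_params(3, filters, channels)
    return head + res_blocks + post_blocks + upsample + tail


total = edsr_baseline_x4_params()
print(f'{total:,} parameters ({total / 1e6:.2f}M)')  # 1,517,571 parameters (1.52M)
```

Under these assumptions the count rounds to the 1.52M stated for the pre-trained EDSR weights.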

EDSR

from model import resolve_single
from model.edsr import edsr

from utils import load_image, plot_sample

model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

lr = load_image('demo/0851x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-edsr

WDSR

from model.wdsr import wdsr_b

model = wdsr_b(scale=4, num_res_blocks=32)
model.load_weights('weights/wdsr-b-32-x4/weights.h5')

lr = load_image('demo/0829x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-wdsr

Weight normalization in WDSR models is implemented with the new WeightNormalization layer wrapper of Tensorflow Addons. In its latest version, this wrapper seems to corrupt weights when running model.predict(...). A workaround is to set model.run_eagerly = True or compile the model with model.compile(loss='mae') in advance. This issue doesn't arise when calling the model directly with model(...) though. To be further investigated ...
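For readers unfamiliar with the technique, weight normalization reparameterizes a weight vector as a direction v and a separately learned magnitude g, so that w = g * v / ||v||. The following is a minimal pure-Python sketch of that formula only. It is not the Tensorflow Addons implementation and says nothing about the predict issue described above; the function name is hypothetical.

```python
import math


def weight_norm(v, g):
    """Return w = g * v / ||v|| for a weight vector v and scalar gain g."""
    norm = math.sqrt(sum(x * x for x in v))
    return [g * x / norm for x in v]


v = [3.0, 4.0]            # unnormalized direction, ||v|| = 5
w = weight_norm(v, g=2.0)
print(w)                  # [1.2, 1.6]
```

Whatever the scale of v, the resulting w always has norm g, which is what decouples the weight's magnitude from its direction.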

SRGAN

from model.srgan import generator

model = generator()
model.load_weights('weights/srgan/gan_generator.h5')

lr = load_image('demo/0869x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

result-srgan

DIV2K dataset

For training and validation on DIV2K images, applications should use the provided DIV2K data loader. It automatically downloads DIV2K images to the .div2k directory and converts them to a different format for faster loading.

Training dataset

from data import DIV2K

train_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult' 
                     subset='train')      # Training dataset are images 001 - 800
                     
# Create a tf.data.Dataset          
train_ds = train_loader.dataset(batch_size=16,         # batch size as described in the EDSR and WDSR papers
                                random_transform=True, # random crop, flip, rotate as described in the EDSR paper
                                repeat_count=None)     # repeat iterating over training images indefinitely

# Iterate over LR/HR image pairs                                
for lr, hr in train_ds:
    ...  # process the LR/HR batch here

Crop size in HR images is 96x96.
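A 96x96 HR crop at scale 4 corresponds to a 24x24 LR crop taken at the matching location. The sketch below shows one way such aligned crop boxes could be chosen. It is an assumption about how paired random cropping generally works, not a copy of the data loader's code, and the function name and the example image size are hypothetical.

```python
import random


def paired_crop_boxes(lr_height, lr_width, scale=4, hr_crop_size=96, rng=random):
    """Pick aligned crop boxes (top, left, height, width) for an LR/HR pair."""
    lr_crop_size = hr_crop_size // scale
    # Choose the LR offset first, then scale it so both crops show the same content.
    lr_top = rng.randint(0, lr_height - lr_crop_size)
    lr_left = rng.randint(0, lr_width - lr_crop_size)
    lr_box = (lr_top, lr_left, lr_crop_size, lr_crop_size)
    hr_box = (lr_top * scale, lr_left * scale, hr_crop_size, hr_crop_size)
    return lr_box, hr_box


# Hypothetical LR image of 339x510 pixels (its HR counterpart is 1356x2040).
lr_box, hr_box = paired_crop_boxes(lr_height=339, lr_width=510, rng=random.Random(0))
print(lr_box, hr_box)
```

Sampling the offset in LR coordinates keeps the HR crop on the pixel grid implied by the scale factor.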

Validation dataset

from data import DIV2K

valid_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult' 
                     subset='valid')      # Validation dataset are images 801 - 900
                     
# Create a tf.data.Dataset          
valid_ds = valid_loader.dataset(batch_size=1,           # use batch size of 1 as DIV2K images have different size
                                random_transform=False, # use DIV2K images in original size 
                                repeat_count=1)         # 1 epoch
                                
# Iterate over LR/HR image pairs                                
for lr, hr in valid_ds:
    ...  # process the LR/HR pair here

Training

The following training examples use the training and validation datasets described earlier. The high-level training API is designed around steps (= minibatch updates) rather than epochs to better match the descriptions in the papers.
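To make the step-based design concrete, here is a minimal sketch of a training loop that counts minibatch updates, evaluates at a fixed interval and records a save only when the score improves. It mirrors the meaning of the steps, evaluate_every and save_best_only arguments used in the examples, but it is an illustration under assumptions, not the trainer's actual code. train_by_steps, train_step and evaluate are hypothetical names.

```python
import itertools


def train_by_steps(train_batches, train_step, evaluate, steps,
                   evaluate_every=1000, save_best_only=True):
    """Run `steps` minibatch updates, evaluating at a fixed step interval."""
    best_psnr = float('-inf')
    saved_at = []  # steps at which a checkpoint would have been written
    for step, batch in enumerate(train_batches, start=1):
        train_step(batch)
        if step % evaluate_every == 0:
            psnr = evaluate()
            if not save_best_only or psnr > best_psnr:
                best_psnr = max(best_psnr, psnr)
                saved_at.append(step)
        if step >= steps:
            break
    return best_psnr, saved_at


# An endlessly repeating dataset, a no-op update and scripted PSNR values.
scores = iter([27.0, 28.5, 28.1, 28.9])
best, saved = train_by_steps(itertools.repeat('batch'),
                             train_step=lambda batch: None,
                             evaluate=lambda: next(scores),
                             steps=4000)
print(best, saved)  # 28.9 [1000, 2000, 4000]
```

Because the dataset repeats indefinitely, the loop length is controlled by the step count alone, which is why the training dataset is created with repeat_count=None.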

EDSR

from model.edsr import edsr
from train import EdsrTrainer

# Create a training context for an EDSR x4 model with 16 
# residual blocks.
trainer = EdsrTrainer(model=edsr(scale=4, num_res_blocks=16), 
                      checkpoint_dir=f'.ckpt/edsr-16-x4')
                      
# Train EDSR model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)
              
# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/edsr-16-x4/weights.h5')                                    

Interrupting training and restarting it again resumes from the latest saved checkpoint. The trained Keras model can be accessed with trainer.model.
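The resume behaviour can be pictured with a small stand-in that stores the step counter on disk. This sketch uses a JSON file purely for illustration. The trainers in this project manage their own checkpoints, and nothing here reflects their file format; load_state, save_state and run are hypothetical helpers.

```python
import json
import os
import tempfile


def load_state(checkpoint_dir):
    """Return the last saved state, or a fresh one if nothing was saved yet."""
    path = os.path.join(checkpoint_dir, 'state.json')
    if os.path.exists(path):
        with open(path) as f:
            return json.load(f)
    return {'step': 0}


def save_state(checkpoint_dir, state):
    os.makedirs(checkpoint_dir, exist_ok=True)
    with open(os.path.join(checkpoint_dir, 'state.json'), 'w') as f:
        json.dump(state, f)


def run(checkpoint_dir, steps, save_every=100):
    """Continue from the saved step until `steps` is reached."""
    state = load_state(checkpoint_dir)
    while state['step'] < steps:
        state['step'] += 1  # one minibatch update would happen here
        if state['step'] % save_every == 0:
            save_state(checkpoint_dir, state)
    return state['step']


with tempfile.TemporaryDirectory() as ckpt_dir:
    first = run(ckpt_dir, steps=250)             # "interrupted" at step 250
    resumed_from = load_state(ckpt_dir)['step']  # last save was at step 200
    second = run(ckpt_dir, steps=500)            # picks up from step 200
print(first, resumed_from, second)  # 250 200 500
```

Progress made after the last save is lost on interruption, so the save interval bounds how much work has to be repeated.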

WDSR

from model.wdsr import wdsr_b
from train import WdsrTrainer

# Create a training context for a WDSR B x4 model with 32 
# residual blocks.
trainer = WdsrTrainer(model=wdsr_b(scale=4, num_res_blocks=32), 
                      checkpoint_dir='.ckpt/wdsr-b-32-x4')

# Train WDSR B model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)

# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/wdsr-b-32-x4/weights.h5')
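Both training examples report PSNR (peak signal-to-noise ratio) in dB. As a reminder of what that number means, the sketch below computes PSNR from the mean squared error between two images, here flattened to short lists of 8-bit pixel values. It uses the standard definition with a peak value of 255 and is not taken from the trainers' evaluation code, which may differ in details such as border handling.

```python
import math


def psnr(reference, estimate, max_val=255.0):
    """PSNR in dB between two equally sized sequences of pixel values."""
    if len(reference) != len(estimate):
        raise ValueError('inputs must have the same number of pixels')
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return float('inf')  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


hr = [52, 55, 61, 59]  # made-up ground-truth pixels
sr = [54, 55, 60, 57]  # made-up super-resolved pixels
value = psnr(hr, sr)
print(f'PSNR = {value:.2f} dB')  # PSNR = 44.61 dB
```

Since PSNR is logarithmic in the error, a gain of a few hundredths of a dB, such as 28.89 versus 28.91, corresponds to a very small change in mean squared error.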

SRGAN

Generator pre-training

from model.srgan import generator
from train import SrganGeneratorTrainer

# Create a training context for the generator (SRResNet) alone.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir=f'.ckpt/pre_generator')

# Pre-train the generator with 1,000,000 steps (100,000 works fine too). 
pre_trainer.train(train_ds, valid_ds.take(10), steps=1000000, evaluate_every=1000)

# Save weights of pre-trained generator (needed for fine-tuning with GAN).
pre_trainer.model.save_weights('weights/srgan/pre_generator.h5')

Generator fine-tuning (GAN)

from model.srgan import generator, discriminator
from train import SrganTrainer

# Create a new generator and init it with pre-trained weights.
gan_generator = generator()
gan_generator.load_weights('weights/srgan/pre_generator.h5')

# Create a training context for the GAN (generator + discriminator).
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())

# Train the GAN with 200,000 steps.
gan_trainer.train(train_ds, steps=200000)

# Save weights of generator and discriminator.
gan_trainer.generator.save_weights('weights/srgan/gan_generator.h5')
gan_trainer.discriminator.save_weights('weights/srgan/gan_discriminator.h5')
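The objectives behind this GAN stage can be sketched with scalar values. The code below follows the formulation in the SRGAN paper, where the generator minimizes a content loss plus an adversarial term weighted by 0.001 and the discriminator minimizes binary cross-entropy on real and generated images. It is an assumption that SrganTrainer follows the same scheme; the function names and the example numbers are hypothetical.

```python
import math


def binary_cross_entropy(target, prob, eps=1e-7):
    """Cross-entropy for one prediction `prob` against a 0/1 target."""
    prob = min(max(prob, eps), 1.0 - eps)  # avoid log(0)
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


def discriminator_loss(d_real, d_fake):
    # The discriminator should output 1 for HR images and 0 for generated ones.
    return binary_cross_entropy(1.0, d_real) + binary_cross_entropy(0.0, d_fake)


def generator_loss(d_fake):
    # The generator is rewarded when the discriminator rates its output as real.
    return binary_cross_entropy(1.0, d_fake)


def perceptual_loss(content_loss, d_fake, adversarial_weight=1e-3):
    # Weighting taken from the SRGAN paper; assumed, not confirmed, for this project.
    return content_loss + adversarial_weight * generator_loss(d_fake)


print(discriminator_loss(d_real=0.9, d_fake=0.1))
print(perceptual_loss(content_loss=0.2, d_fake=0.5))
```

The small adversarial weight keeps the content loss dominant, so the GAN stage refines texture without discarding what the pre-trained generator already learned.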

SRGAN for fine-tuning EDSR and WDSR models

It is also possible to fine-tune EDSR and WDSR x4 models with SRGAN. They can be used as drop-in replacement for the original SRGAN generator. More details in this article.

from model.edsr import edsr
from model.srgan import discriminator
from train import SrganTrainer

# Create EDSR generator and init with pre-trained weights
generator = edsr(scale=4, num_res_blocks=16)
generator.load_weights('weights/edsr-16-x4/weights.h5')

# Fine-tune EDSR model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

from model.wdsr import wdsr_b

# Create WDSR B generator and init with pre-trained weights
generator = wdsr_b(scale=4, num_res_blocks=32)
generator.load_weights('weights/wdsr-b-32-x4/weights.h5')

# Fine-tune WDSR B model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)