
PyTorch implementation of the paper "Learning to Resize Images for Computer Vision Tasks" on Imagenette and Imagewoof datasets


KushajveerSingh/resize_network_cv


Data augmentation with Resizer network for Image Classification

This repository contains the PyTorch implementation of the Resizer model proposed in the paper Learning to Resize Images for Computer Vision Tasks. The model is evaluated on two datasets, Imagenette and Imagewoof, using ResNet-50 as the baseline model. Check the accompanying blog post for details on the model.

In summary, a CNN model is used to learn data augmentation. The augmented image is passed to a standard image classification model (e.g. ResNet-50) for the downstream task (e.g. image classification).


Results

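| Dataset    | Model               | Accuracy (%) |
|------------|---------------------|--------------|
| Imagenette | ResNet-50           | 81.07        |
| Imagenette | Resizer + ResNet-50 | 82.16        |
| Imagewoof  | ResNet-50           | 58.13        |
| Imagewoof  | Resizer + ResNet-50 | 65.20        |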

Note: Due to compute limitations, I stopped training the models early. For better results, increase the number of epochs to 300 and change the learning rate scheduler to multiply the learning rate by 0.8 every 50 epochs.
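The suggested schedule is a step decay. Below is a minimal sketch of the rule in plain Python, assuming a base learning rate of 1e-3 purely for illustration (the repo's actual base rate is not stated here). In PyTorch the same schedule corresponds to `torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.8)`, which would presumably be configured wherever `src/model.py` sets up its scheduler.

```python
def step_lr(base_lr, epoch, step_size=50, gamma=0.8):
    """Learning rate at `epoch` when it is multiplied by `gamma` every
    `step_size` epochs (the same rule torch's StepLR applies)."""
    return base_lr * gamma ** (epoch // step_size)


# Print the schedule over the suggested 300 epochs (assumed base lr of 1e-3)
for epoch in range(0, 300, 50):
    print(epoch, step_lr(1e-3, epoch))
```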

Details of config file

If you are unfamiliar with Hydra, see my blog post "Complete tutorial on how to use Hydra in Machine Learning projects" for a quick guide.

cfg.data contains the arguments used to load the desired dataset. The full list of arguments is shown below:

data:
  root: ../data     # directory where data is downloaded (not including the folder name)
  name: imagenette2  # "imagenette2" or "imagewoof2" (folder name of dataset inside `root`)
  resizer_image_size: 448  # size of images passed to resizer model
  image_size: 224          # size of images passed to CNN model
  num_classes: 10          # number of labels in training dataset

  # Passed to torch.utils.data.DataLoader
  batch_size: 64
  num_workers: 8
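A small sketch of how these fields fit together, using a plain dict in place of the Hydra config. The helper names are hypothetical, and the way the two image sizes are used is an assumption based on the comments above: with the resizer enabled, images are loaded at `resizer_image_size` (448) and the resizer outputs `image_size` (224) for the CNN; without it, images are loaded at `image_size` directly. The `train`/`val` layout is the one the Imagenette and Imagewoof archives extract to.

```python
from pathlib import Path

# Hypothetical mirror of cfg.data (values copied from the config above)
data_cfg = {
    "root": "../data",
    "name": "imagenette2",
    "resizer_image_size": 448,
    "image_size": 224,
    "num_classes": 10,
}


def dataset_dirs(cfg):
    """Train/val folders of the extracted dataset: <root>/<name>/{train,val}."""
    base = Path(cfg["root"]) / cfg["name"]
    return base / "train", base / "val"


def loader_image_size(cfg, apply_resizer_model):
    """Assumed size of the images produced by the data loader: the larger
    resizer input when the resizer is enabled, otherwise the CNN input size."""
    return cfg["resizer_image_size"] if apply_resizer_model else cfg["image_size"]


train_dir, val_dir = dataset_dirs(data_cfg)
print(train_dir, val_dir, loader_image_size(data_cfg, apply_resizer_model=True))
```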

The main arguments you need to adjust are root (the directory where the dataset is downloaded) and name (imagenette2 or imagewoof2, the dataset to train on).

To use the resizer model, set apply_resizer_model: true; the resizer is then applied before the base model. The resizer's arguments are specified in cfg.resizer:

resizer:
  in_channels: 3       # Number of input channels of resizer (for RGB images it is 3)
  out_channels: 3      # Number of output channels of resizer (for RGB images it is 3)
  num_kernels: 16      # Same as `n` in paper
  num_resblocks: 2     # Same as 'r' in paper
  negative_slope: 0.2  # Used by leaky relu
  interpolate_mode: bilinear  # Passed to torch.nn.functional.interpolate

in_channels and out_channels specify the number of input channels to the resizer and the number of channels output by the resizer, respectively. In most scenarios both values should be the same (3 for RGB images).
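For reference, here is a minimal PyTorch sketch of the Resizer as described in the paper: a 7x7 conv, a 1x1 conv and batch norm, a resize, `r` residual blocks, a 3x3 conv with batch norm, and a final 7x7 conv, with an outer skip connection from a plain resize of the input. It is our reading of the paper, not a copy of `src/models/resizer.py`, which may differ in details such as bias terms. The constructor arguments mirror `cfg.resizer`, plus a hypothetical `out_size` for the target resolution. Because the final skip adds the resized input to the network output, `in_channels` and `out_channels` must match in this sketch.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResBlock(nn.Module):
    """conv3x3 -> BN -> LeakyReLU -> conv3x3 -> BN, plus an identity shortcut."""

    def __init__(self, num_kernels, negative_slope):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(num_kernels, num_kernels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_kernels),
            nn.LeakyReLU(negative_slope),
            nn.Conv2d(num_kernels, num_kernels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_kernels),
        )

    def forward(self, x):
        return x + self.body(x)


class Resizer(nn.Module):
    def __init__(self, out_size=224, in_channels=3, out_channels=3, num_kernels=16,
                 num_resblocks=2, negative_slope=0.2, interpolate_mode="bilinear"):
        super().__init__()
        self.out_size = out_size  # hypothetical argument: target resolution (image_size)
        self.mode = interpolate_mode
        # conv7x7 -> LeakyReLU -> conv1x1 -> LeakyReLU -> BN, at the input resolution
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, num_kernels, kernel_size=7, padding=3),
            nn.LeakyReLU(negative_slope),
            nn.Conv2d(num_kernels, num_kernels, kernel_size=1),
            nn.LeakyReLU(negative_slope),
            nn.BatchNorm2d(num_kernels),
        )
        # `r` residual blocks at the output resolution
        self.resblocks = nn.Sequential(
            *[ResBlock(num_kernels, negative_slope) for _ in range(num_resblocks)]
        )
        self.mid = nn.Sequential(
            nn.Conv2d(num_kernels, num_kernels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_kernels),
        )
        self.tail = nn.Conv2d(num_kernels, out_channels, kernel_size=7, padding=3)

    def forward(self, x):
        # Plain (non-learned) resize of the input, used as the outer skip connection
        skip = F.interpolate(x, size=self.out_size, mode=self.mode)
        # Learned features, resized to the target resolution
        feat = F.interpolate(self.head(x), size=self.out_size, mode=self.mode)
        out = self.mid(self.resblocks(feat)) + feat
        # Adding `skip` only works when out_channels == in_channels
        return self.tail(out) + skip


resizer = Resizer(out_size=224)
out = resizer(torch.randn(2, 3, 448, 448))  # shape: (2, 3, 224, 224)
```

The output has the same shape as a 224x224 RGB batch, so it can be passed to ResNet-50 unchanged.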

Download datasets

This repo uses the Imagenette and Imagewoof datasets; you can learn more about them at fastai/imagenette. Instructions to download and set up the data are provided below. Alternatively, run the download_data.sh script (./download_data.sh) to do all of this for you.

Imagenette

Download the archive manually, or run the following commands from the root directory of this repo:

wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
tar -xzf imagenette2.tgz -C data/
rm imagenette2.tgz

Imagewoof

Download the archive manually, or run the following commands from the root directory of this repo:

wget https://s3.amazonaws.com/fast-ai-imageclas/imagewoof2.tgz
tar -xzf imagewoof2.tgz -C data/
rm imagewoof2.tgz
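If `wget` or `tar` is unavailable (e.g. on Windows), the same steps can be scripted with the Python standard library. This is a minimal sketch with hypothetical function names that are not part of this repo; it uses the same URLs as the commands above and extracts into `data/`.

```python
import tarfile
import urllib.request
from pathlib import Path

BASE_URL = "https://s3.amazonaws.com/fast-ai-imageclas"


def dataset_url(name):
    """Archive URL for 'imagenette2' or 'imagewoof2' (same as the wget commands)."""
    if name not in ("imagenette2", "imagewoof2"):
        raise ValueError(f"unknown dataset: {name}")
    return f"{BASE_URL}/{name}.tgz"


def extract_archive(archive, dest):
    """Equivalent of `tar -xzf <archive> -C <dest>`."""
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall(dest)


def download_dataset(name, dest="data"):
    """wget + tar + rm: download the archive, extract it, then delete it."""
    archive = Path(f"{name}.tgz")
    urllib.request.urlretrieve(dataset_url(name), archive)
    extract_archive(archive, dest)
    archive.unlink()


# From the repo root:
# download_dataset("imagenette2")
# download_dataset("imagewoof2")
```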

Reproducing experiments

The config files to reproduce the experiments are provided in the config_files folder. Copy the desired config file to src/config.yaml and run python trainer.py from the src directory.

Note: I trained the models on an RTX 2080 Ti, so you may need to adjust the batch size for your GPU.

An example of how to use the config files is shown below (run from the root of this repo):

cd src

# For ResNet50 on Imagenette
cp ../config_files/imagenette_resnet50.yaml config.yaml
python trainer.py

# For ResNet50 + Resizer on Imagenette
cp ../config_files/imagenette_resnet50_resizer.yaml config.yaml
python trainer.py

# For ResNet50 on Imagewoof
cp ../config_files/imagewoof_resnet50.yaml config.yaml
python trainer.py

# For ResNet50 + Resizer on Imagewoof
cp ../config_files/imagewoof_resnet50_resizer.yaml config.yaml
python trainer.py
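To run all four experiments in sequence, a small driver script could look like the sketch below. It is hypothetical and not part of the repo; run it from the repo root. It copies each config rather than moving it, so config_files stays intact between runs, and it stops at the first training run that fails.

```python
import shutil
import subprocess
import sys
from pathlib import Path

# The four config files listed above
EXPERIMENTS = [
    "imagenette_resnet50.yaml",
    "imagenette_resnet50_resizer.yaml",
    "imagewoof_resnet50.yaml",
    "imagewoof_resnet50_resizer.yaml",
]


def run_all(src_dir="src", config_dir="config_files", cmd=(sys.executable, "trainer.py")):
    """Copy each config to <src_dir>/config.yaml, then run `cmd` inside src_dir."""
    src_dir, config_dir = Path(src_dir), Path(config_dir)
    for name in EXPERIMENTS:
        shutil.copyfile(config_dir / name, src_dir / "config.yaml")
        # check=True raises CalledProcessError if a training run fails
        subprocess.run(list(cmd), cwd=src_dir, check=True)


if __name__ == "__main__":
    run_all()
```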

Repository structure

  • download_data.sh - Script to download the Imagenette and Imagewoof datasets. Usage: ./download_data.sh
  • src
    • config.yaml - The Hydra config file that controls all settings in the repository. All options in the config file are documented; see the Details of config file section for more information.
    • data.py - Contains the code to create a pytorch_lightning.LightningDataModule for the specified dataset.
    • models
      • resizer.py - Contains the implementation of the Resizer model proposed in the paper.
      • base_model.py - Loads the torchvision ResNet-50 model, which is used as the base model in this repo. You can specify your own base model here.
      • __init__.py - Provides a utility function, get_model, that loads either of the above two models by name (resizer or base_model).
    • model.py - Contains the code to create the pytorch_lightning.LightningModule. It loads the above models and specifies the training/validation steps, optimizers, and learning rate scheduler.
    • trainer.py - The main script to call to train your models. It reads the arguments from config.yaml and runs the specified training, saving all outputs to the outputs/{date}/{time} directory.

Requirements

  • Python = 3.8.8
  • hydra-core = 1.0.6
  • matplotlib = 3.4.2
  • pytorch = 1.8.1
  • torchvision = 0.9.1
  • pytorch-lightning = 1.3.3
  • torchmetrics = 0.3.2

License

Apache License 2.0
