
A plug 'n' play CTAugment wrapper that can be used for classification tasks with the PyTorch framework.


shreejalt/pytorch-models-ctaugment


CTAugment Wrapper for PyTorch

This repository contains an unofficial implementation of a CTAugment wrapper that can be used with any image classification task in PyTorch.

  • Integrating CTAugment into an existing training pipeline is a bit difficult, because it usually means changing the pipeline and adding code to support it. This repository comes in handy if you want to use CTAugment directly with an existing image classification training repository or codebase with only minimal changes.

  • CTAugment stands for Control Theory Augment, which was introduced in the ReMixMatch paper.

  • It learns how strongly each augmentation should be applied to an image: the weights of the augmentation magnitudes are adjusted based on the error between the model's output and the true label.

More details can be found in Section 3.2.2 of the paper.
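As a rough illustration of the idea above, the sketch below scores a single probe prediction and nudges the weight of the magnitude bin that was applied. It assumes the score is 1 minus half the L1 distance between the predicted class distribution and the one-hot label (so it lies in [0, 1]) and that bin weights follow an exponential moving average with decay 0.99. The function names proximity and update_bin are hypothetical and are not part of ctaugment.py.

```python
from typing import List, Sequence


def proximity(probs: Sequence[float], label: int) -> float:
    """Score in [0, 1]: 1.0 for a perfect prediction, 0.0 for a confident wrong one.

    Assumption: score = 1 - 0.5 * sum_i |probs[i] - onehot[i]|.
    """
    dist = sum(abs(p - (1.0 if i == label else 0.0)) for i, p in enumerate(probs))
    return 1.0 - 0.5 * dist


def update_bin(weights: List[float], bin_idx: int, score: float, decay: float = 0.99) -> None:
    """Move the weight of the bin that was applied towards the observed score (EMA)."""
    weights[bin_idx] = decay * weights[bin_idx] + (1.0 - decay) * score


# Example: 10 magnitude bins for one hypothetical operation, all starting at 1.0.
weights = [1.0] * 10
score = proximity([0.1, 0.7, 0.2], label=1)  # fairly confident, correct prediction
update_bin(weights, bin_idx=3, score=score)
```

With a decay close to 1, a single probe result moves a bin only slightly, so the weights reflect the model's behaviour over many probe samples rather than one batch.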

Structure of the ctaugment.py file

ctaugment.py contains two main classes: CTAugment and CTUpdater.

  • CTAugment has two key methods. __call__ is used when the CTAugment object applies the transform to an image.

    __probe__ is used to change the weights of the CTAugment bins. CTAugment expects the thresh, decay, and depth parameters, which default to 0.8, 0.99, and 2, respectively.

  • CTUpdater holds the probe dataset that is used to update the bin weights throughout training. It expects the path to the training data (datapath) and the train_transforms used to train the model.
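To make the roles of thresh, decay, and depth concrete, here is a toy, image-free sketch of how a CTAugment-style policy could be sampled and updated: for each of depth randomly chosen operations, bin weights are normalised by their maximum, bins below thresh are zeroed, and a magnitude bin is drawn in proportion to the remaining weights. The class name ToyCTAugment, the operation names, the bin count, and the small (1 - decay) floor added to the weights are all assumptions; the real ctaugment.py may work differently.

```python
import random
from typing import Dict, List, Tuple


class ToyCTAugment:
    """Hypothetical, simplified CTAugment policy sampler (no image operations)."""

    def __init__(self, ops: List[str], num_bins: int = 17,
                 depth: int = 2, thresh: float = 0.8, decay: float = 0.99):
        self.depth, self.thresh, self.decay = depth, thresh, decay
        # One weight per magnitude bin per operation, all starting at 1.0.
        self.rates: Dict[str, List[float]] = {op: [1.0] * num_bins for op in ops}

    def _bin_probs(self, rates: List[float]) -> List[float]:
        # Assumed small floor so that the weights can never all be zero.
        p = [r + (1.0 - self.decay) for r in rates]
        top = max(p)
        p = [x / top for x in p]
        # Bins whose normalised weight is below thresh are never sampled.
        p = [x if x >= self.thresh else 0.0 for x in p]
        total = sum(p)  # > 0, because the max bin always normalises to 1.0
        return [x / total for x in p]

    def sample_policy(self) -> List[Tuple[str, int]]:
        """Pick `depth` operations and one magnitude bin for each."""
        policy = []
        for _ in range(self.depth):
            op = random.choice(list(self.rates))
            probs = self._bin_probs(self.rates[op])
            bin_idx = random.choices(range(len(probs)), weights=probs, k=1)[0]
            policy.append((op, bin_idx))
        return policy

    def update(self, policy: List[Tuple[str, int]], score: float) -> None:
        """EMA update of the bins used by `policy` with a probe score in [0, 1]."""
        for op, bin_idx in policy:
            old = self.rates[op][bin_idx]
            self.rates[op][bin_idx] = self.decay * old + (1.0 - self.decay) * score


random.seed(0)
aug = ToyCTAugment(["rotate", "shear_x", "contrast"])
policy = aug.sample_policy()
aug.update(policy, score=0.0)  # bins that led to a wrong prediction lose weight
```

With decay = 0.99 and thresh = 0.8, a bin in this toy needs roughly two dozen consecutive zero scores before it is excluded from sampling, which is what makes the learned weights robust to noisy individual probes.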

How to use CTAugment and CTUpdater

# Step 1: Add the CTAugment module to the train_transforms list.

from torchvision import transforms
from ctaugment import CTAugment, CTUpdater

train_transforms = [
    transforms.RandomResizedCrop((size, size)),
    transforms.RandomHorizontalFlip(),
    # Defaults: depth=2, thresh=0.8, decay=0.99
    CTAugment(depth=depth, thresh=thresh, decay=decay),
]

train_transforms = transforms.Compose(train_transforms)

# Step 2: Initialize the CTUpdater module

updater = CTUpdater(datapath=datapath,                  # root of the ImageFolder-like training data
                    train_transforms=train_transforms,  # the composed transforms from Step 1
                    batch_size=probe_batch_size,
                    num_workers=num_workers,
                    gpu=gpu)                            # GPU id, e.g. 0

NOTE: Keep probe_batch_size considerably larger than train_batch_size.
In this example, it is 3 times train_batch_size.
datapath is the path to the training data, which must follow an ImageFolder-like structure:

root/dog/xxx.png
root/dog/xxy.png
root/dog/[...]/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/[...]/asd932_.png
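The layout above follows torchvision's ImageFolder convention: each subfolder of root is one class, and class indices are assigned by sorting the folder names. The stdlib-only sketch below mimics that mapping so you can sanity-check a datapath before passing it to CTUpdater. The function name scan_image_folder and the extension list are assumptions, not part of this repository or of torchvision's API.

```python
import os
from typing import Dict, List, Tuple

# Assumption: a subset of the image extensions torchvision accepts.
IMG_EXTS = (".png", ".jpg", ".jpeg", ".bmp", ".webp")


def scan_image_folder(root: str) -> Tuple[Dict[str, int], List[Tuple[str, int]]]:
    """Return (class_to_idx, samples) for an ImageFolder-like directory."""
    # Class names are the sorted subfolder names of root.
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        # Images may sit in nested folders, e.g. root/dog/[...]/xxz.png.
        for dirpath, _, files in sorted(os.walk(os.path.join(root, name))):
            for f in sorted(files):
                if f.lower().endswith(IMG_EXTS):
                    samples.append((os.path.join(dirpath, f), class_to_idx[name]))
    return class_to_idx, samples
```

For the example layout, class_to_idx would be {"cat": 0, "dog": 1}, because "cat" sorts before "dog".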

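# Step 3: Update the weights of the CTAugment.

for images, target in train_loader:
    output = model(images)
    loss = loss_fn(output, target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # ... (rest of your training iteration)

    # Update the CTAugment bin weights using the current model.
    updater.update(model)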

NOTE: Update the weights by calling the update() method of the CTUpdater object at the end of every iteration (for simplicity, right after optimizer.step()).

You need to pass the model so that its error rate on the probe dataset can be calculated.
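The sketch below shows one plausible shape for such a probe step, without PyTorch, so the bookkeeping is easy to follow. It is hypothetical and not CTUpdater's actual code: it assumes probe samples are augmented with randomly chosen operations and uniformly random magnitude bins (so every bin eventually gets evaluated), scores each prediction as 1 minus half the L1 distance to the one-hot label, and applies an EMA update with decay 0.99. The names probe_update and fake_model are illustrative stand-ins.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple

Policy = List[Tuple[str, int]]


def probe_update(rates: Dict[str, List[float]],
                 probe_batch: Sequence[Tuple[object, int]],
                 model: Callable[[object, Policy], Sequence[float]],
                 depth: int = 2, decay: float = 0.99) -> float:
    """One hypothetical probe step; returns the mean proximity score of the batch."""
    scores = []
    for image, label in probe_batch:
        # Probe with random ops and uniformly random bins so every bin gets scored.
        policy: Policy = []
        for _ in range(depth):
            op = random.choice(list(rates))
            policy.append((op, random.randrange(len(rates[op]))))
        # Predicted class distribution for the augmented image.
        probs = model(image, policy)
        # Proximity: 1 - 0.5 * L1 distance to the one-hot label, in [0, 1].
        score = 1.0 - 0.5 * sum(abs(p - (1.0 if i == label else 0.0))
                                for i, p in enumerate(probs))
        # Move each used bin's weight towards the score (EMA with `decay`).
        for op, b in policy:
            rates[op][b] = decay * rates[op][b] + (1.0 - decay) * score
        scores.append(score)
    return sum(scores) / len(scores)


def fake_model(image: object, policy: Policy) -> List[float]:
    """Toy stand-in for a network: correct unless a high-magnitude bin (>= 8) was used."""
    too_strong = any(b >= 8 for _, b in policy)
    return [0.5, 0.5] if too_strong else [0.0, 1.0]


random.seed(0)
rates = {op: [1.0] * 10 for op in ["rotate", "shear_x", "contrast"]}
probe_batch = [(None, 1)] * 256  # every probe sample has label 1
for _ in range(20):
    mean_score = probe_update(rates, probe_batch, fake_model)
```

In this toy setup, the high-magnitude bins that confuse fake_model end up with clearly lower weights than the low-magnitude ones, which is the behaviour the real probe step relies on.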

Running the example script

I have included main.py to test and demonstrate CTAugment with PyTorch models. The script is adapted from the PyTorch Examples repository; see that repository for more information on the script's usage.


Installing the dependencies

The script requires Python >= 3.6.

Other requirements can be installed by running the command pip3 install -r requirements.txt


Structure of the script

PyTorch Training with support of CTAugment Dataset

positional arguments:
  DIR                   path to dataset

optional arguments:
  -h, --help            show this help message and exit
  --use-weighted        Use WeightedRandomSampler
  -a ARCH, --arch ARCH  model architecture: alexnet | densenet121 | densenet161 | densenet169 | densenet201 | googlenet | inception_v3 | mnasnet0_5 | mnasnet0_75 | mnasnet1_0 |
                        mnasnet1_3 | mobilenet_v2 | mobilenet_v3_large | mobilenet_v3_small | resnet101 | resnet152 | resnet18 | resnet34 | resnet50 | resnext101_32x8d |
                        resnext50_32x4d | shufflenet_v2_x0_5 | shufflenet_v2_x1_0 | shufflenet_v2_x1_5 | shufflenet_v2_x2_0 | squeezenet1_0 | squeezenet1_1 | vgg11 | vgg11_bn | vgg13
                        | vgg13_bn | vgg16 | vgg16_bn | vgg19 | vgg19_bn | wide_resnet101_2 | wide_resnet50_2 (default: resnet18)
  -j N, --workers N     number of data loading workers (default: 4)
  -nc N, --num-classes N
                        number of classes
  --data-percent N      percentage of training data to take for training
  --epochs N            number of total epochs to run
  --output-dir V        directory to store output weights and logs
  --size N              image size
  --mu-ratio N          multiplicative ratio for ct augment probe loader
  --min-step-lr N       minimum for step lr
  --max-step-lr N       maximum for step lr
  --save-every S        save checkpoints every N epochs
  --rand-depth RAND_DEPTH
                        depth of RandAugment
  --rand-magnitude RAND_MAGNITUDE
                        magnitude of RandAugment
  --ct-depth CT_DEPTH   depth of CT Augment
  --ct-decay CT_DECAY   decay of CT Augment
  --ct-thresh CT_THRESH
                        thresh of CT Augment
  --start-epoch N       manual epoch number (useful on restarts)
  -b N, --batch-size N  batch size for training/testing
  --lr LR, --learning-rate LR
                        initial learning rate
  --momentum M          momentum
  --wd W, --weight-decay W
                        weight decay (default: 1e-4)
  -p N, --print-freq N  print frequency (default: 10)
  --resume PATH         path to latest checkpoint (default: none)
  -e, --evaluate        evaluate model on validation set
  --pretrained          use pre-trained model
  --use-scheduler       Flag to use the scheduler during the training
  --use-cosine          use Cosine Scheduler
  --use-ct              use CTAugment strategy
  --no-update-ct        Flag to disable updating the CTAugment weights.
  --seed SEED           seed for initializing training.
  --gpu GPU             GPU id to use.
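For orientation, the CTAugment-related flags above could be declared with argparse roughly as follows. This is a hypothetical excerpt, not main.py's actual code: the --ct-depth, --ct-decay, and --ct-thresh defaults are taken from the CTAugment defaults described earlier, the --mu-ratio default of 3 is assumed from the "3 times train_batch_size" note, the --batch-size default is a placeholder, and deriving the probe batch size as mu_ratio * batch_size is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical subset of main.py's CLI, limited to the CTAugment flags."""
    parser = argparse.ArgumentParser(
        description="PyTorch Training with support of CTAugment Dataset")
    parser.add_argument("data", metavar="DIR", help="path to dataset")
    # Placeholder default; the real script's default is not documented here.
    parser.add_argument("-b", "--batch-size", type=int, default=64,
                        help="batch size for training/testing")
    # Assumed default of 3, matching the probe batch size note.
    parser.add_argument("--mu-ratio", type=int, default=3,
                        help="multiplicative ratio for ct augment probe loader")
    parser.add_argument("--ct-depth", type=int, default=2, help="depth of CT Augment")
    parser.add_argument("--ct-decay", type=float, default=0.99, help="decay of CT Augment")
    parser.add_argument("--ct-thresh", type=float, default=0.8, help="thresh of CT Augment")
    parser.add_argument("--use-ct", action="store_true", help="use CTAugment strategy")
    parser.add_argument("--no-update-ct", action="store_true",
                        help="disable updating the CTAugment weights")
    return parser


args = build_parser().parse_args(["/path/to/dataset", "--use-ct", "-b", "64"])
probe_batch_size = args.mu_ratio * args.batch_size  # assumed relationship
```

argparse turns dashes into underscores, so --no-update-ct is read back as args.no_update_ct.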

  • Pass the --use-ct flag to train the model with CTAugment. By default, the script uses the RandAugment strategy during training.

  • The parameters of the CTAugment class can be changed with the --ct-depth, --ct-decay, and --ct-thresh flags.

  • By default, the CTAugment weights are saved as a state_dict() in the model checkpoint file. To resume training, pass --resume with the path to the checkpoint. Note that the CTAugment weights are only loaded if the checkpoint contains them.

  • You can also pass --no-update-ct to stop updating the CTAugment weights during training.

  • Demo command: python3 main.py ${path/to/dataset} --use-cosine --lr 0.01 --gpu 0 -b 64 --size 32 --use-ct --use-scheduler --arch resnet18
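The resume behaviour described above (CTAugment weights stored next to the model weights and restored only when present) can be sketched as follows. The checkpoint key "ct_state", the helper names, and the use of plain dicts instead of real torch.save/torch.load calls are hypothetical choices made to keep the example self-contained.

```python
import copy
from typing import Any, Dict, List, Optional

Rates = Dict[str, List[float]]


def make_checkpoint(epoch: int, model_state: Dict[str, Any],
                    ct_rates: Optional[Rates]) -> Dict[str, Any]:
    """Bundle model and (optionally) CTAugment bin weights into one checkpoint dict."""
    ckpt: Dict[str, Any] = {"epoch": epoch, "state_dict": model_state}
    if ct_rates is not None:
        # Snapshot the weights so later training steps do not mutate the checkpoint.
        ckpt["ct_state"] = copy.deepcopy(ct_rates)  # hypothetical key name
    return ckpt


def restore_ct(ckpt: Dict[str, Any], ct_rates: Rates) -> bool:
    """Load CTAugment weights only if the checkpoint contains them; return whether it did."""
    if "ct_state" not in ckpt:
        return False  # keep the freshly initialised bin weights
    ct_rates.clear()
    ct_rates.update(copy.deepcopy(ckpt["ct_state"]))
    return True
```

In a real script, the dict would be written with torch.save and read back with torch.load before calling restore_ct; the boolean return makes it easy to log whether the CTAugment state was actually restored.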


References and Citations

@article{berthelot2019remixmatch,
    title={ReMixMatch: Semi-Supervised Learning with Distribution Alignment and Augmentation Anchoring},
    author={David Berthelot and Nicholas Carlini and Ekin D. Cubuk and Alex Kurakin and Kihyuk Sohn and Han Zhang and Colin Raffel},
    journal={arXiv preprint arXiv:1911.09785},
    year={2019},
}

Feel free to submit a PR if you find any issues or a better approach to implementing CTAugment.
