Training Lennonnet in PyTorch
The code below is what I used to load the image classification model, adapt it to my needs and train it. Note that this code is quite case- and machine-specific, so I would not recommend copy-pasting it and running it as is.
import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torch.nn.functional as F
import pytorch_lightning as pl
from torchvision.datasets import ImageFolder
# Settings for this sheet.
# Root folder of the training images. It is read with torchvision's ImageFolder,
# which expects one sub-folder per class.
training_dir = "images/training_john"
class PersonClassifier(pl.LightningModule):
    """Image classifier: a timm backbone wrapped in a PyTorch Lightning module.

    The backbone defaults to RESNET50D, which seemed to work pretty well here.
    """

    def __init__(self, lr=1e-3, n_classes=2, model_name="resnet50d", pretrained=False):
        """Build the timm backbone and record the hyperparameters.

        Args:
            lr: learning rate for the RMSprop optimizer.
            n_classes: number of outputs of the classifier head.
            model_name: timm architecture to instantiate (default ``"resnet50d"``).
            pretrained: forwarded to ``timm.create_model``. ``False`` (the
                default) starts from randomly initialised weights.
        """
        super().__init__()
        self.lr = lr
        # Stores all __init__ arguments in self.hparams; Lightning also writes
        # them into checkpoints.
        self.save_hyperparameters()
        self.model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=n_classes,
        )

    def configure_optimizers(self):
        """Return an RMSprop optimizer over the backbone's parameters."""
        return torch.optim.RMSprop(self.model.parameters(), self.lr)

    def _shared_step(self, batch, log_name):
        """Compute, log under ``log_name``, and return the cross-entropy loss of a batch.

        ``batch`` is unpacked as an ``(inputs, targets)`` pair.
        """
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        self.log(log_name, loss)
        return loss

    def training_step(self, dl, idx):
        """Lightning hook: compute and log the loss of one training batch.

        Args:
            dl: the current batch (not a DataLoader, despite the name), an
                ``(inputs, targets)`` pair.
            idx: index of the batch (unused).
        """
        return self._shared_step(dl, 'train_loss')

    def validation_step(self, dl, idx):
        """Lightning hook: compute and log the loss of one validation batch.

        Args:
            dl: the current batch (not a DataLoader, despite the name), an
                ``(inputs, targets)`` pair.
            idx: index of the batch (unused).
        """
        return self._shared_step(dl, 'val_loss')

    def forward(self, xs):
        """Run the backbone on ``xs`` and return its raw class scores (logits)."""
        return self.model(xs)
def split_indices(dataset_size, val_fraction=0.1, seed=42):
    """Shuffle ``range(dataset_size)`` reproducibly and split it into train/val index lists.

    The first ``floor(val_fraction * dataset_size)`` shuffled indices form the
    validation set; the rest form the training set. A fixed ``seed`` keeps the
    split identical between runs, so re-running or resuming training does not
    move validation images into the training set.

    Returns:
        ``(train_idx, val_idx)`` as lists of Python ints.
    """
    shuffled = np.random.default_rng(seed).permutation(dataset_size).tolist()
    val_count = int(np.floor(val_fraction * dataset_size))
    return shuffled[val_count:], shuffled[:val_count]


# Guard the script part. With num_workers > 0, DataLoader workers started via
# "spawn" (the default on macOS, where the "mps" accelerator runs) re-import this
# module and would otherwise re-run the whole training script.
if __name__ == "__main__":
    model = PersonClassifier()

    # Obtain the data transforms the backbone expects. Pass the timm model
    # itself: the Lightning wrapper has no pretrained_cfg, so resolve_data_config
    # would silently fall back to generic defaults.
    config = resolve_data_config({}, model=model.model)
    transform = create_transform(**config)
    # TensorBoardLogger(log_graph=True) can only trace the graph when the module
    # provides an example input; use one image of the resolved input size.
    model.example_input_array = torch.zeros(1, *config["input_size"])

    # optional: use Tensor Board as a logger to check the progress of the training:
    tb_logger = pl.loggers.TensorBoardLogger(save_dir="runs", log_graph=True)
    # set up a checkpoint callback so we can keep the best-performing versions:
    ckpt_callback = pl.callbacks.ModelCheckpoint(
        dirpath="lennonnet_checkpoints",
        save_top_k=3,
        monitor="val_loss",
    )
    # define the trainer (this is machine-specific: "mps" needs Apple silicon,
    # use e.g. accelerator="auto" elsewhere):
    trainer = pl.Trainer(
        max_epochs=32,
        logger=tb_logger,
        accelerator="mps",
        devices=1,
        log_every_n_steps=5,
        callbacks=[ckpt_callback],
    )

    # Scan the image folder once; both loaders sample disjoint index subsets of it.
    imgs = ImageFolder(training_dir, transform=transform)
    dataset_size = len(imgs)
    # use 10% of the data for validation during training:
    train_idx, val_idx = split_indices(dataset_size, val_fraction=0.1)
    if not val_idx:
        # ModelCheckpoint monitors val_loss, which is never logged without validation data.
        raise ValueError(
            f"Validation split is empty: only {dataset_size} images found in {training_dir!r}."
        )
    train_sampler = SubsetRandomSampler(train_idx)
    val_sampler = SubsetRandomSampler(val_idx)
    train_dl = DataLoader(imgs, batch_size=32, sampler=train_sampler, num_workers=4)
    val_dl = DataLoader(imgs, batch_size=32, sampler=val_sampler, num_workers=4)

    # Attach the class-name -> index mapping and the transform to the model for later inference.
    mapping = imgs.class_to_idx
    model.mapping = mapping
    model.transform = transform
    print("mapping:", mapping)
    print("Dataset size:", dataset_size)
    print("Training set size:", len(train_idx))
    print("Validation set size:", len(val_idx))
    print("Training number of batches:", len(train_dl))
    print("Validation number of batches:", len(val_dl))
    # Run the actual training:
    trainer.fit(model, train_dl, val_dl)