Description
Discussed in #1442
Originally posted by Hravan on September 1, 2022
I'm setting up training on this Kaggle competition dataset: https://www.kaggle.com/competitions/plant-pathology-2021-fgvc8
(I'm using only the single-label samples here to keep the problem simpler.)
The problem is that `ImageClassificationData` uses too much RAM and the GPU is underutilized. I wrote equivalent code in plain PyTorch for comparison, to confirm that the problem is somewhere within `ImageClassificationData`.
Code shared by both training versions:
```python
import pandas as pd
from skimage import io
from sklearn.preprocessing import OneHotEncoder
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


class PlantDataset(Dataset):
    def __init__(self, df, transform=None) -> None:
        super().__init__()
        self.img_paths = df["image"].tolist()
        self.transform = transform
        self.encoder = OneHotEncoder()
        self.labels = (
            self.encoder.fit_transform(df["label"].values.reshape(-1, 1))
            .todense()
            .A
        )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = io.imread(self.img_paths[idx])
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[idx]
        # return {
        #     "input": img,
        #     "target": torch.tensor(label, dtype=torch.uint8),
        # }
        return img, torch.tensor(label, dtype=torch.float32)


def preprocess_df(csv_path, images_root):
    df = pd.read_csv(csv_path)
    # Keep only single-label samples (multi-label rows have space-separated labels).
    df = df[~df["labels"].str.contains(" ")]
    df["image"] = images_root + df["image"]
    df = df.rename(columns={"labels": "label"})
    return df


def split_df(df, train_pct):
    df = df.sample(frac=1)  # shuffle rows before splitting
    n_train = int(train_pct * len(df))
    train_df = df.iloc[:n_train].reset_index()
    val_df = df.iloc[n_train:].reset_index()
    return train_df, val_df


def create_dataloader(df):
    train_compose = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((224, 224)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    dataloader = torch.utils.data.DataLoader(
        PlantDataset(df, transform=train_compose),
        batch_size=32,
        num_workers=8,
        prefetch_factor=8,
    )
    return dataloader
```
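As a side note, a quick way to confirm that this shared input pipeline is fast in isolation is to iterate the DataLoader without any model and measure throughput. This is a sketch of mine, not part of the original scripts; `measure_loader_throughput` is a hypothetical helper:

```python
from time import perf_counter

def measure_loader_throughput(dataloader, n_batches=100):
    # Iterate the loader without a model to measure raw input-pipeline speed.
    n_images = 0
    time0 = perf_counter()
    for i, (images, labels) in enumerate(dataloader):
        n_images += images.shape[0]
        if i + 1 >= n_batches:
            break
    print(f"{n_images / (perf_counter() - time0):.1f} images/s")
```

If the plain DataLoader is fast here, any slowdown during training must come from whatever is wrapped around it.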
Training in plain PyTorch:
```python
import argparse
from time import perf_counter

import torch
import torchvision
import tqdm


def train(model, data_loader, n_epochs):
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()
    for i in range(n_epochs):
        for images, labels in tqdm.tqdm(data_loader):
            images = images.cuda()
            preds = model(images)
            loss = loss_fn(preds, labels.cuda())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"End of epoch {i}")


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(512, 6)

    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    train_loader = create_dataloader(train_df)

    time0 = perf_counter()
    train(model, train_loader, 2)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()
```
Training in Lightning Flash:
```python
import argparse
from time import perf_counter

import pytorch_lightning as pl
import torch
import torchvision

import flash
from flash.image import ImageClassificationData


class Resnet18(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()
        self.model.fc = torch.nn.Linear(512, 6)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch["input"], batch["target"]
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters())


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = Resnet18()
    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)

    datamodule = ImageClassificationData.from_data_frame(
        "image",
        "label",
        train_data_frame=train_df,
        batch_size=32,
        transform_kwargs=dict(image_size=(224, 224)),
        num_workers=8,
        persistent_workers=True,
        pin_memory=False,
    )

    time0 = perf_counter()
    trainer = flash.Trainer(max_epochs=2, gpus=torch.cuda.device_count())
    trainer.fit(model, datamodule=datamodule)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()
```
When I increase batch_size to 64 or num_workers to 16 in `ImageClassificationData`, I start running out of RAM, which does not happen with the plain PyTorch version. Any ideas what the problem might be? I tried profiling but didn't reach any sensible conclusion, except that I suspect the problem is in the `BaseDataFetcher` used by the `DataModule`.
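One way to pin down where the RAM actually goes (a rough sketch on my side, assuming `psutil` is installed; `print_rss` is a hypothetical helper, not a Flash API) is to log the resident memory of the training process and its DataLoader workers in both scripts:

```python
import os
import psutil

def print_rss():
    # Resident memory (MB) of this process and its children (DataLoader workers).
    proc = psutil.Process(os.getpid())
    main_mb = proc.memory_info().rss / 2**20
    workers_mb = sum(c.memory_info().rss for c in proc.children(recursive=True)) / 2**20
    print(f"main: {main_mb:.0f} MB, workers: {workers_mb:.0f} MB")
```

Calling this at the end of each epoch in both versions should show whether the extra memory sits in the worker processes (pointing at the dataset/transform pipeline) or in the main process (consistent with my guess that something like the `BaseDataFetcher` is holding on to batches).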