Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Too much RAM usage by ImageClassificationData #1450

Open
@ethanwharris

Description

Discussed in #1442

Originally posted by Hravan on September 1, 2022
I'm setting up a training for this kaggle competition dataset: https://www.kaggle.com/competitions/plant-pathology-2021-fgvc8
(Here I'm using only the single-label samples, to keep the problem simple.)

The problem is that ImageClassificationData uses too much RAM and the GPU is underutilized. For comparison, I wrote the same training in plain PyTorch to confirm that the problem lies somewhere within ImageClassificationData.

Code shared by both training versions:

import pandas as pd
from skimage import io
from sklearn.preprocessing import OneHotEncoder
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


class PlantDataset(Dataset):
    """Map-style dataset yielding ``(image, one_hot_label)`` pairs from a DataFrame.

    Args:
        df: DataFrame with an ``"image"`` column of file paths and a ``"label"``
            column of class names.
        transform: Optional callable applied to each decoded image.
    """

    def __init__(self, df, transform=None) -> None:
        super().__init__()
        self.img_paths = df["image"].tolist()
        self.transform = transform
        # One-hot encode the labels once, up front. NOTE(review): the encoder is
        # fit on this frame only, so its categories come from the classes present
        # in ``df``; separately built train/val datasets can disagree on column
        # order if a class is missing from one split.
        self.encoder = OneHotEncoder()
        # ``toarray()`` turns the sparse result straight into an ndarray, avoiding
        # the intermediate ``np.matrix`` that ``todense()`` would create.
        self.labels = self.encoder.fit_transform(
            df["label"].values.reshape(-1, 1)
        ).toarray()

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Read image ``idx`` from disk, transform it, and pair it with its label.

        Returns:
            ``(img, label)`` where ``label`` is the one-hot row as a float32 tensor.
        """
        img = io.imread(self.img_paths[idx])
        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return img, label


def preprocess_df(csv_path, images_root):
    """Load the label CSV, keep single-label rows, and resolve image paths.

    Args:
        csv_path: Path (or file-like object) accepted by ``pd.read_csv``; the
            CSV must contain ``"image"`` and ``"labels"`` columns.
        images_root: Directory holding the images; it is joined to each file name.

    Returns:
        A new DataFrame with full image paths in ``"image"`` and the class name
        in ``"label"`` (renamed from ``"labels"``).
    """
    import os

    df = pd.read_csv(csv_path)
    # Rows with several labels separate them with spaces; keep only the rest.
    # regex=False matches a literal space instead of compiling a pattern.
    single_label = ~df["labels"].str.contains(" ", regex=False)
    # .copy() detaches the filtered frame from the original, so the column
    # assignment below does not raise SettingWithCopyWarning.
    df = df.loc[single_label].copy()
    # os.path.join works whether or not images_root ends with a separator.
    df["image"] = [os.path.join(images_root, name) for name in df["image"]]
    return df.rename(columns={"labels": "label"})


def split_df(df, train_pct, *, random_state=None):
    """Shuffle ``df`` and split it into train and validation frames.

    Args:
        df: DataFrame to split.
        train_pct: Fraction of rows (0-1) assigned to the training split; the
            resulting row count is truncated by ``int``.
        random_state: Optional seed forwarded to ``DataFrame.sample`` so the split
            can be reproduced. ``None`` (the default) gives a fresh shuffle.

    Returns:
        ``(train_df, val_df)``, each with a fresh ``0..n-1`` index.
    """
    shuffled = df.sample(frac=1, random_state=random_state)
    n_train = int(train_pct * len(shuffled))
    # drop=True discards the pre-shuffle index instead of inserting it as an
    # extra "index" column (which would also fail if such a column existed).
    train_df = shuffled.iloc[:n_train].reset_index(drop=True)
    val_df = shuffled.iloc[n_train:].reset_index(drop=True)
    return train_df, val_df


def create_dataloader(
    df,
    batch_size=32,
    num_workers=8,
    prefetch_factor=8,
    shuffle=True,
):
    """Build a training DataLoader over ``df`` with resize/flip/normalize transforms.

    Args:
        df: DataFrame accepted by ``PlantDataset`` (``"image"``/``"label"`` columns).
        batch_size: Samples per batch.
        num_workers: Number of loader worker processes; ``0`` loads in the main
            process.
        prefetch_factor: Batches each worker loads in advance. Roughly
            ``num_workers * prefetch_factor`` batches can be held in memory at
            once, so lowering either value reduces RAM use. Ignored when
            ``num_workers`` is 0.
        shuffle: Reshuffle the sample order every epoch, as is standard for a
            training loader.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    train_compose = T.Compose(
        [
            # skimage yields ndarrays; convert to PIL for the PIL-based transforms.
            T.ToPILImage(),
            T.Resize((224, 224)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            # Standard ImageNet channel mean/std.
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    loader_kwargs = {}
    if num_workers > 0:
        # DataLoader only accepts prefetch_factor when worker processes are used.
        loader_kwargs["prefetch_factor"] = prefetch_factor
    return torch.utils.data.DataLoader(
        PlantDataset(df, transform=train_compose),
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        **loader_kwargs,
    )

Training in plain PyTorch:

def train(model, data_loader, n_epochs, device="cuda"):
    """Train ``model`` with Adam and cross-entropy loss for ``n_epochs`` epochs.

    Args:
        model: ``torch.nn.Module`` mapping an image batch to class logits; it is
            moved to ``device`` and updated in place.
        data_loader: Iterable of ``(images, labels)`` batches.
        n_epochs: Number of full passes over ``data_loader``.
        device: Device to train on (default ``"cuda"``, as before); pass
            ``"cpu"`` to run without a GPU.
    """
    model = model.to(device)
    model.train()  # make sure dropout/batch-norm layers are in training mode
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(n_epochs):
        for images, labels in tqdm.tqdm(data_loader):
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"End of epoch {epoch}")


def main():
    """Parse CLI paths, train a 6-class ResNet-18 in plain PyTorch, and print the wall time."""
    parser = argparse.ArgumentParser()
    for positional in ("csv_path", "images_root"):
        parser.add_argument(positional)
    cli = parser.parse_args()

    net = torchvision.models.resnet18()
    net.fc = torch.nn.Linear(512, 6)

    frame = preprocess_df(cli.csv_path, cli.images_root)
    # NOTE(review): 0.1 puts only 10% of rows in the training split, presumably
    # to keep the benchmark short; confirm this is intended.
    train_frame, _val_frame = split_df(frame, 0.1)
    loader = create_dataloader(train_frame)

    start = perf_counter()
    train(net, loader, 2)
    elapsed = perf_counter() - start
    print(f"Time elapsed: {elapsed}")


if __name__ == "__main__":
    # Run the plain-PyTorch benchmark only when executed as a script.
    main()

Training in Lightning Flash:

class Resnet18(pl.LightningModule):
    """ResNet-18 classifier trained with cross-entropy on ``{"input", "target"}`` batches.

    Args:
        num_classes: Output size of the replacement classification head
            (default 6, as before).
    """

    def __init__(self, num_classes=6):
        super().__init__()
        self.model = torchvision.models.resnet18()
        # Replace the stock head with one sized for this task; read the input
        # width from the existing layer instead of hard-coding 512.
        self.model.fc = torch.nn.Linear(self.model.fc.in_features, num_classes)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one batch dict with "input"/"target" keys."""
        x, y = batch["input"], batch["target"]
        y_hat = self.model(x)
        return self.loss_fn(y_hat, y)

    def configure_optimizers(self):
        """Optimize the wrapped network with Adam using default hyper-parameters."""
        return torch.optim.Adam(self.model.parameters())


def main():
    """Parse CLI paths, train the Lightning ResNet-18 via Flash, and print the wall time."""
    parser = argparse.ArgumentParser()
    for positional in ("csv_path", "images_root"):
        parser.add_argument(positional)
    cli = parser.parse_args()

    model = Resnet18()
    frame = preprocess_df(cli.csv_path, cli.images_root)
    train_frame, _val_frame = split_df(frame, 0.1)

    # DataLoader-related options forwarded to the Flash data module.
    loader_options = dict(
        batch_size=32,
        num_workers=8,
        persistent_workers=True,
        pin_memory=False,
    )
    datamodule = ImageClassificationData.from_data_frame(
        "image",
        "label",
        train_data_frame=train_frame,
        transform_kwargs=dict(image_size=(224, 224)),
        **loader_options,
    )

    start = perf_counter()
    trainer = flash.Trainer(max_epochs=2, gpus=torch.cuda.device_count())
    trainer.fit(model, datamodule=datamodule)
    elapsed = perf_counter() - start
    print(f"Time elapsed: {elapsed}")


if __name__ == "__main__":
    # Run the Lightning Flash benchmark only when executed as a script.
    main()

When I increase batch_size to 64 or num_workers to 16 in ImageClassificationData, I start running out of RAM, which does not happen with the plain PyTorch version. Any ideas what might be causing this? I tried profiling but couldn't reach a clear conclusion, beyond suspecting that the problem is in BaseDataFetcher in the DataModule.

Metadata

Assignees

No one assigned

    Labels

    bug / fix (Something isn't working), help wanted (Extra attention is needed)

    Type

    No type

    Projects

    No projects

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions