This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Too much RAM usage by ImageClassificationData #1450

Open
@ethanwharris

Description

Discussed in #1442

Originally posted by Hravan September 1, 2022
I'm setting up training for this Kaggle competition dataset: https://www.kaggle.com/competitions/plant-pathology-2021-fgvc8
(I'm using only the single-label samples here to keep the problem simpler.)

The problem is that ImageClassificationData uses too much RAM and the GPU is underutilized. I wrote the same pipeline in plain PyTorch for comparison to confirm that the problem lies somewhere within ImageClassificationData.

Code shared by both training versions:

import pandas as pd
from skimage import io
from sklearn.preprocessing import OneHotEncoder
import torch
from torch.utils.data import Dataset
from torchvision import transforms as T


class PlantDataset(Dataset):
    def __init__(self, df, transform=None) -> None:
        super().__init__()
        self.img_paths = df["image"].tolist()
        self.transform = transform
        self.encoder = OneHotEncoder()
        self.labels = (
            self.encoder.fit_transform(df["label"].values.reshape(-1, 1))
            .todense()
            .A
        )

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = io.imread(self.img_paths[idx])
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[idx]
        # return {
        #    "input": img,
        #    "target": torch.tensor(label, dtype=torch.uint8),
        # }
        return img, torch.tensor(label, dtype=torch.float32)


def preprocess_df(csv_path, images_root):
    df = pd.read_csv(csv_path)
    # Keep only single-label samples (multi-label rows have space-separated labels).
    df = df[~df["labels"].str.contains(" ")]
    df["image"] = images_root + df["image"]
    df = df.rename(columns={"labels": "label"})
    return df


def split_df(df, train_pct):
    df = df.sample(frac=1)
    n_train = int(train_pct * len(df))
    train_df = df.iloc[:n_train].reset_index()
    val_df = df.iloc[n_train:].reset_index()
    return train_df, val_df


def create_dataloader(df):
    train_compose = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((224, 224)),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
    )
    dataloader = torch.utils.data.DataLoader(
        PlantDataset(df, transform=train_compose),
        batch_size=32,
        num_workers=8,
        prefetch_factor=8,
    )
    return dataloader

Training in plain PyTorch:

import argparse
from time import perf_counter

import torchvision
import tqdm

# (the shared PlantDataset / preprocess_df / split_df / create_dataloader code above is assumed to be in scope)


def train(model, data_loader, n_epochs):
    model = model.cuda()
    optimizer = torch.optim.Adam(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()

    for i in range(n_epochs):
        for images, labels in tqdm.tqdm(data_loader):
            images = images.cuda()
            preds = model(images)
            loss = loss_fn(preds, labels.cuda())
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"End of epoch {i}")


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = torchvision.models.resnet18()
    model.fc = torch.nn.Linear(512, 6)

    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    train_loader = create_dataloader(train_df)
    time0 = perf_counter()
    train(model, train_loader, 2)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()

Training in Lightning Flash:

import argparse
from time import perf_counter

import flash
import pytorch_lightning as pl
import torchvision
from flash.image import ImageClassificationData

# (the shared preprocess_df / split_df code above is assumed to be in scope)


class Resnet18(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()
        self.model.fc = torch.nn.Linear(512, 6)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_idx):
        x, y = batch["input"], batch["target"]
        y_hat = self.model(x)
        loss = self.loss_fn(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters())


def main():
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("csv_path")
    arg_parser.add_argument("images_root")
    args = arg_parser.parse_args()

    model = Resnet18()
    df = preprocess_df(args.csv_path, args.images_root)
    train_df, val_df = split_df(df, 0.1)
    datamodule = ImageClassificationData.from_data_frame(
        "image",
        "label",
        train_data_frame=train_df,
        batch_size=32,
        transform_kwargs=dict(image_size=(224, 224)),
        num_workers=8,
        persistent_workers=True,
        pin_memory=False,
    )

    time0 = perf_counter()
    trainer = flash.Trainer(max_epochs=2, gpus=torch.cuda.device_count())
    trainer.fit(model, datamodule=datamodule)
    print(f"Time elapsed: {perf_counter() - time0}")


if __name__ == "__main__":
    main()

When I increase batch_size to 64 or num_workers to 16 in ImageClassificationData, I start running into RAM problems, which does not happen with the plain PyTorch version. Any ideas what might be the problem? I tried profiling but didn't reach any sensible conclusion, except that I suspect the problem is in the BaseDataFetcher used by the DataModule.
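To compare host RAM between the two versions, one option is to sample the resident set size of the training process and its DataLoader workers while training runs. The sketch below is only an illustration and is not part of the original report; it assumes psutil is installed, and the helper name total_rss_mb is made up.

import psutil


def total_rss_mb() -> float:
    """Resident memory of this process plus all child processes (e.g. DataLoader workers), in MB."""
    proc = psutil.Process()
    rss = proc.memory_info().rss
    for child in proc.children(recursive=True):
        rss += child.memory_info().rss
    return rss / 1024 ** 2


# Example usage inside the plain PyTorch loop (or from a LightningModule hook):
# for step, (images, labels) in enumerate(data_loader):
#     if step % 50 == 0:
#         print(f"step {step}: {total_rss_mb():.0f} MB")

Logging the same number from both scripts would show whether memory grows with num_workers in the Flash version only.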


Labels: bug / fix, help wanted
