Skip to content

Commit

Permalink
feat: Added support for AMP to all PyTorch training scripts (#604)
Browse files Browse the repository at this point in the history
* refactor: Renamed the model argument to arch

* refactor: Removed epoch index logging

* feat: Added dataset hash logging

* feat: Added elements to the payload logged on W&B at each epoch

* refactor: Removed unused arguments

* feat: Cleaned sample plotting

* feat: Added AMP support to text detection training

* feat: Added AMP for PyTorch training scripts

* fix: Fixed compatibility with FP16
  • Loading branch information
fg-mindee authored Nov 10, 2021
1 parent 2335b6f commit 250a3cb
Show file tree
Hide file tree
Showing 11 changed files with 190 additions and 77 deletions.
2 changes: 1 addition & 1 deletion doctr/models/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __call__(

for p_, bitmap_ in zip(proba_map, bitmap):
# Perform opening (erosion + dilatation)
bitmap_ = cv2.morphologyEx(bitmap_, cv2.MORPH_OPEN, kernel)
bitmap_ = cv2.morphologyEx(bitmap_.astype(np.float32), cv2.MORPH_OPEN, kernel)
# Rotate bitmap and proba_map
angle = get_bitmap_angle(bitmap_)
angles_batch.append(angle)
Expand Down
2 changes: 1 addition & 1 deletion doctr/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def rotate_image(

height, width = exp_img.shape[:2]
rot_mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
rot_img = cv2.warpAffine(exp_img, rot_mat, (width, height))
rot_img = cv2.warpAffine(exp_img.astype(np.float32), rot_mat, (width, height))
if expand:
# Pad to get the same aspect ratio
if (image.shape[0] / image.shape[1]) != (rot_img.shape[0] / rot_img.shape[1]):
Expand Down
50 changes: 35 additions & 15 deletions references/classification/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
from doctr.datasets import VOCABS, CharacterGenerator


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False):

if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
train_iter = iter(train_loader)
# Iterate over the batches of the dataset
Expand All @@ -36,28 +40,41 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, m

images = batch_transforms(images)

out = model(images)
train_loss = cross_entropy(out, targets)

optimizer.zero_grad()
train_loss.backward()
optimizer.step()
if amp:
with torch.cuda.amp.autocast():
out = model(images)
train_loss = cross_entropy(out, targets)
scaler.scale(train_loss).backward()
# Update the params
scaler.step(optimizer)
scaler.update()
else:
out = model(images)
train_loss = cross_entropy(out, targets)
train_loss.backward()
optimizer.step()
scheduler.step()

mb.child.comment = f'Training loss: {train_loss.item():.6}'


@torch.no_grad()
def evaluate(model, val_loader, batch_transforms):
def evaluate(model, val_loader, batch_transforms, amp=False):
# Model in eval mode
model.eval()
# Validation loop
val_loss, correct, samples, batch_cnt = 0, 0, 0, 0
val_iter = iter(val_loader)
for images, targets in val_iter:
images = batch_transforms(images)
out = model(images)
loss = cross_entropy(out, targets)
if amp:
with torch.cuda.amp.autocast():
out = model(images)
loss = cross_entropy(out, targets)
else:
out = model(images)
loss = cross_entropy(out, targets)
# Compute metric
correct += (out.argmax(dim=1) == targets).sum().item()

Expand Down Expand Up @@ -104,7 +121,7 @@ def main(args):
batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301))

# Load doctr model
model = models.__dict__[args.model](pretrained=args.pretrained, num_classes=len(vocab))
model = models.__dict__[args.arch](pretrained=args.pretrained, num_classes=len(vocab))

# Resume weights
if isinstance(args.resume, str):
Expand Down Expand Up @@ -163,7 +180,7 @@ def main(args):

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.model}_{current_time}" if args.name is None else args.name
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if args.wb:
Expand All @@ -174,12 +191,15 @@ def main(args):
config={
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.model,
"architecture": args.arch,
"input_size": args.input_size,
"optimizer": "adam",
"exp_type": "character-classification",
"framework": "pytorch",
"vocab": args.vocab,
"scheduler": args.sched,
"pretrained": args.pretrained,
}
)

Expand All @@ -200,7 +220,6 @@ def main(args):
# W&B
if args.wb:
wandb.log({
'epochs': epoch + 1,
'val_loss': val_loss,
'acc': acc,
})
Expand All @@ -214,7 +233,7 @@ def parse_args():
parser = argparse.ArgumentParser(description='DocTR training script for character classification (PyTorch)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('model', type=str, help='text-recognition model to train')
parser.add_argument('arch', type=str, help='text-recognition model to train')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training')
Expand Down Expand Up @@ -248,6 +267,7 @@ def parse_args():
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='Load pretrained parameters before starting the training')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
args = parser.parse_args()

return args
Expand Down
14 changes: 8 additions & 6 deletions references/classification/train_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main(args):
f"{val_loader.num_batches} batches)")

# Load doctr model
model = backbones.__dict__[args.model](
model = backbones.__dict__[args.arch](
pretrained=args.pretrained,
input_shape=(args.input_size, args.input_size, 3),
num_classes=len(vocab),
Expand Down Expand Up @@ -170,7 +170,7 @@ def main(args):

# Tensorboard to monitor training
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.model}_{current_time}" if args.name is None else args.name
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if args.wb:
Expand All @@ -181,12 +181,15 @@ def main(args):
config={
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.model,
"architecture": args.arch,
"input_size": args.input_size,
"optimizer": "adam",
"exp_type": "character-classification",
"framework": "tensorflow",
"vocab": args.vocab,
"scheduler": args.sched,
"pretrained": args.pretrained,
}
)

Expand All @@ -208,7 +211,6 @@ def main(args):
# W&B
if args.wb:
wandb.log({
'epochs': epoch + 1,
'val_loss': val_loss,
'acc': acc,
})
Expand All @@ -222,7 +224,7 @@ def parse_args():
parser = argparse.ArgumentParser(description='DocTR training script for character classification (TensorFlow)',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)

parser.add_argument('model', type=str, help='text-recognition model to train')
parser.add_argument('arch', type=str, help='text-recognition model to train')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training')
Expand Down
15 changes: 10 additions & 5 deletions references/classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
def plot_samples(images, targets):
# Unnormalize image
num_samples = min(len(images), 12)
num_rows = min(len(images), 3)
num_cols = int(math.ceil(num_samples / num_rows))
num_cols = min(len(images), 8)
num_rows = int(math.ceil(num_samples / num_cols))
_, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5))
for idx in range(num_samples):
img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8)
Expand All @@ -23,7 +23,12 @@ def plot_samples(images, targets):
row_idx = idx // num_cols
col_idx = idx % num_cols

axes[row_idx][col_idx].imshow(img)
axes[row_idx][col_idx].axis('off')
axes[row_idx][col_idx].set_title(targets[idx])
ax = axes[row_idx] if num_rows > 1 else axes
ax = ax[col_idx] if num_cols > 1 else ax

ax.imshow(img)
ax.set_title(targets[idx])
# Disable axis
for ax in axes.ravel():
ax.axis('off')
plt.show()
64 changes: 47 additions & 17 deletions references/detection/train_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
os.environ['USE_TORCH'] = '1'

import datetime
import hashlib
import logging
import multiprocessing as mp
import time
Expand All @@ -28,7 +29,11 @@
from doctr.utils.metrics import LocalizationConfusion


def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb):
def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False):

if amp:
scaler = torch.cuda.amp.GradScaler()

model.train()
train_iter = iter(train_loader)
# Iterate over the batches of the dataset
Expand All @@ -39,19 +44,30 @@ def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, m
images = images.cuda()
images = batch_transforms(images)

train_loss = model(images, targets)['loss']

optimizer.zero_grad()
train_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()
if amp:
with torch.cuda.amp.autocast():
train_loss = model(images, targets)['loss']
scaler.scale(train_loss).backward()
# Gradient clipping
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
# Update the params
scaler.step(optimizer)
scaler.update()
else:
train_loss = model(images, targets)['loss']
train_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
optimizer.step()

scheduler.step()

mb.child.comment = f'Training loss: {train_loss.item():.6}'


@torch.no_grad()
def evaluate(model, val_loader, batch_transforms, val_metric):
def evaluate(model, val_loader, batch_transforms, val_metric, amp=False):
# Model in eval mode
model.eval()
# Reset val metric
Expand All @@ -63,7 +79,11 @@ def evaluate(model, val_loader, batch_transforms, val_metric):
if torch.cuda.is_available():
images = images.cuda()
images = batch_transforms(images)
out = model(images, targets, return_boxes=True)
if amp:
with torch.cuda.amp.autocast():
out = model(images, targets, return_boxes=True)
else:
out = model(images, targets, return_boxes=True)
# Compute metric
loc_preds, _ = out['preds']
for boxes_gt, boxes_pred in zip(targets, loc_preds):
Expand Down Expand Up @@ -105,11 +125,13 @@ def main(args):
)
print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in "
f"{len(val_loader)} batches)")
with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f:
val_hash = hashlib.sha256(f.read()).hexdigest()

batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287))

# Load doctr model
model = detection.__dict__[args.model](pretrained=args.pretrained)
model = detection.__dict__[args.arch](pretrained=args.pretrained)

# Resume weights
if isinstance(args.resume, str):
Expand Down Expand Up @@ -167,10 +189,12 @@ def main(args):
)
print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in "
f"{len(train_loader)} batches)")
with open(os.path.join(args.train_path, 'labels.json'), 'rb') as f:
train_hash = hashlib.sha256(f.read()).hexdigest()

if args.show_samples:
x, target = next(iter(train_loader))
plot_samples(x, target, rotation=args.rotation)
plot_samples(x, target)
return

# Backbone freezing
Expand All @@ -190,7 +214,7 @@ def main(args):

# Training monitoring
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
exp_name = f"{args.model}_{current_time}" if args.name is None else args.name
exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name

# W&B
if args.wb:
Expand All @@ -201,12 +225,18 @@ def main(args):
config={
"learning_rate": args.lr,
"epochs": args.epochs,
"weight_decay": args.weight_decay,
"batch_size": args.batch_size,
"architecture": args.model,
"architecture": args.arch,
"input_size": args.input_size,
"optimizer": "adam",
"exp_type": "text-detection",
"framework": "pytorch",
"scheduler": args.sched,
"train_hash": train_hash,
"val_hash": val_hash,
"pretrained": args.pretrained,
"rotation": args.rotation,
"amp": args.amp,
}
)

Expand All @@ -216,9 +246,9 @@ def main(args):
# Training loop
mb = master_bar(range(args.epochs))
for epoch in mb:
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb)
fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp)
# Validation loop at the end of each epoch
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric)
val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp)
if val_loss < min_loss:
print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...")
torch.save(model.state_dict(), f"./{exp_name}.pt")
Expand All @@ -231,7 +261,6 @@ def main(args):
# W&B
if args.wb:
wandb.log({
'epochs': epoch + 1,
'val_loss': val_loss,
'recall': recall,
'precision': precision,
Expand All @@ -251,7 +280,7 @@ def parse_args():

parser.add_argument('train_path', type=str, help='path to training data folder')
parser.add_argument('val_path', type=str, help='path to validation data folder')
parser.add_argument('model', type=str, help='text-detection model to train')
parser.add_argument('arch', type=str, help='text-detection model to train')
parser.add_argument('--name', type=str, default=None, help='Name of your training experiment')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training')
Expand All @@ -273,6 +302,7 @@ def parse_args():
parser.add_argument('--rotation', dest='rotation', action='store_true',
help='train with rotated bbox')
parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use')
parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true")
args = parser.parse_args()

return args
Expand Down
Loading

0 comments on commit 250a3cb

Please sign in to comment.