Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] PARSeq pytorch fixes #1227

Merged
merged 6 commits into from
Jun 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 45 additions & 73 deletions doctr/models/recognition/parseq/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
from copy import deepcopy
from itertools import permutations
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
max_length: int = 32, # different from the paper
dropout_prob: float = 0.1,
dec_num_heads: int = 12,
dec_ff_dim: int = 2048,
dec_ff_dim: int = 384, # we use it from the original implementation instead of 2048
dec_ffd_ratio: int = 4,
input_shape: Tuple[int, int, int] = (3, 32, 128),
exportable: bool = False,
Expand Down Expand Up @@ -212,10 +212,7 @@ def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor:
combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
if len(combined) > 1:
combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
# we pad to max length with eos idx to fit the mask generation
return F.pad(
combined, (0, self.max_length + 1 - combined.shape[-1]), value=max_num_chars + 1
) # (num_perms, self.max_length + 1)
return combined

def generate_permutations_attention_masks(self, permutation: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Generate source and target mask for the decoder attention.
Expand Down Expand Up @@ -251,23 +248,23 @@ def decode(
target_query = self.dropout(target_query)
return self.decoder(target_query, content, memory, target_mask)

def decode_autoregressive(self, features: torch.Tensor) -> torch.Tensor:
def decode_autoregressive(self, features: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
"""Generate predictions for the given features."""
max_length = max_len if max_len is not None else self.max_length
max_length = min(max_length, self.max_length) + 1
# Padding symbol + SOS at the beginning
ys = torch.full(
(features.size(0), self.max_length), self.vocab_size + 2, dtype=torch.long, device=features.device
(features.size(0), max_length), self.vocab_size + 2, dtype=torch.long, device=features.device
) # pad
ys[:, 0] = self.vocab_size + 1 # SOS token
pos_queries = self.pos_queries[:, : self.max_length + 1].expand(features.size(0), -1, -1)
pos_queries = self.pos_queries[:, :max_length].expand(features.size(0), -1, -1)
# Create query mask for the decoder attention
query_mask = (
torch.tril(torch.ones((self.max_length + 1, self.max_length + 1), device=features.device), diagonal=0).to(
dtype=torch.bool
)
torch.tril(torch.ones((max_length, max_length), device=features.device), diagonal=0).to(dtype=torch.bool)
).int()

pos_logits = []
for i in range(self.max_length):
for i in range(max_length):
# Decode one token at a time without providing information about the future tokens
tgt_out = self.decode(
ys[:, : i + 1],
Expand All @@ -278,23 +275,19 @@ def decode_autoregressive(self, features: torch.Tensor) -> torch.Tensor:
pos_prob = self.head(tgt_out)
pos_logits.append(pos_prob)

if i + 1 < self.max_length:
if i + 1 < max_length:
# Update with the next token
ys[:, i + 1] = pos_prob.squeeze().argmax(-1)

# Stop decoding if all sequences have reached the EOS token
if (ys == self.vocab_size).any(dim=-1).all():
if max_len is None and (ys == self.vocab_size).any(dim=-1).all():
break

logits = torch.cat(pos_logits, dim=1) # (N, max_length, vocab_size + 1)

# One refine iteration
# Update query mask
query_mask[
torch.triu(
torch.ones(self.max_length + 1, self.max_length + 1, dtype=torch.bool, device=features.device), 2
)
] = 1
query_mask[torch.triu(torch.ones(max_length, max_length, dtype=torch.bool, device=features.device), 2)] = 1

# Prepare target input for 1 refine iteration
sos = torch.full((features.size(0), 1), self.vocab_size + 1, dtype=torch.long, device=features.device)
Expand All @@ -308,12 +301,6 @@ def decode_autoregressive(self, features: torch.Tensor) -> torch.Tensor:

return logits # (N, max_length, vocab_size + 1)

def decode_non_autoregressive(self, features: torch.Tensor) -> torch.Tensor:
"""Decode the given features at once"""
pos_queries = self.pos_queries[:, : self.max_length + 1].expand(features.size(0), -1, -1)
ys = torch.full((features.shape[0], 1), self.vocab_size + 1, dtype=torch.long, device=features.device)
return self.head(self.decode(ys, features, target_query=pos_queries))[:, : self.max_length]

def forward(
self,
x: torch.Tensor,
Expand All @@ -332,26 +319,46 @@ def forward(
# Build target tensor
_gt, _seq_len = self.build_target(target)
gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long).to(x.device), torch.tensor(_seq_len).to(x.device)
gt = gt[:, : int(seq_len.max().item()) + 2] # slice up to the max length of the batch + 2 (SOS + EOS)

if self.training:
# Generate permutations for the target sequences
tgt_perms = self.generate_permutations(seq_len)

gt_in = gt[:, :-1] # remove EOS token from longest target sequence
gt_out = gt[:, 1:] # remove SOS token
# Create padding mask for target input
# [True, True, True, ..., False, False, False] -> False is masked
padding_mask = (
((gt != self.vocab_size + 2) | (gt != self.vocab_size)).unsqueeze(1).unsqueeze(1)
) # (N, 1, 1, max_length)

for perm in tgt_perms:
# Generate attention masks for the permutations
_, target_mask = self.generate_permutations_attention_masks(perm)
# combine target padding mask and query mask
mask = (target_mask & padding_mask).int()
logits = self.head(self.decode(gt, features, mask)) # (N, max_length, vocab_size + 1)
~(((gt_in == self.vocab_size + 2) | (gt_in == self.vocab_size)).int().cumsum(-1) > 0)
.unsqueeze(1)
.unsqueeze(1)
) # (N, 1, 1, seq_len)

loss = torch.tensor(0.0, device=features.device)
loss_numel: Union[int, float] = 0
n = (gt_out != self.vocab_size + 2).sum().item()
for i, perm in enumerate(tgt_perms):
_, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len)
# combine both masks
mask = (target_mask.bool() & padding_mask.bool()).int() # (N, 1, seq_len, seq_len)

logits = self.head(self.decode(gt_in, features, mask)).flatten(end_dim=1)
loss += n * F.cross_entropy(logits, gt_out.flatten(), ignore_index=self.vocab_size + 2)
loss_numel += n
# After the second iteration (i.e. done with canonical and reverse orderings),
# remove the [EOS] tokens for the succeeding perms
if i == 1:
gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
n = (gt_out != self.vocab_size + 2).sum().item()

loss /= loss_numel

else:
# eval step - use non-autoregressive decoding while training evaluation
logits = self.decode_non_autoregressive(features)
gt = gt[:, 1:] # remove SOS token
max_len = gt.shape[1] - 1 # exclude EOS token
logits = self.decode_autoregressive(features, max_len)
loss = F.cross_entropy(logits.flatten(end_dim=1), gt.flatten(), ignore_index=self.vocab_size + 2)
else:
logits = self.decode_autoregressive(features)

Expand All @@ -368,45 +375,10 @@ def forward(
out["preds"] = self.postprocessor(logits)

if target is not None:
out["loss"] = self.compute_loss(logits, gt, seq_len, ignore_index=self.vocab_size + 2)
out["loss"] = loss

return out

@staticmethod
def compute_loss(
model_output: torch.Tensor,
gt: torch.Tensor,
seq_len: torch.Tensor,
ignore_index: int = -100,
) -> torch.Tensor:
"""Compute categorical cross-entropy loss for the model.
Sequences are masked after the EOS character.

Args:
model_output: predicted logits of the model
gt: the encoded tensor with gt labels
seq_len: lengths of each gt word inside the batch
ignore_index: index to ignore in the loss

Returns:
The loss of the model on the batch
"""
# Input length : number of steps
input_len = model_output.shape[1]
# Add one for additional <eos> token (sos disappear in shift!)
seq_len = seq_len + 1
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
# The "masked" first gt char is <sos>. Delete last logit of the model output.
cce = F.cross_entropy(
model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none", ignore_index=ignore_index
)
# Compute mask, remove 1 timestep here as well
mask_2d = torch.arange(input_len - 1, device=model_output.device)[None, :] >= seq_len[:, None]
cce[mask_2d] = 0

ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
return ce_loss.mean()


class PARSeqPostProcessor(_PARSeqPostProcessor):
"""Post processor for PARSeq architecture
Expand Down
2 changes: 1 addition & 1 deletion doctr/models/recognition/vitstr/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def compute_loss(
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
# The "masked" first gt char is <sos>.
cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")
# Compute mask, remove 1 timestep here as well
# Compute mask
mask_2d = torch.arange(input_len, device=model_output.device)[None, :] >= seq_len[:, None]
cce[mask_2d] = 0

Expand Down