
Commit

update comments and mypy
felixdittrich92 committed Jun 21, 2023
1 parent f0c1267 commit 6755842
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions doctr/models/recognition/parseq/pytorch.py
@@ -335,7 +335,7 @@ def forward(
.unsqueeze(1)
) # (N, 1, 1, seq_len)

- loss = 0.0
+ loss = torch.tensor(0.0, device=features.device)
loss_numel: Union[int, float] = 0
n = (gt_out != self.vocab_size + 2).sum().item()
for i, perm in enumerate(tgt_perms):
Expand All @@ -344,7 +344,7 @@ def forward(
mask = (target_mask.bool() & padding_mask.bool()).int() # (N, 1, seq_len, seq_len)

logits = self.head(self.decode(gt_in, features, mask)).flatten(end_dim=1)
- loss += n * F.cross_entropy(logits, gt_out.flatten(), ignore_index=self.vocab_size + 2).item()
+ loss += n * F.cross_entropy(logits, gt_out.flatten(), ignore_index=self.vocab_size + 2)
loss_numel += n
# After the second iteration (i.e. done with canonical and reverse orderings),
# remove the [EOS] tokens for the succeeding perms
Expand All @@ -358,7 +358,7 @@ def forward(
gt = gt[:, 1:] # remove SOS token
max_len = gt.shape[1] - 1 # exclude EOS token
logits = self.decode_autoregressive(features, max_len)
- loss = F.cross_entropy(logits.flatten(end_dim=1), gt.flatten(), ignore_index=self.vocab_size + 2).item()
+ loss = F.cross_entropy(logits.flatten(end_dim=1), gt.flatten(), ignore_index=self.vocab_size + 2)
else:
logits = self.decode_autoregressive(features)

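The substance of this change: `F.cross_entropy(...).item()` converts the loss to a plain Python float, which detaches it from the autograd graph, so an accumulated loss built this way can never drive backpropagation. Keeping the result as a tensor, and initializing the accumulator as a tensor on the same device as `features`, preserves the gradient chain. A minimal standalone sketch of the difference; the `logits`/`targets` here are illustrative dummies, not doctr code:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 10, requires_grad=True)  # dummy predictions
    targets = torch.randint(0, 10, (4,))             # dummy labels

    # Before the fix: .item() returns a Python float, detached from autograd.
    loss_float = 0.0
    loss_float += F.cross_entropy(logits, targets).item()
    # loss_float has no .backward(); gradients can never reach logits.

    # After the fix: accumulate tensors, starting from a tensor on the
    # right device (the commit uses features.device; logits.device here).
    loss = torch.tensor(0.0, device=logits.device)
    loss += F.cross_entropy(logits, targets)
    loss.backward()                  # gradients now flow into logits
    assert logits.grad is not None

Initializing the accumulator with `torch.tensor(0.0, device=...)` rather than the float `0.0` also gives the variable a single, stable tensor type, which matches the commit message's mypy motivation.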
