Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] PARSeq tensorflow fixes #1228

Merged
merged 7 commits into from
Jun 23, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix run
  • Loading branch information
felixdittrich92 committed Jun 23, 2023
commit 0b47c81a067aa17613966f67aa97f82afb1380d8
40 changes: 20 additions & 20 deletions doctr/models/recognition/parseq/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def __init__(self, vocab_size: int, d_model: int):
self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
self.d_model = d_model

def call(self, x: tf.Tensor) -> tf.Tensor:
return math.sqrt(self.d_model) * self.embedding(x)
def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
return math.sqrt(self.d_model) * self.embedding(x, **kwargs)


class PARSeqDecoder(layers.Layer):
Expand Down Expand Up @@ -242,16 +242,16 @@ def decode(
) -> tf.Tensor:
batch_size, sequence_length = target.shape
# apply positional information to the target sequence excluding the SOS token
null_ctx = self.embed(target[:, :1])
content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:])
null_ctx = self.embed(target[:, :1], **kwargs)
content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs)
content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs)
if target_query is None:
target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1])
target_query = self.dropout(target_query, **kwargs)
return self.decoder(target_query, content, memory, target_mask, **kwargs)

@tf.function
def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None) -> tf.Tensor:
def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
"""Generate predictions for the given features."""
max_length = max_len if max_len is not None else self.max_length
max_length = min(max_length, self.max_length) + 1
Expand All @@ -271,6 +271,7 @@ def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = No
features,
query_mask[i : i + 1, : i + 1],
target_query=pos_queries[:, i : i + 1],
**kwargs,
)
pos_prob = self.head(tgt_out)
pos_logits.append(pos_prob)
Expand Down Expand Up @@ -305,7 +306,7 @@ def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = No
target_pad_mask = tf.cumsum(tf.cast(tf.equal(ys, self.vocab_size), dtype=tf.int32), axis=-1, reverse=True) > 0
target_pad_mask = target_pad_mask[:, tf.newaxis, tf.newaxis, :]
mask = tf.math.logical_and(target_pad_mask, query_mask[:, : ys.shape[1]])
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs)

return logits # (N, max_length, vocab_size + 1)

Expand Down Expand Up @@ -344,24 +345,21 @@ def call(
padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :] # (N, 1, 1, seq_len)

loss = tf.constant(0.0, dtype=tf.float32)
loss_numel = tf.constant(0, dtype=tf.float32)
loss_numel = tf.constant(0.0, dtype=tf.float32)
n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
for i, perm in enumerate(tgt_perms):
_, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len)
# combine both masks to (N, 1, seq_len, seq_len)
mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0))

logits = self.head(self.decode(gt_in, features, mask))
# TODO: Fix me
cce_loss = losses.sparse_categorical_crossentropy(
flattened_gt_out, flattened_logits, from_logits=False, ignore_class=self.vocab_size + 2
)
print(cce_loss)
# Convert cce_loss to float tensor
cce_loss = tf.cast(cce_loss, dtype=tf.float32)
# Convert n to float tensor
n = tf.cast(n, dtype=tf.float32)
loss += cce_loss * n
logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs)

# TODO: Fix me :)
# Compute loss
cce_loss = tf.reduce_sum(losses.sparse_categorical_crossentropy(gt_out, logits, from_logits=True, ignore_class=self.vocab_size + 2))

loss += cce_loss * tf.cast(n, dtype=tf.float32)
loss_numel += tf.cast(n, dtype=tf.float32)

# After the second iteration (i.e. done with canonical and reverse orderings),
# remove the [EOS] tokens for the succeeding perms
Expand All @@ -370,16 +368,18 @@ def call(
n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.int32))

loss /= loss_numel
print(loss)
#loss = tf.constant(1.0, dtype=tf.float32)

else:
gt = gt[:, 1:] # remove SOS token
max_len = gt.shape[1] - 1 # exclude EOS token
logits = self.decode_autoregressive(features, max_len)
logits = self.decode_autoregressive(features, max_len, **kwargs)
loss = losses.sparse_categorical_crossentropy(
tf.nest.flatten(gt_out), logits, from_logits=True, ignore_class=self.vocab_size + 2
)
else:
logits = self.decode_autoregressive(features)
logits = self.decode_autoregressive(features, **kwargs)

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand Down