[Fix] PARSeq tensorflow fixes (#1228)
felixdittrich92 authored Jun 23, 2023
1 parent b4b613a commit 61a32a1
Showing 1 changed file with 79 additions and 82 deletions.
161 changes: 79 additions & 82 deletions doctr/models/recognition/parseq/tensorflow.py
@@ -45,8 +45,8 @@ def __init__(self, vocab_size: int, d_model: int):
self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
self.d_model = d_model

def call(self, x: tf.Tensor) -> tf.Tensor:
return math.sqrt(self.d_model) * self.embedding(x)
def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
return math.sqrt(self.d_model) * self.embedding(x, **kwargs)
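For context, a minimal standalone sketch of the pattern above (TF 2.x assumed; ScaledEmbedding is a hypothetical name, not part of doctr): the embedding output is scaled by sqrt(d_model), and **kwargs is forwarded to the inner layer as in the change above.

import math
import tensorflow as tf

class ScaledEmbedding(tf.keras.layers.Layer):
    """Token embedding scaled by sqrt(d_model), as in the original Transformer."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def call(self, x: tf.Tensor, **kwargs) -> tf.Tensor:
        # forward **kwargs (e.g. training=...) to the inner layer, mirroring the change above
        return math.sqrt(self.d_model) * self.embedding(x, **kwargs)

emb = ScaledEmbedding(vocab_size=100, d_model=384)
print(emb(tf.constant([[1, 2, 3]])).shape)  # (1, 3, 384)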


class PARSeqDecoder(layers.Layer):
@@ -136,7 +136,7 @@ def __init__(
max_length: int = 32, # different from paper
dropout_prob: float = 0.1,
dec_num_heads: int = 12,
dec_ff_dim: int = 2048,
dec_ff_dim: int = 384,  # we use the value from the original implementation instead of 2048
dec_ffd_ratio: int = 4,
input_shape: Tuple[int, int, int] = (32, 128, 3),
exportable: bool = False,
@@ -209,10 +209,7 @@ def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
combined = tf.tensor_scatter_nd_update(
combined, [[1, i] for i in range(1, max_num_chars + 2)], max_num_chars + 1 - tf.range(max_num_chars + 1)
)
# we pad to max length with eos idx to fit the mask generation
return tf.pad(
combined, [[0, 0], [0, self.max_length + 1 - tf.shape(combined)[1]]], constant_values=max_num_chars + 2
) # (num_perms, self.max_length + 1)
return combined
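As a side note, a toy sketch of the idea behind generate_permutations (simplified: it works on raw position indices only and skips the SOS/EOS bookkeeping of the real method; toy_permutations is a hypothetical helper): PARSeq trains with several decoding orders, always keeping the canonical left-to-right order and its reverse, plus a few random shuffles.

import tensorflow as tf

def toy_permutations(seqlen: int, num_random: int = 2) -> tf.Tensor:
    """Return a (num_perms, seqlen) tensor of decoding orders over positions 0..seqlen-1."""
    base = tf.range(seqlen, dtype=tf.int32)
    perms = [base, tf.reverse(base, axis=[0])]  # canonical order + its reverse
    for _ in range(num_random):
        perms.append(tf.random.shuffle(base))   # additional random orderings
    return tf.stack(perms)

print(toy_permutations(5).numpy())  # e.g. [[0 1 2 3 4], [4 3 2 1 0], ...]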

@tf.function
def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
@@ -232,7 +229,6 @@ def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple
mask, tf.where(eye_indices), tf.zeros_like(tf.boolean_mask(mask, eye_indices))
)
target_mask = mask[1:, :-1]

return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
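For illustration, a toy version of the mask construction (toy_mask_from_permutation is a hypothetical helper; eager execution assumed, no SOS/EOS rows, and only a single mask is returned instead of the source/target pair): for a permutation p, position p[i] may attend only to the positions decoded before it, i.e. p[:i].

import tensorflow as tf

def toy_mask_from_permutation(perm: tf.Tensor) -> tf.Tensor:
    """Return a (seq_len, seq_len) boolean mask where True means 'may attend'."""
    sz = int(tf.shape(perm)[0])
    mask = tf.zeros((sz, sz), dtype=tf.float32)
    for i in range(sz):
        query = perm[i]          # position decoded at step i
        visible = perm[:i]       # positions decoded at earlier steps
        indices = tf.stack([tf.fill(tf.shape(visible), query), visible], axis=1)
        mask = tf.tensor_scatter_nd_update(mask, indices, tf.ones_like(visible, dtype=tf.float32))
    return tf.cast(mask, tf.bool)

print(toy_mask_from_permutation(tf.constant([0, 2, 1, 3])).numpy().astype(int))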

@tf.function
@@ -246,110 +242,78 @@ def decode(
) -> tf.Tensor:
batch_size, sequence_length = target.shape
# apply positional information to the target sequence excluding the SOS token
null_ctx = self.embed(target[:, :1])
content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:])
null_ctx = self.embed(target[:, :1], **kwargs)
content = self.pos_queries[:, : sequence_length - 1] + self.embed(target[:, 1:], **kwargs)
content = self.dropout(tf.concat([null_ctx, content], axis=1), **kwargs)
if target_query is None:
target_query = tf.tile(self.pos_queries[:, :sequence_length], [batch_size, 1, 1])
target_query = self.dropout(target_query, **kwargs)
return self.decoder(target_query, content, memory, target_mask, **kwargs)
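As an aside, a toy sketch of the positional-query scheme that decode() relies on (a single tf.keras.layers.MultiHeadAttention stands in for the PARSeq decoder block; all names and sizes here are made up): learned position embeddings act as the attention queries, while the embedded, SOS-shifted target tokens provide keys and values.

import tensorflow as tf

d_model, max_len, batch = 16, 8, 2
# learned position embeddings, shared across the batch
pos_queries = tf.Variable(tf.random.normal((1, max_len, d_model)))
# stand-in for the embedded target tokens
content = tf.random.normal((batch, max_len, d_model))
attn = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=d_model // 4)

queries = tf.tile(pos_queries, [batch, 1, 1])      # (N, max_len, d_model)
out = attn(query=queries, value=content, key=content)
print(out.shape)                                   # (2, 8, 16)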

@tf.function
def decode_autoregressive(self, features: tf.Tensor) -> tf.Tensor:
def decode_autoregressive(self, features: tf.Tensor, max_len: Optional[int] = None, **kwargs) -> tf.Tensor:
"""Generate predictions for the given features."""
# Padding symbol + SOS at the beginning
max_length = max_len if max_len is not None else self.max_length
max_length = min(max_length, self.max_length) + 1
b = tf.shape(features)[0]
ys = tf.fill(dims=(b, self.max_length), value=self.vocab_size + 2)
# Padding symbol + SOS at the beginning
ys = tf.fill(dims=(b, max_length), value=self.vocab_size + 2)
start_vector = tf.fill(dims=(b, 1), value=self.vocab_size + 1)
ys = tf.concat([start_vector, ys], axis=-1)
pos_queries = tf.tile(self.pos_queries[:, : self.max_length + 1], [b, 1, 1])
query_mask = tf.cast(
tf.linalg.band_part(tf.ones((self.max_length + 1, self.max_length + 1)), -1, 0), dtype=tf.bool
)
pos_queries = tf.tile(self.pos_queries[:, :max_length], [b, 1, 1])
query_mask = tf.cast(tf.linalg.band_part(tf.ones((max_length, max_length)), -1, 0), dtype=tf.bool)

pos_logits = []
for i in range(self.max_length):
for i in range(max_length):
# Decode one token at a time without providing information about the future tokens
tgt_out = self.decode(
ys[:, : i + 1],
features,
query_mask[i : i + 1, : i + 1],
target_query=pos_queries[:, i : i + 1],
**kwargs,
)
pos_prob = self.head(tgt_out)
pos_logits.append(pos_prob)

if i + 1 < self.max_length:
if i + 1 < max_length:
# update ys with the next token
i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(self.max_length), indexing="ij")
i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_length), indexing="ij")
indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1)
ys = tf.tensor_scatter_nd_update(
ys, indices, tf.cast(tf.argmax(pos_prob[:, -1, :], axis=-1), dtype=tf.int32)
)

# Stop decoding if all sequences have reached the EOS token
# We need to compare explicitly against True to stay compatible with ONNX
if tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) is True:
if (
max_len is None
and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1)) is True
):
break

logits = tf.concat(pos_logits, axis=1) # (N, max_length, vocab_size + 1)

# One refine iteration
# Update query mask
query_mask = tf.cast(1 - tf.linalg.diag(tf.ones(self.max_length, dtype=tf.int32), k=-1), dtype=tf.bool)
diag_matrix = tf.eye(max_length)
diag_matrix = tf.cast(tf.logical_not(tf.cast(diag_matrix, dtype=tf.bool)), dtype=tf.float32)
query_mask = tf.cast(tf.concat([diag_matrix[1:], tf.ones((1, max_length))], axis=0), dtype=tf.bool)

sos = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
ys = tf.concat([sos, tf.cast(tf.argmax(logits[:, :-1], axis=-1), dtype=tf.int32)], axis=1)
# Create padding mask for the refined target input; masks everything behind the EOS token as False
# (N, 1, 1, max_length)
target_pad_mask = tf.cumsum(tf.cast(tf.equal(ys, self.vocab_size), dtype=tf.int32), axis=1, reverse=False)
target_pad_mask = tf.logical_not(tf.cast(target_pad_mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool))
mask = tf.cast(tf.equal(ys, self.vocab_size), tf.float32)
first_eos_indices = tf.argmax(mask, axis=1, output_type=tf.int32)
mask = tf.sequence_mask(first_eos_indices + 1, maxlen=ys.shape[-1], dtype=tf.float32)
target_pad_mask = tf.cast(mask[:, tf.newaxis, tf.newaxis, :], dtype=tf.bool)

mask = tf.math.logical_and(target_pad_mask, query_mask[:, : ys.shape[1]])
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
logits = self.head(self.decode(ys, features, mask, target_query=pos_queries, **kwargs), **kwargs)

return logits # (N, max_length, vocab_size + 1)
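For reference, a toy version of the greedy loop above (toy_step is a fake decoder returning random logits; the real model runs the decoded prefix through the transformer decoder instead, and the EOS/SOS/padding ids follow the same layout as in the code): one position is predicted per step, the argmax is written back into ys, and decoding stops once every sequence in the batch has produced the EOS id.

import tensorflow as tf

VOCAB = 10
EOS, SOS, PAD = VOCAB, VOCAB + 1, VOCAB + 2  # assumption: same id layout as above
MAX_LEN, BATCH = 6, 2

def toy_step(prefix: tf.Tensor) -> tf.Tensor:
    """Fake decoder head: logits over vocab + EOS for the next position."""
    return tf.random.normal((tf.shape(prefix)[0], VOCAB + 1))

ys = tf.concat([tf.fill((BATCH, 1), SOS), tf.fill((BATCH, MAX_LEN), PAD)], axis=-1)
for i in range(MAX_LEN):
    next_tok = tf.cast(tf.argmax(toy_step(ys[:, : i + 1]), axis=-1), tf.int32)
    indices = tf.stack([tf.range(BATCH), tf.fill((BATCH,), i + 1)], axis=1)
    ys = tf.tensor_scatter_nd_update(ys, indices, next_tok)
    # stop as soon as every row contains an EOS token
    if bool(tf.reduce_all(tf.reduce_any(tf.equal(ys, EOS), axis=-1))):
        break
print(ys.numpy())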

@tf.function
def decode_non_autoregressive(self, features: tf.Tensor) -> tf.Tensor:
"""Decode the given features at once"""
pos_queries = tf.tile(self.pos_queries[:, : self.max_length + 1], [tf.shape(features)[0], 1, 1])
ys = tf.fill((tf.shape(features)[0], 1), self.vocab_size + 1)
return self.head(self.decode(ys, features, target_query=pos_queries))[:, : self.max_length]

@staticmethod
def compute_loss(
model_output: tf.Tensor,
gt: tf.Tensor,
seq_len: List[int],
) -> tf.Tensor:
"""Compute categorical cross-entropy loss for the model.
Sequences are masked after the EOS character.
Args:
model_output: predicted logits of the model
gt: the encoded tensor with gt labels
seq_len: lengths of each gt word inside the batch
Returns:
The loss of the model on the batch
"""
# Input length : number of steps
input_len = tf.shape(model_output)[1]
# Add one for additional <eos> token (sos disappear in shift!)
seq_len = tf.cast(seq_len, tf.int32) + 1
# One-hot gt labels
oh_gt = tf.one_hot(gt, depth=model_output.shape[2])
# Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
# The "masked" first gt char is <sos>. Delete last logit of the model output.
cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output[:, :-1, :])
# Compute mask
mask_values = tf.zeros_like(cce)
mask_2d = tf.sequence_mask(seq_len, input_len - 1) # delete the last mask timestep as well
masked_loss = tf.where(mask_2d, cce, mask_values)
ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype))

return tf.expand_dims(ce_loss, axis=1)

def call(
self,
x: tf.Tensor,
@@ -362,36 +326,69 @@ def call(
# remove cls token
features = features[:, 1:, :]

if target is not None:
gt, seq_len = self.build_target(target)
seq_len = tf.cast(seq_len, tf.int32)

if kwargs.get("training", False) and target is None:
raise ValueError("Need to provide labels during training")

if target is not None:
gt, seq_len = self.build_target(target)
seq_len = tf.cast(seq_len, tf.int32)
gt = gt[:, : int(tf.reduce_max(seq_len)) + 2] # slice up to the max length of the batch + 2 (SOS + EOS)

if kwargs.get("training", False):
# Generate permutations of the target sequences
tgt_perms = self.generate_permutations(seq_len)

gt_in = gt[:, :-1] # remove EOS token from longest target sequence
gt_out = gt[:, 1:] # remove SOS token

# Create padding mask for target input
# [True, True, True, ..., False, False, False] -> False is masked
padding_mask = ((gt != self.vocab_size + 2) | (gt != self.vocab_size))[:, tf.newaxis, tf.newaxis, :]

for perm in tgt_perms:
# Generate attention masks for the permutations
_, target_mask = self.generate_permutations_attention_masks(perm)
# combine target padding mask and query mask
mask = tf.math.logical_and(target_mask, padding_mask)
logits = self.head(self.decode(gt, features, mask))
padding_mask = tf.math.logical_and(
tf.math.not_equal(gt_in, self.vocab_size + 2), tf.math.not_equal(gt_in, self.vocab_size)
)
padding_mask = padding_mask[:, tf.newaxis, tf.newaxis, :] # (N, 1, 1, seq_len)

loss = tf.constant(0.0)
loss_numel = tf.constant(0.0)
n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))
for i, perm in enumerate(tgt_perms):
_, target_mask = self.generate_permutations_attention_masks(perm) # (seq_len, seq_len)
# combine both masks to (N, 1, seq_len, seq_len)
mask = tf.logical_and(padding_mask, tf.expand_dims(tf.expand_dims(target_mask, axis=0), axis=0))

logits = self.head(self.decode(gt_in, features, mask, **kwargs), **kwargs)
logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
targets_flat = tf.reshape(gt_out, (-1,))
mask = tf.not_equal(targets_flat, self.vocab_size + 2)
loss += n * tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
)
)
loss_numel += n

# After the second iteration (i.e. done with canonical and reverse orderings),
# remove the [EOS] tokens for the succeeding perms
if i == 1:
gt_out = tf.where(tf.equal(gt_out, self.vocab_size), self.vocab_size + 2, gt_out)
n = tf.reduce_sum(tf.cast(tf.math.not_equal(gt_out, self.vocab_size + 2), dtype=tf.float32))

loss /= loss_numel

else:
# eval step - use non-autoregressive decoding while training evaluation
logits = self.decode_non_autoregressive(features)
gt = gt[:, 1:] # remove SOS token
max_len = gt.shape[1] - 1 # exclude EOS token
logits = self.decode_autoregressive(features, max_len, **kwargs)
logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
targets_flat = tf.reshape(gt, (-1,))
mask = tf.not_equal(targets_flat, self.vocab_size + 2)
loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=tf.boolean_mask(targets_flat, mask), logits=tf.boolean_mask(logits_flat, mask)
)
)
else:
logits = self.decode_autoregressive(features)
logits = self.decode_autoregressive(features, **kwargs)

out: Dict[str, tf.Tensor] = {}
if self.exportable:
Expand All @@ -406,7 +403,7 @@ def call(
out["preds"] = self.postprocessor(logits)

if target is not None:
out["loss"] = self.compute_loss(logits, gt, seq_len)
out["loss"] = loss

return out
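For reference, a toy sketch of the masked loss pattern used in the training branch above (made-up logits and targets; PAD and EOS ids follow the same layout as in the code): logits and targets are flattened, padding positions are dropped, and the cross-entropy is averaged over the remaining real tokens.

import tensorflow as tf

VOCAB = 10
EOS, PAD = VOCAB, VOCAB + 2

targets = tf.constant([[3, 7, EOS, PAD], [5, EOS, PAD, PAD]])  # (N, seq_len)
logits = tf.random.normal((2, 4, VOCAB + 1))                   # (N, seq_len, vocab + 1)

logits_flat = tf.reshape(logits, (-1, logits.shape[-1]))
targets_flat = tf.reshape(targets, (-1,))
keep = tf.not_equal(targets_flat, PAD)                         # drop padding positions
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.boolean_mask(targets_flat, keep),
        logits=tf.boolean_mask(logits_flat, keep),
    )
)
print(float(loss))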
