Add Masked language modeling with Bert example #244

Merged: 13 commits, Sep 23, 2020

Conversation

Ankur3107 (Contributor)

  • Uses pretrained bert-base-uncased from HuggingFace transformers and their tokenizers.
  • Trains a masked language BERT model on the IMDB dataset.

@fchollet (Contributor) left a comment

Thank you for the PR!


return compute_mlm_loss, train_mlm_loss_metric

def masked_sparse_categorical_crossentropy(self, y_true, y_pred):
fchollet (Contributor):

You should be able to achieve the same with just SparseCategoricalCrossentropy by using the sample_weight argument. https://keras.io/api/losses/probabilistic_losses/#sparsecategoricalcrossentropy-class
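For reference, a minimal sketch of the suggested approach (illustrative names, not code from this PR): the labels keep a valid token id at every position, and a per-token sample_weight of 0 drops the unmasked positions from the loss.

import tensorflow as tf

# Instantiated once; from_logits assumes the model outputs raw logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


def mlm_loss(y_true, y_pred, masked_positions):
    # masked_positions: 1.0 where a token was masked, 0.0 elsewhere.
    return loss_fn(y_true, y_pred, sample_weight=masked_positions)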

Ankur3107 (Contributor Author):

I have tried using the SparseCategoricalCrossentropy sample_weight argument, but it gives a NaN loss (because I used -1 to mark negative/unmasked tokens in the labels tensor). Should I use a non-negative id for unmasked tokens and then use sample_weight?

strategy.experimental_run_v2(self.mlm_train_step, args=(data,))

@tf.function
def mlm_train_step(self, inputs):
fchollet (Contributor):

Instead, please use the train_step overriding workflow. This will be much simpler and won't need to rely on experimental_run_v2. https://keras.io/guides/customizing_what_happens_in_fit/
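For context, a minimal sketch of that train_step overriding workflow (hypothetical class and input structure, not this PR's code):

import tensorflow as tf
from tensorflow import keras


class MaskedLanguageModel(keras.Model):
    def train_step(self, inputs):
        # Assumes the dataset yields (features, labels, sample_weight).
        features, labels, sample_weight = inputs

        with tf.GradientTape() as tape:
            predictions = self(features, training=True)
            loss = self.compiled_loss(labels, predictions, sample_weight=sample_weight)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.compiled_metrics.update_state(labels, predictions, sample_weight=sample_weight)
        return {m.name: m.result() for m in self.metrics}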

Ankur3107 (Contributor Author):

I have implemented this using the train_step overriding workflow.

"""


class Dataset:
fchollet (Contributor):

Dataset is already a base class in TF. I would recommend simply refactoring this as separate utility functions, since there is no need for an OO approach here.

Ankur3107 (Contributor Author):

OK, I have implemented it. Please review the functions.

@Ankur3107 (Contributor Author)

@fchollet Thanks for your review.
I have updated it according to your guidance. One doubt I have is with the SparseCategoricalCrossentropy sample_weight argument: I use -1 in the target tensor to identify and ignore unmasked tokens when calculating the loss, so if I use SparseCategoricalCrossentropy with sample_weight, the loss output would be NaN.

Should I use a non-negative number to represent unmasked tokens and then use SparseCategoricalCrossentropy with sample_weight?

@Ankur3107 (Contributor Author)

Hi @fchollet

I have tried implementing SparseCategoricalCrossentropy with the sample_weight parameter. It runs successfully on CPU and GPU, but it fails on TPU with the error below:

Invalid argument: {{function_node __inference_train_function_196061}} Compilation failure: Detected unsupported operations when trying to compile graph broadcast_weights_assert_broadcastable_is_valid_shape_has_valid_nonscalar_shape_true_195924_const_0[] on XLA_TPU_JIT: DenseToDenseSetOperation (No registered 'DenseToDenseSetOperation' OpKernel for XLA_TPU_JIT devices compatible with node {{node broadcast_weights/assert_broadcastable/is_valid_shape/has_valid_nonscalar_shape/has_invalid_dims/DenseToDenseSetOperation}}){{node broadcast_weights/assert_broadcastable/is_valid_shape/has_valid_nonscalar_shape/has_invalid_dims/DenseToDenseSetOperation}}
    [[broadcast_weights/assert_broadcastable/is_valid_shape/has_valid_nonscalar_shape]]
    [[broadcast_weights/assert_broadcastable/is_valid_shape]]
   TPU compilation failed
    [[tpu_compile_succeeded_assert/_18335038187613310273/_6]]
    [[tpu_compile_succeeded_assert/_18335038187613310273/_6/_265]]

Colab notebook reproducing the TPU failure: https://colab.research.google.com/drive/19U0VJkd9nOIn2vwZ0dcNflmNoV-CZ5Zz?usp=sharing

Please see the code.

@fchollet (Contributor) left a comment

Thanks for the update!



with strategy.scope():
    loss_tracker = tf.keras.metrics.Mean(name="loss")
fchollet (Contributor):

You should define this in __init__ (or compile)

self.optimizer.apply_gradients(zip(gradients, trainable_vars))

# Compute our own metrics
loss_tracker.update_state(loss)
fchollet (Contributor):

A problem here is that you never reset_states() on the loss tracker. So what you display is an average of batch losses since the start of training.

You should create the property metrics:

    @property
    def metrics(self):
        # We list our `Metric` objects here so that `reset_states()` can be
        # called automatically at the start of each epoch
        # or at the start of `evaluate()`.
        return [loss_tracker]

"""


def masked_sparse_categorical_crossentropy(y_true, y_pred):
fchollet (Contributor):

I think it's fine to keep it implemented like this -- thanks for investigating the sample_weight argument.
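For reference, a hand-rolled masked loss along these lines can be sketched as follows (assuming -1 marks unmasked positions, as discussed above; illustrative, not necessarily the exact code in the example):

import tensorflow as tf


def masked_sparse_categorical_crossentropy(y_true, y_pred):
    # Keep only the positions that were actually masked (label != -1).
    mask = tf.not_equal(y_true, -1)
    y_true_masked = tf.boolean_mask(y_true, mask)
    y_pred_masked = tf.boolean_mask(y_pred, mask)
    return tf.keras.losses.sparse_categorical_crossentropy(
        y_true_masked, y_pred_masked, from_logits=True
    )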

"""

imdb_bert_model = TFAutoModelWithLMHead.from_pretrained("imdb_bert_uncased")
nlp = pipeline("fill-mask", model=imdb_bert_model, tokenizer=tokenizer, framework="tf")
fchollet (Contributor):

A concern I have here: this example is basically more about "how to use HuggingFace Transformers" than about masked language modeling or Keras. The pipeline, tokenizer, and TFAutoModelWithLMHead are black boxes.

The purpose of our code examples is to be Keras-centric and as didactic as possible, i.e. they should explain the task and problem domain, similar to what we do here for object detection, where everything is implemented from scratch.

Would you consider implementing these components from scratch in Keras, e.g. with the TextVectorization layer for tokenization and with a Transformer block built with the new MultiHeadAttention layer (in tf-nightly)? No pretraining necessary.

The model could be similar to this, except you wouldn't even have to implement MultiHeadAttention.
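For orientation, a rough sketch of the kind of from-scratch model being suggested, using the MultiHeadAttention layer available in tf-nightly at the time; all hyperparameter values below are illustrative, not the ones used in the final example:

from tensorflow import keras
from tensorflow.keras import layers

VOCAB_SIZE = 20000  # illustrative
MAX_LEN = 256       # illustrative
EMBED_DIM = 128     # illustrative
NUM_HEADS = 4       # illustrative
FF_DIM = 128        # illustrative

inputs = keras.Input(shape=(MAX_LEN,), dtype="int64")
x = layers.Embedding(VOCAB_SIZE, EMBED_DIM)(inputs)

# A single self-attention encoder block with residual connections.
attention_output = layers.MultiHeadAttention(
    num_heads=NUM_HEADS, key_dim=EMBED_DIM // NUM_HEADS
)(x, x)
x = layers.LayerNormalization(epsilon=1e-6)(x + attention_output)
ffn_output = layers.Dense(FF_DIM, activation="relu")(x)
ffn_output = layers.Dense(EMBED_DIM)(ffn_output)
x = layers.LayerNormalization(epsilon=1e-6)(x + ffn_output)

# Per-token vocabulary logits for masked language modeling.
outputs = layers.Dense(VOCAB_SIZE)(x)
mlm_model = keras.Model(inputs, outputs)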

Ankur3107 (Contributor Author):

Hi @fchollet, thanks for your comment.

True, I feel the same: this example looks like 'how to use HuggingFace Transformers'.

Yes, I would like to implement these components from scratch in Keras, e.g. with the TextVectorization layer for tokenization and with a Transformer block built with the new MultiHeadAttention layer.

Should I use this for classification problems or pretraining problems?

Ankur3107 (Contributor Author):

I have started implementing your suggestion to use the MultiHeadAttention layer and the TextVectorization layer, and it's almost done for both pretraining and fine-tuning.

One doubt: the TextVectorization layer is not running on TPU. The same notebook runs on CPU and GPU, but not on TPU.

fchollet (Contributor):

Are you using TextVectorization as part of the tf.data pipeline? Can you share your notebook?

Ankur3107 (Contributor Author):

No, I am not using tf.data.

from tensorflow.keras.layers.experimental import preprocessing


def get_vectorize_layer(texts, vocab_size, max_seq, special_tokens=["[MASK]"]):
    # Build the vectorization layer and learn the vocabulary from the texts.
    vectorize_layer = preprocessing.TextVectorization(
        max_tokens=vocab_size,
        output_mode="int",
        standardize=None,
        output_sequence_length=max_seq,
    )
    vectorize_layer.adapt(texts)  # error happens in adapt

    # Insert the special tokens at the end of the learned vocabulary.
    vocab = vectorize_layer.get_vocabulary()
    vocab = vocab[2 : vocab_size - len(special_tokens)] + ["[MASK]"]
    vectorize_layer.set_vocabulary(vocab)
    return vectorize_layer


vectorize_layer = get_vectorize_layer(
    data.text.values.tolist(), 20000, 196, special_tokens=["[MASK]"]
)

The error is:

NotFoundError: 'OptimizeDatasetV2' is neither a type of a primitive operation nor a name of a function registered in binary running on n-86c78cbc-w-0. Make sure the operation or function is registered in the binary running in this process.

TensorFlow: 2.4.0-dev20200915
Reference notebook: https://colab.research.google.com/drive/1QiQmA2M2WkwsjiHHxBhPM0Jw3_LBJnGV?usp=sharing

Note: the same code works with TensorFlow 2.3.0.

fchollet (Contributor):

This is a bug, likely due to a recently introduced optimization setting for tf.data datasets (which are used internally in adapt). We're looking into it -- thank you for the report.

Since we will fix this before the next release, I recommend that you keep developing this example with TextVectorization, using GPU only for now. You could try to use an older tf-nightly version that wouldn't include the bug (though I am not sure if you'll find one that also has the MultiHeadAttention layer).

Ankur3107 (Contributor Author):

I am developing on GPU for now. I have tried 5-7 different tf-nightly versions at random; all have the same issue.

My notebook is almost ready; I am going to commit it in a few minutes.
Thanks!

I have implemented Mask Language Model and Fine-tune model with bert from scratch using tf-nightly
@Ankur3107 (Contributor Author)

Hi @fchollet,

Thanks for your guidance. I have implemented BERT from scratch using tf-nightly, with the TextVectorization and MultiHeadAttention layers, and I have also implemented the end-to-end export method.

Thanks

@fchollet (Contributor)

Thanks a lot, I will review shortly.

Here's an update about the TPU failure: it is due to the fact that the TPU client is using tf-nightly while the TPU worker is using TF 2.3. In general, it is not safe to use TPUs in Colab with tf-nightly; you should always use them with the latest stable TF version instead.

@fchollet (Contributor) left a comment

Thanks for the update!

texts (list): List of String i.e input texts
vocab_size (int): vocab size
max_seq (int): maximum sequence len
special_tokens (list, optional): List of special tokens. Defaults to ['[MASK]'].
fchollet (Contributor):

In practice, this can just be mask_token='[MASK]'.

vectorize_layer.adapt(texts)

vocab = vectorize_layer.get_vocabulary()
vocab = vocab[2 : vocab_size - len(special_tokens)] + ["[MASK]"]
fchollet (Contributor):

If you put it at the start of the list, you won't need to compute mask_token_id by hand (it will be 0).

Ankur3107 (Contributor Author), Sep 17, 2020:

I am a little bit confused. My understanding is that index 0 is reserved for the padding token. (Reference)

Also, I don't see an argument called mask_token in the class constructor of TextVectorization. (Reference)

I need to insert special tokens ('[MASK]' in this case) into the vocabulary, and I would like to keep index zero for padding.

Am I thinking about this right?

fchollet (Contributor):

Got it. The mask token is normally index 0. So you could just insert it at index 0, which is what I was saying. Note that there is a mask token already, by default it is the empty string. You could simply replace [MASK] with the empty string in your code and you wouldn't need to configure it (and its index would be 0).

Note, the purpose of the default mask token is to handle padding, so your data will already contain 0s for sections that are padded (sequences shorter than MAX_LEN). Not sure if that forces you to use a separate mask token...

I agree with you that it should be configurable in the constructor of TextVectorization. We will add it (alongside a couple other arguments). Internally TextVectorization uses a StringLookup layer which allows this customization. However if we configure it in this way it would still be used for padding as well.



def build_mlm_model():
    tf.keras.backend.clear_session()
fchollet (Contributor):

You should be able to remove this statement

Ankur3107 (Contributor Author):

OK, I will refactor.



def build_classifer_model():
    tf.keras.backend.clear_session()
fchollet (Contributor):

You should be able to remove this statement

)

# query, value, key for attention
query = tf.keras.layers.Dense(flags.EMBED_DIM)(embedding_layer)
fchollet (Contributor):

Still nit: use

from tensorflow.keras import layers

then layers.Embedding, etc.

Ankur3107 (Contributor Author):

OK, I will refactor it in the same way.

"""

mlm_model = build_mlm_model()
mlm_model.fit(mlm_train_data, epochs=1)
fchollet (Contributor):

Any call to fit() should include a way to monitor results on an evaluation set.

Why only train for 1 epoch?

fchollet (Contributor):

Note: since you're training using a generator, consider using multiple workers for better performance.
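Taken together with the previous note, a hedged sketch of what that call could look like (the generator names and values here are placeholders):

# Placeholders; applies when the inputs are keras.utils.Sequence generators.
mlm_model.fit(
    train_generator,
    validation_data=eval_generator,
    epochs=3,
    workers=4,
    use_multiprocessing=True,
)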

Ankur3107 (Contributor Author):

I used one epoch because of time constraints; it takes 22 minutes per epoch. I will use multiple workers for better performance and will train for 3 epochs.

Ankur3107 (Contributor Author):

Actually, I am not able to install tf-nightly-gpu on Colab or Kaggle (it runs without GPU). I tried to install it, but I was unable to get GPU support.

mlm_model = tf.keras.Model(inputs, outputs)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction=tf.keras.losses.Reduction.NONE, from_logits=True
fchollet (Contributor):

Why no reduction? Since you're using compile/fit, the results will be averaged anyway (some form of reduction is necessary).

labels[inp_mask] = X[inp_mask]

# prepare input
X_mlm = np.copy(X)
fchollet (Contributor):

Style nit: variable names should be lowercase. Please always use fully-spelled out variable names for better readability, instead of abbreviations or single letters.

labels (list): List of Labels i.e for classification model only
batch_size (int): batch size
vectorizer (TextVectorization): Keras TextVectorization layer
is_training (bool, optional): Is training dataset generator. Defaults to True.
fchollet (Contributor):

Prefer using include_labels=True which is more explicit

mask_token_id = vectorize_layer(["[MASK]"]).numpy()[0][0]


class IMDBReviewsDataGenerator(tf.keras.utils.Sequence):
fchollet (Contributor):

Is there any way to handle this using tf.data.Dataset instead?

Ankur3107 (Contributor Author):

OK, I will try to implement it with tf.data.Dataset.
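A minimal sketch of what that could look like once the masking step produces arrays (names and batch size are illustrative):

import tensorflow as tf

# x_masked, y_labels, sample_weights are assumed to be NumPy arrays:
# masked input ids, original token ids, and per-token weights respectively.
mlm_dataset = (
    tf.data.Dataset.from_tensor_slices((x_masked, y_labels, sample_weights))
    .shuffle(1000)
    .batch(32)
)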

@Ankur3107 (Contributor Author)

Hi @fchollet,
According to your feedback, I have updated the file:

  1. Used tf.data.Dataset
  2. Refactored variable names
  3. Refactored the package import format

I have one doubt: what is a good way to add the '[MASK]' token (or a list of special tokens) to the vectorizer layer?

@fchollet (Contributor) left a comment

Thanks for the update -- the code looks good to me!

data = pd.read_csv("IMDb_Reviews.csv")

"""
## Dataset Preparation
fchollet (Contributor):

For each subsection, please add a text paragraph explaining what you're doing:

  • How do you prepare the dataset? what do your samples end up looking like?
  • What does the masked language model do (what are its inputs and outputs)?
  • What does the classification model do?
  • What does the end to end model do?
  • How to interpret evaluation metrics?

"""

"""shell
wget https://github.com/LawrenceDuan/IMDb-Review-Analysis/raw/master/IMDb_Reviews.csv
"""
"""
## Introduction
Masked language modeling is a fill-in-the-blank task, where a model uses the context words surrounding a [MASK] token to try to predict what the [MASK] word should be. We will use the raw text of IMDB reviews to pretrain the model, and then use the pretrained weights to fine-tune a sentiment classification model.
fchollet (Contributor):

Please expand this description and provide concrete examples.

@Ankur3107 (Contributor Author)

Hi @fchollet

Please review my latest update.

@fchollet (Contributor) left a comment

Thanks for the update! I have applied extensive copyediting; please pull the changes first.

Note that I have changed the fine-tuning logic. When doing fine-tuning, you should:

  • Freeze the pretrained part
  • Train the model
  • Unfreeze the pretrained part
  • Train again

If you don't do this, then you're not doing fine-tuning, and you're not significantly better off compared to just training a randomly initialized model.
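A hedged sketch of that freeze/train/unfreeze/train sequence (model and dataset names are placeholders):

from tensorflow import keras

# 1. Freeze the pretrained part and train only the new classification head.
pretrained_bert_model.trainable = False
classifier_model.compile(
    optimizer=keras.optimizers.Adam(), loss="binary_crossentropy", metrics=["accuracy"]
)
classifier_model.fit(train_ds, validation_data=eval_ds, epochs=3)

# 2. Unfreeze the pretrained part, recompile, and train end to end
#    with a lower learning rate.
pretrained_bert_model.trainable = True
classifier_model.compile(
    optimizer=keras.optimizers.Adam(1e-5), loss="binary_crossentropy", metrics=["accuracy"]
)
classifier_model.fit(train_ds, validation_data=eval_ds, epochs=3)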

@fchollet (Contributor)

Please reapply black formatting, and try running the code again. You may want to adjust the number of epochs used for the fine-tuning part, as well as the learning rate of the optimizer.

What classification accuracy do you get?

@Ankur3107 (Contributor Author)

I have made the following updates:

  • Added custom_standardization for the vectorizer
  • Implemented a small generic BERT model using the Keras subclassing API
  • Implemented a custom callback to check the masked language model on a sample input

Here are the results I got:

  • Masked language model

Input: "i have watched this [mask] and it was awesome"
top_k = 5  # top 5 predictions

Epoch 1/5
1563/1563 [==============================] - ETA: 0s - loss: 7.2558
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'this',
 'prediction': 'i have watched this this and it was awesome',
 'probability': 0.04289173}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'movie',
 'prediction': 'i have watched this movie and it was awesome',
 'probability': 0.04171709}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'i',
 'prediction': 'i have watched this i and it was awesome',
 'probability': 0.034618173}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'a',
 'prediction': 'i have watched this a and it was awesome',
 'probability': 0.03268459}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'to',
 'prediction': 'i have watched this to and it was awesome',
 'probability': 0.031339437}
1563/1563 [==============================] - 152s 97ms/step - loss: 7.2556
Epoch 2/5
1563/1563 [==============================] - ETA: 0s - loss: 6.5399
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'movie',
 'prediction': 'i have watched this movie and it was awesome',
 'probability': 0.36367047}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'film',
 'prediction': 'i have watched this film and it was awesome',
 'probability': 0.101376995}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'is',
 'prediction': 'i have watched this is and it was awesome',
 'probability': 0.054904718}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'one',
 'prediction': 'i have watched this one and it was awesome',
 'probability': 0.047664806}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'was',
 'prediction': 'i have watched this was and it was awesome',
 'probability': 0.03694917}
1563/1563 [==============================] - 151s 96ms/step - loss: 6.5398
Epoch 3/5
1563/1563 [==============================] - ETA: 0s - loss: 5.8020
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'movie',
 'prediction': 'i have watched this movie and it was awesome',
 'probability': 0.27628118}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'film',
 'prediction': 'i have watched this film and it was awesome',
 'probability': 0.1621771}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'time',
 'prediction': 'i have watched this time and it was awesome',
 'probability': 0.023138938}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'one',
 'prediction': 'i have watched this one and it was awesome',
 'probability': 0.018522203}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'review',
 'prediction': 'i have watched this review and it was awesome',
 'probability': 0.015409218}
1563/1563 [==============================] - 151s 97ms/step - loss: 5.8019
Epoch 4/5
1563/1563 [==============================] - ETA: 0s - loss: 5.1151
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'movie',
 'prediction': 'i have watched this movie and it was awesome',
 'probability': 0.41562507}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'film',
 'prediction': 'i have watched this film and it was awesome',
 'probability': 0.19526063}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'one',
 'prediction': 'i have watched this one and it was awesome',
 'probability': 0.035126403}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'night',
 'prediction': 'i have watched this night and it was awesome',
 'probability': 0.025516002}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'time',
 'prediction': 'i have watched this time and it was awesome',
 'probability': 0.024052398}
1563/1563 [==============================] - 151s 97ms/step - loss: 5.1150
Epoch 5/5
1563/1563 [==============================] - ETA: 0s - loss: 4.7227
{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'movie',
 'prediction': 'i have watched this movie and it was awesome',
 'probability': 0.46062306}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'film',
 'prediction': 'i have watched this film and it was awesome',
 'probability': 0.12770538}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'one',
 'prediction': 'i have watched this one and it was awesome',
 'probability': 0.051738814}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'miniseries',
 'prediction': 'i have watched this miniseries and it was awesome',
 'probability': 0.025324458}

{'input_text': 'i have watched this [mask] and it was awesome',
 'predicted mask token': 'show',
 'prediction': 'i have watched this show and it was awesome',
 'probability': 0.015968936}
1563/1563 [==============================] - 152s 97ms/step - loss: 4.7226

Our masked language model looks good to me.

  • Classification model: I am getting classification accuracy between 84% and 86%.

@Ankur3107 (Contributor Author)

Hi @fchollet,

Sorry for the delayed response. This is the first time I am implementing everything from scratch in Keras (preprocessing -> model building -> callbacks -> custom training -> end-to-end). I have learned a lot, and I am still learning and trying to improve my code.

Please review and let me know if there is anything I missed.

Ankur3107 requested a review from fchollet on September 22, 2020, 03:58
@fchollet (Contributor) left a comment

Thank you for the update!

return sequence_output


class Encoder(layers.Layer):
fchollet (Contributor):

This would be far more readable as a Sequential or a Functional model. It's just a list of layers applied in sequence, but this class makes it a puzzle to decode.

"""


class BERTLayer(layers.Layer):
fchollet (Contributor):

You could specify this in 8-9 lines in the Functional API.

def bert_module(query, key, value):
    x = layers.MultiHeadAttention(...)(query, key, value)
    x = layers.Dropout(...)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Dense(...)(x)
    x = layers.Dense(...)(x)
    x = layers.Dropout(...)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    return x

Ankur3107 (Contributor Author):

Agreed



# Loss metric to track masked language model loss
loss_tracker = keras.metrics.Mean(name="mlm_loss")
fchollet (Contributor):

You can create this in the constructor. It should be managed by the model.
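A minimal sketch of that pattern (hypothetical class, not the PR's final code):

from tensorflow import keras


class MaskedLanguageModel(keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Created once and owned by the model.
        self.loss_tracker = keras.metrics.Mean(name="mlm_loss")

    @property
    def metrics(self):
        # Listing it here lets Keras reset it automatically each epoch.
        return [self.loss_tracker]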

Ankur3107 (Contributor Author), Sep 23, 2020:

I had included it in the constructor, but I was facing an issue when saving (.save('model.h5')) because the saved model included the Mean layer (the loss tracker metric).

classifer_model.compile(
    optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"]
)
print(classifer_model.summary())
fchollet (Contributor):

No need for the print statement here.

return id2token[id]

def on_epoch_end(self, epoch, logs=None):
    prediction, _ = self.model.predict(self.sample_tokens)
fchollet (Contributor):

Why doesn't predict already return a softmax?

mask_prediction = prediction.numpy()[0][masked_index]

topk = tf.math.top_k(mask_prediction, k=self.k)
values, predictions = topk.values.numpy(), topk.indices.numpy()
fchollet (Contributor):

Prefer doing everything in NumPy, for consistency

}

print()
pprint(result)
fchollet (Contributor):

Please use string formatting

"""

# Load pretrained bert model
BERT_MODEL = keras.models.load_model("mlm_imdb_bert.h5py")
fchollet (Contributor):

Use lowercase variable name

@Ankur3107 (Contributor Author) commented Sep 23, 2020

Hi @fchollet

Thanks for your great feedback.

I have updated it according to your guidance.

  • Refactored code, i.e. used the Functional API
  • Refactored the callback to use NumPy functions
  • Used softmax inside the model
  • Refactored variable names

I have tried keeping the loss tracker in the custom Model class constructor, but I was facing an issue when saving (.save('model.h5')) because the model added the Mean layer, so for simplicity I kept it outside. Please guide me if there is anything else I need to refactor.

Ankur3107 requested a review from fchollet on September 23, 2020, 03:53
@fchollet (Contributor)

"I was facing an issue in saving (.save('model.h5')) because the model added the Mean layer."

What was the issue?

@Ankur3107
Copy link
Contributor Author

I made a dummy notebook: https://colab.research.google.com/drive/1GH-WxghmkhgLSO2vVcSKgQ4UI3WX6GKC?usp=sharing

Sorry, the problem is not with saving (I misspoke); it's in loading the model.

model.save('model.h5')

new_model = keras.models.load_model('model.h5', custom_objects={'CustomModel': CustomModel})

usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
294 cls = get_registered_object(class_name, custom_objects, module_objects)
295 if cls is None:
--> 296 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
297
298 cls_config = config['config']

ValueError: Unknown layer: Mean

new_model = keras.models.load_model('model.h5', custom_objects={'CustomModel': CustomModel, 'Mean':keras.metrics.Mean})

/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_weights_from_hdf5_group(f, layers)
684 'containing ' + str(len(layer_names)) +
685 ' layers into a model with ' + str(len(filtered_layers)) +
--> 686 ' layers.')
687
688 # We batch weight value assignments in a single backend call

ValueError: You are trying to load a weight file containing 2 layers into a model with 3 layers.

@fchollet (Contributor)

Thanks for the report. That's a bug, which we will fix.

@fchollet (Contributor) left a comment

Looks good to me! I've pushed some copyedits, and I've refactored the loss function (the way it was written, it was instantiating a SparseCategoricalCrossentropy object every time it was called, but it's actually the same to just instantiate it once and reuse it).

Please add the generated files, and we can merge the example.
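For reference, a sketch of the instantiate-once pattern described above (illustrative; assumes -1 marks unmasked positions, as earlier in the thread):

import tensorflow as tf

# Created once at module level and reused on every call, instead of
# constructing a new SparseCategoricalCrossentropy object per call.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)


def masked_sparse_categorical_crossentropy(y_true, y_pred):
    masked = tf.not_equal(y_true, -1)
    return loss_fn(tf.boolean_mask(y_true, masked), tf.boolean_mask(y_pred, masked))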

@fchollet (Contributor) left a comment

LGTM, thank you for the great contribution! 👍

fchollet merged commit f553517 into keras-team:master on Sep 23, 2020
@Ankur3107 (Contributor Author)

Hi @fchollet

Thank you very much for your guidance. I have added the generated files. I am very happy to have contributed to Keras. I have been using Keras for the past 3 years; it is an awesome package, and you are a great mentor (saying this from the bottom of my heart) :)

Regards,
Ankur
