Adding CTRL (squashed commit)
adding conversion script

adding first draft of modeling & tokenization

adding placeholder for test files

bunch of changes

registering the tokenizer/model/etc

tests

change link; something is very VERY wrong here

weird end-of-word thingy going on

I think the tokenization works now; wrote the unit tests

overall structure works; load w next

the monster is alive!

works after some cleanup as well

adding emacs autosave to gitignore

currently only supporting the 48 layer one; seems to infer fine on my macbook

cleanup

fixing some documentation

fixing some documentation

tests passing?

now works on CUDA also

adding greedy?

adding greedy sampling

works well
keskarnitish committed Oct 4, 2019
1 parent 2dc8cb8 commit dbed1c5
Showing 12 changed files with 1,129 additions and 23 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -131,4 +131,7 @@ examples/runs

# data
/data
serialization_dir
serialization_dir

# emacs
*.*~
18 changes: 15 additions & 3 deletions README.md
@@ -22,7 +22,7 @@
<p>State-of-the-art Natural Language Processing for TensorFlow 2.0 and PyTorch
</h3>

🤗 Transformers (formerly known as `pytorch-transformers` and `pytorch-pretrained-bert`) provides state-of-the-art general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet...) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch.
🤗 Transformers (formerly known as `pytorch-transformers` and `pytorch-pretrained-bert`) provides state-of-the-art general-purpose architectures (BERT, GPT-2, RoBERTa, XLM, DistilBert, XLNet, CTRL...) for Natural Language Understanding (NLU) and Natural Language Generation (NLG) with over 32+ pretrained models in 100+ languages and deep interoperability between TensorFlow 2.0 and PyTorch.

### Features

@@ -122,6 +122,7 @@ At some point in the future, you'll be able to seamlessly move from pre-training
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), released together with the paper a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
8. **[DistilBERT](https://github.com/huggingface/transformers/tree/master/examples/distillation)** (from HuggingFace), released together with the blogpost [Smaller, faster, cheaper, lighter: Introducing DistilBERT, a distilled version of BERT](https://medium.com/huggingface/distilbert-8cf3380435b5
) by Victor Sanh, Lysandre Debut and Thomas Wolf.
9. **[CTRL](https://github.com/salesforce/ctrl/)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.

These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Pearson R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html).

@@ -148,6 +149,7 @@ from transformers import *
MODELS = [(BertModel, BertTokenizer, 'bert-base-uncased'),
(OpenAIGPTModel, OpenAIGPTTokenizer, 'openai-gpt'),
(GPT2Model, GPT2Tokenizer, 'gpt2'),
(CTRLModel, CTRLTokenizer, 'ctrl'),
(TransfoXLModel, TransfoXLTokenizer, 'transfo-xl-wt103'),
(XLNetModel, XLNetTokenizer, 'xlnet-base-cased'),
(XLMModel, XLMTokenizer, 'xlm-mlm-enfr-1024'),
@@ -253,7 +255,7 @@ The library comprises several example scripts with SOTA performances for NLU and

- `run_glue.py`: an example fine-tuning Bert, XLNet and XLM on nine different GLUE tasks (*sequence-level classification*)
- `run_squad.py`: an example fine-tuning Bert, XLNet and XLM on the question answering dataset SQuAD 2.0 (*token-level classification*)
- `run_generation.py`: an example using GPT, GPT-2, Transformer-XL and XLNet for conditional language generation
- `run_generation.py`: an example using GPT, GPT-2, CTRL, Transformer-XL and XLNet for conditional language generation
- other model-specific examples (see the documentation).

Here are three quick usage examples for these scripts:
@@ -391,7 +393,7 @@ python $SQUAD_DIR/evaluate-v1.1.py $SQUAD_DIR/dev-v1.1.json ../models/wwm_uncase

This is the model provided as `bert-large-uncased-whole-word-masking-finetuned-squad`.

### `run_generation.py`: Text generation with GPT, GPT-2, Transformer-XL and XLNet
### `run_generation.py`: Text generation with GPT, GPT-2, CTRL, Transformer-XL and XLNet

A conditional generation script is also included to generate text from a prompt.
The generation script includes the [tricks](https://github.com/rusiaaman/XLNet-gen#methodology) proposed by Aman Rusia to get high quality generation with memory models like Transformer-XL and XLNet (include a predefined text to make short inputs longer).
@@ -405,6 +407,16 @@ python ./examples/run_generation.py \
--model_name_or_path=gpt2 \
```

and from the Salesforce CTRL model:
```shell
python ./examples/run_generation.py \
--model_type=ctrl \
--length=20 \
--model_name_or_path=ctrl \
--temperature=0 \
--repetition_penalty=1.2 \
```
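
CTRL conditions its generations on control codes, so prompts normally begin with one; a fuller invocation might look like the following (the `Links` control code and the prompt text are purely illustrative):
```shell
# Illustrative only: "Links" is one of CTRL's control codes; the prompt text is made up.
python ./examples/run_generation.py \
    --model_type=ctrl \
    --model_name_or_path=ctrl \
    --length=20 \
    --temperature=0 \
    --repetition_penalty=1.2 \
    --prompt="Links An overview of controllable text generation"
```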

## Migrating from pytorch-transformers to transformers

Here is a quick summary of what you should take care of when migrating from `pytorch-transformers` to `transformers`.
34 changes: 25 additions & 9 deletions examples/run_generation.py
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""
from __future__ import absolute_import, division, print_function, unicode_literals

@@ -26,13 +26,13 @@
import torch.nn.functional as F
import numpy as np

from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig
from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, CTRLConfig

from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer

from transformers import CTRLLMHeadModel, CTRLTokenizer

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@@ -41,10 +41,11 @@

MAX_LENGTH = int(10000) # Hardcoded max length to avoid infinite loop

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig)), ())
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, CTRLConfig)), ())

MODEL_CLASSES = {
'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
@@ -103,7 +104,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
return logits


def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, is_xlnet=False, device='cpu'):
def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, is_xlnet=False, device='cpu'):
context = torch.tensor(context, dtype=torch.long, device=device)
context = context.unsqueeze(0).repeat(num_samples, 1)
generated = context
@@ -122,9 +123,17 @@ def sample_sequence(model, length, context, num_samples=1, temperature=1, top_k=
inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

outputs = model(**inputs) # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet (cached hidden-states)
next_token_logits = outputs[0][0, -1, :] / temperature
next_token_logits = outputs[0][0, -1, :] / (temperature if temperature > 0 else 1.)

# repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
for token_id in set(generated.view(-1).tolist()):
    next_token_logits[token_id] /= repetition_penalty

filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
if temperature == 0:  # greedy sampling
next_token = torch.argmax(filtered_logits).unsqueeze(0)
else:
next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
return generated
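
For intuition, here is a minimal, self-contained sketch of the two pieces added above: dividing the logits of already-generated tokens by the repetition penalty, and switching to greedy decoding when the temperature is 0. The helper name and toy logits are illustrative.
```python
import torch
import torch.nn.functional as F

def pick_next_token(next_token_logits, generated_ids, temperature=1.0, repetition_penalty=1.2):
    # CTRL-style repetition penalty: discount every token id that was already generated.
    for token_id in set(generated_ids):
        next_token_logits[token_id] /= repetition_penalty
    if temperature == 0:
        # Greedy decoding: simply take the most likely token.
        return torch.argmax(next_token_logits).item()
    probs = F.softmax(next_token_logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# Toy example: token 2 has the highest raw logit, but it was generated before,
# so the penalty lets token 1 win under greedy decoding.
logits = torch.tensor([0.1, 1.9, 2.0, 0.3])
print(pick_next_token(logits, generated_ids=[2], temperature=0))  # -> 1
```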

@@ -138,15 +147,21 @@ def main():
parser.add_argument("--prompt", type=str, default="")
parser.add_argument("--padding_text", type=str, default="")
parser.add_argument("--length", type=int, default=20)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--temperature", type=float, default=1.0,
help="temperature of 0 implies greedy sampling")
parser.add_argument("--repetition_penalty", type=float, default=1.0,
help="primarily useful for CTRL model; in that case, use 1.2")
parser.add_argument("--top_k", type=int, default=0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
args = parser.parse_args()

if args.model_type in ["ctrl"]:
if args.temperature > 0.7:
print('CTRL typically works better with lower temperatures (and lower top_k).')

args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()

@@ -180,6 +195,7 @@ def main():
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
device=args.device,
is_xlnet=bool(args.model_type == "xlnet"),
)
6 changes: 6 additions & 0 deletions transformers/__init__.py
@@ -37,6 +37,7 @@
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .tokenization_gpt2 import GPT2Tokenizer
from .tokenization_ctrl import CTRLTokenizer
from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
from .tokenization_xlm import XLMTokenizer
from .tokenization_roberta import RobertaTokenizer
@@ -49,7 +50,9 @@
from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_ctrl import CTRLConfig, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
@@ -73,6 +76,9 @@
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel,
load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_ctrl import (CTRLPreTrainedModel, CTRLModel,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel,
XLNetForSequenceClassification, XLNetForMultipleChoice,
XLNetForQuestionAnsweringSimple, XLNetForQuestionAnswering,
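
With these symbols exported from the package root, the new model can be used directly from `transformers`. A minimal loading sketch follows; `ctrl` is the shortcut name registered in the archive maps, and the prompt text is illustrative:
```python
import torch
from transformers import CTRLLMHeadModel, CTRLTokenizer

tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLLMHeadModel.from_pretrained('ctrl')
model.eval()

# "Links" is one of CTRL's control codes; the rest of the prompt is made up.
input_ids = torch.tensor([tokenizer.encode("Links Hello world")])
with torch.no_grad():
    logits = model(input_ids)[0]  # language-modeling scores, shape (batch, sequence, vocab)
```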
10 changes: 6 additions & 4 deletions transformers/configuration_auto.py
@@ -26,6 +26,7 @@
from .configuration_xlm import XLMConfig
from .configuration_roberta import RobertaConfig
from .configuration_distilbert import DistilBertConfig
from .configuration_ctrl import CTRLConfig

logger = logging.getLogger(__name__)

@@ -49,7 +50,7 @@ class method.
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlm`: XLMConfig (XLM model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `ctrl`: CTRLConfig (CTRL model)
This class cannot be instantiated using `__init__()` (throw an error).
"""
def __init__(self):
@@ -71,7 +72,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
- contains `xlnet`: XLNetConfig (XLNet model)
- contains `xlm`: XLMConfig (XLM model)
- contains `roberta`: RobertaConfig (RoBERTa model)
- contains `ctrl`: CTRLConfig (CTRL model)
Params:
pretrained_model_name_or_path: either:
@@ -129,7 +130,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif 'xlm' in pretrained_model_name_or_path:
return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

elif 'ctrl' in pretrained_model_name_or_path:
return CTRLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
raise ValueError("Unrecognized model identifier in {}. Should contain one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(pretrained_model_name_or_path))
"'xlm', 'roberta', 'ctrl'".format(pretrained_model_name_or_path))
144 changes: 144 additions & 0 deletions transformers/configuration_ctrl.py
@@ -0,0 +1,144 @@
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Salesforce CTRL configuration """

from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import sys
from io import open

from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/ctrl-config.json"}

class CTRLConfig(PretrainedConfig):
"""Configuration class to store the configuration of a `CTRLModel`.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
dff: Size of the inner dimension of the FFN.
n_embd: Dimensionality of the embeddings and hidden states.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
the Transformer encoder.
layer_norm_epsilon: epsilon to use in the layer norm layers
resid_pdrop: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attn_pdrop: The dropout ratio for the attention
probabilities.
embd_pdrop: The dropout ratio for the embeddings.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP

def __init__(
self,
vocab_size_or_config_json_file=246534,
n_positions=50000,
n_ctx=512,
n_embd=1280,
dff=8192,
n_layer=48,
n_head=16,
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-6,
initializer_range=0.02,

num_labels=1,
summary_type='cls_index',
summary_use_proj=True,
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
**kwargs
):
"""Constructs CTRLConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `CTRLModel` or a configuration json file.
n_positions: Number of positional embeddings.
n_ctx: Size of the causal mask (usually same as n_positions).
dff: Size of the inner dimension of the FFN.
n_embd: Dimensionality of the embeddings and hidden states.
n_layer: Number of hidden layers in the Transformer encoder.
n_head: Number of attention heads for each attention layer in
the Transformer encoder.
layer_norm_epsilon: epsilon to use in the layer norm layers
resid_pdrop: The dropout probability for all fully connected
layers in the embeddings, encoder, and pooler.
attn_pdrop: The dropout ratio for the attention
probabilities.
embd_pdrop: The dropout ratio for the embeddings.
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
super(CTRLConfig, self).__init__(**kwargs)

if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.n_ctx = n_ctx
self.n_positions = n_positions
self.n_embd = n_embd
self.n_layer = n_layer
self.n_head = n_head
self.dff = dff
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range

self.num_labels = num_labels
self.summary_type = summary_type
self.summary_use_proj = summary_use_proj
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
else:
raise ValueError(
"First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)"
)

@property
def max_position_embeddings(self):
return self.n_positions

@property
def hidden_size(self):
return self.n_embd

@property
def num_attention_heads(self):
return self.n_head

@property
def num_hidden_layers(self):
return self.n_layer
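
A short sketch of the configuration in use, relying only on the defaults and the properties defined above:
```python
from transformers import CTRLConfig

config = CTRLConfig()  # defaults correspond to the released 48-layer model
print(config.vocab_size)               # 246534
print(config.max_position_embeddings)  # 50000 (alias for n_positions)
print(config.hidden_size)              # 1280  (alias for n_embd)
print(config.num_hidden_layers)        # 48    (alias for n_layer)
```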
