
Commit

examples testing
thomwolf committed Feb 18, 2019
1 parent 5ff0c60 commit fbb248a
Showing 3 changed files with 28 additions and 15 deletions.
15 changes: 11 additions & 4 deletions examples/run_gpt2_generate_unconditional_samples.py
@@ -4,7 +4,9 @@
 import logging

 import torch
+import torch.nn.functional as F
 import numpy as np
+from tqdm import trange

 from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
 def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
-        context = torch.tensor(context, device=device)
+        context = torch.tensor(context, device=device, dtype=torch.long)
     else:
         assert context is None, 'Specify exactly one of start_token and context!'
-        context = torch.full((batch_size, 1), start_token, device=device)
+        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
     prev = context
     output = context
     past = None
     with torch.no_grad():
-        for i in range(length):
+        for i in trange(length):
             logits, past = model(prev, past=past)
             logits = logits[:, -1, :] / temperature
             logits = top_k_logits(logits, k=top_k)
-            prev = torch.multinomial(logits, 1)
+            log_probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(log_probs, num_samples=1)
             output = torch.cat((output, prev), dim=1)
     return output

@@ -57,6 +61,8 @@ def sample_model():

     enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
+    model.to(device)
+    model.eval()

     if args.length == -1:
         args.length = model.config.n_ctx

@@ -71,6 +77,7 @@ def sample_model():
             batch_size=args.batch_size,
             temperature=args.temperature, top_k=args.top_k, device=device
         )
+        out = out.tolist()
         for i in range(args.batch_size):
             generated += args.batch_size
             text = enc.decode(out[i])
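
The substantive fix in this script is the sampling step: torch.multinomial expects non-negative weights, so drawing directly from raw (possibly negative) logits was incorrect. The logits are now filtered by top_k_logits and passed through F.softmax before sampling (note that the variable the diff names log_probs actually holds softmax probabilities). Below is a minimal, self-contained sketch of that corrected step; the random logits and the batch_size / vocab_size / top_k values are illustrative, and the helper is a re-implementation mirroring the one in the example script, not the script's own code.

# Sketch of the corrected sampling step; random logits stand in for model output.
import torch
import torch.nn.functional as F

def top_k_logits(logits, k):
    # Keep the k largest logits per row, push the rest to -1e10 so softmax
    # assigns them ~zero probability (same idea as the example's helper).
    if k == 0:
        return logits
    values, _ = torch.topk(logits, k)
    min_values = values[:, -1].unsqueeze(-1)
    return torch.where(logits < min_values, torch.full_like(logits, -1e10), logits)

batch_size, vocab_size, top_k, temperature = 2, 50257, 40, 1.0
logits = torch.randn(batch_size, vocab_size) / temperature
logits = top_k_logits(logits, k=top_k)
probs = F.softmax(logits, dim=-1)                      # multinomial needs non-negative weights
next_token = torch.multinomial(probs, num_samples=1)   # shape (batch_size, 1), dtype torch.long
print(next_token)
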
18 changes: 12 additions & 6 deletions examples/run_gpt2_interactive_conditional_samples.py
@@ -2,8 +2,10 @@
 import argparse
 import logging
+from tqdm import trange

 import torch
+import torch.nn.functional as F
 import numpy as np

 from pytorch_pretrained_bert import GPT2LMHeadModel, GPT2Tokenizer

@@ -23,18 +25,20 @@ def top_k_logits(logits, k):
 def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda'):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
-        context = torch.tensor(context, device=device)
+        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
     else:
         assert context is None, 'Specify exactly one of start_token and context!'
-        context = torch.full((batch_size, 1), start_token, device=device)
+        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
     prev = context
     output = context
     past = None
     with torch.no_grad():
-        for i in range(length):
+        for i in trange(length):
             logits, past = model(prev, past=past)
             logits = logits[:, -1, :] / temperature
             logits = top_k_logits(logits, k=top_k)
-            prev = torch.multinomial(logits, 1)
+            log_probs = F.softmax(logits, dim=-1)
+            prev = torch.multinomial(log_probs, num_samples=1)
             output = torch.cat((output, prev), dim=1)
     return output

@@ -50,7 +54,7 @@ def interact_model():
     args = parser.parse_args()
     print(args)

-    if args.batch_size is None:
+    if args.batch_size == -1:
         args.batch_size = 1
     assert args.nsamples % args.batch_size == 0

@@ -61,6 +65,8 @@ def interact_model():

     enc = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
     model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
+    model.to(device)
+    model.eval()

     if args.length == -1:
         args.length = model.config.n_ctx // 2

@@ -81,7 +87,7 @@ def interact_model():
                 batch_size=args.batch_size,
                 temperature=args.temperature, top_k=args.top_k, device=device
             )
-            out = out[:, len(context_tokens):]
+            out = out[:, len(context_tokens):].tolist()
             for i in range(args.batch_size):
                 generated += 1
                 text = enc.decode(out[i])
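
The interactive script gets the same sampling fix, plus two pieces of tensor bookkeeping: the encoded prompt is tiled across the batch with unsqueeze(0).repeat(batch_size, 1), and only the continuation after the prompt is decoded. The sketch below walks through that bookkeeping with made-up token ids standing in for enc.encode / model output; no tokenizer or model is involved.

# Sketch of the context batching and output slicing in the interactive example.
import torch

batch_size = 3
context_tokens = [15496, 995]                              # stand-in for enc.encode(raw_text)
context = torch.tensor(context_tokens, dtype=torch.long)
context = context.unsqueeze(0).repeat(batch_size, 1)       # shape (batch_size, len(context_tokens))

generated = torch.randint(0, 50257, (batch_size, 4), dtype=torch.long)  # pretend 4 sampled tokens each
out = torch.cat((context, generated), dim=1)               # what sample_sequence returns: prompt + continuation

out = out[:, len(context_tokens):].tolist()                # keep only the continuation, as lists of ints
for i in range(batch_size):
    print(out[i])                                          # the script calls enc.decode(out[i]) here
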
10 changes: 5 additions & 5 deletions pytorch_pretrained_bert/modeling_gpt2.py
@@ -244,10 +244,10 @@ def forward(self, x, layer_past=None):
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = layer_past[0], layer_past[1]
-            key = torch.cat((past_key, key), dim=-2)
+            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose to have same shapes
+            key = torch.cat((past_key, key), dim=-1)
             value = torch.cat((past_value, value), dim=-2)
-        present = torch.stack((key, value))
+        present = torch.stack((key.transpose(-2, -1), value))
         a = self._attn(query, key, value)
         a = self.merge_heads(a)
         a = self.c_proj(a)

@@ -278,7 +278,7 @@ def __init__(self, n_ctx, config, scale=False):
         self.mlp = MLP(4 * nx, config)

     def forward(self, x, layer_past=None):
-        a, present = self.attn(self.ln_1(x), layer_past=past)
+        a, present = self.attn(self.ln_1(x), layer_past=layer_past)
         x = x + a
         m = self.mlp(self.ln_2(x))
         x = x + m

@@ -531,7 +531,7 @@ def forward(self, input_ids, position_ids=None, token_type_ids=None, past=None):
             past_length = 0
             past = [None] * len(self.h)
         else:
-            past[0][0].size(-2)
+            past_length = past[0][0].size(-2)
         if position_ids is None:
             position_ids = torch.arange(past_length, input_ids.size(-1) + past_length, dtype=torch.long, device=input_ids.device)
             position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
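
The modeling change fixes the key/value cache convention the examples rely on: inside attention, keys live as (batch, heads, head_dim, seq) while values live as (batch, heads, seq, head_dim); present stores the key transposed into the value layout so the two can be stacked, which means the cached key must be transposed back before concatenating along the keys' sequence axis (dim=-1, versus dim=-2 for values), and past_length can be read off past[0][0].size(-2). A shape-only sketch under illustrative dimensions (batch, heads, head_dim and the sequence lengths are made up):

# Shape-only sketch of the cache convention implied by the diff; all sizes are illustrative.
import torch

batch, heads, head_dim = 2, 12, 64
past_len, new_len = 5, 1

key = torch.randn(batch, heads, head_dim, new_len)         # keys kept as (batch, heads, head_dim, seq)
value = torch.randn(batch, heads, new_len, head_dim)       # values kept as (batch, heads, seq, head_dim)

past_key_t = torch.randn(batch, heads, past_len, head_dim) # cached key, already in the value layout
past_value = torch.randn(batch, heads, past_len, head_dim)
layer_past = torch.stack((past_key_t, past_value))         # what a `present` from the previous step looks like

past_key = layer_past[0].transpose(-2, -1)                 # back to (batch, heads, head_dim, past_len)
key = torch.cat((past_key, key), dim=-1)                   # concatenate along the keys' sequence axis
value = torch.cat((layer_past[1], value), dim=-2)          # concatenate along the values' sequence axis
present = torch.stack((key.transpose(-2, -1), value))      # both entries now (batch, heads, seq, head_dim)

past_length = layer_past[0].size(-2)                       # what GPT2Model uses to offset position_ids
print(past_length, key.shape, value.shape, present.shape)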
