Add attention scaling
duzx16 committed Apr 25, 2021
1 parent 4bc1c20 commit bf1f708
Showing 5 changed files with 33 additions and 14 deletions.
arguments.py (2 changes: 1 addition & 1 deletion)
@@ -100,7 +100,7 @@ def add_fp16_config_args(parser):
help='Window over which to raise/lower dynamic scale')
group.add_argument('--min-scale', type=float, default=1,
help='Minimum loss scale for dynamic loss scale')

group.add_argument('--attention-scale', type=float, default=1.0)
return parser


config/ds_blockta_1.25.sh (3 changes: 2 additions & 1 deletion)
@@ -21,12 +21,13 @@ gpt_options=" \
--num-attention-heads 18 \
--seq-length 512 \
--max-position-embeddings 1024 \
--attention-scale 8.0 \
--save /dataset/fd5061f6/english_data/checkpoints \
--load /dataset/fd5061f6/english_data/checkpoints/blocklm-roberta-1.25-blank04-22-14-01 \
--save-interval 2500 \
--train-iters 500000 \
--resume-dataloader \
--shuffle \
--filter-english \
--train-data wikibook cc-news openwebtext stories \
--loader-scatter 8 \
--no-lazy-loader \
model/gpt2_modeling.py (4 changes: 3 additions & 1 deletion)
@@ -59,7 +59,8 @@ def __init__(self,
block_position_encoding=False,
nonautoregressive=False,
output_predict=True,
spell_length=None
spell_length=None,
attention_scale=1.0,
):
super(GPT2Model, self).__init__()

@@ -94,6 +95,7 @@ def __init__(self,
output_dropout_prob,
checkpoint_activations,
checkpoint_num_layers,
attention_scale=attention_scale,
relative_encoding=relative_encoding,
block_position_encoding=block_position_encoding)
if spell_length is not None:
mpu/transformer.py (35 changes: 25 additions & 10 deletions)
@@ -181,7 +181,7 @@ class ParallelSelfAttention(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads,
attention_dropout_prob, output_dropout_prob,
init_method, output_layer_init_method=None, relative_encoding=False,
performer=False):
performer=False, attention_scale=1.0):
super(ParallelSelfAttention, self).__init__()
self.performer = performer
# Set output layer initialization if not provided.
@@ -195,6 +195,7 @@ def __init__(self, hidden_size, num_attention_heads,
self.num_attention_heads_per_partition = divide(num_attention_heads,
world_size)
self.relative_encoding = relative_encoding
self.attention_scale = attention_scale
# Strided linear layer.
self.query_key_value = ColumnParallelLinear(hidden_size, 3 * hidden_size,
stride=3,
@@ -286,14 +287,24 @@ def forward(self, hidden_states, ltor_mask, position_embeddings=None, r_w_bias=N
attention_scores = ac_score + bd_score
attention_scores = attention_scores / math.sqrt(self.hidden_size_per_attention_head)
else:
# Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2) / math.sqrt(
self.hidden_size_per_attention_head))
if self.attention_scale > 1.0:
# Raw attention scores. [b, np, s, s]
attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_scale),
key_layer.transpose(-1, -2) / math.sqrt(
self.hidden_size_per_attention_head * self.attention_scale))
else:
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2) / math.sqrt(
self.hidden_size_per_attention_head))

# Apply the left to right attention mask.
attention_scores = torch.mul(attention_scores, ltor_mask)
min_attention_scores = attention_scores.min().item()
attention_scores = attention_scores + (min_attention_scores - 1000.0) * (1.0 - ltor_mask)
if self.attention_scale > 1.0:
max_attention_scores = attention_scores.max(dim=-1, keepdim=True)[0]
attention_scores -= max_attention_scores
attention_scores *= self.attention_scale
# if torch.distributed.get_rank() == 0:
# print(min_attention_scores, attention_scores.max().item())
attention_scores = attention_scores + (-65504.0) * (1.0 - ltor_mask)
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
# This is actually dropping out entire tokens to attend to, which might
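
The hunk above replaces the previous unscaled matmul and min-based masking with an attention_scale branch: when the scale is greater than 1, the query and key are each pre-divided so the fp16 matmul stays small, the per-row maximum is subtracted (softmax is invariant to a per-row shift), and the scale is multiplied back in before the mask constant of -65504.0, the most negative finite fp16 value, is added to the blocked positions. Up to floating-point rounding this yields the same attention distribution as the standard computation, but the intermediates stay well inside the fp16 range. A minimal standalone sketch of that computation (the function name, shapes, and mask convention here are illustrative, not taken from the repository):

import math
import torch

def fp16_safe_attention_probs(query, key, mask, attention_scale=1.0):
    # query, key: [b, np, s, hn]; mask: 1.0 where attention is allowed, 0.0 where blocked.
    d_head = query.size(-1)
    if attention_scale > 1.0:
        # (Q / sqrt(a)) @ (K^T / sqrt(d * a)) == Q @ K^T / (sqrt(d) * a):
        # both operands are shrunk, so the fp16 matmul is far less likely to overflow.
        scores = torch.matmul(query / math.sqrt(attention_scale),
                              key.transpose(-1, -2) / math.sqrt(d_head * attention_scale))
    else:
        scores = torch.matmul(query, key.transpose(-1, -2) / math.sqrt(d_head))
    scores = scores * mask
    if attention_scale > 1.0:
        # Subtracting the row maximum makes every score non-positive, so
        # multiplying the scale back in cannot push anything past the fp16 limit.
        scores = (scores - scores.max(dim=-1, keepdim=True)[0]) * attention_scale
    # Push blocked positions toward the most negative finite fp16 value before softmax.
    scores = scores + (-65504.0) * (1.0 - mask)
    return torch.nn.Softmax(dim=-1)(scores)

With the default --attention-scale of 1.0 the branch is skipped, and the only change relative to the old code is the fixed -65504.0 masking constant in place of the previous (min - 1000.0) term.
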
@@ -516,7 +527,8 @@ def __init__(self,
init_method,
output_layer_init_method=None,
relative_encoding=False,
performer=False):
performer=False,
attention_scale=1.0):
super(ParallelTransformerLayer, self).__init__()
# Set output layer initialization if not provided.
if output_layer_init_method is None:
@@ -534,7 +546,8 @@ def __init__(self,
init_method,
output_layer_init_method=output_layer_init_method,
relative_encoding=relative_encoding,
performer=performer)
performer=performer,
attention_scale=attention_scale)

# Layernorm on the input data.
self.post_attention_layernorm = LayerNorm(hidden_size,
@@ -639,7 +652,8 @@ def __init__(self,
relative_encoding=False,
block_position_encoding=False,
performer=False,
use_decoder_layer=False
use_decoder_layer=False,
attention_scale=1.0,
):
super(GPT2ParallelTransformer, self).__init__()
self.hidden_size = hidden_size
@@ -710,7 +724,8 @@ def get_layer():
unscaled_init_method(init_method_std),
output_layer_init_method=output_layer_init_method,
relative_encoding=relative_encoding,
performer=performer)
performer=performer,
attention_scale=attention_scale)

# Transformer layers.
self.layers = torch.nn.ModuleList(
train_utils.py (3 changes: 2 additions & 1 deletion)
@@ -56,7 +56,8 @@ def get_model(args, model_type=None, multi_token=True, num_labels=None, spell_le
block_position_encoding=args.block_lm and not args.masked_lm,
output_predict=output_predict,
spell_length=spell_length,
nonautoregressive=args.nonautoregressive)
nonautoregressive=args.nonautoregressive,
attention_scale=args.attention_scale)
if model_type is not None:
if model_type == 'multiple_choice':
if args.cloze_eval:
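
Taken together, the five files thread a single value from the command line down to the attention computation: arguments.py defines --attention-scale (default 1.0), train_utils.get_model() passes args.attention_scale into GPT2Model, and the model forwards it through GPT2ParallelTransformer and ParallelTransformerLayer to ParallelSelfAttention. A minimal, hypothetical sketch of that plumbing (only the flag itself comes from the diff; the comments summarize the wiring and omit every other constructor argument):

import argparse

# The new flag, as added in arguments.py.
parser = argparse.ArgumentParser()
parser.add_argument('--attention-scale', type=float, default=1.0)

# ds_blockta_1.25.sh now passes --attention-scale 8.0; simulate that here.
args = parser.parse_args(['--attention-scale', '8.0'])

# train_utils.get_model() forwards the value as
# GPT2Model(..., attention_scale=args.attention_scale), and each layer passes
# it down until ParallelSelfAttention stores it as self.attention_scale.
print(args.attention_scale)  # 8.0
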
