Skip to content

Commit

Permalink
Added BertSelfAttention
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Feb 16, 2020
1 parent 20ec99e commit 8526a6b
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 1 deletion.
10 changes: 10 additions & 0 deletions examples/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use rust_tokenizers::{BertTokenizer, TruncationStrategy, MultiThreadedTokenizer}
use rust_bert::bert::bert::BertConfig;
use rust_bert::bert::embeddings::BertEmbeddings;
use rust_bert::common::config::Config;
use rust_bert::bert::attention::BertSelfAttention;

fn main() -> failure::Fallible<()> {
// Resources paths
Expand Down Expand Up @@ -43,10 +44,19 @@ fn main() -> failure::Fallible<()> {
// Forward pass

let embeddings = BertEmbeddings::new(&vs.root(), &config);
let bert_self_attention = BertSelfAttention::new(vs.root(), &config);

let output = no_grad(|| {
embeddings
.forward_t(Some(input_tensor), None, None, None, false)
.unwrap()
});

println!("{:?}", output);

let output = no_grad(|| {
bert_self_attention
.forward_t(&output, &None, &None, &None, false)
});

println!("{:?}", output);
Expand Down
109 changes: 109 additions & 0 deletions src/bert/attention.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use crate::common::dropout::Dropout;
use tch::{nn, Tensor};
use crate::BertConfig;
use tch::kind::Kind::Float;

/// Multi-head self-attention layer for BERT.
///
/// Holds the three linear projections (query, key, value) together with the
/// head geometry and the dropout applied to the attention weights.
/// Instances are built with `BertSelfAttention::new` and evaluated with `forward_t`.
#[derive(Debug)]
pub struct BertSelfAttention {
/// Number of attention heads, copied from `config.num_attention_heads`.
num_attention_heads: i64,
/// Size of a single head: `config.hidden_size / config.num_attention_heads`.
attention_head_size: i64,
/// Output width of each projection: `num_attention_heads * attention_head_size`.
all_head_size: i64,
/// Dropout applied to the softmax-normalised attention weights
/// (probability taken from `config.attention_probs_dropout_prob`).
dropout: Dropout,
/// When `true`, `forward_t` also returns the attention weights.
/// Falls back to `false` when the configuration leaves it unset.
output_attentions: bool,
/// Linear projection (`hidden_size -> all_head_size`) registered under "query".
query: nn::Linear,
/// Linear projection (`hidden_size -> all_head_size`) registered under "key".
key: nn::Linear,
/// Linear projection (`hidden_size -> all_head_size`) registered under "value".
value: nn::Linear,
}

impl BertSelfAttention {
    /// Builds the self-attention layer.
    ///
    /// The `query`, `key` and `value` linear projections are registered under the
    /// sub-paths `"query"`, `"key"` and `"value"` of `p`, each mapping
    /// `config.hidden_size` to `num_attention_heads * attention_head_size` features.
    /// `output_attentions` defaults to `false` when the configuration does not set it.
    ///
    /// # Panics
    ///
    /// Panics if `config.hidden_size` is not a multiple of `config.num_attention_heads`.
    pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
        assert_eq!(
            config.hidden_size % config.num_attention_heads,
            0,
            "Hidden size not a multiple of the number of attention heads"
        );
        let attention_head_size = config.hidden_size / config.num_attention_heads;
        let all_head_size = config.num_attention_heads * attention_head_size;

        let query = nn::linear(&p / "query", config.hidden_size, all_head_size, Default::default());
        let key = nn::linear(&p / "key", config.hidden_size, all_head_size, Default::default());
        let value = nn::linear(&p / "value", config.hidden_size, all_head_size, Default::default());

        let dropout = Dropout::new(config.attention_probs_dropout_prob);
        let output_attentions = config.output_attentions.unwrap_or(false);

        BertSelfAttention {
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            all_head_size,
            dropout,
            output_attentions,
            query,
            key,
            value,
        }
    }

    /// Views `x` as `(bs, -1, num_attention_heads, dim_per_head)` and swaps
    /// dimensions 1 and 2, so that the head dimension comes right after the batch.
    fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
        x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
    }

    /// Inverse of `split_heads`: swaps dimensions 1 and 2 back and merges the
    /// heads into a single trailing dimension of size `num_attention_heads * dim_per_head`.
    fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
        // `contiguous` is required because `view` cannot be applied to the
        // non-contiguous result of `transpose`.
        x.transpose(1, 2).contiguous().view((bs, -1, self.num_attention_heads * dim_per_head))
    }

    /// Runs scaled dot-product attention.
    ///
    /// # Arguments
    ///
    /// * `hidden_states` - input used to compute the queries, and also the keys
    ///   and values when `encoder_hidden_states` is `None`. Its first dimension
    ///   is taken as the batch size.
    /// * `mask` - optional tensor added to the attention scores before the
    ///   softmax; only used when `encoder_hidden_states` is `None`.
    ///   NOTE(review): presumably an additive mask broadcastable to the score
    ///   shape - confirm with callers.
    /// * `encoder_hidden_states` - when provided, keys and values are computed
    ///   from this tensor instead of `hidden_states` (cross-attention).
    /// * `encoder_mask` - optional tensor added to the scores in place of `mask`
    ///   when `encoder_hidden_states` is provided.
    /// * `train` - forwarded to the dropout applied on the attention weights.
    ///
    /// # Returns
    ///
    /// A tuple `(context, attentions)` where `context` has its heads merged back
    /// into a trailing dimension of size `num_attention_heads * attention_head_size`,
    /// and `attentions` holds the post-dropout attention weights when
    /// `output_attentions` is enabled, `None` otherwise.
    pub fn forward_t(&self,
                     hidden_states: &Tensor,
                     mask: &Option<Tensor>,
                     encoder_hidden_states: &Option<Tensor>,
                     encoder_mask: &Option<Tensor>,
                     train: bool) -> (Tensor, Option<Tensor>) {
        // Cross-attention reads keys/values (and the mask) from the encoder side.
        let (key_layer, value_layer, mask) = match encoder_hidden_states {
            Some(encoder_hidden_state_values) => {
                (encoder_hidden_state_values.apply(&self.key),
                 encoder_hidden_state_values.apply(&self.value),
                 encoder_mask)
            }
            None => {
                (hidden_states.apply(&self.key),
                 hidden_states.apply(&self.value),
                 mask)
            }
        };

        let bs = hidden_states.size()[0];

        let query_layer = self.split_heads(hidden_states.apply(&self.query), bs, self.attention_head_size);
        let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
        let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
        // Scale the queries by sqrt(head size) up-front rather than scaling the scores.
        let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();

        // Raw scores are computed once; the mask is only added when one is present.
        let scores = query_layer.matmul(&key_layer.transpose(-1, -2));
        let scores = match mask {
            Some(mask) => scores + mask,
            None => scores,
        };

        let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
        let context = self.flatten(weights.matmul(&value_layer), bs, self.attention_head_size);

        let attentions = if self.output_attentions {
            Some(weights)
        } else {
            None
        };
        (context, attentions)
    }
}
1 change: 1 addition & 0 deletions src/bert/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub struct BertConfig {
pub num_hidden_layers: i64,
pub type_vocab_size: i64,
pub vocab_size: i64,
pub output_attentions: Option<bool>,
}

impl Config<BertConfig> for BertConfig {}
Expand Down
3 changes: 2 additions & 1 deletion src/bert/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
pub mod bert;
pub mod embeddings;
pub mod embeddings;
pub mod attention;

0 comments on commit 8526a6b

Please sign in to comment.