diff --git a/examples/bert.rs b/examples/bert.rs
index b4554ddbc..02b3d64ec 100644
--- a/examples/bert.rs
+++ b/examples/bert.rs
@@ -7,6 +7,7 @@ use rust_tokenizers::{BertTokenizer, TruncationStrategy, MultiThreadedTokenizer}
 use rust_bert::bert::bert::BertConfig;
 use rust_bert::bert::embeddings::BertEmbeddings;
 use rust_bert::common::config::Config;
+use rust_bert::bert::attention::BertSelfAttention;
 
 fn main() -> failure::Fallible<()> {
     // Resources paths
@@ -43,10 +44,20 @@ fn main() -> failure::Fallible<()> {
 
     // Forward pass
     let embeddings = BertEmbeddings::new(&vs.root(), &config);
+    let bert_self_attention = BertSelfAttention::new(vs.root(), &config);
     let output = no_grad(|| {
         embeddings
             .forward_t(Some(input_tensor), None, None, None, false)
+            .unwrap()
+    });
+
+    println!("{:?}", output);
+
+    // Self-attention forward pass over the embedding output
+    let output = no_grad(|| {
+        bert_self_attention
+            .forward_t(&output, &None, &None, &None, false)
     });
 
     println!("{:?}", output);
 
diff --git a/src/bert/attention.rs b/src/bert/attention.rs
new file mode 100644
index 000000000..8df494152
--- /dev/null
+++ b/src/bert/attention.rs
@@ -0,0 +1,115 @@
+// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
+// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+// Copyright 2019 Guillaume Becquin
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
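+
+//! Multi-head self-attention for BERT: scaled dot-product attention,
+//! softmax(Q K^T / sqrt(d)) V, computed over `num_attention_heads` heads in parallel.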
+
+use crate::common::dropout::Dropout;
+use tch::{nn, Tensor};
+use crate::BertConfig;
+use tch::kind::Kind::Float;
+
+#[derive(Debug)]
+pub struct BertSelfAttention {
+    num_attention_heads: i64,
+    attention_head_size: i64,
+    all_head_size: i64,
+    dropout: Dropout,
+    output_attentions: bool,
+    query: nn::Linear,
+    key: nn::Linear,
+    value: nn::Linear,
+}
+
+impl BertSelfAttention {
+    pub fn new(p: nn::Path, config: &BertConfig) -> BertSelfAttention {
+        assert_eq!(config.hidden_size % config.num_attention_heads, 0, "Hidden size is not a multiple of the number of attention heads");
+        let attention_head_size = config.hidden_size / config.num_attention_heads;
+        let all_head_size = config.num_attention_heads * attention_head_size;
+
+        let query = nn::linear(&p / "query", config.hidden_size, all_head_size, Default::default());
+        let key = nn::linear(&p / "key", config.hidden_size, all_head_size, Default::default());
+        let value = nn::linear(&p / "value", config.hidden_size, all_head_size, Default::default());
+
+        let dropout = Dropout::new(config.attention_probs_dropout_prob);
+        let output_attentions = match config.output_attentions {
+            Some(value) => value,
+            None => false
+        };
+
+        BertSelfAttention {
+            num_attention_heads: config.num_attention_heads,
+            attention_head_size,
+            all_head_size,
+            dropout,
+            output_attentions,
+            query,
+            key,
+            value,
+        }
+    }
+
+    // Reshapes (bs, seq_len, all_head_size) to (bs, num_heads, seq_len, head_size)
+    fn split_heads(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
+        x.view((bs, -1, self.num_attention_heads, dim_per_head)).transpose(1, 2)
+    }
+
+    // Inverse of split_heads: merges the heads back to (bs, seq_len, all_head_size)
+    fn flatten(&self, x: Tensor, bs: i64, dim_per_head: i64) -> Tensor {
+        x.transpose(1, 2).contiguous().view((bs, -1, self.num_attention_heads * dim_per_head))
+    }
+
+    pub fn forward_t(&self,
+                     hidden_states: &Tensor,
+                     mask: &Option<Tensor>,
+                     encoder_hidden_states: &Option<Tensor>,
+                     encoder_mask: &Option<Tensor>,
+                     train: bool) -> (Tensor, Option<Tensor>) {
+        // For cross-attention, keys and values are computed from the encoder states
+        let (key_layer, value_layer, mask) = match encoder_hidden_states {
+            Some(encoder_hidden_state_values) => {
+                (encoder_hidden_state_values.apply(&self.key),
+                 encoder_hidden_state_values.apply(&self.value),
+                 encoder_mask)
+            }
+            None => {
+                (hidden_states.apply(&self.key),
+                 hidden_states.apply(&self.value),
+                 mask)
+            }
+        };
+
+        let bs = hidden_states.size()[0];
+
+        let query_layer = self.split_heads(hidden_states.apply(&self.query), bs, self.attention_head_size);
+        let key_layer = self.split_heads(key_layer, bs, self.attention_head_size);
+        let value_layer = self.split_heads(value_layer, bs, self.attention_head_size);
+        let query_layer: Tensor = query_layer / (self.attention_head_size as f64).sqrt();
+
+        // Attention scores, with an optional additive mask
+        let scores = if let Some(mask) = mask {
+            query_layer.matmul(&key_layer.transpose(-1, -2)) + mask
+        } else {
+            query_layer.matmul(&key_layer.transpose(-1, -2))
+        };
+
+        let weights = scores.softmax(-1, Float).apply_t(&self.dropout, train);
+        let context = self.flatten(weights.matmul(&value_layer), bs, self.attention_head_size);
+
+        if !self.output_attentions {
+            (context, None)
+        } else {
+            (context, Some(weights))
+        }
+    }
+}
diff --git a/src/bert/bert.rs b/src/bert/bert.rs
index 4d6e1a08d..757b9af8d 100644
--- a/src/bert/bert.rs
+++ b/src/bert/bert.rs
@@ -36,6 +36,7 @@ pub struct BertConfig {
     pub num_hidden_layers: i64,
     pub type_vocab_size: i64,
     pub vocab_size: i64,
+    pub output_attentions: Option<bool>,
 }
 
 impl Config<BertConfig> for BertConfig {}
diff --git a/src/bert/mod.rs b/src/bert/mod.rs
index 907c32b84..1eb29e20d 100644
--- a/src/bert/mod.rs
+++ b/src/bert/mod.rs
@@ -1,2 +1,3 @@
 pub mod bert;
-pub mod embeddings;
\ No newline at end of file
+pub mod embeddings;
+pub mod attention;
\ No newline at end of file
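
A quick shape check for the new attention layer (an illustrative sketch, not
part of the patch: it assumes `config` and `vs` are set up as in
examples/bert.rs, with a BERT-base style configuration):

    use tch::{kind, no_grad, Tensor};
    use rust_bert::bert::attention::BertSelfAttention;

    let attention = BertSelfAttention::new(vs.root(), &config);
    // Dummy batch: 2 sequences of 7 tokens, one hidden_size vector per token
    let hidden_states = Tensor::rand(&[2, 7, config.hidden_size], kind::FLOAT_CPU);
    let (context, weights) = no_grad(|| {
        attention.forward_t(&hidden_states, &None, &None, &None, false)
    });
    // Self-attention preserves the input shape: [2, 7, hidden_size]
    assert_eq!(context.size(), &[2, 7, config.hidden_size]);
    // weights is None unless the config sets output_attentions: Some(true),
    // in which case it holds probabilities shaped [2, num_attention_heads, 7, 7]
    assert!(weights.is_none());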