From f8ed5079c13080e947f99548f723015bd144e526 Mon Sep 17 00:00:00 2001 From: Guillaume B Date: Thu, 7 Jan 2021 10:41:50 +0100 Subject: [PATCH] initial commit for ProphetNet --- Cargo.toml | 2 +- examples/prophetnet.rs | 51 +++++++++++++++++ src/lib.rs | 1 + src/prophetnet/attention.rs | 11 ++++ src/prophetnet/decoder.rs | 29 ++++++++++ src/prophetnet/mod.rs | 7 +++ src/prophetnet/prophetnet_model.rs | 88 ++++++++++++++++++++++++++++++ utils/convert_model.py | 2 +- 8 files changed, 189 insertions(+), 2 deletions(-) create mode 100644 examples/prophetnet.rs create mode 100644 src/prophetnet/attention.rs create mode 100644 src/prophetnet/decoder.rs create mode 100644 src/prophetnet/mod.rs create mode 100644 src/prophetnet/prophetnet_model.rs diff --git a/Cargo.toml b/Cargo.toml index 27e987986..065136ccf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,7 +57,7 @@ all-tests = [] features = ["doc-only"] [dependencies] -rust_tokenizers = "~6.1.0" +rust_tokenizers = {version = "~6.1.1", path="E:/Coding/backup-rust/rust-tokenizers/main"} tch = "~0.3.0" serde_json = "1.0.59" serde = { version = "1.0.117", features = ["derive"] } diff --git a/examples/prophetnet.rs b/examples/prophetnet.rs new file mode 100644 index 000000000..7a2b5e0d8 --- /dev/null +++ b/examples/prophetnet.rs @@ -0,0 +1,51 @@ +// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +// Copyright 2019 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+extern crate anyhow;
+
+use rust_bert::bart::{
+    BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
+    BartVocabResources,
+};
+use rust_bert::prophetnet::{
+    ProphetNetConfig, ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
+};
+use rust_bert::resources::{RemoteResource, Resource};
+use rust_bert::Config;
+use rust_tokenizers::tokenizer::{
+    ProphetNetTokenizer, RobertaTokenizer, Tokenizer, TruncationStrategy,
+};
+use tch::{nn, no_grad, Device, Tensor};
+
+/// Resolves the pretrained ProphetNet resources to local paths, then loads tokenizer and config.
+fn main() -> anyhow::Result<()> {
+    // Remote resources for the `prophetnet-large-uncased` configuration, vocabulary and weights
+    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
+        ProphetNetConfigResources::PROPHETNET_LARGE_UNCASED,
+    ));
+    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
+        ProphetNetVocabResources::PROPHETNET_LARGE_UNCASED,
+    ));
+    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
+        ProphetNetModelResources::PROPHETNET_LARGE_UNCASED,
+    ));
+    let config_path = config_resource.get_local_path()?;
+    let vocab_path = vocab_resource.get_local_path()?;
+    let _weights_path = weights_resource.get_local_path()?;
+
+    // Load the tokenizer and configuration; fail with an error (not a panic) on a non-UTF-8 path
+    let vocab_path = vocab_path.to_str();
+    let vocab_path = vocab_path.ok_or_else(|| anyhow::anyhow!("vocabulary path is not UTF-8"))?;
+    let _tokenizer = ProphetNetTokenizer::from_file(vocab_path, true, true)?;
+    let _config = ProphetNetConfig::from_file(config_path);
+    Ok(())
+}
diff --git a/src/lib.rs b/src/lib.rs
index 9dedfffc9..9fc41527c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -73,6 +73,7 @@ pub mod marian;
 pub mod mobilebert;
 pub mod openai_gpt;
 pub mod pipelines;
+pub mod prophetnet;
 pub mod reformer;
 pub mod roberta;
 pub mod t5;
diff --git a/src/prophetnet/attention.rs b/src/prophetnet/attention.rs
new file mode 100644
index 000000000..eeba4f34f
--- /dev/null
+++ b/src/prophetnet/attention.rs
@@ -0,0 +1,11 @@
+// Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
+// Copyright 2020 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. diff --git a/src/prophetnet/decoder.rs b/src/prophetnet/decoder.rs new file mode 100644 index 000000000..6f999bc20 --- /dev/null +++ b/src/prophetnet/decoder.rs @@ -0,0 +1,29 @@ +// Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team. +// Copyright 2020 Guillaume Becquin +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+use tch::{Device, Kind, Tensor};
+
+/// Builds the n-gram attention bias of shape `[ngram, sequence_length, 2 * sequence_length]`.
+/// Left half: `-inf` on/above diagonal `1 - stream_idx` per stream, zero below and in column 0.
+/// Right half: `-inf` everywhere except a zero main diagonal; halves are joined on the last axis.
+fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device) -> Tensor {
+    let shape = [ngram, sequence_length, sequence_length];
+    let left_block = Tensor::ones(&shape, (Kind::Float, device)) * f64::NEG_INFINITY;
+    let right_block = left_block.copy();
+    for stream_idx in 0..ngram {
+        let _ = right_block.get(stream_idx).fill_diagonal_(0, false);
+        let _ = left_block.get(stream_idx).triu_(-stream_idx + 1);
+    }
+    // Zero only column 0 (`left_block[:, :, 0] = 0`); a full-width slice would erase the mask
+    let _ = left_block.slice(2, 0, 1, 1).fill_(0);
+    Tensor::cat(&[left_block, right_block], 2)
+}
diff --git a/src/prophetnet/mod.rs b/src/prophetnet/mod.rs
new file mode 100644
index 000000000..e6b9de8a8
--- /dev/null
+++ b/src/prophetnet/mod.rs
@@ -0,0 +1,7 @@
+mod attention;
+mod decoder;
+mod prophetnet_model;
+
+pub use prophetnet_model::{
+    ProphetNetConfig, ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
+};
diff --git a/src/prophetnet/prophetnet_model.rs b/src/prophetnet/prophetnet_model.rs
new file mode 100644
index 000000000..b45952c1d
--- /dev/null
+++ b/src/prophetnet/prophetnet_model.rs
@@ -0,0 +1,88 @@
+// Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
+// Copyright 2020 Guillaume Becquin
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//     http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+use crate::{Activation, Config};
+use serde::{Deserialize, Serialize};
+use std::collections::HashMap;
+
+/// # Locations of the pretrained ProphetNet model weights
+pub struct ProphetNetModelResources;
+
+impl ProphetNetModelResources {
+    /// Weights shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
+    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
+        "prophetnet-large-uncased/model",
+        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/rust_model.ot",
+    );
+}
+
+/// # Locations of the pretrained ProphetNet configuration files
+pub struct ProphetNetConfigResources;
+
+impl ProphetNetConfigResources {
+    /// Configuration shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
+    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
+        "prophetnet-large-uncased/config",
+        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json",
+    );
+}
+
+/// # Locations of the pretrained ProphetNet vocabulary files
+pub struct ProphetNetVocabResources;
+
+impl ProphetNetVocabResources {
+    /// Vocabulary shared under MIT license by the Microsoft team at <https://github.com/microsoft/ProphetNet>. Modified with conversion to C-array format.
+    pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
+        "prophetnet-large-uncased/vocab",
+        "https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
+    );
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+/// # ProphetNet model configuration
+/// Defines the ProphetNet model architecture (e.g. number of layers, hidden layer size, label mapping...)
+pub struct ProphetNetConfig {
+    pub activation_function: Activation,
+    pub activation_dropout: f64,
+    pub attention_dropout: f64,
+    pub decoder_ffn_dim: i64,
+    pub decoder_layerdrop: f64,
+    pub decoder_max_position_embeddings: i64,
+    pub decoder_start_token_id: i64,
+    pub disable_ngram_loss: bool,
+    pub dropout: f64,
+    pub encoder_ffn_dim: i64,
+    pub encoder_layerdrop: f64,
+    pub encoder_max_position_embeddings: i64,
+    pub eps: f64,
+    pub hidden_size: i64,
+    pub init_std: f64,
+    pub is_encoder_decoder: bool,
+    pub max_position_embeddings: i64,
+    pub bos_token_id: i64,
+    pub eos_token_id: i64,
+    pub ngram: i64,
+    pub id2label: Option<HashMap<i64, String>>,
+    pub label2id: Option<HashMap<String, i64>>,
+    pub num_buckets: i64,
+    pub num_decoder_attention_heads: i64,
+    pub num_decoder_layers: i64,
+    pub num_encoder_attention_heads: i64,
+    pub num_encoder_layers: i64,
+    pub output_past: Option<bool>,
+    pub pad_token_id: i64,
+    pub relative_max_distance: i64,
+    pub vocab_size: i64,
+    pub output_attentions: Option<bool>,
+    pub output_hidden_states: Option<bool>,
+}
+
+impl Config<ProphetNetConfig> for ProphetNetConfig {}
diff --git a/utils/convert_model.py b/utils/convert_model.py index 272c482f1..5f87a45f6 100644 --- a/utils/convert_model.py +++ b/utils/convert_model.py @@ -7,7 +7,7 @@ source_file = Path("path/to/pytorch_model.bin") target_folder = source_file.parent -weights = torch.load(source_file, map_location='cpu') +weights = torch.load(str(source_file), map_location='cpu') nps = {} for k, v in weights.items():