Skip to content

Commit

Permalink
initial commit for ProphetNet
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Jan 7, 2021
1 parent 7890d2d commit f8ed507
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ all-tests = []
features = ["doc-only"]

[dependencies]
rust_tokenizers = "~6.1.0"
rust_tokenizers = {version = "~6.1.1", path="E:/Coding/backup-rust/rust-tokenizers/main"}
tch = "~0.3.0"
serde_json = "1.0.59"
serde = { version = "1.0.117", features = ["derive"] }
Expand Down
51 changes: 51 additions & 0 deletions examples/prophetnet.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate anyhow;

use rust_bert::bart::{
BartConfig, BartConfigResources, BartMergesResources, BartModel, BartModelResources,
BartVocabResources,
};
use rust_bert::prophetnet::{
ProphetNetConfig, ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{
ProphetNetTokenizer, RobertaTokenizer, Tokenizer, TruncationStrategy,
};
use tch::{nn, no_grad, Device, Tensor};

/// Resolves the pretrained `PROPHETNET_LARGE_UNCASED` configuration, vocabulary
/// and weights resources to local paths, then loads the tokenizer and the model
/// configuration from those paths.
///
/// # Errors
/// Returns an error if any resource cannot be resolved to a local path, if the
/// vocabulary path is not valid UTF-8, or if the tokenizer cannot be built from
/// the vocabulary file.
fn main() -> anyhow::Result<()> {
    // Resources paths
    let config_resource = Resource::Remote(RemoteResource::from_pretrained(
        ProphetNetConfigResources::PROPHETNET_LARGE_UNCASED,
    ));
    let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
        ProphetNetVocabResources::PROPHETNET_LARGE_UNCASED,
    ));
    let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
        ProphetNetModelResources::PROPHETNET_LARGE_UNCASED,
    ));
    let config_path = config_resource.get_local_path()?;
    let vocab_path = vocab_resource.get_local_path()?;
    // The weights are resolved but not loaded yet: no model is built below.
    let _weights_path = weights_resource.get_local_path()?;

    // Set-up model (device / variable store creation is still disabled)
    // let device = Device::cuda_if_available();
    // let mut vs = nn::VarStore::new(device);

    // The tokenizer constructor takes the path as `&str`; surface a non-UTF-8
    // path as a regular error instead of panicking inside a fallible `main`.
    let vocab_file = vocab_path
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("vocabulary path is not valid UTF-8: {:?}", vocab_path))?;
    let _tokenizer = ProphetNetTokenizer::from_file(vocab_file, true, true)?;
    let _config = ProphetNetConfig::from_file(config_path);

    Ok(())
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ pub mod marian;
pub mod mobilebert;
pub mod openai_gpt;
pub mod pipelines;
pub mod prophetnet;
pub mod reformer;
pub mod roberta;
pub mod t5;
Expand Down
11 changes: 11 additions & 0 deletions src/prophetnet/attention.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
29 changes: 29 additions & 0 deletions src/prophetnet/decoder.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use tch::{Device, Kind, Tensor};

/// Builds the additive attention bias for the n-gram streams.
///
/// The result has shape `[ngram, sequence_length, 2 * sequence_length]` and is
/// the concatenation, along the last dimension, of two
/// `[ngram, sequence_length, sequence_length]` blocks whose entries are either
/// `0` or negative infinity:
///
/// * left block: for stream `i`, the upper triangle starting at diagonal
///   `1 - i` keeps negative infinity and everything below it is `0`; the first
///   column is then set to `0` for every stream and row.
/// * right block: negative infinity everywhere except the main diagonal,
///   which is `0`.
///
/// # Arguments
/// * `sequence_length` - size of the two trailing dimensions of each block.
/// * `ngram` - number of streams (size of the leading dimension).
/// * `device` - device on which the tensors are allocated.
fn ngram_attention_bias(sequence_length: i64, ngram: i64, device: Device) -> Tensor {
    let left_block = Tensor::ones(
        &[ngram, sequence_length, sequence_length],
        (Kind::Float, device),
    ) * f64::NEG_INFINITY;
    let right_block = left_block.copy();
    for stream_idx in 0..ngram {
        // `get` returns a view, so the in-place operations update the blocks.
        let _ = right_block.get(stream_idx).fill_diagonal_(0, false);
        let _ = left_block.get(stream_idx).triu_(-stream_idx + 1);
    }
    // Only the first column of the left block is reset to 0. Slicing up to the
    // full size of the last dimension (as done previously) zeroed the whole
    // block and discarded the triangular mask computed in the loop above.
    let _ = left_block.slice(2, 0, 1, 1).fill_(0);
    Tensor::cat(&[left_block, right_block], 2)
}
7 changes: 7 additions & 0 deletions src/prophetnet/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod attention;
mod decoder;
mod prophetnet_model;

pub use prophetnet_model::{
ProphetNetConfig, ProphetNetConfigResources, ProphetNetModelResources, ProphetNetVocabResources,
};
88 changes: 88 additions & 0 deletions src/prophetnet/prophetnet_model.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright 2020 The Microsoft Authors and The HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::{Activation, Config};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// # ProphetNet Pretrained model weight files
/// Field-less marker type: it only carries associated constants describing
/// where pretrained weight files can be fetched from.
pub struct ProphetNetModelResources;

/// # ProphetNet Pretrained model config files
/// Field-less marker type: it only carries associated constants describing
/// where pretrained configuration files can be fetched from.
pub struct ProphetNetConfigResources;

/// # ProphetNet Pretrained model vocab files
/// Field-less marker type: it only carries associated constants describing
/// where pretrained vocabulary files can be fetched from.
pub struct ProphetNetVocabResources;

impl ProphetNetModelResources {
/// Weights for `prophetnet-large-uncased`, as a pair of strings whose second
/// element is the download URL of the `rust_model.ot` file.
/// The first element is presumably the name used for the local cached copy
/// (TODO confirm against `RemoteResource::from_pretrained`).
///
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
"prophetnet-large-uncased/model",
"https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/rust_model.ot",
);
}

impl ProphetNetConfigResources {
/// Configuration for `prophetnet-large-uncased`, as a pair of strings whose
/// second element is the download URL of the `config.json` file.
/// The first element is presumably the name used for the local cached copy
/// (TODO confirm against `RemoteResource::from_pretrained`).
///
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
// NOTE(review): the "Modified with conversion to C-array format" notice looks
// copied from the weights entry; this URL points to `config.json` - confirm
// whether the notice applies here.
pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
"prophetnet-large-uncased/config",
"https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/config.json",
);
}

impl ProphetNetVocabResources {
/// Vocabulary for `prophetnet-large-uncased`, as a pair of strings whose
/// second element is the download URL of the `prophetnet.tokenizer` file.
/// The first element is presumably the name used for the local cached copy
/// (TODO confirm against `RemoteResource::from_pretrained`).
///
/// Shared under MIT license by the Microsoft team at https://github.com/microsoft/ProphetNet. Modified with conversion to C-array format.
// NOTE(review): the "Modified with conversion to C-array format" notice looks
// copied from the weights entry; this URL points to `prophetnet.tokenizer` -
// confirm whether the notice applies here.
pub const PROPHETNET_LARGE_UNCASED: (&'static str, &'static str) = (
"prophetnet-large-uncased/vocab",
"https://huggingface.co/microsoft/prophetnet-large-uncased/resolve/main/prophetnet.tokenizer",
);
}

#[derive(Debug, Serialize, Deserialize)]
/// # ProphetNet model configuration
/// Defines the ProphetNet model architecture (e.g. number of layers, hidden layer size, label mapping...)
///
/// The struct derives `Serialize`/`Deserialize`, and field names are used as-is
/// as serialization keys. Fields typed `Option<_>` may be absent from the
/// serialized configuration (they then deserialize to `None`); every other
/// field has to be present.
pub struct ProphetNetConfig {
pub activation_function: Activation,
pub activation_dropout: f64,
pub attention_dropout: f64,
pub decoder_ffn_dim: i64,
pub decoder_layerdrop: f64,
pub decoder_max_position_embeddings: i64,
pub decoder_start_token_id: i64,
pub disable_ngram_loss: bool,
pub dropout: f64,
pub encoder_ffn_dim: i64,
pub encoder_layerdrop: f64,
pub encoder_max_position_embeddings: i64,
pub eps: f64,
pub hidden_size: i64,
pub init_std: f64,
pub is_encoder_decoder: bool,
pub max_position_embeddings: i64,
pub bos_token_id: i64,
pub eos_token_id: i64,
pub ngram: i64,
// Optional label mappings (index -> label and label -> index).
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
pub num_buckets: i64,
pub num_decoder_attention_heads: i64,
pub num_decoder_layers: i64,
pub num_encoder_attention_heads: i64,
pub num_encoder_layers: i64,
pub output_past: Option<bool>,
pub pad_token_id: i64,
pub relative_max_distance: i64,
pub vocab_size: i64,
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
}

// Empty impl: relies entirely on the default methods of the crate's `Config`
// trait (presumably including the `from_file` loader - confirm in `Config`).
impl Config<ProphetNetConfig> for ProphetNetConfig {}
2 changes: 1 addition & 1 deletion utils/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
source_file = Path("path/to/pytorch_model.bin")
target_folder = source_file.parent

weights = torch.load(source_file, map_location='cpu')
weights = torch.load(str(source_file), map_location='cpu')

nps = {}
for k, v in weights.items():
Expand Down

0 comments on commit f8ed507

Please sign in to comment.