Skip to content

Commit

Permalink
Download DialoGPT dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Jun 24, 2020
1 parent c98695c commit 30528ca
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/gpt2/gpt2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ impl Gpt2ModelResources {
"distilgpt2/model.ot",
"https://cdn.huggingface.co/distilgpt2-rust_model.ot",
);
/// DialoGPT-medium model weights, shared under MIT license by the Microsoft team at
/// https://huggingface.co/microsoft/DialoGPT-medium. Modified with conversion to C-array format.
/// Tuple of (local cache alias, remote URL) following the pattern of the other GPT2 resources above.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
    "dialogpt-medium/model.ot",
    "https://cdn.huggingface.co/microsoft/DialoGPT-medium/rust_model.ot",
);
}

impl Gpt2ConfigResources {
Expand Down Expand Up @@ -89,6 +94,11 @@ impl Gpt2ConfigResources {
"distilgpt2/config.json",
"https://cdn.huggingface.co/distilgpt2-config.json",
);
/// DialoGPT-medium model configuration, shared under MIT license by the Microsoft team at
/// https://huggingface.co/microsoft/DialoGPT-medium.
/// Tuple of (local cache alias, remote URL). The JSON configuration is served as-is;
/// the C-array conversion mentioned for other resources applies only to model weights.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
    "dialogpt-medium/config.json",
    "https://cdn.huggingface.co/microsoft/DialoGPT-medium/config.json",
);
}

impl Gpt2VocabResources {
Expand Down Expand Up @@ -117,6 +127,11 @@ impl Gpt2VocabResources {
"distilgpt2/vocab.txt",
"https://cdn.huggingface.co/distilgpt2-vocab.json",
);
/// DialoGPT-medium vocabulary, shared under MIT license by the Microsoft team at
/// https://huggingface.co/microsoft/DialoGPT-medium.
/// Tuple of (local cache alias, remote URL). The vocabulary is served as-is;
/// the C-array conversion mentioned for other resources applies only to model weights.
/// NOTE(review): the local alias uses a `.txt` extension while the remote file is JSON —
/// this matches the distilgpt2 entry above, but confirm downstream loaders expect this alias.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
    "dialogpt-medium/vocab.txt",
    "https://cdn.huggingface.co/microsoft/DialoGPT-medium/vocab.json",
);
}

impl Gpt2MergesResources {
Expand Down Expand Up @@ -145,6 +160,11 @@ impl Gpt2MergesResources {
"distilgpt2/merges.txt",
"https://cdn.huggingface.co/distilgpt2-merges.txt",
);
/// DialoGPT-medium BPE merges file, shared under MIT license by the Microsoft team at
/// https://huggingface.co/microsoft/DialoGPT-medium.
/// Tuple of (local cache alias, remote URL). The merges file is served as-is;
/// the C-array conversion mentioned for other resources applies only to model weights.
pub const DIALOGPT_MEDIUM: (&'static str, &'static str) = (
    "dialogpt-medium/merges.txt",
    "https://cdn.huggingface.co/microsoft/DialoGPT-medium/merges.txt",
);
}

#[allow(non_camel_case_types)]
Expand Down
200 changes: 200 additions & 0 deletions src/pipelines/conversation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
// // Copyright 2019-present Microsoft
// // Copyright 2020-present, the HuggingFace Inc. team.
// // Copyright 2020 Guillaume Becquin
// // Licensed under the Apache License, Version 2.0 (the "License");
// // you may not use this file except in compliance with the License.
// // You may obtain a copy of the License at
// // http://www.apache.org/licenses/LICENSE-2.0
// // Unless required by applicable law or agreed to in writing, software
// // distributed under the License is distributed on an "AS IS" BASIS,
// // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// // See the License for the specific language governing permissions and
// // limitations under the License.
//
// /// # Disclaimer
// /// This repository aims to facilitate research in large-scale pre-training for conversational data.
// /// This toolkit contains only part of the modeling machinery needed to actually produce a model
// /// weight file in a running dialog. On its own, this model provides only information about the
// /// weights of various text spans; in order for a researcher to actually use it, they will need
// /// to bring conversational data of their own and decode the response generation from the pretrained
// /// system. Neither the author of this repository nor Microsoft is responsible for any generation
// /// from the 3rd party utilization of the pretrained system.
// ///
// ///
// ///
// ///
// use crate::bart::{
// BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources,
// };
// use crate::common::resources::{RemoteResource, Resource};
// use crate::pipelines::generation::{BartGenerator, GenerateConfig, LanguageGenerator};
// use tch::Device;
//
// /// # Configuration for multi-turn conversation
// /// Contains information regarding the model to load, mirrors the GenerationConfig, with a
// /// different set of default parameters and sets the device to place the model on.
// pub struct ConversationConfig {
// /// Model weights resource (default: DialoGPT-medium)
// pub model_resource: Resource,
// /// Config resource (default: DialoGPT-medium)
// pub config_resource: Resource,
// /// Vocab resource (default: DialoGPT-medium)
// pub vocab_resource: Resource,
// /// Merges resource (default: DialoGPT-medium)
// pub merges_resource: Resource,
// /// Minimum sequence length (default: 0)
// pub min_length: u64,
// /// Maximum sequence length (default: 20)
// pub max_length: u64,
// /// Sampling flag. If true, will perform top-k and/or nucleus sampling on generated tokens, otherwise greedy (deterministic) decoding (default: true)
// pub do_sample: bool,
// /// Early stopping flag indicating if the beam search should stop as soon as `num_beam` hypotheses have been generated (default: false)
// pub early_stopping: bool,
// /// Number of beams for beam search (default: 5)
// pub num_beams: u64,
// /// Temperature setting. Values higher than 1 will improve originality at the risk of reducing relevance (default: 1.0)
// pub temperature: f64,
// /// Top_k values for sampling tokens. Value higher than 0 will enable the feature (default: 0)
// pub top_k: u64,
// /// Top_p value for [Nucleus sampling, Holtzman et al.](http://arxiv.org/abs/1904.09751). Keep top tokens until cumulative probability reaches top_p (default: 0.9)
// pub top_p: f64,
// /// Repetition penalty (mostly useful for CTRL decoders). Values higher than 1 will penalize tokens that have been already generated. (default: 1.0)
// pub repetition_penalty: f64,
// /// Exponential penalty based on the length of the hypotheses generated (default: 1.0)
// pub length_penalty: f64,
// /// Number of allowed repetitions of n-grams. Values higher than 0 turn on this feature (default: 3)
// pub no_repeat_ngram_size: u64,
// /// Number of sequences to return for each prompt text (default: 1)
// pub num_return_sequences: u64,
// /// Device to place the model on (default: CUDA/GPU when available)
// pub device: Device,
// }
//
// impl Default for ConversationConfig {
// fn default() -> ConversationConfig {
// ConversationConfig {
// model_resource: Resource::Remote(RemoteResource::from_pretrained(
// BartModelResources::BART_CNN,
// )),
// config_resource: Resource::Remote(RemoteResource::from_pretrained(
// BartConfigResources::BART_CNN,
// )),
// vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
// BartVocabResources::BART_CNN,
// )),
// merges_resource: Resource::Remote(RemoteResource::from_pretrained(
// BartMergesResources::BART_CNN,
// )),
// min_length: 56,
// max_length: 142,
// do_sample: false,
// early_stopping: false,
// num_beams: 3,
// temperature: 1.0,
// top_k: 50,
// top_p: 1.0,
// repetition_penalty: 1.0,
// length_penalty: 1.0,
// no_repeat_ngram_size: 3,
// num_return_sequences: 1,
// device: Device::cuda_if_available(),
// }
// }
// }
//
// /// # SummarizationModel to perform summarization
// pub struct SummarizationModel {
// model: BartGenerator,
// }
//
// impl SummarizationModel {
// /// Build a new `SummarizationModel`
// ///
// /// # Arguments
// ///
// /// * `summarization_config` - `SummarizationConfig` object containing the resource references (model, vocabulary, configuration), summarization options and device placement (CPU/GPU)
// ///
// /// # Example
// ///
// /// ```no_run
// /// # fn main() -> failure::Fallible<()> {
// /// use rust_bert::pipelines::summarization::SummarizationModel;
// ///
// /// let mut summarization_model = SummarizationModel::new(Default::default())?;
// /// # Ok(())
// /// # }
// /// ```
// pub fn new(summarization_config: SummarizationConfig) -> failure::Fallible<SummarizationModel> {
// let generate_config = GenerateConfig {
// model_resource: summarization_config.model_resource,
// config_resource: summarization_config.config_resource,
// merges_resource: summarization_config.merges_resource,
// vocab_resource: summarization_config.vocab_resource,
// min_length: summarization_config.min_length,
// max_length: summarization_config.max_length,
// do_sample: summarization_config.do_sample,
// early_stopping: summarization_config.early_stopping,
// num_beams: summarization_config.num_beams,
// temperature: summarization_config.temperature,
// top_k: summarization_config.top_k,
// top_p: summarization_config.top_p,
// repetition_penalty: summarization_config.repetition_penalty,
// length_penalty: summarization_config.length_penalty,
// no_repeat_ngram_size: summarization_config.no_repeat_ngram_size,
// num_return_sequences: summarization_config.num_return_sequences,
// device: summarization_config.device,
// };
//
// let model = BartGenerator::new(generate_config)?;
//
// Ok(SummarizationModel { model })
// }
//
// /// Summarize texts provided
// ///
// /// # Arguments
// ///
// /// * `input` - `&[&str]` Array of texts to summarize.
// ///
// /// # Returns
// /// * `Vec<String>` Summarized texts
// ///
// /// # Example
// ///
// /// ```no_run
// /// # fn main() -> failure::Fallible<()> {
// /// use rust_bert::pipelines::generation::LanguageGenerator;
// /// use rust_bert::pipelines::summarization::SummarizationModel;
// /// let model = SummarizationModel::new(Default::default())?;
// ///
// /// let input = ["In findings published Tuesday in Cornell University's arXiv by a team of scientists
// /// from the University of Montreal and a separate report published Wednesday in Nature Astronomy by a team
// /// from University College London (UCL), the presence of water vapour was confirmed in the atmosphere of K2-18b,
// /// a planet circling a star in the constellation Leo. This is the first such discovery in a planet in its star's
// /// habitable zone — not too hot and not too cold for liquid water to exist. The Montreal team, led by Björn Benneke,
// /// used data from the NASA's Hubble telescope to assess changes in the light coming from K2-18b's star as the planet
// /// passed between it and Earth. They found that certain wavelengths of light, which are usually absorbed by water,
// /// weakened when the planet was in the way, indicating not only does K2-18b have an atmosphere, but the atmosphere
// /// contains water in vapour form. The team from UCL then analyzed the Montreal team's data using their own software
// /// and confirmed their conclusion. This was not the first time scientists have found signs of water on an exoplanet,
// /// but previous discoveries were made on planets with high temperatures or other pronounced differences from Earth.
// /// \"This is the first potentially habitable planet where the temperature is right and where we now know there is water,\"
// /// said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\",
// /// said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors.
// /// \"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being
// /// a potentially habitable planet, but further observations will be required to say for sure. \"
// /// K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger
// /// but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year
// /// on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space
// /// telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more
// /// about exoplanets like K2-18b."];
// ///
// /// let output = model.summarize(&input);
// /// # Ok(())
// /// # }
// /// ```
// /// (New sample credits: [WikiNews](https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b))
// pub fn summarize(&self, texts: &[&str]) -> Vec<String> {
// self.model.generate(Some(texts.to_vec()), None)
// }
// }
1 change: 1 addition & 0 deletions src/pipelines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@
//!
pub mod common;
// pub mod conversation;
pub mod generation;
pub mod ner;
pub mod question_answering;
Expand Down
53 changes: 53 additions & 0 deletions utils/download-dependencies_dialogpt-medium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from transformers.file_utils import get_from_cache, S3_BUCKET_PREFIX
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

# Download the DialoGPT-medium artifacts (config, vocabulary, merges, weights) from the
# Hugging Face CDN, convert the PyTorch weights to the C-array (.ot) format expected by
# rust-bert's convert-tensor binary, and stage everything under ~/rustbert/dialogpt-medium.

ROOT_PATH = S3_BUCKET_PREFIX + '/' + 'microsoft/DialoGPT-medium'

# Remote artifact locations (kept distinct from the local destination paths below to
# avoid the variable shadowing present in sibling download scripts).
remote_config_path = ROOT_PATH + '/config.json'
remote_vocab_path = ROOT_PATH + '/vocab.json'
remote_merges_path = ROOT_PATH + '/merges.txt'
remote_weights_path = ROOT_PATH + '/pytorch_model.bin'

target_path = Path.home() / 'rustbert' / 'dialogpt-medium'

# get_from_cache downloads each file (unless already cached) and returns its local cache path.
temp_config = get_from_cache(remote_config_path)
temp_vocab = get_from_cache(remote_vocab_path)
temp_merges = get_from_cache(remote_merges_path)
temp_weights = get_from_cache(remote_weights_path)

os.makedirs(str(target_path), exist_ok=True)

# Copy the small text artifacts to their final destination. The raw pytorch_model.bin is
# deliberately NOT copied: it is only needed transiently (loaded below straight from the
# cache), and copying ~1.4 GB only to delete it again at the end is wasted I/O.
shutil.copy(temp_config, str(target_path / 'config.json'))
shutil.copy(temp_vocab, str(target_path / 'vocab.json'))
shutil.copy(temp_merges, str(target_path / 'merges.txt'))

# Re-export every weight tensor as a contiguous float32 numpy array. Keys are prefixed
# with 'transformer.' to match the rust-bert GPT2 variable naming, and the token embedding
# matrix is duplicated as 'lm_head.weight' (tied input/output embeddings).
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
    nps['transformer.' + k] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)
    if k == 'wte.weight':
        nps['lm_head.weight'] = np.ascontiguousarray(v.cpu().numpy()).astype(np.float32)

np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

# check_call (rather than call) raises CalledProcessError if the conversion fails,
# instead of silently leaving a missing or stale model.ot behind.
subprocess.check_call(
    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])

# The intermediate npz archive is only needed as input to the converter.
os.remove(str(target_path / 'model.npz'))

0 comments on commit 30528ca

Please sign in to comment.