Skip to content

Commit

Permalink
Update tch to 0.15 (guillaume-be#443)
Browse files Browse the repository at this point in the history
* Update tch to 0.15

It allows to run in system with libtorch v2.2.0, upgrading from v2.1.0

* remove torch-sys dependency from benches

* Updated readmes

---------

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
  • Loading branch information
blmarket and guillaume-be authored Feb 11, 2024
1 parent b68f7dc commit 29f9a7a
Show file tree
Hide file tree
Showing 11 changed files with 12 additions and 31 deletions.
7 changes: 5 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@ All notable changes to this project will be documented in this file. The format

## [Unreleased]

## Changed
- (BREAKING) Upgraded to `torch` 2.2 (via `tch` 0.15.0).

## [0.22.0] - 2024-01-20
## Added
- Addition of `new_with_tokenizer` constructor for `SentenceEmbeddingsModel` allowing passing custom tokenizers for sentence embeddings pipelines.
- Support for [Tokenizers](https://github.com/huggingface/tokenizers) in pipelines, allowing loading `tokenizer.json` and `special_token_map.json` tokenizer files.
- (BREAKING) Most model configuration can now take an optional `kind` parameter to specify the model weight precision. If not provided, will default to full precision on CPU, or the serialized weights precision otherwise.

## Fixed
- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering m-grams spanning multiple sentences).
- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Add new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).
- Improved MPS device compatibility setting the `sparse_grad` flag to false for `gather` operations
- Updated ONNX runtime backend version to 1.15.x
- Issue with incorrect results for QA models with a tokenizer not using segment ids
Expand Down Expand Up @@ -449,4 +452,4 @@ All notable changes to this project will be documented in this file. The format

- Tensor conversion tools from Pytorch to Libtorch format
- DistilBERT model architecture
- Ready-to-use `SentimentClassifier` using a DistilBERT model fine-tuned on SST2
- Ready-to-use `SentimentClassifier` using a DistilBERT model fine-tuned on SST2
3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ features = ["doc-only"]

[dependencies]
rust_tokenizers = "8.1.1"
tch = "0.14.0"
tch = "0.15.0"
serde_json = "1"
serde = { version = "1", features = ["derive"] }
ordered-float = "3"
Expand All @@ -97,7 +97,6 @@ anyhow = "1"
csv = "1"
criterion = "0.5"
tokio = { version = "1.35", features = ["sync", "rt-multi-thread", "macros"] }
torch-sys = "0.14.0"
tempfile = "3"
itertools = "0.12"
tracing-subscriber = { version = "0.3", default-features = false, features = [ "env-filter", "fmt" ] }
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ This cache location defaults to `~/.cache/.rustbert`, but can be changed by sett

### Manual installation (recommended)

1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.1`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
1. Download `libtorch` from https://pytorch.org/get-started/locally/. This package requires `v2.2`: if this version is no longer available on the "get started" page,
the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` for a Linux version with CUDA12. **NOTE:** When using `rust-bert` as dependency from [crates.io](https://crates.io), please check the required `LIBTORCH` on the published package [readme](https://crates.io/crates/rust-bert) as it may differ from the version documented here (applying to the current repository version).
2. Extract the library to a location of your choice
3. Set the following environment variables
##### Linux:
Expand Down
4 changes: 0 additions & 4 deletions benches/generation_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,6 @@ fn generation_forward_pass(iters: u64, model: &TextGenerationModel, data: &[&str
}

fn bench_generation(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_text_generation_model();

// Define input
Expand Down
4 changes: 1 addition & 3 deletions benches/squad_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ fn qa_load_model(iters: u64) -> Duration {
fn bench_squad(c: &mut Criterion) {
// Set-up QA model
let model = create_qa_model();
unsafe {
torch_sys::dummy_cuda_dependency();
}

// Define input
let mut squad_path = PathBuf::from(env::var("squad_dataset")
.expect("Please set the \"squad_dataset\" environment variable pointing to the SQuAD dataset folder"));
Expand Down
4 changes: 1 addition & 3 deletions benches/sst2_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,7 @@ fn sst2_load_model(iters: u64) -> Duration {
fn bench_sst2(c: &mut Criterion) {
// Set-up classifier
let model = create_sentiment_model();
unsafe {
torch_sys::dummy_cuda_dependency();
}

// Define input
let mut sst2_path = PathBuf::from(env::var("SST2_PATH").expect(
"Please set the \"SST2_PATH\" environment variable pointing to the SST2 dataset folder",
Expand Down
3 changes: 0 additions & 3 deletions benches/summarization_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ fn summarization_load_model(iters: u64) -> Duration {

fn bench_squad(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_summarization_model();

// Define input
Expand Down
4 changes: 0 additions & 4 deletions benches/tensor_operations_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ fn matrix_multiply(iters: u64, input: &Tensor, weights: &Tensor) -> Duration {
}

fn bench_tensor_ops(c: &mut Criterion) {
// Set-up summarization model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let input = Tensor::rand([32, 128, 512], (Kind::Float, Device::cuda_if_available()));
let weights = Tensor::rand([512, 512], (Kind::Float, Device::cuda_if_available()));

Expand Down
3 changes: 0 additions & 3 deletions benches/token_classification_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,6 @@ fn create_model() -> TokenClassificationModel {

fn bench_token_classification_predict(c: &mut Criterion) {
// Set-up model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_model();

// Define input
Expand Down
3 changes: 0 additions & 3 deletions benches/translation_benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,6 @@ fn translation_load_model(iters: u64) -> Duration {

fn bench_squad(c: &mut Criterion) {
// Set-up translation model
unsafe {
torch_sys::dummy_cuda_dependency();
}
let model = create_translation_model();

// Define input
Expand Down
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@
//!
//! ### Manual installation (recommended)
//!
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.1`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-2.1.1%2Bcu118.zip` for a Linux version with CUDA11.
//! 1. Download `libtorch` from <https://pytorch.org/get-started/locally/>. This package requires `v2.2`: if this version is no longer available on the "get started" page,
//! the file should be accessible by modifying the target link, for example `https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcu121.zip` for a Linux version with CUDA12.
//! 2. Extract the library to a location of your choice
//! 3. Set the following environment variables
//! ##### Linux:
Expand Down

0 comments on commit 29f9a7a

Please sign in to comment.