forked from tracel-ai/burn
Commit eee90a5: Example/text classification (tracel-ai#123)
1 parent: 7c38a98
Showing 13 changed files with 578 additions and 3 deletions.
examples/text-classification/Cargo.toml (new file, 23 additions)

```toml
[package]
name = "text-classification"
version = "0.1.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
license = "MIT/Apache-2.0"
edition = "2021"
publish = false

[features]
default = []

[dependencies]
# Burn
burn = { path = "../../burn" }
burn-autodiff = { path = "../../burn-autodiff" }
burn-tch = { path = "../../burn-tch" }

# Tokenizer
tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }

# Utils
derive-new = "0.5"
serde = { version = "1.0", features = ["derive"] }
```
examples/text-classification/README.md (new file, 12 additions)

# Text Classification

The example can be run like so:

```bash
git clone https://github.com/burn-rs/burn.git
cd burn
# Use the --release flag to really speed up training.
export TORCH_CUDA_VERSION=cu113  # Set the CUDA version.
cargo run --example text-classification-ag-news --release   # Train on the AG News dataset.
cargo run --example text-classification-db-pedia --release  # Train on the DBpedia dataset.
```
examples/text-classification/examples/text-classification-ag-news.rs (new file, 23 additions)

```rust
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, AgNewsDataset};

type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, AgNewsDataset>(
        burn_tch::TchDevice::Cuda(0),
        AgNewsDataset::train(),
        AgNewsDataset::test(),
        config,
        "/tmp/text-classification-ag-news",
    );
}
```
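Both training binaries assume the first CUDA device. On a machine without a GPU, the same entry point can target the CPU; the following is a minimal sketch, not part of this commit, assuming burn-tch's `TchDevice::Cpu` variant and an `f32` element type (half precision is generally only worthwhile on GPU):

```rust
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, AgNewsDataset};

// Hypothetical CPU variant (assumed, not in this commit): f32 instead of f16.
type CpuBackend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<f32>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<CpuBackend, AgNewsDataset>(
        burn_tch::TchDevice::Cpu, // assumed variant; the example above uses Cuda(0)
        AgNewsDataset::train(),
        AgNewsDataset::test(),
        config,
        "/tmp/text-classification-ag-news-cpu",
    );
}
```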
examples/text-classification/examples/text-classification-db-pedia.rs (new file, 22 additions)

```rust
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, DbPediaDataset};

type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, DbPediaDataset>(
        burn_tch::TchDevice::Cuda(0),
        DbPediaDataset::train(),
        DbPediaDataset::test(),
        config,
        "/tmp/text-classification-db-pedia",
    );
}
```
examples/text-classification/src/data/batcher.rs (new file, 86 additions)

```rust
use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{
    data::dataloader::batcher::Batcher,
    tensor::{backend::Backend, BoolTensor, Data, Shape, Tensor},
};
use std::sync::Arc;

#[derive(new)]
pub struct TextClassificationBatcher<B: Backend> {
    tokenizer: Arc<dyn Tokenizer>,
    num_classes: usize,
    device: B::Device,
    max_seq_length: usize,
}

#[derive(Debug, Clone, new)]
pub struct TextClassificationBatch<B: Backend> {
    pub tokens: Tensor<B::IntegerBackend, 2>,
    pub labels: Tensor<B, 2>,
    pub mask_pad: BoolTensor<B, 2>,
}

impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>
    for TextClassificationBatcher<B>
{
    fn batch(&self, items: Vec<TextClassificationItem>) -> TextClassificationBatch<B> {
        let mut tokens_list = Vec::with_capacity(items.len());
        let mut labels_list = Vec::with_capacity(items.len());

        // Tokenize each text and one-hot encode its label.
        for item in items {
            tokens_list.push(self.tokenizer.encode(&item.text));
            labels_list.push(Tensor::one_hot(item.label, self.num_classes));
        }

        // Pad every sequence to the longest in the batch, capped at max_seq_length.
        let (tokens, mask_pad) =
            pad_tokens::<B>(self.tokenizer.pad_token(), tokens_list, self.max_seq_length);

        TextClassificationBatch {
            tokens: tokens.to_device(self.device).detach(),
            labels: Tensor::cat(labels_list, 0).to_device(self.device).detach(),
            mask_pad: mask_pad.to_device(self.device),
        }
    }
}

pub fn pad_tokens<B: Backend>(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: usize,
) -> (Tensor<B::IntegerBackend, 2>, BoolTensor<B, 2>) {
    let mut max_size = 0;
    let batch_size = tokens_list.len();

    // Find the longest sequence, capped at max_seq_length.
    for tokens in tokens_list.iter() {
        if tokens.len() > max_size {
            max_size = tokens.len();
        }
        if tokens.len() >= max_seq_length {
            max_size = max_seq_length;
            break;
        }
    }

    // Start from a tensor filled entirely with the padding token.
    let mut tensor = Tensor::zeros([batch_size, max_size]);
    tensor = tensor.add_scalar(pad_token as i64);

    // Copy each (possibly truncated) sequence into its row.
    for (index, tokens) in tokens_list.into_iter().enumerate() {
        let mut seq_length = tokens.len();
        let mut tokens = tokens;
        if seq_length > max_seq_length {
            seq_length = max_seq_length;
            let _ = tokens.split_off(seq_length);
        }
        tensor = tensor.index_assign(
            [index..index + 1, 0..tokens.len()],
            &Tensor::from_data(Data::new(
                tokens.into_iter().map(|e| e as i64).collect(),
                Shape::new([1, seq_length]),
            )),
        );
    }

    // Padding positions are exactly those still equal to the padding token.
    let mask_pad = BoolTensor::from_int_backend(tensor.equal_scalar(pad_token as i64));

    (tensor, mask_pad)
}
```
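In training, this batcher is driven by burn's DataLoader, but it can also be exercised directly. A minimal sketch, assuming the batcher, item, and tokenizer types are re-exported at the crate root (the data module's `mod.rs` below re-exports them) and reusing the backend alias from the examples:

```rust
use std::sync::Arc;

use burn::data::dataloader::batcher::Batcher;
use text_classification::{BertCasedTokenizer, TextClassificationBatcher, TextClassificationItem};

type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;

fn main() {
    // 4 classes (AG News), sequences capped at 256 tokens.
    let batcher = TextClassificationBatcher::<Backend>::new(
        Arc::new(BertCasedTokenizer::default()),
        4,
        burn_tch::TchDevice::Cuda(0),
        256,
    );

    let batch = batcher.batch(vec![
        TextClassificationItem::new("The market rallied today.".to_string(), 2),
        TextClassificationItem::new("A close game last night.".to_string(), 1),
    ]);
    // batch.tokens: [2, seq], batch.labels: [2, 4] one-hot,
    // batch.mask_pad flags the padded positions.
    let _ = batch;
}
```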
examples/text-classification/src/data/dataset.rs (new file, 150 additions)

```rust
use burn::data::dataset::{
    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
};

#[derive(new, Clone, Debug)]
pub struct TextClassificationItem {
    pub text: String,
    pub label: usize,
}

pub trait TextClassificationDataset: Dataset<TextClassificationItem> {
    fn num_classes() -> usize;
    fn class_name(label: usize) -> String;
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AgNewsItem {
    pub text: String,
    pub label: usize,
}

pub struct AgNewsDataset {
    dataset: InMemDataset<AgNewsItem>,
}

impl Dataset<TextClassificationItem> for AgNewsDataset {
    fn get(&self, index: usize) -> Option<TextClassificationItem> {
        self.dataset
            .get(index)
            .map(|item| TextClassificationItem::new(item.text, item.label))
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

impl AgNewsDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "train")
            .extract_string("text")
            .extract_number("label")
            .load_in_memory()
            .unwrap();
        Self { dataset }
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "test")
            .extract_string("text")
            .extract_number("label")
            .load_in_memory()
            .unwrap();
        Self { dataset }
    }
}

impl TextClassificationDataset for AgNewsDataset {
    fn num_classes() -> usize {
        4
    }

    fn class_name(label: usize) -> String {
        match label {
            0 => "World",
            1 => "Sports",
            2 => "Business",
            3 => "Technology",
            _ => panic!("invalid class"),
        }
        .to_string()
    }
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DbPediaItem {
    pub title: String,
    pub content: String,
    pub label: usize,
}

pub struct DbPediaDataset {
    dataset: InMemDataset<DbPediaItem>,
}

impl Dataset<TextClassificationItem> for DbPediaDataset {
    fn get(&self, index: usize) -> Option<TextClassificationItem> {
        self.dataset.get(index).map(|item| {
            TextClassificationItem::new(
                format!("Title: {} - Content: {}", item.title, item.content),
                item.label,
            )
        })
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

impl DbPediaDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "train")
                .extract_string("title")
                .extract_string("content")
                .extract_number("label")
                .load_in_memory()
                .unwrap();
        Self { dataset }
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "test")
                .extract_string("title")
                .extract_string("content")
                .extract_number("label")
                .load_in_memory()
                .unwrap();
        Self { dataset }
    }
}

impl TextClassificationDataset for DbPediaDataset {
    fn num_classes() -> usize {
        14
    }

    fn class_name(label: usize) -> String {
        match label {
            0 => "Company",
            1 => "EducationalInstitution",
            2 => "Artist",
            3 => "Athlete",
            4 => "OfficeHolder",
            5 => "MeanOfTransportation",
            6 => "Building",
            7 => "NaturalPlace",
            8 => "Village",
            9 => "Animal",
            10 => "Plant",
            11 => "Album",
            12 => "Film",
            13 => "WrittenWork",
            _ => panic!("invalid class"),
        }
        .to_string()
    }
}
```
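Each dataset can be sanity-checked directly before training; a minimal sketch using only the API above (the constructor downloads the Hugging Face dataset, so the first run needs network access):

```rust
use burn::data::dataset::Dataset;
use text_classification::{AgNewsDataset, TextClassificationDataset};

fn main() {
    // Downloads the ag_news dataset from Hugging Face on first use.
    let dataset = AgNewsDataset::train();
    println!("{} training items", dataset.len());

    if let Some(item) = dataset.get(0) {
        println!(
            "label {} ({}): {}",
            item.label,
            AgNewsDataset::class_name(item.label),
            item.text
        );
    }
}
```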
examples/text-classification/src/data/mod.rs (new file, 7 additions)

```rust
mod batcher;
mod dataset;
mod tokenizer;

pub use batcher::*;
pub use dataset::*;
pub use tokenizer::*;
```
examples/text-classification/src/data/tokenizer.rs (new file, 42 additions)

```rust
/// Minimal tokenizer abstraction used by the batcher.
pub trait Tokenizer: Send + Sync {
    fn encode(&self, value: &str) -> Vec<usize>;
    fn decode(&self, tokens: &[usize]) -> String;
    fn vocab_size(&self) -> usize;
    fn pad_token(&self) -> usize;
    fn pad_token_value(&self) -> String {
        self.decode(&[self.pad_token()])
    }
}

pub struct BertCasedTokenizer {
    tokenizer: tokenizers::Tokenizer,
}

impl Default for BertCasedTokenizer {
    fn default() -> Self {
        Self {
            // Downloads the pretrained vocabulary on first use
            // (enabled by the tokenizers crate's "http" feature).
            tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(),
        }
    }
}

impl Tokenizer for BertCasedTokenizer {
    fn encode(&self, value: &str) -> Vec<usize> {
        let tokens = self.tokenizer.encode(value, true).unwrap();
        tokens.get_ids().iter().map(|t| *t as usize).collect()
    }

    fn decode(&self, tokens: &[usize]) -> String {
        self.tokenizer
            .decode(tokens.iter().map(|t| *t as u32).collect(), false)
            .unwrap()
    }

    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(true)
    }

    fn pad_token(&self) -> usize {
        self.tokenizer.token_to_id("[PAD]").unwrap() as usize
    }
}
```
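The tokenizer contract is easiest to see as an encode/decode round trip; a minimal sketch (the first run downloads the bert-base-cased vocabulary), assuming `Tokenizer` and `BertCasedTokenizer` are re-exported at the crate root like the dataset types:

```rust
use text_classification::{BertCasedTokenizer, Tokenizer};

fn main() {
    let tokenizer = BertCasedTokenizer::default();

    let ids = tokenizer.encode("Burn is a deep learning framework.");
    println!("{} token ids, vocab size {}", ids.len(), tokenizer.vocab_size());

    // encode(_, true) adds special tokens, and decode(_, false) keeps them,
    // so the decoded string includes [CLS] and [SEP].
    println!("decoded: {}", tokenizer.decode(&ids));
    println!("pad token: {}", tokenizer.pad_token_value());
}
```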