Example/text classification (tracel-ai#123)
nathanielsimard authored Dec 2, 2022
1 parent 7c38a98 commit eee90a5
Showing 13 changed files with 578 additions and 3 deletions.
3 changes: 1 addition & 2 deletions README.md
@@ -42,9 +42,8 @@ It may also be a good idea to take a look at the main [components](#components) o

 ### Examples
 
-For now there is only one example, but more to come 💪..
-
 * [MNIST](https://github.com/burn-rs/burn/tree/main/examples/mnist) train a model on CPU/GPU using different backends.
+* [Text Classification](https://github.com/burn-rs/burn/tree/main/examples/text-classification) train a transformer encoder from scratch on GPU.
 
 ### Components
 
2 changes: 1 addition & 1 deletion burn-derive/src/config/analyzer.rs
@@ -83,6 +83,6 @@ fn parse_asm(ast: &syn::DeriveInput) -> ConfigType {
             ConfigType::Struct(struct_data.fields.clone().into_iter().collect())
         }
         syn::Data::Enum(enum_data) => ConfigType::Enum(enum_data.clone()),
-        syn::Data::Union(_) => panic!("Only struct cna be derived"),
+        syn::Data::Union(_) => panic!("Only struct and enum can be derived"),
     }
 }
23 changes: 23 additions & 0 deletions examples/text-classification/Cargo.toml
@@ -0,0 +1,23 @@
[package]
name = "text-classification"
version = "0.1.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
license = "MIT/Apache-2.0"
edition = "2021"
publish = false

[features]
default = []

[dependencies]
# Burn
burn = { path = "../../burn" }
burn-autodiff = { path = "../../burn-autodiff" }
burn-tch = { path = "../../burn-tch" }

# Tokenizer
tokenizers = { version = "0.13", default-features = false, features = ["onig", "http"] }

# Utils
derive-new = "0.5"
serde = { version = "1.0", features = ["derive"] }
12 changes: 12 additions & 0 deletions examples/text-classification/README.md
@@ -0,0 +1,12 @@
# Text Classification

The examples can be run like so:

```bash
git clone https://github.com/burn-rs/burn.git
cd burn
# Use the --release flag to really speed up training.
export TORCH_CUDA_VERSION=cu113 # Set the CUDA version
cargo run --example text-classification-ag-news --release # Train on the AG News dataset
cargo run --example text-classification-db-pedia --release # Train on the DBpedia dataset
```
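
Note that both example binaries hard-code `burn_tch::TchDevice::Cuda(0)` (see their sources below), so a CUDA-capable GPU is expected; the chosen backend also trains in `f16`.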
22 changes: 22 additions & 0 deletions examples/text-classification/examples/text-classification-ag-news.rs
@@ -0,0 +1,22 @@
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, AgNewsDataset};

type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, AgNewsDataset>(
        burn_tch::TchDevice::Cuda(0),
        AgNewsDataset::train(),
        AgNewsDataset::test(),
        config,
        "/tmp/text-classification-ag-news",
    );
}
22 changes: 22 additions & 0 deletions examples/text-classification/examples/text-classification-db-pedia.rs
@@ -0,0 +1,22 @@
use burn::optim::{decay::WeightDecayConfig, momentum::MomentumConfig};
use text_classification::{training::ExperimentConfig, DbPediaDataset};

type Backend = burn_autodiff::ADBackendDecorator<burn_tch::TchBackend<burn::tensor::f16>>;

fn main() {
    let config = ExperimentConfig::new(
        burn::nn::transformer::TransformerEncoderConfig::new(256, 512, 4, 4),
        burn::optim::SgdConfig::new()
            .with_learning_rate(5.0e-3)
            .with_momentum(Some(MomentumConfig::new().with_nesterov(true)))
            .with_weight_decay(Some(WeightDecayConfig::new(5e-4))),
    );

    text_classification::training::train::<Backend, DbPediaDataset>(
        burn_tch::TchDevice::Cuda(0),
        DbPediaDataset::train(),
        DbPediaDataset::test(),
        config,
        "/tmp/text-classification-db-pedia",
    );
}
86 changes: 86 additions & 0 deletions examples/text-classification/src/data/batcher.rs
@@ -0,0 +1,86 @@
use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{
    data::dataloader::batcher::Batcher,
    tensor::{backend::Backend, BoolTensor, Data, Shape, Tensor},
};
use std::sync::Arc;

#[derive(new)]
pub struct TextClassificationBatcher<B: Backend> {
    tokenizer: Arc<dyn Tokenizer>,
    num_classes: usize,
    device: B::Device,
    max_seq_length: usize,
}

#[derive(Debug, Clone, new)]
pub struct TextClassificationBatch<B: Backend> {
    pub tokens: Tensor<B::IntegerBackend, 2>,
    pub labels: Tensor<B, 2>,
    pub mask_pad: BoolTensor<B, 2>,
}

impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>
    for TextClassificationBatcher<B>
{
    fn batch(&self, items: Vec<TextClassificationItem>) -> TextClassificationBatch<B> {
        let mut tokens_list = Vec::with_capacity(items.len());
        let mut labels_list = Vec::with_capacity(items.len());

        for item in items {
            tokens_list.push(self.tokenizer.encode(&item.text));
            labels_list.push(Tensor::one_hot(item.label, self.num_classes));
        }

        let (tokens, mask_pad) =
            pad_tokens::<B>(self.tokenizer.pad_token(), tokens_list, self.max_seq_length);

        TextClassificationBatch {
            tokens: tokens.to_device(self.device).detach(),
            labels: Tensor::cat(labels_list, 0).to_device(self.device).detach(),
            mask_pad: mask_pad.to_device(self.device),
        }
    }
}

pub fn pad_tokens<B: Backend>(
    pad_token: usize,
    tokens_list: Vec<Vec<usize>>,
    max_seq_length: usize,
) -> (Tensor<B::IntegerBackend, 2>, BoolTensor<B, 2>) {
    let mut max_size = 0;
    let batch_size = tokens_list.len();

    for tokens in tokens_list.iter() {
        if tokens.len() > max_size {
            max_size = tokens.len();
        }
        if tokens.len() >= max_seq_length {
            max_size = max_seq_length;
            break;
        }
    }

    let mut tensor = Tensor::zeros([batch_size, max_size]);
    tensor = tensor.add_scalar(pad_token as i64);

    for (index, tokens) in tokens_list.into_iter().enumerate() {
        let mut seq_length = tokens.len();
        let mut tokens = tokens;
        if seq_length > max_seq_length {
            seq_length = max_seq_length;
            let _ = tokens.split_off(seq_length);
        }
        tensor = tensor.index_assign(
            [index..index + 1, 0..tokens.len()],
            &Tensor::from_data(Data::new(
                tokens.into_iter().map(|e| e as i64).collect(),
                Shape::new([1, seq_length]),
            )),
        );
    }

    let mask_pad = BoolTensor::from_int_backend(tensor.equal_scalar(pad_token as i64));

    (tensor, mask_pad)
}
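
To make the padding rule above concrete without burn's tensor types, here is a minimal framework-free sketch; the helper name `pad_to_batch` and the sample values are illustrative, not part of the commit:

```rust
/// Mirrors the control flow of `pad_tokens`: every sequence is padded to the
/// longest length in the batch, capped at `max_seq_length`; longer sequences
/// are truncated.
fn pad_to_batch(pad_token: i64, mut batch: Vec<Vec<i64>>, max_seq_length: usize) -> Vec<Vec<i64>> {
    // Target width: the longest sequence present, but never more than the cap.
    let target = batch
        .iter()
        .map(|tokens| tokens.len().min(max_seq_length))
        .max()
        .unwrap_or(0);

    for tokens in batch.iter_mut() {
        tokens.truncate(target); // drop tokens past the target width
        tokens.resize(target, pad_token); // right-pad short sequences
    }
    batch
}

fn main() {
    let batch = vec![vec![7, 8, 9, 10, 11], vec![7, 8]];
    // With a cap of 4, the first row is truncated and the second right-padded.
    assert_eq!(
        pad_to_batch(0, batch, 4),
        vec![vec![7, 8, 9, 10], vec![7, 8, 0, 0]]
    );
}
```

In the real batcher, the padded positions are additionally exposed as `mask_pad`, built by comparing the padded tensor against the pad token.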
150 changes: 150 additions & 0 deletions examples/text-classification/src/data/dataset.rs
@@ -0,0 +1,150 @@
use burn::data::dataset::{
    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
};

#[derive(new, Clone, Debug)]
pub struct TextClassificationItem {
    pub text: String,
    pub label: usize,
}

pub trait TextClassificationDataset: Dataset<TextClassificationItem> {
    fn num_classes() -> usize;
    fn class_name(label: usize) -> String;
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct AgNewsItem {
    pub text: String,
    pub label: usize,
}

pub struct AgNewsDataset {
    dataset: InMemDataset<AgNewsItem>,
}

impl Dataset<TextClassificationItem> for AgNewsDataset {
    fn get(&self, index: usize) -> Option<TextClassificationItem> {
        self.dataset
            .get(index)
            .map(|item| TextClassificationItem::new(item.text, item.label))
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

impl AgNewsDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "train")
            .extract_string("text")
            .extract_number("label")
            .load_in_memory()
            .unwrap();
        Self { dataset }
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "test")
            .extract_string("text")
            .extract_number("label")
            .load_in_memory()
            .unwrap();
        Self { dataset }
    }
}

impl TextClassificationDataset for AgNewsDataset {
    fn num_classes() -> usize {
        4
    }

    fn class_name(label: usize) -> String {
        match label {
            0 => "World",
            1 => "Sports",
            2 => "Business",
            3 => "Technology",
            _ => panic!("invalid class"),
        }
        .to_string()
    }
}

#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct DbPediaItem {
    pub title: String,
    pub content: String,
    pub label: usize,
}

pub struct DbPediaDataset {
    dataset: InMemDataset<DbPediaItem>,
}

impl Dataset<TextClassificationItem> for DbPediaDataset {
    fn get(&self, index: usize) -> Option<TextClassificationItem> {
        self.dataset.get(index).map(|item| {
            TextClassificationItem::new(
                format!("Title: {} - Content: {}", item.title, item.content),
                item.label,
            )
        })
    }

    fn len(&self) -> usize {
        self.dataset.len()
    }
}

impl DbPediaDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "train")
                .extract_string("title")
                .extract_string("content")
                .extract_number("label")
                .load_in_memory()
                .unwrap();
        Self { dataset }
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "test")
                .extract_string("title")
                .extract_string("content")
                .extract_number("label")
                .load_in_memory()
                .unwrap();
        Self { dataset }
    }
}

impl TextClassificationDataset for DbPediaDataset {
    fn num_classes() -> usize {
        14
    }

    fn class_name(label: usize) -> String {
        match label {
            0 => "Company",
            1 => "EducationalInstitution",
            2 => "Artist",
            3 => "Athlete",
            4 => "OfficeHolder",
            5 => "MeanOfTransportation",
            6 => "Building",
            7 => "NaturalPlace",
            8 => "Village",
            9 => "Animal",
            10 => "Plant",
            11 => "Album",
            12 => "Film",
            13 => "WrittenWork",
            _ => panic!("invalid class"),
        }
        .to_string()
    }
}
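
As a quick orientation, a hedged usage sketch for the datasets above; it assumes the crate root re-exports these types (the example binaries already import `text_classification::AgNewsDataset`, so at least that path holds) and that the Hugging Face download succeeds:

```rust
use burn::data::dataset::Dataset;
use text_classification::{AgNewsDataset, TextClassificationDataset};

fn main() {
    // Downloads and caches the AG News test split on first use (needs network).
    let dataset = AgNewsDataset::test();
    println!("{} items, {} classes", dataset.len(), AgNewsDataset::num_classes());

    if let Some(item) = dataset.get(0) {
        // Char-boundary-safe preview of the raw text.
        let preview: String = item.text.chars().take(80).collect();
        println!("[{}] {}", AgNewsDataset::class_name(item.label), preview);
    }
}
```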
7 changes: 7 additions & 0 deletions examples/text-classification/src/data/mod.rs
@@ -0,0 +1,7 @@
mod batcher;
mod dataset;
mod tokenizer;

pub use batcher::*;
pub use dataset::*;
pub use tokenizer::*;
42 changes: 42 additions & 0 deletions examples/text-classification/src/data/tokenizer.rs
@@ -0,0 +1,42 @@
pub trait Tokenizer: Send + Sync {
    fn encode(&self, value: &str) -> Vec<usize>;
    fn decode(&self, tokens: &[usize]) -> String;
    fn vocab_size(&self) -> usize;
    fn pad_token(&self) -> usize;
    fn pad_token_value(&self) -> String {
        self.decode(&[self.pad_token()])
    }
}

pub struct BertCasedTokenizer {
    tokenizer: tokenizers::Tokenizer,
}

impl Default for BertCasedTokenizer {
    fn default() -> Self {
        Self {
            tokenizer: tokenizers::Tokenizer::from_pretrained("bert-base-cased", None).unwrap(),
        }
    }
}

impl Tokenizer for BertCasedTokenizer {
    fn encode(&self, value: &str) -> Vec<usize> {
        let tokens = self.tokenizer.encode(value, true).unwrap();
        tokens.get_ids().iter().map(|t| *t as usize).collect()
    }

    fn decode(&self, tokens: &[usize]) -> String {
        self.tokenizer
            .decode(tokens.iter().map(|t| *t as u32).collect(), false)
            .unwrap()
    }

    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(true)
    }

    fn pad_token(&self) -> usize {
        self.tokenizer.token_to_id("[PAD]").unwrap() as usize
    }
}
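
Finally, a minimal roundtrip sketch for the `Tokenizer` trait above; it assumes `BertCasedTokenizer` is re-exported at the crate root and that the pretrained vocabulary can be fetched (network access on first run):

```rust
use text_classification::{BertCasedTokenizer, Tokenizer};

fn main() {
    // Fetches the "bert-base-cased" vocabulary on first use.
    let tokenizer = BertCasedTokenizer::default();

    let ids = tokenizer.encode("Burn is a deep learning framework.");
    println!("{} token ids, vocab size {}", ids.len(), tokenizer.vocab_size());

    // `decode` keeps special tokens such as [CLS]/[SEP], since the
    // implementation above passes `false` for skipping special tokens.
    println!("decoded: {}", tokenizer.decode(&ids));
    println!("pad id {} = {:?}", tokenizer.pad_token(), tokenizer.pad_token_value());
}
```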