Skip to content

Commit

Permalink
feat: cross entropy loss (tracel-ai#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Dec 25, 2022
1 parent 1a1d86d commit 3a9dfe6
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 55 deletions.
3 changes: 2 additions & 1 deletion burn-tensor/src/tensor/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ impl<P: std::fmt::Debug, const D: usize> Data<P, D>
where
P: Zeros + Default,
{
pub fn zeros(shape: Shape<D>) -> Data<P, D> {
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<P, D> {
let shape = shape.into();
let elem = P::default();
let num_elements = shape.num_elements();
let mut data = Vec::with_capacity(num_elements);
Expand Down
90 changes: 90 additions & 0 deletions burn/src/nn/loss/cross_entropy.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
use burn_tensor::{backend::Backend, loss::cross_entropy_with_logits, Tensor};

/// Calculate the cross entropy loss from the input logits and the targets.
///
/// The integer targets are expanded into a one-hot matrix of shape
/// `[batch_size, num_targets]` which is then passed, together with the logits, to
/// [`cross_entropy_with_logits`]. Samples whose target equals the optional padding
/// index keep an all-zero row in that matrix.
pub struct CrossEntropyLoss<B: Backend> {
    /// Number of classes, i.e. the width of the one-hot target matrix.
    num_targets: usize,
    /// Target value that marks a padded sample, if any.
    pad_index: Option<usize>,
    /// Ties the criterion to a backend type without storing a backend value.
    _b: std::marker::PhantomData<B>,
}

impl<B: Backend> CrossEntropyLoss<B> {
    /// Create the criterion.
    ///
    /// # Notes
    ///
    /// The number of targets must be specified; this corresponds to the number of classes in a
    /// classification task. A padding index can also be specified.
    pub fn new(num_targets: usize, pad_index: Option<usize>) -> Self {
        Self {
            num_targets,
            pad_index,
            _b: std::marker::PhantomData,
        }
    }

    /// Compute the criterion on the input tensor.
    ///
    /// # Shapes
    ///
    /// - logits: [batch_size, num_targets]
    /// - targets: [batch_size]
    ///
    /// # Panics
    ///
    /// Panics if a target that is not the padding index is negative or is not strictly
    /// smaller than the number of targets given to [`CrossEntropyLoss::new`].
    pub fn forward(
        &self,
        logits: &Tensor<B, 2>,
        targets: &Tensor<B::IntegerBackend, 1>,
    ) -> Tensor<B, 1> {
        let device = logits.device();
        let [batch_size] = targets.dims();
        let indexes = targets.to_data();

        let mut targets_logits =
            Tensor::<B, 2>::zeros_device([batch_size, self.num_targets], device);
        // The value written at every target position is always the same 1x1 tensor of ones,
        // so it is built once instead of once per sample.
        let one = Tensor::<B, 2>::ones_device([1, 1], device);

        for (b, &target) in indexes.value.iter().enumerate() {
            // A negative target wraps to a very large value here and is rejected by the
            // range check below.
            let index = target as usize;

            // Padded samples keep their all-zero row.
            if self.pad_index == Some(index) {
                continue;
            }

            assert!(
                index < self.num_targets,
                "target index {} at batch position {} is out of range for {} targets \
                 (a huge index usually means the target was negative)",
                index,
                b,
                self.num_targets,
            );

            targets_logits = targets_logits.index_assign([b..b + 1, index..index + 1], &one);
        }

        // The one-hot targets are constants: detach them before computing the loss.
        cross_entropy_with_logits(logits, &targets_logits.detach())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;
    use burn_tensor::{Data, Distribution};

    /// Without a padding index, the criterion must match `cross_entropy_with_logits`
    /// applied to the hand-written one-hot encoding of the same targets.
    #[test]
    fn test_cross_entropy_loss() {
        let [batch_size, num_targets] = [4, 5];
        let logits = Tensor::<TestBackend, 2>::random(
            [batch_size, num_targets],
            Distribution::Normal(0., 1.0),
        );
        let targets =
            Tensor::<<TestBackend as Backend>::IntegerBackend, 1>::from_data(Data::from([
                2, 0, 4, 1_i64,
            ]));
        let targets_logits = Tensor::<TestBackend, 2>::from_data(Data::from([
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
        ]));

        let loss_1 = CrossEntropyLoss::new(num_targets, None).forward(&logits, &targets);
        let loss_2 = cross_entropy_with_logits(&logits, &targets_logits);

        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }

    /// With a padding index, the padded sample must be encoded as an all-zero row
    /// while the other samples keep their one-hot encoding.
    #[test]
    fn test_cross_entropy_loss_with_pad_index() {
        let [batch_size, num_targets] = [4, 5];
        let pad_index = 0;
        let logits = Tensor::<TestBackend, 2>::random(
            [batch_size, num_targets],
            Distribution::Normal(0., 1.0),
        );
        let targets =
            Tensor::<<TestBackend as Backend>::IntegerBackend, 1>::from_data(Data::from([
                2, 0, 4, 1_i64,
            ]));
        // Second sample has target == pad_index, so its row stays at zero.
        let targets_logits = Tensor::<TestBackend, 2>::from_data(Data::from([
            [0.0, 0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 1.0],
            [0.0, 1.0, 0.0, 0.0, 0.0],
        ]));

        let loss_1 =
            CrossEntropyLoss::new(num_targets, Some(pad_index)).forward(&logits, &targets);
        let loss_2 = cross_entropy_with_logits(&logits, &targets_logits);

        loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
    }
}
3 changes: 3 additions & 0 deletions burn/src/nn/loss/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod cross_entropy;

pub use cross_entropy::*;
1 change: 1 addition & 0 deletions burn/src/nn/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod attention;
pub mod cache;
pub mod loss;
pub mod transformer;

mod dropout;
Expand Down
4 changes: 2 additions & 2 deletions burn/src/train/learner/classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use burn_tensor::Tensor;
pub struct ClassificationOutput<B: Backend> {
pub loss: Tensor<B, 1>,
pub output: Tensor<B, 2>,
pub targets: Tensor<B, 2>,
pub targets: Tensor<B::IntegerBackend, 1>,
}

impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::LossMetric {
Expand All @@ -24,6 +24,6 @@ impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMet
}

fn clear(&mut self) {
<metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B, 2>)>>::clear(self);
<metric::AccuracyMetric as metric::Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)>>::clear(self);
}
}
21 changes: 10 additions & 11 deletions burn/src/train/metric/acc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,18 @@ impl Numeric for AccuracyMetric {
}
}

impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B, 2>)> for AccuracyMetric {
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B, 2>)) -> MetricStateDyn {
impl<B: Backend> Metric<(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)> for AccuracyMetric {
fn update(&mut self, batch: &(Tensor<B, 2>, Tensor<B::IntegerBackend, 1>)) -> MetricStateDyn {
let (outputs, targets) = batch;
let logits_outputs = outputs.argmax(1).to_device(B::Device::default());
let logits_targets = targets.argmax(1).to_device(B::Device::default());
let count_current = logits_targets.shape().dims[0];
let count_current = outputs.dims()[0];

let total_current = logits_outputs
.equal(&logits_targets)
.to_int()
.sum()
.to_data()
.value[0] as usize;
let targets = targets.to_device(B::Device::default());
let outputs = outputs
.argmax(1)
.to_device(B::Device::default())
.reshape([count_current]);

let total_current = outputs.equal(&targets).to_int().sum().to_data().value[0] as usize;

self.count += count_current;
self.total += total_current;
Expand Down
4 changes: 2 additions & 2 deletions examples/mnist/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct MNISTBatcher<B: Backend> {
#[derive(Clone, Debug)]
pub struct MNISTBatch<B: Backend> {
pub images: Tensor<B, 2>,
pub targets: Tensor<B, 2>,
pub targets: Tensor<B::IntegerBackend, 1>,
}

impl<B: Backend> MNISTBatcher<B> {
Expand All @@ -31,7 +31,7 @@ impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {

let targets = items
.iter()
.map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
.map(|item| Tensor::<B::IntegerBackend, 1>::from_data(Data::from([item.label as i64])))
.collect();

let images = Tensor::cat(images, 0).to_device(self.device).detach();
Expand Down
8 changes: 5 additions & 3 deletions examples/mnist/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@ use crate::{
use burn::{
config::Config,
module::{Module, Param},
nn,
nn::{self, loss::CrossEntropyLoss},
optim::SgdConfig,
tensor::{
backend::{ADBackend, Backend},
loss::cross_entropy_with_logits,
Tensor,
},
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
Expand All @@ -34,6 +33,7 @@ pub struct Model<B: Backend> {
mlp: Param<Mlp<B>>,
input: Param<nn::Linear<B>>,
output: Param<nn::Linear<B>>,
num_classes: usize,
}

impl<B: Backend> Model<B> {
Expand All @@ -46,6 +46,7 @@ impl<B: Backend> Model<B> {
mlp: Param::new(mlp),
output: Param::new(output),
input: Param::new(input),
num_classes,
}
}

Expand All @@ -62,7 +63,8 @@ impl<B: Backend> Model<B> {
pub fn forward_classification(&self, item: MNISTBatch<B>) -> ClassificationOutput<B> {
let targets = item.targets;
let output = self.forward(item.images);
let loss = cross_entropy_with_logits(&output, &targets);
let loss = CrossEntropyLoss::new(self.num_classes, None);
let loss = loss.forward(&output, &targets);

ClassificationOutput {
loss,
Expand Down
7 changes: 3 additions & 4 deletions examples/text-classification/src/data/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,21 @@ use super::{dataset::TextClassificationItem, tokenizer::Tokenizer};
use burn::{
data::dataloader::batcher::Batcher,
nn::attention::generate_padding_mask,
tensor::{backend::Backend, BoolTensor, Tensor},
tensor::{backend::Backend, BoolTensor, Data, Tensor},
};
use std::sync::Arc;

#[derive(new)]
pub struct TextClassificationBatcher<B: Backend> {
tokenizer: Arc<dyn Tokenizer>,
num_classes: usize,
device: B::Device,
max_seq_lenght: usize,
}

#[derive(Debug, Clone, new)]
pub struct TextClassificationBatch<B: Backend> {
pub tokens: Tensor<B::IntegerBackend, 2>,
pub labels: Tensor<B, 2>,
pub labels: Tensor<B::IntegerBackend, 1>,
pub mask_pad: BoolTensor<B, 2>,
}

Expand All @@ -30,7 +29,7 @@ impl<B: Backend> Batcher<TextClassificationItem, TextClassificationBatch<B>>

for item in items {
tokens_list.push(self.tokenizer.encode(&item.text));
labels_list.push(Tensor::one_hot(item.label, self.num_classes));
labels_list.push(Tensor::from_data(Data::from([item.label as i64])));
}

let mask = generate_padding_mask(
Expand Down
6 changes: 4 additions & 2 deletions examples/text-classification/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@ use burn::{
config::Config,
module::{Module, Param},
nn::{
loss::CrossEntropyLoss,
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
Embedding, EmbeddingConfig, Linear, LinearConfig,
},
tensor::backend::{ADBackend, Backend},
tensor::{loss::cross_entropy_with_logits, Tensor},
tensor::Tensor,
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

Expand Down Expand Up @@ -77,7 +78,8 @@ impl<B: Backend> TextClassificationModel<B> {
.index([0..batch_size, 0..1])
.reshape([batch_size, self.n_classes]);

let loss = cross_entropy_with_logits(&output_classification, &labels.clone().detach());
let loss = CrossEntropyLoss::new(self.n_classes, None);
let loss = loss.forward(&output_classification, &labels);

ClassificationOutput {
loss,
Expand Down
2 changes: 0 additions & 2 deletions examples/text-classification/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,11 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
let tokenizer = Arc::new(BertCasedTokenizer::default());
let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
tokenizer.clone(),
n_classes,
device,
config.max_seq_length,
));
let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
tokenizer.clone(),
n_classes,
device,
config.max_seq_length,
));
Expand Down
19 changes: 1 addition & 18 deletions examples/text-generation/src/data/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ use std::sync::Arc;
#[derive(new)]
pub struct TextGenerationBatcher {
tokenizer: Arc<dyn Tokenizer>,
vocab_size: usize,
pad_token: usize,
max_seq_lenght: usize,
}

Expand All @@ -23,7 +21,7 @@ pub struct TextGenerationBatch<B: Backend> {
#[derive(Debug, Clone, new)]
pub struct TrainingTextGenerationBatch<B: Backend> {
pub tokens_inputs: Tensor<B::IntegerBackend, 2>,
pub targets: Tensor<B, 2>,
pub targets: Tensor<B::IntegerBackend, 2>,
pub mask_pad: BoolTensor<B, 2>,
}

Expand Down Expand Up @@ -60,21 +58,6 @@ impl<B: Backend> Batcher<TextGenerationItem, TrainingTextGenerationBatch<B>>
let targets = item.tokens.index([0..batch_size, 1..seq_length]);
let mask_pad = item.mask_pad.index([0..batch_size, 0..seq_length - 1]);

let seq_length = seq_length - 1;

let targets = targets
.reshape([batch_size * seq_length])
.to_data()
.value
.iter()
.map(|index| match *index as usize == self.pad_token {
true => Tensor::<B, 2>::zeros([1, self.vocab_size]),
false => Tensor::<B, 2>::one_hot(*index as usize, self.vocab_size),
})
.collect();

let targets = Tensor::cat(targets, 0);

TrainingTextGenerationBatch::new(inputs, targets, mask_pad)
}
}
14 changes: 8 additions & 6 deletions examples/text-generation/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use burn::{
module::{Module, Param},
nn::{
attention::generate_autoregressive_mask,
loss::CrossEntropyLoss,
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
Embedding, EmbeddingConfig, Linear, LinearConfig,
},
tensor::backend::{ADBackend, Backend},
tensor::{loss::cross_entropy_with_logits, Tensor},
tensor::Tensor,
train::{ClassificationOutput, TrainOutput, TrainStep, ValidStep},
};

Expand Down Expand Up @@ -83,15 +84,16 @@ impl<B: Backend> TextClassificationModel<B> {
);

let output = self.output.forward(encoded);
let output_classification = output.reshape([batch_size * seq_length, self.vocab_size]);
let targets = item.targets.to_device(device).detach();
let output_flatten = output.reshape([batch_size * seq_length, self.vocab_size]);
let targets_flatten = item.targets.reshape([batch_size * seq_length]);

let loss = cross_entropy_with_logits(&output_classification, &targets);
let loss = CrossEntropyLoss::new(self.vocab_size, Some(self.pad_token));
let loss = loss.forward(&output_flatten, &targets_flatten);

ClassificationOutput {
loss,
output: output_classification,
targets,
output: output_flatten,
targets: targets_flatten,
}
}
}
Expand Down
4 changes: 0 additions & 4 deletions examples/text-generation/src/training.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,10 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
let tokenizer = Arc::new(Gpt2Tokenizer::default());
let batcher_train = Arc::new(TextGenerationBatcher::new(
tokenizer.clone(),
tokenizer.vocab_size(),
tokenizer.pad_token(),
config.max_seq_length,
));
let batcher_test = Arc::new(TextGenerationBatcher::new(
tokenizer.clone(),
tokenizer.vocab_size(),
tokenizer.pad_token(),
config.max_seq_length,
));

Expand Down

0 comments on commit 3a9dfe6

Please sign in to comment.