Skip to content

Commit

Permalink
Fix one_hot implementation for Int Tensors (tracel-ai#2500) (tracel-a…
Browse files Browse the repository at this point in the history
…i#2501)

* Fix one_hot implementation for Int Tensors (tracel-ai#2500)

* Correct dimensions in one_hot implementation

Co-authored-by: Tiago Sanona <40792244+tsanona@users.noreply.github.com>

---------

Co-authored-by: Tiago Sanona <40792244+tsanona@users.noreply.github.com>
  • Loading branch information
maun and tsanona authored Nov 18, 2024
1 parent f57c145 commit ab317e7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions crates/burn-tensor/src/tensor/api/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ where
pub fn one_hot(self, num_classes: usize) -> Tensor<B, 2, Int> {
check!(TensorCheck::one_hot_tensor(self.clone(), num_classes));
let [num_samples] = self.dims();
let indices = self.unsqueeze();
let indices = self.unsqueeze_dim(1);
let values = indices.ones_like();
Tensor::zeros([num_samples, num_samples], &indices.device()).scatter(1, indices, values)
Tensor::zeros([num_samples, num_classes], &indices.device()).scatter(1, indices, values)
}
}

Expand Down

0 comments on commit ab317e7

Please sign in to comment.