Skip to content

Commit

Permalink
docs: add numeric tensor examples (#2514)
Browse files Browse the repository at this point in the history
* docs: add-numeric-tensor-examples

* fix: doc-tests

* Fix examples print statements

* Fix formatting

* Fix chunks fmt

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
  • Loading branch information
quinton11 and laggui authored Nov 20, 2024
1 parent a0e8e4d commit 2132d47
Show file tree
Hide file tree
Showing 2 changed files with 1,265 additions and 35 deletions.
58 changes: 29 additions & 29 deletions crates/burn-tensor/src/tensor/api/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,20 @@ use crate::{DType, Element, TensorPrimitive};
/// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0]]
/// // The resulting tensor will have dimensions [2, 3].
/// let slice = tensor.clone().slice([1..3]);
/// println!("{slice:?}");
/// println!("{slice}");
///
/// // Slice the tensor to get the first two rows and the first 2 columns:
/// // [[3.0, 4.9], [2.0, 1.9]]
/// // The resulting tensor will have dimensions [2, 2].
/// let slice = tensor.clone().slice([0..2, 0..2]);
/// println!("{slice:?}");
/// println!("{slice}");
///
/// // Index the tensor along the dimension 1 to get the elements 0 and 2:
/// // [[3.0, 2.0], [2.0, 3.0], [6.0, 7.0], [3.0, 9.0]]
/// // The resulting tensor will have dimensions [4, 2]
/// let indices = Tensor::<B, 1, Int>::from_data([0, 2], &device);
/// let indexed = tensor.select(1, indices);
/// println!("{indexed:?}");
/// println!("{indexed}");
/// }
/// ```
#[derive(new, Clone, Debug)]
Expand Down Expand Up @@ -193,7 +193,7 @@ where
/// let tensor = Tensor::<B, 3>::ones([2, 3, 4], &device);
/// // Reshape it to [2, 12], where 12 is inferred from the number of elements.
/// let reshaped = tensor.reshape([2, -1]);
/// println!("{reshaped:?}");
/// println!("{reshaped}");
/// }
/// ```
pub fn reshape<const D2: usize, S: ReshapeArgs<D2>>(self, shape: S) -> Tensor<B, D2, K> {
Expand Down Expand Up @@ -227,7 +227,7 @@ where
/// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
/// // The resulting tensor will have dimensions [3, 2].
/// let transposed = tensor.transpose();
/// println!("{transposed:?}");
/// println!("{transposed}");
/// }
/// ```
pub fn transpose(self) -> Tensor<B, D, K> {
Expand Down Expand Up @@ -261,7 +261,7 @@ where
/// // [[1.0, 5.0], [-2.0, 9.0], [3.0, 6.0]]
/// // The resulting tensor will have dimensions [3, 2].
/// let swapped = tensor.swap_dims(0, 1);
/// println!("{swapped:?}");
/// println!("{swapped}");
/// }
/// ```
pub fn swap_dims(self, dim1: usize, dim2: usize) -> Tensor<B, D, K> {
Expand Down Expand Up @@ -297,7 +297,7 @@ where
/// // [[1.0, -2.0, 3.0], [5.0, 9.0, 6.0]]
/// // The resulting tensor will have dimensions [3, 2].
/// let permuted = tensor.permute([1, 0]);
/// println!("{permuted:?}");
/// println!("{permuted}");
/// }
/// ```
pub fn permute(self, axes: [isize; D]) -> Tensor<B, D, K> {
Expand Down Expand Up @@ -354,7 +354,7 @@ where
/// // [[[1.0], [-2.0], [3.0]], [[5.0], [9.0], [6.0]]]
/// // The resulting tensor will have dimensions [2, 3, 1].
/// let moved = tensor.movedim(1, 0);
/// println!("{moved:?}");
/// println!("{moved}");
/// }
/// ```
// This is a syntactic sugar for `permute`. It is used widely enough, so we define a separate Op
Expand Down Expand Up @@ -427,7 +427,7 @@ where
/// // [2.0, 4.9, 3.0]]
/// // The resulting tensor will have dimensions [4, 3].
/// let flipped = tensor.flip([0, 1]);
/// println!("{flipped:?}");
/// println!("{flipped}");
/// }
/// ```
pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
Expand Down Expand Up @@ -480,7 +480,7 @@ where
/// // Flatten the tensor from dimensions 1 to 2 (inclusive).
/// // The resulting tensor will have dimensions [2, 12]
/// let flattened: Tensor<B, 2> = tensor.flatten(1, 2);
/// println!("{flattened:?}");
/// println!("{flattened}");
/// }
/// ```
pub fn flatten<const D2: usize>(self, start_dim: usize, end_dim: usize) -> Tensor<B, D2, K> {
Expand Down Expand Up @@ -534,7 +534,7 @@ where
/// // Squeeze the dimension 1.
/// // The resulting tensor will have dimensions [3, 3].
/// let squeezed = tensor.squeeze::<2>(1);
/// println!("{squeezed:?}");
/// println!("{squeezed}");
/// }
/// ```
pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
Expand Down Expand Up @@ -583,7 +583,7 @@ where
/// // Squeeze the dimensions 1 and 3.
/// // The resulting tensor will have dimensions [2, 4].
/// let squeezed: Tensor<B, 2> = tensor.squeeze_dims(&[1, 3]);
/// println!("{squeezed:?}");
/// println!("{squeezed}");
/// }
/// ```
pub fn squeeze_dims<const D2: usize>(self, dims: &[isize]) -> Tensor<B, D2, K> {
Expand Down Expand Up @@ -655,7 +655,7 @@ where
/// // Unsqueeze the tensor up to 4 dimensions.
/// // The resulting tensor will have dimensions [1, 1, 3, 3].
/// let unsqueezed = tensor.unsqueeze::<4>();
/// println!("{unsqueezed:?}");
/// println!("{unsqueezed}");
/// }
/// ```
pub fn unsqueeze<const D2: usize>(self) -> Tensor<B, D2, K> {
Expand Down Expand Up @@ -686,7 +686,7 @@ where
/// // Unsqueeze the dimension 1.
/// // The resulting tensor will have dimensions [3, 1, 3].
/// let unsqueezed: Tensor<B, 3> = tensor.unsqueeze_dim(1);
/// println!("{unsqueezed:?}");
/// println!("{unsqueezed}");
/// }
/// ```
pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
Expand Down Expand Up @@ -725,7 +725,7 @@ where
/// // Unsqueeze the leading dimension (0) once and the trailing dimension (-1) twice.
/// // The resulting tensor will have dimensions [1, 3, 4, 5, 1, 1].
/// let unsqueezed: Tensor<B, 6> = tensor.unsqueeze_dims(&[0, -1, -1]);
/// println!("{unsqueezed:?}");
/// println!("{unsqueezed}");
/// }
/// ```
pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> Tensor<B, D2, K> {
Expand Down Expand Up @@ -961,7 +961,7 @@ where
/// // [[3.0, 4.9], [2.0, 1.9], [4.0, 5.9], [3.0, 4.9], [2.0, 1.9], [4.0, 5.9]]
/// // The resulting tensor will have dimensions [6, 2].
/// let repeated = tensor.repeat_dim(0, 2);
/// println!("{repeated:?}");
/// println!("{repeated}");
/// }
/// ```
pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
Expand Down Expand Up @@ -1023,7 +1023,7 @@ where
/// // Compare the elements of the two 2D tensors with dimensions [3, 2].
/// // [[false, true], [true, true], [true, true]]
/// let equal = t1.equal(t2);
/// println!("{equal:?}");
/// println!("{equal}");
/// }
/// ```
pub fn equal(self, other: Self) -> Tensor<B, D, Bool> {
Expand All @@ -1050,7 +1050,7 @@ where
/// // Compare the elements of the two 2D tensors for inequality.
/// // [[true, false], [false, false], [false, false]]
/// let not_equal = t1.not_equal(t2);
/// println!("{not_equal:?}");
/// println!("{not_equal}");
/// }
/// ```
pub fn not_equal(self, other: Self) -> Tensor<B, D, Bool> {
Expand Down Expand Up @@ -1079,7 +1079,7 @@ where
/// // [[3.0, 4.9, 2.0, 4.0, 5.9, 8.0], [2.0, 1.9, 3.0, 1.4, 5.8, 6.0]]
/// // The resulting tensor will have shape [2, 6].
/// let concat = Tensor::cat(vec![t1, t2], 1);
/// println!("{concat:?}");
/// println!("{concat}");
/// }
/// ```
pub fn cat(tensors: Vec<Self>, dim: usize) -> Self {
Expand Down Expand Up @@ -1116,7 +1116,7 @@ where
/// // [[4.0, 5.9, 8.0], [1.4, 5.8, 6.0]]]
/// // The resulting tensor will have shape [3, 2, 3].
/// let stacked= Tensor::stack::<3>(vec![t1, t2, t3], 0);
/// println!("{stacked:?}");
/// println!("{stacked}");
/// }
/// ```
pub fn stack<const D2: usize>(tensors: Vec<Tensor<B, D, K>>, dim: usize) -> Tensor<B, D2, K> {
Expand Down Expand Up @@ -1146,7 +1146,7 @@ where
/// // Given a 2D tensor with dimensions (2, 3), iterate over slices of tensors along the dimension 0.
/// let iter = tensor.iter_dim(0);
/// for (i,tensor) in iter.enumerate() {
/// println!("Tensor {}: {:?}", i, tensor);
/// println!("Tensor {}: {}", i, tensor);
/// // Tensor 0: Tensor { data: [[3.0, 4.9, 2.0]], ... }
/// // Tensor 1: Tensor { data: [[2.0, 1.9, 3.0]], ... }
/// }
Expand Down Expand Up @@ -1190,7 +1190,7 @@ where
/// // [[2.0, 1.9, 3.0], [6.0, 1.5, 7.0], [3.0, 4.9, 9.0]]
/// // The resulting tensor will have dimensions [3, 3].
/// let narrowed = tensor.narrow(0, 1, 3);
/// println!("{narrowed:?}");
/// println!("{narrowed}");
/// }
/// ```
pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
Expand Down Expand Up @@ -1271,12 +1271,12 @@ where
///
/// // Given a 2D tensor with dimensions (2, 3), test if any element in the tensor evaluates to True.
/// let any_tensor = tensor.any();
/// println!("{:?}", any_tensor);
/// println!("{}", any_tensor);
/// // Tensor { data: [true], ... }
///
/// // Given a 2D tensor with dimensions (2, 3), test if any element in the tensor evaluates to True.
/// let any_tensor_two = tensor_two.any();
/// println!("{:?}", any_tensor_two);
/// println!("{}", any_tensor_two);
/// // Tensor { data: [false], ... }
/// }
/// ```
Expand Down Expand Up @@ -1310,7 +1310,7 @@ where
/// // Check if any element in the tensor evaluates to True along the dimension 1.
/// // [[true], [true]],
/// let any_dim = tensor.clone().any_dim(1);
/// println!("{any_dim:?}");
/// println!("{any_dim}");
/// }
/// ```
pub fn any_dim(self, dim: usize) -> Tensor<B, D, Bool> {
Expand Down Expand Up @@ -1341,7 +1341,7 @@ where
/// // Check if all elements in the tensor evaluate to True (which is not the case).
/// // [false]
/// let all = tensor.all();
/// println!("{all:?}");
/// println!("{all}");
/// }
/// ```
pub fn all(self) -> Tensor<B, 1, Bool> {
Expand Down Expand Up @@ -1374,7 +1374,7 @@ where
/// // Check if all elements in the tensor evaluate to True along the dimension 1.
/// // [[true, true, false]]
/// let all_dim = tensor.clone().all_dim(0);
/// println!("{all_dim:?}");
/// println!("{all_dim}");
/// }
/// ```
pub fn all_dim(self, dim: usize) -> Tensor<B, D, Bool> {
Expand Down Expand Up @@ -1403,7 +1403,7 @@ where
/// let tensor = Tensor::<B, 2>::from_data([[3.0]], &device);
/// // Convert the tensor with a single element into a scalar.
/// let scalar = tensor.into_scalar();
/// println!("{scalar:?}");
/// println!("{scalar}");
/// }
/// ```
pub fn into_scalar(self) -> K::Elem {
Expand Down Expand Up @@ -1454,7 +1454,7 @@ where
/// // Expand the tensor to a new shape [3, 4]
/// // [[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]]
/// let expanded = tensor.expand([3, 4]);
/// println!("{:?}", expanded);
/// println!("{}", expanded);
/// }
/// ```
pub fn expand<const D2: usize, S: BroadcastArgs<D, D2>>(self, shape: S) -> Tensor<B, D2, K> {
Expand Down
Loading

0 comments on commit 2132d47

Please sign in to comment.