From 297173124f4061cf501f77a62d4b52b8968957b8 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 31 Jul 2024 12:08:26 -0500 Subject: [PATCH] Add 1d and 2d modules for interpolate with scaling (also fix ONNX Resize op) (#2081) * Add interpolate module * Update module.md * Add interpolate 1d and 2d modules * Consolidated InterpolateMode for 1d and 2d * Remove CoordinateTransformationMode * Add 1d tests for interpolate * Refactor and fixes of ONNX Resize OP * Fix clippy * Fix docs * Fix no_std --- burn-book/src/building-blocks/module.md | 32 +-- .../src/nn/interpolate/interpolate1d.rs | 248 +++++++++++++++++ .../src/nn/interpolate/interpolate2d.rs | 251 ++++++++++++++++++ crates/burn-core/src/nn/interpolate/mod.rs | 46 ++++ crates/burn-core/src/nn/mod.rs | 3 + crates/burn-import/onnx-tests/build.rs | 69 ++--- .../onnx-tests/tests/onnx_tests.rs | 200 ++++++++++++-- .../tests/resize/resize_1d_linear_scale.onnx | Bin 0 -> 370 bytes .../tests/resize/resize_1d_nearest_scale.onnx | Bin 0 -> 368 bytes .../tests/resize/resize_2d_bicubic_scale.onnx | Bin 0 -> 381 bytes .../resize/resize_2d_bilinear_scale.onnx | Bin 0 -> 382 bytes .../tests/resize/resize_2d_nearest_scale.onnx | Bin 0 -> 380 bytes .../onnx-tests/tests/resize/resize_scale.py | 71 +++++ .../{resize.onnx => resize_with_sizes.onnx} | Bin 200 -> 198 bytes .../{resize.py => resize_with_sizes.py} | 16 +- crates/burn-import/src/burn/codegen.rs | 7 + crates/burn-import/src/burn/node/resize.rs | 250 ++++++++++------- .../burn-import/src/onnx/op_configuration.rs | 132 ++++++++- crates/burn-import/src/onnx/to_burn.rs | 7 +- .../src/tensor/ops/modules/base.rs | 2 +- .../src/tests/module/bicubic_interpolate.rs | 41 +++ .../src/tests/module/bilinear_interpolate.rs | 49 ++++ .../src/tests/module/nearest_interpolate.rs | 39 +++ crates/onnx-ir/src/dim_inference.rs | 31 +-- crates/onnx-ir/src/ir.rs | 103 ++++--- 25 files changed, 1340 insertions(+), 257 deletions(-) create mode 100644 crates/burn-core/src/nn/interpolate/interpolate1d.rs create mode 100644 crates/burn-core/src/nn/interpolate/interpolate2d.rs create mode 100644 crates/burn-core/src/nn/interpolate/mod.rs create mode 100644 crates/burn-import/onnx-tests/tests/resize/resize_1d_linear_scale.onnx create mode 100644 crates/burn-import/onnx-tests/tests/resize/resize_1d_nearest_scale.onnx create mode 100644 crates/burn-import/onnx-tests/tests/resize/resize_2d_bicubic_scale.onnx create mode 100644 crates/burn-import/onnx-tests/tests/resize/resize_2d_bilinear_scale.onnx create mode 100644 crates/burn-import/onnx-tests/tests/resize/resize_2d_nearest_scale.onnx create mode 100755 crates/burn-import/onnx-tests/tests/resize/resize_scale.py rename crates/burn-import/onnx-tests/tests/resize/{resize.onnx => resize_with_sizes.onnx} (63%) rename crates/burn-import/onnx-tests/tests/resize/{resize.py => resize_with_sizes.py} (65%) mode change 100644 => 100755 diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index f280e94bcd..3d7b738b95 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -161,21 +161,23 @@ Burn comes with built-in modules that you can use to build your own modules. ### General -| Burn API | PyTorch Equivalent | -| -------------- | --------------------------------------------- | -| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. | -| `Dropout` | `nn.Dropout` | -| `Embedding` | `nn.Embedding` | -| `Gelu` | `nn.Gelu` | -| `GroupNorm` | `nn.GroupNorm` | -| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. | -| `LayerNorm` | `nn.LayerNorm` | -| `LeakyRelu` | `nn.LeakyReLU` | -| `Linear` | `nn.Linear` | -| `Prelu` | `nn.PReLu` | -| `Relu` | `nn.ReLU` | -| `RmsNorm` | _No direct equivalent_ | -| `SwiGlu` | _No direct equivalent_ | +| Burn API | PyTorch Equivalent | +| --------------- | --------------------------------------------- | +| `BatchNorm` | `nn.BatchNorm1d`, `nn.BatchNorm2d` etc. | +| `Dropout` | `nn.Dropout` | +| `Embedding` | `nn.Embedding` | +| `Gelu` | `nn.Gelu` | +| `GroupNorm` | `nn.GroupNorm` | +| `InstanceNorm` | `nn.InstanceNorm1d`, `nn.InstanceNorm2d` etc. | +| `LayerNorm` | `nn.LayerNorm` | +| `LeakyRelu` | `nn.LeakyReLU` | +| `Linear` | `nn.Linear` | +| `Prelu` | `nn.PReLu` | +| `Relu` | `nn.ReLU` | +| `RmsNorm` | _No direct equivalent_ | +| `SwiGlu` | _No direct equivalent_ | +| `Interpolate1d` | _No direct equivalent_ | +| `Interpolate2d` | _No direct equivalent_ | ### Convolutions diff --git a/crates/burn-core/src/nn/interpolate/interpolate1d.rs b/crates/burn-core/src/nn/interpolate/interpolate1d.rs new file mode 100644 index 0000000000..192527b55f --- /dev/null +++ b/crates/burn-core/src/nn/interpolate/interpolate1d.rs @@ -0,0 +1,248 @@ +use alloc::format; + +use burn_tensor::module::interpolate; + +use crate as burn; + +use crate::config::Config; +use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay}; +use crate::tensor::backend::Backend; +use crate::tensor::ops::InterpolateOptions; +use crate::tensor::Tensor; + +use super::InterpolateMode; + +/// Configuration for the 1D interpolation module. +/// +/// This struct defines the configuration options for the 1D interpolation operation. +/// It allows specifying the output size, scale factor, and interpolation mode. +#[derive(Config, Debug)] +pub struct Interpolate1dConfig { + /// Output size of the interpolated tensor. + /// If specified, this takes precedence over `scale_factor`. + #[config(default = "None")] + pub output_size: Option, + + /// Scale factor for resizing the input tensor. + /// This is used when `output_size` is not specified. + #[config(default = "None")] + pub scale_factor: Option, + + /// Interpolation mode to use for resizing. + /// Determines how the output values are calculated. + #[config(default = "InterpolateMode::Nearest")] + pub mode: InterpolateMode, +} + +/// Interpolate module for resizing 1D tensors with shape [N, C, L]. +/// +/// This struct represents a 1D interpolation module that can resize tensors +/// using various interpolation methods. It provides flexibility in specifying +/// either an output size or a scale factor for resizing, along with options +/// for the interpolation mode. +/// +/// The module can be used to upsample or downsample 1D tensors, preserving the +/// number of channels and batch size while adjusting the length dimension. +/// +/// The module can be created using the [Interpolate1dConfig] struct and the +/// `init` method, which returns an instance of the [Interpolate1d] struct. +#[derive(Module, Clone, Debug)] +#[module(custom_display)] +pub struct Interpolate1d { + /// Output size of the interpolated tensor + pub output_size: Option, + + /// Scale factor for resizing the input tensor + pub scale_factor: Option, + + /// Interpolation mode used for resizing + pub mode: Ignored, +} + +impl Interpolate1dConfig { + /// Initialize the interpolation module + pub fn init(self) -> Interpolate1d { + Interpolate1d { + output_size: self.output_size, + scale_factor: self.scale_factor, + mode: Ignored(self.mode), + } + } +} + +impl Interpolate1d { + /// Performs the forward pass of the 1D interpolation module + /// + /// # Arguments + /// + /// * `input` - Input tensor with shape [N, C, L] + /// + /// # Returns + /// + /// Resized tensor with shape [N, C, L'], where L' is determined by + /// the output_size or scale_factor specified in the module configuration + /// + /// # Example + /// + /// ```ignore + /// let input = Tensor::::random([1, 3, 64], Distribution::Uniform(0.0, 1.0), &device); + /// let interpolate = Interpolate1dConfig::new() + /// .with_output_size(Some(128)) + /// .init(); + /// let output = interpolate.forward(input); + /// assert_eq!(output.dims(), [1, 3, 128]); + /// ``` + pub fn forward(&self, input: Tensor) -> Tensor { + let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor); + + // Use the interpolate operation to resize the temporal input tensor + // by adding a new dimension for the interpolation axis + let input = input.unsqueeze_dim(2); + + let result = interpolate( + input, + [1, output_size], + InterpolateOptions::new(self.mode.0.clone().into()), + ); + + result.squeeze_dims(&[2]) + } +} + +/// Calculate output size based on input dimensions, output size, and scale factor +/// +/// # Arguments +/// +/// * `input_dims` - Input dimensions of the tensor +/// * `output_size` - Output size for the interpolated tensor +/// * `scale_factor` - Scale factor for resizing the tensor +/// +/// # Returns +/// +/// Output size for the interpolated tensor +/// +/// # Panics +/// +/// Panics if neither output_size nor scale_factor is provided +/// or if the scale factor is too large +fn calculate_output_size( + input_dims: [usize; 3], + output_size: Option, + scale_factor: Option, +) -> usize { + match (output_size, scale_factor) { + (Some(output_size), None) => { + // Use provided + output_size + } + (None, Some(scale_factor)) => { + // Calculate output size based on scale factor + let [_, _, l] = input_dims; + + let new_dim = (l as f64) * (scale_factor as f64); + + if new_dim > usize::MAX as f64 { + panic!("Scale factor is too large"); + } + + new_dim as usize + } + _ => panic!("Either output_size or scale_factor must be provided"), + } +} + +impl ModuleDisplay for Interpolate1d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("mode", &self.mode) + .add("output_size", &format!("{:?}", self.output_size)) + .add("scale_factor", &self.scale_factor) + .optional() + } +} + +#[cfg(test)] +mod tests { + + use burn_tensor::Distribution; + + use super::*; + use crate::TestBackend; + #[test] + fn test_calculate_output_size() { + let input_dims = [1, 1, 4]; + + let output_size = calculate_output_size(input_dims, Some(2), None); + assert_eq!(output_size, 2); + + let output_size = calculate_output_size(input_dims, None, Some(2.0)); + assert_eq!(output_size, 8); + + let output_size = calculate_output_size(input_dims, None, Some(0.5)); + assert_eq!(output_size, 2); + + let output_size = calculate_output_size(input_dims, None, Some(1.5)); + assert_eq!(output_size, 6); + } + + #[test] + #[should_panic(expected = "Either output_size or scale_factor must be provided")] + fn test_panic() { + let input_dims = [1, 1, 4]; + calculate_output_size(input_dims, None, None); + } + + #[test] + #[should_panic(expected = "Scale factor is too large")] + fn test_large_scale_factor() { + let input_dims = [1, 1, usize::MAX - 1]; + calculate_output_size(input_dims, None, Some(2.0)); + } + + #[test] + fn test_module() { + let input = Tensor::::random( + [2, 3, 4], + Distribution::Uniform(0.0, 1.0), + &Default::default(), + ); + + // Test with output_size + let config = Interpolate1dConfig::new().with_output_size(Some(8)); + let interpolate = config.init(); + let output = interpolate.forward(input.clone()); + assert_eq!(output.dims(), [2, 3, 8]); + + // Test with scale_factor + let config = Interpolate1dConfig::new().with_scale_factor(Some(0.5)); + let interpolate = config.init(); + let output = interpolate.forward(input.clone()); + assert_eq!(output.dims(), [2, 3, 2]); + + // Test with different interpolation mode + let config = Interpolate1dConfig::new() + .with_output_size(Some(6)) + .with_mode(InterpolateMode::Linear); + let interpolate = config.init(); + let output = interpolate.forward(input); + assert_eq!(output.dims(), [2, 3, 6]); + } + + #[test] + fn display() { + let config = Interpolate1dConfig::new().with_output_size(Some(20)); + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "Interpolate1d {mode: Nearest, output_size: Some(20), \ + scale_factor: None}" + ); + } +} diff --git a/crates/burn-core/src/nn/interpolate/interpolate2d.rs b/crates/burn-core/src/nn/interpolate/interpolate2d.rs new file mode 100644 index 0000000000..9154ceb1c0 --- /dev/null +++ b/crates/burn-core/src/nn/interpolate/interpolate2d.rs @@ -0,0 +1,251 @@ +use alloc::format; + +use burn_tensor::module::interpolate; + +use crate as burn; + +use crate::config::Config; +use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay}; +use crate::tensor::backend::Backend; +use crate::tensor::ops::InterpolateOptions; +use crate::tensor::Tensor; + +use super::InterpolateMode; + +/// Configuration for the 2D interpolation module. +/// +/// This struct defines the configuration options for the 2D interpolation operation. +/// It allows specifying the output size, scale factor, and interpolation mode. +#[derive(Config, Debug)] +pub struct Interpolate2dConfig { + /// Output size of the interpolated tensor. + /// If specified, this takes precedence over `scale_factor`. + #[config(default = "None")] + pub output_size: Option<[usize; 2]>, + + /// Scale factor for resizing the input tensor. + /// This is used when `output_size` is not specified. + #[config(default = "None")] + pub scale_factor: Option<[f32; 2]>, + + /// Interpolation mode to use for resizing. + /// Determines how the output values are calculated. + #[config(default = "InterpolateMode::Nearest")] + pub mode: InterpolateMode, +} + +/// Interpolate module for resizing tensors with shape [N, C, H, W]. +/// +/// This struct represents an interpolation module that can resize tensors +/// using various interpolation methods. It provides flexibility in specifying +/// either an output size or a scale factor for resizing, along with options +/// for the interpolation mode. +/// +/// The module can be used to upsample or downsample tensors, preserving the +/// number of channels and batch size while adjusting the height and width +/// dimensions. +/// +/// The module can be created using the [Interpolate2dConfig] struct and the +/// `init` method, which returns an instance of the [Interpolate2d] struct. +#[derive(Module, Clone, Debug)] +#[module(custom_display)] +pub struct Interpolate2d { + /// Output size of the interpolated tensor + pub output_size: Option<[usize; 2]>, + + /// Scale factor for resizing the input tensor + pub scale_factor: Option<[f32; 2]>, + + /// Interpolation mode used for resizing + pub mode: Ignored, +} + +impl Interpolate2dConfig { + /// Initialize the interpolation module + pub fn init(self) -> Interpolate2d { + Interpolate2d { + output_size: self.output_size, + scale_factor: self.scale_factor, + mode: Ignored(self.mode), + } + } +} +impl Interpolate2d { + /// Performs the forward pass of the interpolation module + /// + /// # Arguments + /// + /// * `input` - Input tensor with shape [N, C, H, W] + /// + /// # Returns + /// + /// Resized tensor with shape [N, C, H', W'], where H' and W' are determined by + /// the output_size or scale_factor specified in the module configuration + /// + /// # Example + /// + /// ```ignore + /// let input = Tensor::::random([1, 3, 64, 64], Distribution::Uniform(0.0, 1.0), &device); + /// let interpolate = Interpolate2dConfig::new() + /// .with_output_size(Some([128, 128])) + /// .init(); + /// let output = interpolate.forward(input); + /// assert_eq!(output.dims(), [1, 3, 128, 128]); + /// ``` + pub fn forward(&self, input: Tensor) -> Tensor { + let output_size = calculate_output_size(input.dims(), self.output_size, self.scale_factor); + interpolate( + input, + output_size, + InterpolateOptions::new(self.mode.0.clone().into()), + ) + } +} + +/// Calculates the output size for tensor interpolation. +/// +/// # Arguments +/// +/// * `input_dims` - The dimensions of the input tensor [N, C, H, W]. +/// * `output_size` - Optional desired output size [H', W']. +/// * `scale_factor` - Optional scale factor for height and width [scale_h, scale_w]. +/// +/// # Returns +/// +/// A tuple [H', W'] representing the calculated output size. +/// +/// # Panics +/// +/// Panics if neither `output_size` nor `scale_factor` is provided, +/// or if the scale factor results in dimensions exceeding usize::MAX. +fn calculate_output_size( + input_dims: [usize; 4], + output_size: Option<[usize; 2]>, + scale_factor: Option<[f32; 2]>, +) -> [usize; 2] { + match (output_size, scale_factor) { + (Some(output_size), None) => { + // Use provided + output_size + } + (None, Some(scale_factor)) => { + // Calculate output size based on scale factor + let [_, _, h, w] = input_dims; + + let new_dim_h = (h as f64) * (scale_factor[0] as f64); + + if new_dim_h > usize::MAX as f64 { + panic!("Scale factor for height is too large"); + } + + let new_dim_w = (w as f64) * (scale_factor[1] as f64); + + if new_dim_w > usize::MAX as f64 { + panic!("Scale factor for width is too large"); + } + + [new_dim_h as usize, new_dim_w as usize] + } + _ => panic!("Either output_size or scale_factor must be provided"), + } +} + +impl ModuleDisplay for Interpolate2d { + fn custom_settings(&self) -> Option { + DisplaySettings::new() + .with_new_line_after_attribute(false) + .optional() + } + + fn custom_content(&self, content: Content) -> Option { + content + .add("mode", &self.mode) + .add("output_size", &format!("{:?}", self.output_size)) + .add("scale_factor", &self.scale_factor) + .optional() + } +} +#[cfg(test)] +mod tests { + use burn_tensor::Distribution; + + use crate::TestBackend; + + use super::*; + + #[test] + fn test_calculate_output_size() { + let input_dims = [1, 1, 4, 4]; + + let output_size = calculate_output_size(input_dims, Some([2, 2]), None); + assert_eq!(output_size, [2, 2]); + + let output_size = calculate_output_size(input_dims, None, Some([2.0, 2.0])); + assert_eq!(output_size, [8, 8]); + + let output_size = calculate_output_size([1, 1, 4, 4], None, Some([0.5, 0.5])); + assert_eq!(output_size, [2, 2]); + + let output_size = calculate_output_size([1, 1, 4, 4], None, Some([2.0, 1.5])); + assert_eq!(output_size, [8, 6]); + } + + #[test] + #[should_panic(expected = "Either output_size or scale_factor must be provided")] + fn test_missing_params() { + calculate_output_size([1, 1, 4, 4], None, None); + } + + #[test] + #[should_panic(expected = "Scale factor for height is too large")] + fn test_infinite_height() { + calculate_output_size([1, 1, usize::MAX - 1, 4], None, Some([2.0, 1.0])); + } + + #[test] + #[should_panic(expected = "Scale factor for width is too large")] + fn test_infinite_width() { + calculate_output_size([1, 1, 4, usize::MAX - 1], None, Some([1.0, 2.0])); + } + + #[test] + fn test_module() { + let input = Tensor::::random( + [2, 3, 4, 4], + Distribution::Uniform(0.0, 1.0), + &Default::default(), + ); + + // Test with output_size + let config = Interpolate2dConfig::new().with_output_size(Some([8, 8])); + let interpolate = config.init(); + let output = interpolate.forward(input.clone()); + assert_eq!(output.dims(), [2, 3, 8, 8]); + + // Test with scale_factor + let config = Interpolate2dConfig::new().with_scale_factor(Some([0.5, 0.5])); + let interpolate = config.init(); + let output = interpolate.forward(input.clone()); + assert_eq!(output.dims(), [2, 3, 2, 2]); + + // Test with different interpolation mode + let config = Interpolate2dConfig::new() + .with_output_size(Some([6, 6])) + .with_mode(InterpolateMode::Linear); + let interpolate = config.init(); + let output = interpolate.forward(input); + assert_eq!(output.dims(), [2, 3, 6, 6]); + } + + #[test] + fn display() { + let config = Interpolate2dConfig::new().with_output_size(Some([20, 20])); + let layer = config.init(); + + assert_eq!( + alloc::format!("{}", layer), + "Interpolate2d {mode: Nearest, output_size: Some([20, 20]), \ + scale_factor: None}" + ); + } +} diff --git a/crates/burn-core/src/nn/interpolate/mod.rs b/crates/burn-core/src/nn/interpolate/mod.rs new file mode 100644 index 0000000000..e3b2ce47ee --- /dev/null +++ b/crates/burn-core/src/nn/interpolate/mod.rs @@ -0,0 +1,46 @@ +mod interpolate1d; +mod interpolate2d; + +pub use interpolate1d::*; +pub use interpolate2d::*; + +use crate::tensor::ops::InterpolateMode as OpsInterpolateMode; + +/// Algorithm used for downsampling and upsampling +/// +/// This enum defines different interpolation modes for resampling data. +#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)] +pub enum InterpolateMode { + /// Nearest-neighbor interpolation + /// + /// This mode selects the value of the nearest sample point for each output pixel. + /// It is applicable for both temporal and spatial data. + Nearest, + + /// Linear interpolation + /// + /// This mode calculates the output value using linear + /// interpolation between nearby sample points. + /// + /// It is applicable for both temporal and spatial data. + Linear, + + /// Cubic interpolation + /// + /// This mode uses cubic interpolation to calculate the output value + /// based on surrounding sample points. + /// + /// It is applicable for both temporal and spatial data and generally + /// provides smoother results than linear interpolation. + Cubic, +} + +impl From for OpsInterpolateMode { + fn from(mode: InterpolateMode) -> Self { + match mode { + InterpolateMode::Nearest => OpsInterpolateMode::Nearest, + InterpolateMode::Linear => OpsInterpolateMode::Bilinear, + InterpolateMode::Cubic => OpsInterpolateMode::Bicubic, + } + } +} diff --git a/crates/burn-core/src/nn/mod.rs b/crates/burn-core/src/nn/mod.rs index 80fb02c8be..170d11877b 100644 --- a/crates/burn-core/src/nn/mod.rs +++ b/crates/burn-core/src/nn/mod.rs @@ -16,6 +16,9 @@ pub mod pool; /// Transformer module pub mod transformer; +/// Interpolate module +pub mod interpolate; + mod dropout; mod embedding; mod gelu; diff --git a/crates/burn-import/onnx-tests/build.rs b/crates/burn-import/onnx-tests/build.rs index 2ccca2e55b..67e2f091a3 100644 --- a/crates/burn-import/onnx-tests/build.rs +++ b/crates/burn-import/onnx-tests/build.rs @@ -6,8 +6,8 @@ fn main() { // Add onnx models. ModelGen::new() - .input("tests/add/add_int.onnx") .input("tests/add/add.onnx") + .input("tests/add/add_int.onnx") .input("tests/argmax/argmax.onnx") .input("tests/avg_pool1d/avg_pool1d.onnx") .input("tests/avg_pool2d/avg_pool2d.onnx") @@ -16,9 +16,13 @@ fn main() { .input("tests/clip/clip_opset16.onnx") .input("tests/clip/clip_opset7.onnx") .input("tests/concat/concat.onnx") + .input("tests/constant_of_shape/constant_of_shape.onnx") + .input("tests/constant_of_shape/constant_of_shape_full_like.onnx") .input("tests/conv1d/conv1d.onnx") .input("tests/conv2d/conv2d.onnx") .input("tests/conv3d/conv3d.onnx") + .input("tests/conv_transpose2d/conv_transpose2d.onnx") + .input("tests/conv_transpose3d/conv_transpose3d.onnx") .input("tests/cos/cos.onnx") .input("tests/div/div.onnx") .input("tests/dropout/dropout_opset16.onnx") @@ -26,70 +30,71 @@ fn main() { .input("tests/equal/equal.onnx") .input("tests/erf/erf.onnx") .input("tests/exp/exp.onnx") + .input("tests/expand/expand.onnx") .input("tests/flatten/flatten.onnx") .input("tests/gather/gather.onnx") .input("tests/gather_elements/gather_elements.onnx") .input("tests/gelu/gelu.onnx") .input("tests/global_avr_pool/global_avr_pool.onnx") + .input("tests/greater/greater.onnx") + .input("tests/greater_or_equal/greater_or_equal.onnx") .input("tests/layer_norm/layer_norm.onnx") + .input("tests/leaky_relu/leaky_relu.onnx") + .input("tests/less/less.onnx") + .input("tests/less_or_equal/less_or_equal.onnx") .input("tests/linear/linear.onnx") - .input("tests/log_softmax/log_softmax.onnx") .input("tests/log/log.onnx") + .input("tests/log_softmax/log_softmax.onnx") + .input("tests/mask_where/mask_where.onnx") .input("tests/matmul/matmul.onnx") - .input("tests/min/min.onnx") .input("tests/max/max.onnx") .input("tests/maxpool1d/maxpool1d.onnx") .input("tests/maxpool2d/maxpool2d.onnx") + .input("tests/min/min.onnx") .input("tests/mul/mul.onnx") .input("tests/neg/neg.onnx") .input("tests/not/not.onnx") .input("tests/pad/pad.onnx") - .input("tests/expand/expand.onnx") - .input("tests/greater/greater.onnx") - .input("tests/greater_or_equal/greater_or_equal.onnx") - .input("tests/less/less.onnx") - .input("tests/less_or_equal/less_or_equal.onnx") - .input("tests/recip/recip.onnx") - .input("tests/relu/relu.onnx") - .input("tests/leaky_relu/leaky_relu.onnx") + .input("tests/pow/pow.onnx") + .input("tests/pow/pow_int.onnx") .input("tests/prelu/prelu.onnx") + .input("tests/random_normal/random_normal.onnx") + .input("tests/random_uniform/random_uniform.onnx") + .input("tests/range/range.onnx") + .input("tests/recip/recip.onnx") .input("tests/reduce_max/reduce_max.onnx") - .input("tests/reduce_min/reduce_min.onnx") .input("tests/reduce_mean/reduce_mean.onnx") + .input("tests/reduce_min/reduce_min.onnx") .input("tests/reduce_prod/reduce_prod.onnx") - .input("tests/reduce_sum/reduce_sum_opset13.onnx") .input("tests/reduce_sum/reduce_sum_opset11.onnx") + .input("tests/reduce_sum/reduce_sum_opset13.onnx") + .input("tests/relu/relu.onnx") .input("tests/reshape/reshape.onnx") - .input("tests/resize/resize.onnx") + .input("tests/resize/resize_with_sizes.onnx") + .input("tests/resize/resize_1d_linear_scale.onnx") + .input("tests/resize/resize_1d_nearest_scale.onnx") + .input("tests/resize/resize_2d_bicubic_scale.onnx") + .input("tests/resize/resize_2d_bilinear_scale.onnx") + .input("tests/resize/resize_2d_nearest_scale.onnx") .input("tests/shape/shape.onnx") .input("tests/sigmoid/sigmoid.onnx") .input("tests/sign/sign.onnx") .input("tests/sin/sin.onnx") + .input("tests/slice/slice.onnx") .input("tests/softmax/softmax.onnx") .input("tests/sqrt/sqrt.onnx") - .input("tests/sub/sub_int.onnx") + .input("tests/squeeze/squeeze_multiple.onnx") + .input("tests/squeeze/squeeze_opset13.onnx") + .input("tests/squeeze/squeeze_opset16.onnx") .input("tests/sub/sub.onnx") - .input("tests/tanh/tanh.onnx") - .input("tests/transpose/transpose.onnx") - .input("tests/conv_transpose2d/conv_transpose2d.onnx") - .input("tests/conv_transpose3d/conv_transpose3d.onnx") - .input("tests/pow/pow.onnx") - .input("tests/pow/pow_int.onnx") - .input("tests/slice/slice.onnx") + .input("tests/sub/sub_int.onnx") .input("tests/sum/sum.onnx") .input("tests/sum/sum_int.onnx") + .input("tests/tanh/tanh.onnx") + .input("tests/transpose/transpose.onnx") .input("tests/unsqueeze/unsqueeze.onnx") - .input("tests/unsqueeze/unsqueeze_opset16.onnx") .input("tests/unsqueeze/unsqueeze_opset11.onnx") - .input("tests/mask_where/mask_where.onnx") - .input("tests/squeeze/squeeze_opset16.onnx") - .input("tests/squeeze/squeeze_opset13.onnx") - .input("tests/squeeze/squeeze_multiple.onnx") - .input("tests/random_uniform/random_uniform.onnx") - .input("tests/random_normal/random_normal.onnx") - .input("tests/constant_of_shape/constant_of_shape.onnx") - .input("tests/constant_of_shape/constant_of_shape_full_like.onnx") - .input("tests/range/range.onnx") + .input("tests/unsqueeze/unsqueeze_opset16.onnx") .out_dir("model/") .run_from_script(); diff --git a/crates/burn-import/onnx-tests/tests/onnx_tests.rs b/crates/burn-import/onnx-tests/tests/onnx_tests.rs index b6cefac8c1..2fc5ec6758 100644 --- a/crates/burn-import/onnx-tests/tests/onnx_tests.rs +++ b/crates/burn-import/onnx-tests/tests/onnx_tests.rs @@ -15,19 +15,23 @@ macro_rules! include_models { // ATTENTION: Modify this macro to include all models in the `model` directory. include_models!( - add_int, add, + add_int, argmax, - avg_pool2d, avg_pool1d, + avg_pool2d, batch_norm, cast, clip_opset16, clip_opset7, concat, + constant_of_shape, + constant_of_shape_full_like, conv1d, conv2d, conv3d, + conv_transpose2d, + conv_transpose3d, cos, div, dropout_opset16, @@ -41,37 +45,46 @@ include_models!( gather_elements, gelu, global_avr_pool, + greater, + greater_or_equal, layer_norm, leaky_relu, + less, + less_or_equal, linear, - log_softmax, log, + log_softmax, mask_where, matmul, - min, max, maxpool1d, maxpool2d, + min, mul, neg, not, pad, - greater, - greater_or_equal, - less, - less_or_equal, + pow, + pow_int, prelu, + random_normal, + random_uniform, range, recip, reduce_max, - reduce_min, reduce_mean, + reduce_min, reduce_prod, - reduce_sum_opset13, reduce_sum_opset11, + reduce_sum_opset13, relu, reshape, - resize, + resize_with_sizes, + resize_1d_linear_scale, + resize_1d_nearest_scale, + resize_2d_bicubic_scale, + resize_2d_bilinear_scale, + resize_2d_nearest_scale, shape, sigmoid, sign, @@ -79,26 +92,18 @@ include_models!( slice, softmax, sqrt, - sub_int, + squeeze_multiple, + squeeze_opset13, + squeeze_opset16, sub, + sub_int, sum, sum_int, tanh, transpose, - conv_transpose2d, - conv_transpose3d, - pow, - pow_int, unsqueeze, - unsqueeze_opset16, unsqueeze_opset11, - squeeze_opset16, - squeeze_opset13, - squeeze_multiple, - random_uniform, - random_normal, - constant_of_shape, - constant_of_shape_full_like + unsqueeze_opset16 ); #[cfg(test)] @@ -865,10 +870,10 @@ mod tests { } #[test] - fn resize() { + fn resize_with_sizes() { // Initialize the model without weights (because the exported file does not contain them) let device = Default::default(); - let model: resize::Model = resize::Model::new(&device); + let model: resize_with_sizes::Model = resize_with_sizes::Model::new(&device); // Run the model let input = Tensor::::from_floats( @@ -880,14 +885,153 @@ mod tests { ]]], &device, ); - let size = Tensor::::from_ints([1, 1, 2, 3], &device); - let output = model.forward(input, size); + // The sizes are [1, 1, 2, 3] + let output = model.forward(input); let expected = TensorData::from([[[[0.0f32, 1.5, 3.0], [12.0, 13.5, 15.0]]]]); output.to_data().assert_eq(&expected, true); } + #[test] + #[ignore = "https://github.com/tracel-ai/burn/issues/2080"] + fn resize_with_scales_1d_linear() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + let model: resize_1d_linear_scale::Model = + resize_1d_linear_scale::Model::new(&device); + + // Run the model + let input = Tensor::::from_floats( + [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], + &device, + ); + + // The scales are 1.5 + let output = model.forward(input); + + let output_sum = output.sum().into_scalar(); + let expected_sum = -4.568_224; // from pytorch + + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } + + #[test] + fn resize_with_scales_2d_bilinear() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + let model: resize_2d_bilinear_scale::Model = + resize_2d_bilinear_scale::Model::new(&device); + + // Run the model + let input = Tensor::::from_floats( + [[[ + [-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920], + [-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081], + [0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959], + [0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412], + [-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022], + [-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048], + ]]], + &device, + ); + + // The scales are 1.5, 1.5 + let output = model.forward(input); + + let output_sum = output.sum().into_scalar(); + let expected_sum = -3.401_126_6; // from pytorch + + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } + + #[test] + fn resize_with_scales_2d_nearest() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + let model: resize_2d_nearest_scale::Model = + resize_2d_nearest_scale::Model::::new(&device); + + // Run the model + let input = Tensor::::from_floats( + [[[ + [-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920], + [-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081], + [0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959], + [0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412], + [-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022], + [-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048], + ]]], + &device, + ); + + // The scales are 1.5, 1.5 + let output = model.forward(input); + + assert_eq!(output.dims(), [1, 1, 9, 9]); + + let output_sum = output.sum().into_scalar(); + let expected_sum = -0.812_227_7; // from pytorch + + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } + + #[test] + fn resize_with_scales_1d_nearest() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + let model: resize_1d_nearest_scale::Model = + resize_1d_nearest_scale::Model::::new(&device); + + // Run the model + let input = Tensor::::from_floats( + [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], + &device, + ); + + // The scales are 1.5, 1.5 + let output = model.forward(input); + + assert_eq!(output.dims(), [1, 1, 9]); + + let output_sum = output.sum().into_scalar(); + let expected_sum = -4.568_224; // from pytorch + + assert!(expected_sum.approx_eq(output_sum, (1.0e-4, 2))); + } + + #[test] + fn resize_with_scales_2d_bicubic() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + let model: resize_2d_bicubic_scale::Model = + resize_2d_bicubic_scale::Model::::new(&device); + + // Run the model + let input = Tensor::::from_floats( + [[[ + [-1.1258, -1.1524, -0.2506, -0.4339, 0.8487, 0.6920], + [-0.3160, -2.1152, 0.3223, -1.2633, 0.3500, 0.3081], + [0.1198, 1.2377, 1.1168, -0.2473, -1.3527, -1.6959], + [0.5667, 0.7935, 0.4397, 0.1124, 0.6408, 0.4412], + [-0.2159, -0.7425, 0.5627, 0.2596, 0.5229, 2.3022], + [-1.4689, -1.5867, 1.2032, 0.0845, -1.2001, -0.0048], + ]]], + &device, + ); + + // The scales are 1.5, 1.5 + let output = model.forward(input); + + assert_eq!(output.dims(), [1, 1, 9, 9]); + + let output_sum = output.sum().into_scalar(); + + let expected_sum = -3.515_921; // from pytorch + + assert!(expected_sum.approx_eq(output_sum, (1.0e-3, 2))); + } + #[test] fn shape() { let device = Default::default(); diff --git a/crates/burn-import/onnx-tests/tests/resize/resize_1d_linear_scale.onnx b/crates/burn-import/onnx-tests/tests/resize/resize_1d_linear_scale.onnx new file mode 100644 index 0000000000000000000000000000000000000000..b98596a4165261b4f8baa746694437bce10f184b GIT binary patch literal 370 zcmaJ+Jx{|h5RKC&bx*0(TH-?y>JSNr3@v12DS`<`Mi$F)oyJmPM~;I8Q-4Z+0{;O1 z2LFbH6eh&*bfrWT%30dE1Ra>Sk)s3whOE1IeFBjpT zX}kpQnH7zQ2_DVSL*O|3&47}&o00;gsCH`O^PS%g2maii6hZ-la8(3od;#gOX&wLo literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/resize/resize_1d_nearest_scale.onnx b/crates/burn-import/onnx-tests/tests/resize/resize_1d_nearest_scale.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dd22c30a80054676e2f4b8552668d018e17ad946 GIT binary patch literal 368 zcmaKnzfQw25XR%QN!(MAiY5LiLLDN(kfDK&EUjRIk&(r6T!&a{?8tGD%G5{66YvV~ z7Q9YV5||Lf>2&vf`rUUJlc3sKZFmuRi}_-HbNl7O10m_XR<&iyQr+09vGh8c{O)4> zqb4)(UfHG*3BjW&dJG+Bx9sb4*`hJLBk-i^0N}tMP~!LaC`g|~E#E}!_dQ8r#O+>IIE8qp@hKr+aD;w zjHvCwd=B9-^(-&nbjLb{K*^+{FNRgc^W~78TtI{oIfalw#w;)BAu6~Sb>(MQFnsZ^ JZ*u2OkN4!Xj5n zZ-iuqQ=^$|3#Dsjgp#yY8y?SDCAO00N=vTW9vvlTFweV8H0u~sN2xbH_X7Z8_x55M?3x7bC Qz~PmDvw_V9ca!k=8#!lc7ytkO literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/resize/resize_2d_bilinear_scale.onnx b/crates/burn-import/onnx-tests/tests/resize/resize_2d_bilinear_scale.onnx new file mode 100644 index 0000000000000000000000000000000000000000..d60fd8ca9537b5f38f1be519e2a8294411c6a339 GIT binary patch literal 382 zcmaJ+Jx{|h5RF5UI;U1jE%BiY70O127CN#N!G^@b$YMFE(^yLE$Z?Qh;-};%@DJ#p z;eX(e!h{%}-s$e$dv_RrWRR2wB}J*%~GdRh?-%Loeg-x0m3b zX_|ol%&Ly3C)kV7edIWAn-P1<`%RBLct_w1X-&X^Ka5E*M#PiVBX7kkPlA!tH7s(a z^j1h_I5nEdwotleMkq;ZwdKi-RbnS;uC(O3?a@(s3bVY+L~eV$C@4F2oSTmxp>z)3 z_exL+$(bIsXD}T~-kQ&ZDRFd{Kb31rOVB^{0h7bA;c SK$gPcmw&aP)rPljH2(s|Dr%7c literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/resize/resize_2d_nearest_scale.onnx b/crates/burn-import/onnx-tests/tests/resize/resize_2d_nearest_scale.onnx new file mode 100644 index 0000000000000000000000000000000000000000..42e55cd4b3ee580db918d2deaadeaa1459852980 GIT binary patch literal 380 zcmaJ+J5Izf6pWXApB5=t;-eHRlnxgg*pg-ubVxLmG?wFah^37kIS$fNag>~ZE3kLr zA|xbGA&SxHy&28C7lt^fcUBu-Cf;hfT3%g$&fo##>`tr3GG(c5ZQWXWnMA+4H2hIX z3f>Fbwt}35ZjA0@$9Y?i*!R3&cgTfz1fEoV2srRZF!slaB*-2`BVR@8k4!RHaIMWo zDrN<>hN-5|reaoVMJv4#DX?Z&Rl*v{J9I!!;gGkvfsN_2~3~- N>y3JEbUTZWz5u-9Y1jY& literal 0 HcmV?d00001 diff --git a/crates/burn-import/onnx-tests/tests/resize/resize_scale.py b/crates/burn-import/onnx-tests/tests/resize/resize_scale.py new file mode 100755 index 0000000000..5e3861ec6d --- /dev/null +++ b/crates/burn-import/onnx-tests/tests/resize/resize_scale.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.onnx + +class InterpolateModel(nn.Module): + def __init__(self, scale_factor=None, size=None, mode='nearest', align_corners=None): + super(InterpolateModel, self).__init__() + self.scale_factor = scale_factor + self.size = size + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + return nn.functional.interpolate(x, scale_factor=self.scale_factor, size=self.size, + mode=self.mode, align_corners=self.align_corners) + +def export_interpolate_onnx(filename, batch_size=1, channels=1, height=6, width=6, + scale_factor=None, size=None, mode='nearest', dim=2, align_corners=None): + model = InterpolateModel(scale_factor, size, mode, align_corners) + model.eval() + + # Add seed for reproducibility + torch.manual_seed(0) + + # Create a dummy input + if dim == 1: + dummy_input = torch.randn(batch_size, channels, width) + elif dim == 2: + dummy_input = torch.randn(batch_size, channels, height, width) + else: + raise ValueError("Unsupported dimension. Use 1 for temporal or 2 for spatial.") + + # Export the model + torch.onnx.export(model, dummy_input, filename, + input_names=['input'], output_names=['output'], + dynamic_axes={'input': {0: 'batch_size'}, + 'output': {0: 'batch_size'}}, + opset_version=17) + + output = model(dummy_input) + print(f"Input shape: {dummy_input.shape}") + print(f"Output shape: {output.shape}") + + # Print sum data + print(f"Input sum: {dummy_input.sum()}") + print(f"Output sum: {output.sum()}") + + print(f"Input: {dummy_input}") + print(f"Output: {output}") + + print(f"Model exported to {filename}") + + print() + +# Usage examples: +if __name__ == "__main__": + + + # 1D (temporal) examples + export_interpolate_onnx("resize_1d_nearest_scale.onnx", scale_factor=1.5, mode='nearest', dim=1) + export_interpolate_onnx("resize_1d_linear_scale.onnx", scale_factor=1.5, mode='linear', dim=1, align_corners=True) + + # Cubic interpolation is not supported for 1D tensors + # export_interpolate_onnx("resize_1d_cubic_scale.onnx", scale_factor=1.5, mode='cubic', dim=1) + + # 2D (spatial) examples + export_interpolate_onnx("resize_2d_nearest_scale.onnx", scale_factor=1.5, mode='nearest', dim=2) + export_interpolate_onnx("resize_2d_bilinear_scale.onnx", scale_factor=1.5, mode='bilinear', dim=2, align_corners=True) + export_interpolate_onnx("resize_2d_bicubic_scale.onnx", scale_factor=1.5, mode='bicubic', dim=2, align_corners=True) diff --git a/crates/burn-import/onnx-tests/tests/resize/resize.onnx b/crates/burn-import/onnx-tests/tests/resize/resize_with_sizes.onnx similarity index 63% rename from crates/burn-import/onnx-tests/tests/resize/resize.onnx rename to crates/burn-import/onnx-tests/tests/resize/resize_with_sizes.onnx index 3067216282a1e671c3f324a57f0d1250526a2ef0..dca2f3769cb5be46736118ecfe7e0d364cfb9cc6 100644 GIT binary patch delta 48 zcmX@Xc#M&UgG-35D784VD%EQ1M4m`tEkO None: input_tensor = helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [1, 1, 4, 4]) - sizes_tensor = helper.make_tensor_value_info("sizes", TensorProto.INT64, [4]) + + # Create sizes as a constant tensor + sizes = np.array([1, 1, 2, 3], dtype=np.int64) + sizes_tensor = helper.make_tensor( + name="sizes", + data_type=TensorProto.INT64, + dims=sizes.shape, + vals=sizes.flatten().tolist(), + ) resize_node = helper.make_node( "Resize", @@ -20,15 +29,16 @@ def main() -> None: graph_def = helper.make_graph( nodes=[resize_node], name="ResizeGraph", - inputs=[input_tensor, sizes_tensor], + inputs=[input_tensor], outputs=[ helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1, 2, 2]) ], + initializer=[sizes_tensor], ) model_def = helper.make_model(graph_def, producer_name="resize") - onnx.save(model_def, "resize.onnx") + onnx.save(model_def, "resize_with_sizes.onnx") if __name__ == "__main__": diff --git a/crates/burn-import/src/burn/codegen.rs b/crates/burn-import/src/burn/codegen.rs index 811f63b526..798636c323 100644 --- a/crates/burn-import/src/burn/codegen.rs +++ b/crates/burn-import/src/burn/codegen.rs @@ -64,6 +64,13 @@ impl ToTokens for f64 { } } +/// Prettier output for `f32` +impl ToTokens for f32 { + fn to_tokens(&self) -> TokenStream { + convert_primitive(self) + } +} + /// Padding configuration impl ToTokens for PaddingConfig1d { fn to_tokens(&self) -> TokenStream { diff --git a/crates/burn-import/src/burn/node/resize.rs b/crates/burn-import/src/burn/node/resize.rs index 72f71bfb22..59afcfb607 100644 --- a/crates/burn-import/src/burn/node/resize.rs +++ b/crates/burn-import/src/burn/node/resize.rs @@ -1,29 +1,17 @@ use super::{Node, NodeCodegen}; -use crate::burn::{OtherType, Scope, TensorType, Type}; -use burn::module::Module; +use crate::burn::{OtherType, Scope, TensorType, ToTokens, Type}; use burn::record::PrecisionSettings; use proc_macro2::TokenStream; use quote::quote; -#[derive(Module, Debug, Clone)] -pub enum ResizeMode { - Nearest, - Linear, - Cubic, -} - -#[derive(new, Module, Debug, Clone)] -pub struct ResizeOptions { - pub mode: ResizeMode, -} - #[derive(Debug, Clone)] pub struct ResizeNode { pub field: OtherType, pub input: TensorType, pub output: TensorType, - pub output_size: TensorType, - pub config: ResizeOptions, + mode: String, + scales: Vec, + sizes: Vec, } impl ResizeNode { @@ -31,20 +19,29 @@ impl ResizeNode { name: S, input: TensorType, output: TensorType, - output_size: TensorType, - config: ResizeOptions, + mode: String, + scales: Vec, + sizes: Vec, ) -> Self { + let ty = if input.dim == 3 { + quote! { + Interpolate1d + } + } else if input.dim == 4 { + quote! { + Interpolate2d + } + } else { + panic!("Unsupported input dimension for resize node"); + }; + Self { - field: OtherType::new( - name, - quote! { - burn::module::Ignored - }, - ), + field: OtherType::new(name, ty), input, output, - output_size, - config, + mode, + scales, + sizes, } } } @@ -55,10 +52,7 @@ impl NodeCodegen for ResizeNode { } fn input_types(&self) -> Vec { - vec![ - Type::Tensor(self.input.clone()), - Type::Tensor(self.output_size.clone()), - ] + vec![Type::Tensor(self.input.clone())] } fn field_type(&self) -> Option { @@ -68,59 +62,96 @@ impl NodeCodegen for ResizeNode { fn field_init(&self) -> Option { let name = &self.field.name; - let mode = match self.config.mode { - ResizeMode::Linear => quote! { InterpolateMode::Bilinear }, - ResizeMode::Nearest => quote! { InterpolateMode::Nearest }, - ResizeMode::Cubic => quote! { InterpolateMode::Bicubic }, + let mode = match self.mode.as_str() { + "nearest" => quote! { InterpolateMode::Nearest }, + "linear" => quote! { InterpolateMode::Linear }, + "cubic" => quote! { InterpolateMode::Cubic }, + _ => panic!("Unsupported mode for resize node"), }; - let tokens = quote! { - let #name = InterpolateOptions { - mode: #mode, + let tokens = if self.input.dim == 3 { + let size = if let Some(size) = self.sizes.first() { + let size = size.to_tokens(); + quote! { Some(#size) } + } else { + quote! { None } }; - let #name = burn::module::Ignored(#name); + + let scale_factor = if let Some(scale) = self.scales.first() { + let scale = scale.to_tokens(); + quote! { Some(#scale) } + } else { + quote! { None } + }; + + quote! { + let #name = Interpolate1dConfig::new() + .with_output_size(#size) + .with_scale_factor(#scale_factor) + .with_mode(#mode) + .init(); + } + } else if self.input.dim == 4 { + let size = if self.sizes.len() == 2 { + let h = self.sizes[0].to_tokens(); + let w = self.sizes[1].to_tokens(); + quote! { Some([#h, #w]) } + } else { + quote! { None } + }; + + let scale_factor = if self.scales.len() == 2 { + let h = self.scales[0].to_tokens(); + let w = self.scales[1].to_tokens(); + quote! { Some([#h, #w]) } + } else { + quote! { None } + }; + + quote! { + let #name = Interpolate2dConfig::new() + .with_output_size(#size) + .with_scale_factor(#scale_factor) + .with_mode(#mode) + .init(); + } + } else { + panic!("Unsupported input dimension for resize node"); }; Some(tokens) } + fn register_imports(&self, imports: &mut crate::burn::BurnImports) { + imports.register("burn::nn::interpolate::InterpolateMode"); + if self.input.dim == 3 { + imports.register("burn::nn::interpolate::Interpolate1dConfig"); + imports.register("burn::nn::interpolate::Interpolate1d"); + } else if self.input.dim == 4 { + imports.register("burn::nn::interpolate::Interpolate2dConfig"); + imports.register("burn::nn::interpolate::Interpolate2d"); + } else { + panic!("Unsupported input dimension for resize node"); + } + } + fn field_serialize(&self, serializer: S) -> Result { S::serialize_none(serializer) } fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream { let input = scope.tensor_use_owned(&self.input, node_position); - let output_size = scope.tensor_use_owned(&self.output_size, node_position); let output = &self.output.name; - let field = &self.field.name; quote! { - let output_size_data = #output_size.to_data(); - let mut output_size = [0usize; 2]; - - for (i, &x) in output_size_data.as_slice::().unwrap().iter().rev().take(2).rev().enumerate() { - output_size[i] = x.elem::() as usize; - } - - let #output = interpolate( - #input, - output_size, - self.#field.0.clone(), - ); + let #output = self.#field.forward(#input); } } fn into_node(self) -> Node { Node::Resize(self) } - - fn register_imports(&self, imports: &mut crate::burn::BurnImports) { - imports.register("burn::tensor::ElementConversion"); - imports.register("burn::tensor::module::interpolate"); - imports.register("burn::tensor::ops::InterpolateMode"); - imports.register("burn::tensor::ops::InterpolateOptions"); - } } #[cfg(test)] @@ -135,47 +166,42 @@ mod tests { }; #[test] - fn test_codegen_nodes() { + fn test_codegen_nodes_2d() { let mut graph = BurnGraph::::default(); graph.register(ResizeNode::new( "resize", TensorType::new_float("tensor1", 4), TensorType::new_float("tensor2", 4), - TensorType::new_int("output_size", 1), - ResizeOptions::new(ResizeMode::Linear), + "nearest".to_string(), + vec![0.5, 0.5], + vec![], )); - graph.register_input_output( - vec!["tensor1".to_string(), "output_size".to_string()], - vec!["tensor2".to_string()], - ); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); let expected = quote! { - use burn::tensor::module::interpolate; - use burn::tensor::ops::InterpolateMode; - use burn::tensor::ops::InterpolateOptions; - use burn::tensor::ElementConversion; - use burn::tensor::Int; + use burn::nn::interpolate::Interpolate2d; + use burn::nn::interpolate::Interpolate2dConfig; + use burn::nn::interpolate::InterpolateMode; use burn::{ module::Module, tensor::{backend::Backend, Tensor}, }; - #[derive(Module, Debug)] pub struct Model { - resize: burn::module::Ignored, + resize: Interpolate2d, phantom: core::marker::PhantomData, device: burn::module::Ignored, } - - impl Model { + impl Model { #[allow(unused_variables)] pub fn new(device: &B::Device) -> Self { - let resize = InterpolateOptions { - mode: InterpolateMode::Bilinear, - }; - let resize = burn::module::Ignored(resize); + let resize = Interpolate2dConfig::new() + .with_output_size(None) + .with_scale_factor(Some([0.5, 0.5])) + .with_mode(InterpolateMode::Nearest) + .init(); Self { resize, phantom: core::marker::PhantomData, @@ -183,20 +209,62 @@ mod tests { } } #[allow(clippy::let_and_return, clippy::approx_constant)] - pub fn forward( - &self, - tensor1: Tensor, - output_size: Tensor - ) -> Tensor { - let output_size_data = output_size.to_data(); - let mut output_size = [0usize; 2]; - - for (i, &x) in output_size_data.as_slice::().unwrap().iter().rev().take(2).rev().enumerate() { - output_size[i] = x.elem::() as usize; - } + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = self.resize.forward(tensor1); + tensor2 + } + } + }; + + assert_tokens(graph.codegen(), expected); + } + + #[test] + fn test_codegen_nodes_1d() { + let mut graph = BurnGraph::::default(); + + graph.register(ResizeNode::new( + "resize", + TensorType::new_float("tensor1", 3), + TensorType::new_float("tensor2", 3), + "cubic".to_string(), + vec![], + vec![20], + )); - let tensor2 = interpolate(tensor1, output_size, self.resize.0.clone()); + graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]); + let expected = quote! { + use burn::nn::interpolate::Interpolate1d; + use burn::nn::interpolate::Interpolate1dConfig; + use burn::nn::interpolate::InterpolateMode; + use burn::{ + module::Module, + tensor::{backend::Backend, Tensor}, + }; + #[derive(Module, Debug)] + pub struct Model { + resize: Interpolate1d, + phantom: core::marker::PhantomData, + device: burn::module::Ignored, + } + impl Model { + #[allow(unused_variables)] + pub fn new(device: &B::Device) -> Self { + let resize = Interpolate1dConfig::new() + .with_output_size(Some(20)) + .with_scale_factor(None) + .with_mode(InterpolateMode::Cubic) + .init(); + Self { + resize, + phantom: core::marker::PhantomData, + device: burn::module::Ignored(device.clone()), + } + } + #[allow(clippy::let_and_return, clippy::approx_constant)] + pub fn forward(&self, tensor1: Tensor) -> Tensor { + let tensor2 = self.resize.forward(tensor1); tensor2 } } diff --git a/crates/burn-import/src/onnx/op_configuration.rs b/crates/burn-import/src/onnx/op_configuration.rs index 8ba031e04d..919d74b399 100644 --- a/crates/burn-import/src/onnx/op_configuration.rs +++ b/crates/burn-import/src/onnx/op_configuration.rs @@ -7,7 +7,7 @@ use burn::nn::{ PaddingConfig2d, PaddingConfig3d, }; -use crate::burn::node::{pad::PadConfig, resize::ResizeMode}; +use crate::burn::node::pad::PadConfig; use onnx_ir::ir::{ArgType, AttributeValue, Data, Node}; /// Create a Conv1dConfig from the attributes of the node @@ -976,26 +976,132 @@ pub fn reshape_config(node: &Node) -> Vec { } } -pub fn resize_config(node: &Node) -> ResizeMode { +pub fn resize_config(node: &Node) -> (String, Vec, Vec) { let mut mode: String = "".to_string(); + + let mut scales: Vec; + let mut sizes: Vec; + + let input = if let ArgType::Tensor(tensor) = &node + .inputs + .first() + .expect("Resize: Input tensor must be present") + .ty + { + tensor + } else { + panic!("Resize: input must be a tensor") + }; + + // Note: we are ignoring some attributes because results are approximately the same + // and we are not supporting all the attributes of the Resize operator. + // However, some attributes are important to be checked and we are checking + // against the default values of the attributes. + // TODO revisit this when we have more Resize operators in the model for (key, value) in node.attrs.iter() { match key.as_str() { - "coordinate_transformation_mode" => {} - "cubic_coeff_a" => {} - "mode" => mode = value.clone().into_string(), - "nearest_mode" => {} + "antialias" => assert_eq!( + value.clone().into_i32(), + 0, + "Resize: antialias other than 0 is not supported" + ), + "axes" => panic!("Resize: custom axes attribute is not supported"), + "coordinate_transformation_mode" => { + log::warn!("Resize: coordinate_transformation_mode is ignored") + } + + "cubic_coeff_a" => log::warn!("Resize: cubic_coeff_a is ignored"), + "exclude_outside" => assert_eq!( + value.clone().into_i32(), + 0, + "Resize: exclude_outside other than 0 is not supported" + ), + "extrapolation_value" => assert_eq!( + value.clone().into_f32(), + 0.0, + "Resize: extrapolation_value other than 0.0 is not supported" + ), + "keep_aspect_ratio_policy" => { + assert_eq!( + value.clone().into_string().to_lowercase(), + "stretch", + "Resize: keep_aspect_ratio_policy other than 'stretch' is not supported" + ) + } + "mode" => mode = value.clone().into_string().to_lowercase(), + "nearest_mode" => log::warn!("Resize: nearest_mode is ignored"), + _ => {} } } - let mode = match mode.as_str() { - "nearest" => ResizeMode::Nearest, - "linear" => ResizeMode::Linear, - "cubic" => ResizeMode::Cubic, - _ => panic!("Resize: invalid mode string, must be 'nearest', 'linear', or 'cubic'"), - }; + let roi: Vec = node + .inputs + .get(1) + .map(|input| { + if let Some(data) = &input.value { + data.clone().into_f32s() + } else { + vec![] + } + }) + .unwrap_or_default(); + + scales = node + .inputs + .get(2) + .map(|input| { + if let Some(data) = &input.value { + data.clone().into_f32s() + } else { + vec![] + } + }) + .unwrap_or_default(); + + sizes = node + .inputs + .get(3) + .map(|input| { + if let Some(data) = &input.value { + data.clone() + .into_i64s() + .iter() + .map(|&x| x as usize) + .collect() + } else { + vec![] + } + }) + .unwrap_or_default(); + + if mode.is_empty() { + panic!("Resize: mode attribute is required") + } + + if !roi.is_empty() { + panic!("Resize: roi input is not supported") + } + + if scales.is_empty() && sizes.is_empty() { + panic!("Resize: either scales or sizes input is required") + } + + if !scales.is_empty() { + assert!(scales.len() == input.dim); + // ignore the fist two items from scales + // because they are the batch and channel dimensions + scales = scales.iter().skip(2).cloned().collect(); + } + + if !sizes.is_empty() { + assert!(sizes.len() == input.dim); + // ignore the fist two items from sizes + // because they are the batch and channel dimensions + sizes = sizes.iter().skip(2).cloned().collect(); + } - mode + (mode, scales, sizes) } //Note this function should only execute if the second input is a constant diff --git a/crates/burn-import/src/onnx/to_burn.rs b/crates/burn-import/src/onnx/to_burn.rs index ad839aed4f..f8b0e0ad8f 100644 --- a/crates/burn-import/src/onnx/to_burn.rs +++ b/crates/burn-import/src/onnx/to_burn.rs @@ -46,7 +46,7 @@ use crate::{ random_uniform::RandomUniformNode, range::RangeNode, reshape::ReshapeNode, - resize::{ResizeNode, ResizeOptions}, + resize::ResizeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, @@ -646,13 +646,12 @@ impl ParsedOnnxGraph { let name = &node.name; let input = TensorType::from(&node.inputs[0]); - let output_size = TensorType::from(&node.inputs[3]); let output = TensorType::from(node.outputs.first().unwrap()); - let mode = resize_config(&node); + let (mode, scales, sizes) = resize_config(&node); - ResizeNode::new(name, input, output, output_size, ResizeOptions { mode }) + ResizeNode::new(name, input, output, mode, scales, sizes) } fn min_conversion(node: Node) -> BinaryNode { diff --git a/crates/burn-tensor/src/tensor/ops/modules/base.rs b/crates/burn-tensor/src/tensor/ops/modules/base.rs index 02a56ce1a5..1f76b3d840 100644 --- a/crates/burn-tensor/src/tensor/ops/modules/base.rs +++ b/crates/burn-tensor/src/tensor/ops/modules/base.rs @@ -128,7 +128,7 @@ pub struct UnfoldOptions { } /// Algorithm used for upsampling. -#[derive(new, Debug, Clone)] +#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)] pub enum InterpolateMode { /// Nearest-neighbor interpolation. /// diff --git a/crates/burn-tensor/src/tests/module/bicubic_interpolate.rs b/crates/burn-tensor/src/tests/module/bicubic_interpolate.rs index 564f41bf88..28ed3b57a6 100644 --- a/crates/burn-tensor/src/tests/module/bicubic_interpolate.rs +++ b/crates/burn-tensor/src/tests/module/bicubic_interpolate.rs @@ -85,6 +85,47 @@ mod tests { ]]])); } + #[test] + #[ignore = "https://github.com/tracel-ai/burn/issues/2080"] + fn test_1d_bicubic() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + + // Run the model + let input = TestTensor::<3>::from_floats( + [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], + &device, + ); + + let input = input.unsqueeze_dim(2); + + let output = interpolate( + input, + [1, 9], + InterpolateOptions::new(InterpolateMode::Bicubic), + ); + + assert_eq!(output.dims(), [1, 1, 1, 9]); + + // assert output data does not contain NaN + assert!( + !output + .clone() + .to_data() + .as_slice::() + .unwrap() + .iter() + .any(|&x| x.is_nan()), + "interpolate output contains NaN" + ); + + TestTensor::<4>::from([[[[ + 1.541, 0.5747652, -1.010614, -2.197787, -0.8269969, 0.59609234, -0.5803058, -1.3792794, + -1.3986, + ]]]]) + .to_data() + .assert_approx_eq(&output.into_data(), 3); + } struct InterpolateTestCase { batch_size: usize, channels: usize, diff --git a/crates/burn-tensor/src/tests/module/bilinear_interpolate.rs b/crates/burn-tensor/src/tests/module/bilinear_interpolate.rs index 32d8b236b2..634e95756f 100644 --- a/crates/burn-tensor/src/tests/module/bilinear_interpolate.rs +++ b/crates/burn-tensor/src/tests/module/bilinear_interpolate.rs @@ -85,6 +85,55 @@ mod tests { ]]])); } + #[test] + #[ignore = "https://github.com/tracel-ai/burn/issues/2080"] + fn test_1d_bilinear() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + + // Run the model + let input = TestTensor::<3>::from_floats( + [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], + &device, + ); + + let input = input.unsqueeze_dim(2); + + let output = interpolate( + input, + [1, 9], + InterpolateOptions::new(InterpolateMode::Bilinear), + ); + + assert_eq!(output.dims(), [1, 1, 1, 9]); + + // assert output data does not contain NaN + assert!( + !output + .clone() + .to_data() + .as_slice::() + .unwrap() + .iter() + .any(|&x| x.is_nan()), + "interpolate output contains NaN" + ); + + TestTensor::<4>::from([[[[ + 1.541f32, + 0.39450002, + -0.76475, + -1.943125, + -0.80520004, + 0.36178753, + -0.671275, + -1.2022874, + -1.3986, + ]]]]) + .to_data() + .assert_approx_eq(&output.into_data(), 3); + } + struct InterpolateTestCase { batch_size: usize, channels: usize, diff --git a/crates/burn-tensor/src/tests/module/nearest_interpolate.rs b/crates/burn-tensor/src/tests/module/nearest_interpolate.rs index 36d127db19..be0e7b12fd 100644 --- a/crates/burn-tensor/src/tests/module/nearest_interpolate.rs +++ b/crates/burn-tensor/src/tests/module/nearest_interpolate.rs @@ -59,6 +59,45 @@ mod tests { ]]])); } + #[test] + fn test_1d_nearest() { + // Initialize the model without weights (because the exported file does not contain them) + let device = Default::default(); + + // Run the model + let input = TestTensor::<3>::from_floats( + [[[1.5410, -0.2934, -2.1788, 0.5684, -1.0845, -1.3986]]], + &device, + ); + + let input = input.unsqueeze_dim(2); + + let output = interpolate( + input, + [1, 9], + InterpolateOptions::new(InterpolateMode::Nearest), + ); + assert_eq!(output.dims(), [1, 1, 1, 9]); + + // assert output data does not contain NaN + assert!( + !output + .clone() + .to_data() + .as_slice::() + .unwrap() + .iter() + .any(|&x| x.is_nan()), + "interpolate output contains NaN" + ); + + TestTensor::<4>::from([[[[ + 1.541, 1.541, -0.2934, -2.1788, -2.1788, 0.5684, -1.0845, -1.0845, -1.3986, + ]]]]) + .to_data() + .assert_approx_eq(&output.into_data(), 3); + } + struct InterpolateTestCase { batch_size: usize, channels: usize, diff --git a/crates/onnx-ir/src/dim_inference.rs b/crates/onnx-ir/src/dim_inference.rs index 8b0ea3029f..769225e770 100644 --- a/crates/onnx-ir/src/dim_inference.rs +++ b/crates/onnx-ir/src/dim_inference.rs @@ -62,7 +62,7 @@ pub fn dim_inference(node: &mut Node) { NodeType::ReduceSum => reduce_sum_update_outputs(node), NodeType::Relu => same_as_input(node), NodeType::Reshape => reshape_update_outputs(node), - NodeType::Resize => resize_update_outputs(node), + NodeType::Resize => same_as_input(node), NodeType::Shape => shape_update_outputs(node), NodeType::Sigmoid => same_as_input(node), NodeType::Sign => same_as_input(node), @@ -318,33 +318,6 @@ fn reshape_update_outputs(node: &mut Node) { } } -fn resize_update_outputs(node: &mut Node) { - let input = match &node.inputs[0].ty { - ArgType::Tensor(tensor) => tensor.clone(), - _ => panic!("Resize: invalid input type"), - }; - - let output = match &node.outputs[0].ty { - ArgType::Tensor(tensor) => tensor.clone(), - _ => panic!("Resize: invalid output type"), - }; - - let output_size = match &node.inputs[3].ty { - ArgType::Tensor(output_size) => output_size.clone(), - _ => panic!("Resize: invalid output_size type"), - }; - - if output_size.dim != 1 { - panic!("Resize: output_size must be 1D"); - } - - node.outputs[0].ty = ArgType::Tensor(TensorType { - dim: input.dim, - shape: None, // shape is calculated at runtime - ..output - }); -} - fn greater_update_outputs(node: &mut Node) { match &node.inputs[0].ty { ArgType::Tensor(tensor) => { @@ -838,7 +811,7 @@ fn gather_update_outputs(node: &mut Node) { let input_tensor = match &node.inputs[0].ty { ArgType::Tensor(tensor) => tensor, - _ => panic!("Only tensor input is valid"), + ty => panic!("Only tensor input is valid but received: {:?}", ty), }; let indices_tensor = match &node.inputs[1].ty { diff --git a/crates/onnx-ir/src/ir.rs b/crates/onnx-ir/src/ir.rs index 8dd94d68dc..52f9ee21e3 100644 --- a/crates/onnx-ir/src/ir.rs +++ b/crates/onnx-ir/src/ir.rs @@ -523,44 +523,54 @@ impl Data { _ => self, } } - pub fn into_f16(self) -> f16 { - if let Data::Float16(elem) = self { - elem - } else { - panic!("Expected Float16, got {:?}", self); + match self { + Data::Float16(elem) => elem, + Data::Float32(elem) => f16::from_f32(elem), + Data::Float64(elem) => f16::from_f64(elem), + _ => panic!("Cannot convert {:?} to f16", self), } } pub fn into_f32(self) -> f32 { - if let Data::Float32(elem) = self { - elem - } else { - panic!("Expected Float32, got {:?}", self); + match self { + Data::Float16(elem) => elem.to_f32(), + Data::Float32(elem) => elem, + Data::Float64(elem) => elem as f32, + Data::Int32(elem) => elem as f32, + Data::Int64(elem) => elem as f32, + _ => panic!("Cannot convert {:?} to f32", self), } } pub fn into_f64(self) -> f64 { - if let Data::Float64(elem) = self { - elem - } else { - panic!("Expected Float64, got {:?}", self); + match self { + Data::Float16(elem) => elem.to_f64(), + Data::Float32(elem) => elem as f64, + Data::Float64(elem) => elem, + Data::Int32(elem) => elem as f64, + Data::Int64(elem) => elem as f64, + _ => panic!("Cannot convert {:?} to f64", self), } } pub fn into_i32(self) -> i32 { - if let Data::Int32(elem) = self { - elem - } else { - panic!("Expected Int32, got {:?}", self); + match self { + Data::Int32(elem) => elem, + Data::Int64(elem) => elem as i32, + Data::Float32(elem) => elem as i32, + Data::Float64(elem) => elem as i32, + _ => panic!("Cannot convert {:?} to i32", self), } } pub fn into_i64(self) -> i64 { - if let Data::Int64(elem) = self { - elem - } else { - panic!("Expected Int64, got {:?}", self); + match self { + Data::Int32(elem) => elem as i64, + Data::Int64(elem) => elem, + Data::Float32(elem) => elem as i64, + Data::Float64(elem) => elem as i64, + _ => panic!("Cannot convert {:?} to i64", self), } } @@ -581,42 +591,53 @@ impl Data { } pub fn into_f16s(self) -> Vec { - if let Data::Float16s(elem) = self { - elem - } else { - panic!("Expected Float16s, got {:?}", self); + match self { + Data::Float16s(elem) => elem, + Data::Float32s(elem) => elem.into_iter().map(f16::from_f32).collect(), + Data::Float64s(elem) => elem.into_iter().map(f16::from_f64).collect(), + _ => panic!("Cannot convert {:?} to Vec", self), } } pub fn into_f32s(self) -> Vec { - if let Data::Float32s(elem) = self { - elem - } else { - panic!("Expected Float32s, got {:?}", self); + match self { + Data::Float16s(elem) => elem.into_iter().map(|x| x.to_f32()).collect(), + Data::Float32s(elem) => elem, + Data::Float64s(elem) => elem.into_iter().map(|x| x as f32).collect(), + Data::Int32s(elem) => elem.into_iter().map(|x| x as f32).collect(), + Data::Int64s(elem) => elem.into_iter().map(|x| x as f32).collect(), + _ => panic!("Cannot convert {:?} to Vec", self), } } pub fn into_f64s(self) -> Vec { - if let Data::Float64s(elem) = self { - elem - } else { - panic!("Expected Float64s, got {:?}", self); + match self { + Data::Float16s(elem) => elem.into_iter().map(|x| x.to_f64()).collect(), + Data::Float32s(elem) => elem.into_iter().map(|x| x as f64).collect(), + Data::Float64s(elem) => elem, + Data::Int32s(elem) => elem.into_iter().map(|x| x as f64).collect(), + Data::Int64s(elem) => elem.into_iter().map(|x| x as f64).collect(), + _ => panic!("Cannot convert {:?} to Vec", self), } } pub fn into_i32s(self) -> Vec { - if let Data::Int32s(elem) = self { - elem - } else { - panic!("Expected Int32s, got {:?}", self); + match self { + Data::Int32s(elem) => elem, + Data::Int64s(elem) => elem.into_iter().map(|x| x as i32).collect(), + Data::Float32s(elem) => elem.into_iter().map(|x| x as i32).collect(), + Data::Float64s(elem) => elem.into_iter().map(|x| x as i32).collect(), + _ => panic!("Cannot convert {:?} to Vec", self), } } pub fn into_i64s(self) -> Vec { - if let Data::Int64s(elem) = self { - elem - } else { - panic!("Expected Int64s, got {:?}", self); + match self { + Data::Int32s(elem) => elem.into_iter().map(|x| x as i64).collect(), + Data::Int64s(elem) => elem, + Data::Float32s(elem) => elem.into_iter().map(|x| x as i64).collect(), + Data::Float64s(elem) => elem.into_iter().map(|x| x as i64).collect(), + _ => panic!("Cannot convert {:?} to Vec", self), } }