Fix/constant tensors #984
Looks good. I had one comment. Please see below.
@@ -4,22 +4,36 @@ use crate::tensor::{
    backend::{AutodiffBackend, Backend},
    Tensor,
};
use alloc::vec::Vec;
use burn_tensor::{Bool, Int};

impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
    fn from(value: Tensor<B, D>) -> Self {
        Param::new(ParamId::new(), value.require_grad())
    }
}

Do we need to explain why we do not require grad?
I added an explanation of why we require grad here: the other tensor kinds (Bool and Int) can't have gradients.
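To make the distinction concrete, here is a minimal, self-contained sketch of the idea. It does not use Burn's real types: `Tensor`, `Param`, the kind markers, and the `requires_grad` field are all hypothetical stand-ins, and the actual impls in this PR may be shaped differently. It only illustrates that a float tensor is marked as requiring gradients on conversion, while int and bool tensors are wrapped unchanged.

```rust
use std::marker::PhantomData;

// Hypothetical kind markers standing in for the Float, Int and Bool tensor kinds.
pub struct Float;
pub struct Int;
pub struct Bool;

// Hypothetical tensor that only records its kind and a gradient flag.
pub struct Tensor<K> {
    pub requires_grad: bool,
    _kind: PhantomData<K>,
}

impl<K> Tensor<K> {
    pub fn new() -> Self {
        Tensor { requires_grad: false, _kind: PhantomData }
    }
}

// In this sketch, only float tensors expose `require_grad`.
impl Tensor<Float> {
    pub fn require_grad(mut self) -> Self {
        self.requires_grad = true;
        self
    }
}

// Hypothetical parameter wrapper.
pub struct Param<T> {
    pub value: T,
}

// Float tensors become trainable parameters, so gradients are required.
impl From<Tensor<Float>> for Param<Tensor<Float>> {
    fn from(value: Tensor<Float>) -> Self {
        Param { value: value.require_grad() }
    }
}

// Int and Bool tensors cannot carry gradients, so they are wrapped as-is.
impl From<Tensor<Int>> for Param<Tensor<Int>> {
    fn from(value: Tensor<Int>) -> Self {
        Param { value }
    }
}

impl From<Tensor<Bool>> for Param<Tensor<Bool>> {
    fn from(value: Tensor<Bool>) -> Self {
        Param { value }
    }
}
```

With these stand-ins, `Tensor::<Float>::new().into()` yields a `Param` whose tensor has the flag set, while the `Int` and `Bool` conversions leave it unset.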
Nice work! 🔥 🚀
The devices() method needing a devices argument looks a bit odd to me, especially when it sometimes leads to it being called as self.devices(Vec::new()), but maybe I'm missing something.
OK. We got a couple approvals and I need this PR for my onnx changes, hence I am merging.
Maybe we could rename the method, but we collect all the devices in the module tree in the provided vector. The alternative would imply a lot more allocations. Maybe we could rename the method to
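The accumulator pattern described in this reply can be sketched as follows. This is an illustration only: `Device`, `Module`, `Leaf`, `Parent`, and `example_tree` are hypothetical names, and the signature `devices(&self, Vec<Device>) -> Vec<Device>` is an assumption inferred from the `self.devices(Vec::new())` call mentioned in the thread, not Burn's actual trait.

```rust
// Hypothetical device type; not Burn's real device representation.
#[derive(Debug, Clone, PartialEq)]
pub enum Device {
    Cpu,
    Gpu(usize),
}

// Hypothetical trait mirroring the accumulator-style method under discussion.
pub trait Module {
    // Appends the devices used by this module and its children to `devices`
    // and returns the same vector to the caller.
    fn devices(&self, devices: Vec<Device>) -> Vec<Device>;
}

pub struct Leaf {
    pub device: Device,
}

impl Module for Leaf {
    fn devices(&self, mut devices: Vec<Device>) -> Vec<Device> {
        // Record each distinct device only once.
        if !devices.contains(&self.device) {
            devices.push(self.device.clone());
        }
        devices
    }
}

pub struct Parent {
    pub children: Vec<Leaf>,
}

impl Module for Parent {
    fn devices(&self, mut devices: Vec<Device>) -> Vec<Device> {
        // Thread one vector through the whole tree instead of allocating
        // a fresh Vec per child and merging the results afterwards.
        for child in &self.children {
            devices = child.devices(devices);
        }
        devices
    }
}

// Small example tree: two leaves on the CPU and one on GPU 0.
pub fn example_tree() -> Parent {
    Parent {
        children: vec![
            Leaf { device: Device::Cpu },
            Leaf { device: Device::Gpu(0) },
            Leaf { device: Device::Cpu },
        ],
    }
}
```

Only the caller at the root passes `Vec::new()`; every nested module reuses that one vector, which is where the saving in allocations comes from.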
In this effort to support constant tensors of any kind (Int and Bool, not just Float), I had to fix how we handle devices within modules.

Fix #461 #665 #982