feat/tensor: add resize to SharedTensor
hobofan committed Feb 7, 2016
1 parent e4d6d4f commit f3bb3b4
Showing 2 changed files with 40 additions and 10 deletions.
41 changes: 31 additions & 10 deletions src/tensor.rs
@@ -227,17 +227,8 @@ impl<T> SharedTensor<T> {
     /// [1]: ../memory/index.html
     pub fn new<D: IntoTensorDesc>(dev: &DeviceType, desc: &D) -> Result<SharedTensor<T>, Error> {
         let copies = LinearMap::<DeviceType, MemoryType>::new();
+        let copy = try!(Self::alloc_on_device(dev, desc));
         let tensor_desc: TensorDesc = desc.into();
-        let copy: MemoryType;
-        let alloc_size = Self::mem_size(tensor_desc.size());
-        match *dev {
-            #[cfg(feature = "native")]
-            DeviceType::Native(ref cpu) => copy = MemoryType::Native(try!(cpu.alloc_memory(alloc_size))),
-            #[cfg(feature = "opencl")]
-            DeviceType::OpenCL(ref context) => copy = MemoryType::OpenCL(try!(context.alloc_memory(alloc_size))),
-            #[cfg(feature = "cuda")]
-            DeviceType::Cuda(ref context) => copy = MemoryType::Cuda(try!(context.alloc_memory(alloc_size))),
-        }
         Ok(SharedTensor {
             desc: tensor_desc,
             latest_location: dev.clone(),
@@ -250,6 +241,7 @@ impl<T> SharedTensor<T> {
     /// Change the shape of the Tensor.
     ///
     /// Will return an Error if size of new shape is not equal to the old shape.
+    /// If you want to change the shape to one of a different size, use `resize`.
     pub fn reshape<D: IntoTensorDesc>(&mut self, desc: &D) -> Result<(), Error> {
         let new_desc: TensorDesc = desc.into();
         if new_desc.size() == self.desc().size() {
@@ -260,6 +252,35 @@ impl<T> SharedTensor<T> {
         }
     }
 
+    /// Change the size and shape of the Tensor.
+    ///
+    /// **Caution**: Drops all copies which are not on the current device.
+    ///
+    /// `reshape` is preferred over this method if the sizes of the old and new shape
+    /// are identical, because it will not reallocate memory.
+    pub fn resize<D: IntoTensorDesc>(&mut self, desc: &D) -> Result<(), Error> {
+        self.copies.clear();
+        self.latest_copy = try!(Self::alloc_on_device(self.latest_device(), desc));
+        let new_desc: TensorDesc = desc.into();
+        self.desc = new_desc;
+        Ok(())
+    }
+
+    /// Allocate memory on the provided DeviceType.
+    fn alloc_on_device<D: IntoTensorDesc>(dev: &DeviceType, desc: &D) -> Result<MemoryType, Error> {
+        let tensor_desc: TensorDesc = desc.into();
+        let alloc_size = Self::mem_size(tensor_desc.size());
+        let copy = match *dev {
+            #[cfg(feature = "native")]
+            DeviceType::Native(ref cpu) => MemoryType::Native(try!(cpu.alloc_memory(alloc_size))),
+            #[cfg(feature = "opencl")]
+            DeviceType::OpenCL(ref context) => MemoryType::OpenCL(try!(context.alloc_memory(alloc_size))),
+            #[cfg(feature = "cuda")]
+            DeviceType::Cuda(ref context) => MemoryType::Cuda(try!(context.alloc_memory(alloc_size))),
+        };
+        Ok(copy)
+    }
+
     /// Synchronize memory from latest location to `destination`.
     pub fn sync(&mut self, destination: &DeviceType) -> Result<(), Error> {
         if &self.latest_location != destination {
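Taken together, `reshape` and `resize` now split shape changes into a cheap in-place path and an allocating path. The sketch below shows how a caller might pick between them; `Backend`, `Native`, `SharedTensor`, `device()`, `desc()`, `reshape` and `resize` are all API confirmed by this commit, but the exact `use` paths are an assumption modeled on the test file, and error handling is elided via `unwrap`.

extern crate collenchyma as co;

use co::backend::{Backend, Native};   // assumed paths, mirroring the test file
use co::tensor::SharedTensor;

fn main() {
    let native = Backend::<Native>::default().unwrap();

    // 10 * 20 * 30 = 6000 elements, allocated on the native device.
    let mut tensor = SharedTensor::<f32>::new(native.device(), &(10, 20, 30)).unwrap();

    // Same element count (100 * 60 = 6000): `reshape` succeeds in place,
    // without touching the allocation.
    tensor.reshape(&(100, 60)).unwrap();
    assert_eq!(tensor.desc(), &[100, 60]);

    // Different element count (2 * 3 * 4 * 5 = 120): `resize` reallocates on
    // the latest device and, per the doc comment, drops all other copies.
    tensor.resize(&(2, 3, 4, 5)).unwrap();
    assert_eq!(tensor.desc(), &[2, 3, 4, 5]);
}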
9 changes: 9 additions & 0 deletions tests/tensor_specs.rs
@@ -28,4 +28,13 @@ mod tensor_spec {
         let tensor_desc_r0_into = <() as IntoTensorDesc>::into(&());
         assert_eq!(1, tensor_desc_r0_into.size());
     }
+
+    #[test]
+    fn it_resizes_tensor() {
+        let native = Backend::<Native>::default().unwrap();
+        let mut tensor = SharedTensor::<f32>::new(native.device(), &(10, 20, 30)).unwrap();
+        assert_eq!(tensor.desc(), &[10, 20, 30]);
+        tensor.resize(&(2, 3, 4, 5)).unwrap();
+        assert_eq!(tensor.desc(), &[2, 3, 4, 5]);
+    }
 }
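The new test only covers the happy path. A companion case in the same style (hypothetical, not part of this commit) would pin down the boundary between the two methods: `reshape` must refuse a size-changing descriptor, leaving `resize` as the only way to change the element count. This assumes the failed `reshape` leaves the descriptor untouched, which matches the success-only assignment in the method body.

    #[test]
    fn it_returns_error_when_reshaping_to_different_size() {
        let native = Backend::<Native>::default().unwrap();
        let mut tensor = SharedTensor::<f32>::new(native.device(), &(10, 20, 30)).unwrap();
        // 6000 elements cannot be reshaped into 120; only `resize` may change the size.
        assert!(tensor.reshape(&(2, 3, 4, 5)).is_err());
        // The descriptor is untouched by the failed reshape.
        assert_eq!(tensor.desc(), &[10, 20, 30]);
    }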
