Skip to content

Commit

Permalink
fix: ValueRef(Mut) should clone ValueInner
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Oct 1, 2024
1 parent 2502224 commit 1f6f436
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 60 deletions.
16 changes: 6 additions & 10 deletions src/value/impl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,22 +277,18 @@ impl<K: IntoTensorElementType + Debug + Clone + Hash + Eq, V: IntoTensorElementT
/// Converts from a strongly-typed [`Map<K, V>`] to a reference to a type-erased [`DynMap`].
#[inline]
pub fn upcast_ref(&self) -> DynMapRef {
DynMapRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
DynMapRef::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
})
}

/// Converts from a strongly-typed [`Map<K, V>`] to a mutable reference to a type-erased [`DynMap`].
#[inline]
pub fn upcast_mut(&mut self) -> DynMapRefMut {
DynMapRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
DynMapRefMut::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
})
}
}
21 changes: 7 additions & 14 deletions src/value/impl_sequence.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ impl<Type: SequenceValueTypeMarker + Sized> Value<Type> {
let mut value_ptr = ptr::null_mut();
ortsys![unsafe GetValue(self.ptr(), i as _, allocator.ptr.as_ptr(), &mut value_ptr)?; nonNull(value_ptr)];

let value = ValueRef {
inner: unsafe { Value::from_ptr(NonNull::new_unchecked(value_ptr), None) },
lifetime: PhantomData
};
let value = ValueRef::new(unsafe { Value::from_ptr(NonNull::new_unchecked(value_ptr), None) });
let value_type = value.dtype();
if !OtherType::can_downcast(&value.dtype()) {
return Err(Error::new_with_code(
Expand Down Expand Up @@ -138,22 +135,18 @@ impl<T: ValueTypeMarker + DowncastableTarget + Debug + Sized> Value<SequenceValu
/// Converts from a strongly-typed [`Sequence<T>`] to a reference to a type-erased [`DynSequence`].
#[inline]
pub fn upcast_ref(&self) -> DynSequenceRef {
DynSequenceRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
DynSequenceRef::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
})
}

/// Converts from a strongly-typed [`Sequence<T>`] to a mutable reference to a type-erased [`DynSequence`].
#[inline]
pub fn upcast_mut(&mut self) -> DynSequenceRefMut {
DynSequenceRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
DynSequenceRefMut::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
})
}
}
21 changes: 9 additions & 12 deletions src/value/impl_tensor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ use std::{
fmt::Debug,
marker::PhantomData,
ops::{Index, IndexMut},
ptr::NonNull
ptr::NonNull,
sync::Arc
};

use super::{DowncastableTarget, DynValue, Value, ValueInner, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
use super::{DowncastableTarget, DynValue, Value, ValueRef, ValueRefMut, ValueType, ValueTypeMarker};
use crate::{error::Result, memory::MemoryInfo, ortsys, tensor::IntoTensorElementType};

pub trait TensorValueTypeMarker: ValueTypeMarker {
Expand Down Expand Up @@ -178,11 +179,9 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
/// ```
#[inline]
pub fn upcast_ref(&self) -> DynTensorRef {
DynTensorRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
DynTensorRef::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
})
}

Expand All @@ -204,11 +203,9 @@ impl<T: IntoTensorElementType + Debug> Tensor<T> {
/// ```
#[inline]
pub fn upcast_mut(&mut self) -> DynTensorRefMut {
DynTensorRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
DynTensorRefMut::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
})
}
}
Expand Down
47 changes: 23 additions & 24 deletions src/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{
any::Any,
fmt::{self, Debug},
marker::PhantomData,
mem::ManuallyDrop,
ops::{Deref, DerefMut},
ptr::NonNull,
sync::Arc
Expand Down Expand Up @@ -239,13 +240,16 @@ impl ValueInner {
/// A temporary version of a [`Value`] with a lifetime specifier.
#[derive(Debug)]
pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
inner: Value<Type>,
inner: ManuallyDrop<Value<Type>>,
lifetime: PhantomData<&'v ()>
}

impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> {
pub(crate) fn new(inner: Value<Type>) -> Self {
ValueRef { inner, lifetime: PhantomData }
ValueRef {
inner: ManuallyDrop::new(inner),
lifetime: PhantomData
}
}

/// Attempts to downcast a temporary dynamic value (like [`DynValue`] or [`DynTensor`]) to a more strongly typed
Expand Down Expand Up @@ -276,13 +280,16 @@ impl<'v, Type: ValueTypeMarker + ?Sized> Deref for ValueRef<'v, Type> {
/// A mutable temporary version of a [`Value`] with a lifetime specifier.
#[derive(Debug)]
pub struct ValueRefMut<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
inner: Value<Type>,
inner: ManuallyDrop<Value<Type>>,
lifetime: PhantomData<&'v ()>
}

impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> {
pub(crate) fn new(inner: Value<Type>) -> Self {
ValueRefMut { inner, lifetime: PhantomData }
ValueRefMut {
inner: ManuallyDrop::new(inner),
lifetime: PhantomData
}
}

/// Attempts to downcast a temporary mutable dynamic value (like [`DynValue`] or [`DynTensor`]) to a more
Expand Down Expand Up @@ -462,21 +469,17 @@ impl<Type: ValueTypeMarker + ?Sized> Value<Type> {

/// Create a view of this value's data.
pub fn view(&self) -> ValueRef<'_, Type> {
ValueRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
ValueRef::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
})
}

/// Create a mutable view of this value's data.
pub fn view_mut(&mut self) -> ValueRefMut<'_, Type> {
ValueRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
ValueRefMut::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
})
}

Expand Down Expand Up @@ -522,11 +525,9 @@ impl Value<DynValueTypeMarker> {
pub fn downcast_ref<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(&self) -> Result<ValueRef<'_, OtherType>> {
let dt = self.dtype();
if OtherType::can_downcast(&dt) {
Ok(ValueRef::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
Ok(ValueRef::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
}))
} else {
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", OtherType::format())))
Expand All @@ -539,11 +540,9 @@ impl Value<DynValueTypeMarker> {
pub fn downcast_mut<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(&mut self) -> Result<ValueRefMut<'_, OtherType>> {
let dt = self.dtype();
if OtherType::can_downcast(&dt) {
Ok(ValueRefMut::new(unsafe {
Value::from_ptr_nodrop(
NonNull::new_unchecked(self.ptr()),
if let ValueInner::CppOwned { _session, .. } = &*self.inner { _session.clone() } else { None }
)
Ok(ValueRefMut::new(Value {
inner: Arc::clone(&self.inner),
_markers: PhantomData
}))
} else {
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", OtherType::format())))
Expand Down

0 comments on commit 1f6f436

Please sign in to comment.