From 6f16ea460ab437173e14d2f5f3584ca7e1c9841d Mon Sep 17 00:00:00 2001
From: teoxoy <28601907+teoxoy@users.noreply.github.com>
Date: Fri, 12 Jul 2024 14:38:38 +0200
Subject: [PATCH] make the `StagingBuffer` implementation more robust

---
 wgpu-core/src/device/global.rs   | 11 ++---
 wgpu-core/src/device/queue.rs    | 49 ++++++++------------
 wgpu-core/src/device/resource.rs |  9 ++--
 wgpu-core/src/resource.rs        | 76 ++++++++++++++++++++++++--------
 wgpu-hal/src/lib.rs              |  6 ++-
 5 files changed, 89 insertions(+), 62 deletions(-)

diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs
index 94b59ad6cb..e5643a3da9 100644
--- a/wgpu-core/src/device/global.rs
+++ b/wgpu-core/src/device/global.rs
@@ -2519,7 +2519,7 @@ impl Global {
         }
         let map_state = &*buffer.map_state.lock();
         match *map_state {
-            resource::BufferMapState::Init { ref ptr, .. } => {
+            resource::BufferMapState::Init { ref staging_buffer } => {
                 // offset (u64) can not be < 0, so no need to validate the lower bound
                 if offset + range_size > buffer.size {
                     return Err(BufferAccessError::OutOfBoundsOverrun {
@@ -2527,12 +2527,9 @@ impl Global {
                         max: buffer.size,
                     });
                 }
-                unsafe {
-                    Ok((
-                        NonNull::new_unchecked(ptr.as_ptr().offset(offset as isize)),
-                        range_size,
-                    ))
-                }
+                let ptr = unsafe { staging_buffer.ptr() };
+                let ptr = unsafe { NonNull::new_unchecked(ptr.as_ptr().offset(offset as isize)) };
+                Ok((ptr, range_size))
             }
             resource::BufferMapState::Active {
                 ref ptr, ref range, ..
diff --git a/wgpu-core/src/device/queue.rs b/wgpu-core/src/device/queue.rs
index 05f58f2078..52edb528a3 100644
--- a/wgpu-core/src/device/queue.rs
+++ b/wgpu-core/src/device/queue.rs
@@ -31,7 +31,7 @@ use smallvec::SmallVec;
 use std::{
     iter,
     mem::{self},
-    ptr::{self, NonNull},
+    ptr::NonNull,
     sync::{atomic::Ordering, Arc},
 };
 use thiserror::Error;
@@ -405,17 +405,13 @@ impl Global {
         // Platform validation requires that the staging buffer always be
         // freed, even if an error occurs. All paths from here must call
         // `device.pending_writes.consume`.
-        let (staging_buffer, staging_buffer_ptr) = StagingBuffer::new(device, data_size)?;
+        let mut staging_buffer = StagingBuffer::new(device, data_size)?;
         let mut pending_writes = device.pending_writes.lock();
         let pending_writes = pending_writes.as_mut().unwrap();
 
-        let staging_buffer = unsafe {
+        let staging_buffer = {
             profiling::scope!("copy");
-            ptr::copy_nonoverlapping(
-                data.as_ptr(),
-                staging_buffer_ptr.as_ptr(),
-                data_size.get() as usize,
-            );
+            staging_buffer.write(data);
             staging_buffer.flush()
         };
 
@@ -448,13 +444,14 @@ impl Global {
 
         let device = &queue.device;
 
-        let (staging_buffer, staging_buffer_ptr) = StagingBuffer::new(device, buffer_size)?;
+        let staging_buffer = StagingBuffer::new(device, buffer_size)?;
+        let ptr = unsafe { staging_buffer.ptr() };
 
         let fid = hub.staging_buffers.prepare(id_in);
         let id = fid.assign(Arc::new(staging_buffer));
         resource_log!("Queue::create_staging_buffer {id:?}");
 
-        Ok((id, staging_buffer_ptr))
+        Ok((id, ptr))
     }
 
     pub fn queue_write_staging_buffer(
@@ -487,7 +484,7 @@ impl Global {
         // user. Platform validation requires that the staging buffer always
         // be freed, even if an error occurs. All paths from here must call
         // `device.pending_writes.consume`.
-        let staging_buffer = unsafe { staging_buffer.flush() };
+        let staging_buffer = staging_buffer.flush();
 
         let result = self.queue_write_staging_buffer_impl(
             &queue,
@@ -779,42 +776,34 @@ impl Global {
        // Platform validation requires that the staging buffer always be
        // freed, even if an error occurs. All paths from here must call
        // `device.pending_writes.consume`.
-        let (staging_buffer, staging_buffer_ptr) = StagingBuffer::new(device, stage_size)?;
+        let mut staging_buffer = StagingBuffer::new(device, stage_size)?;
 
         if stage_bytes_per_row == bytes_per_row {
             profiling::scope!("copy aligned");
             // Fast path if the data is already being aligned optimally.
-            unsafe {
-                ptr::copy_nonoverlapping(
-                    data.as_ptr().offset(data_layout.offset as isize),
-                    staging_buffer_ptr.as_ptr(),
-                    stage_size.get() as usize,
-                );
-            }
+            staging_buffer.write(&data[data_layout.offset as usize..]);
         } else {
             profiling::scope!("copy chunked");
             // Copy row by row into the optimal alignment.
             let copy_bytes_per_row = stage_bytes_per_row.min(bytes_per_row) as usize;
             for layer in 0..size.depth_or_array_layers {
                 let rows_offset = layer * block_rows_per_image;
-                for row in 0..height_blocks {
+                for row in rows_offset..rows_offset + height_blocks {
+                    let src_offset = data_layout.offset as u32 + row * bytes_per_row;
+                    let dst_offset = row * stage_bytes_per_row;
                     unsafe {
-                        ptr::copy_nonoverlapping(
-                            data.as_ptr().offset(
-                                data_layout.offset as isize
-                                    + (rows_offset + row) as isize * bytes_per_row as isize,
-                            ),
-                            staging_buffer_ptr.as_ptr().offset(
-                                (rows_offset + row) as isize * stage_bytes_per_row as isize,
-                            ),
+                        staging_buffer.write_with_offset(
+                            data,
+                            src_offset as isize,
+                            dst_offset as isize,
                             copy_bytes_per_row,
-                        );
+                        )
                     }
                 }
             }
         }
 
-        let staging_buffer = unsafe { staging_buffer.flush() };
+        let staging_buffer = staging_buffer.flush();
 
         let regions = (0..array_layer_count).map(|rel_array_layer| {
             let mut texture_base = dst_base.clone();
diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs
index 3e3e5f9049..2ec3a3e9eb 100644
--- a/wgpu-core/src/device/resource.rs
+++ b/wgpu-core/src/device/resource.rs
@@ -591,18 +591,15 @@ impl<A: HalApi> Device<A> {
             };
             hal::BufferUses::MAP_WRITE
         } else {
-            let (staging_buffer, staging_buffer_ptr) =
+            let mut staging_buffer =
                 StagingBuffer::new(self, wgt::BufferSize::new(aligned_size).unwrap())?;
 
             // Zero initialize memory and then mark the buffer as initialized
             // (it's guaranteed that this is the case by the time the buffer is usable)
-            unsafe { std::ptr::write_bytes(staging_buffer_ptr.as_ptr(), 0, aligned_size as usize) };
+            staging_buffer.write_zeros();
             buffer.initialization_status.write().drain(0..aligned_size);
 
-            *buffer.map_state.lock() = resource::BufferMapState::Init {
-                staging_buffer,
-                ptr: staging_buffer_ptr,
-            };
+            *buffer.map_state.lock() = resource::BufferMapState::Init { staging_buffer };
             hal::BufferUses::COPY_DST
         };
 
diff --git a/wgpu-core/src/resource.rs b/wgpu-core/src/resource.rs
index 612d68ff61..ced9edbb56 100644
--- a/wgpu-core/src/resource.rs
+++ b/wgpu-core/src/resource.rs
@@ -256,10 +256,7 @@ pub enum BufferMapAsyncStatus {
 #[derive(Debug)]
 pub(crate) enum BufferMapState<A: HalApi> {
     /// Mapped at creation.
-    Init {
-        staging_buffer: StagingBuffer<A>,
-        ptr: NonNull<u8>,
-    },
+    Init { staging_buffer: StagingBuffer<A> },
     /// Waiting for GPU to be done before mapping
     Waiting(BufferPendingMapping<A>),
     /// Mapped
@@ -651,15 +648,10 @@ impl<A: HalApi> Buffer<A> {
         let raw_buf = self.try_raw(&snatch_guard)?;
         log::debug!("{} map state -> Idle", self.error_ident());
         match mem::replace(&mut *self.map_state.lock(), BufferMapState::Idle) {
-            BufferMapState::Init {
-                staging_buffer,
-                ptr,
-            } => {
+            BufferMapState::Init { staging_buffer } => {
                 #[cfg(feature = "trace")]
                 if let Some(ref mut trace) = *device.trace.lock() {
-                    let data = trace.make_binary("bin", unsafe {
-                        std::slice::from_raw_parts(ptr.as_ptr(), self.size as usize)
-                    });
+                    let data = trace.make_binary("bin", staging_buffer.get_data());
                     trace.add(trace::Action::WriteBuffer {
                         id: buffer_id,
                         data,
@@ -667,12 +659,11 @@ impl<A: HalApi> Buffer<A> {
                         queued: true,
                     });
                 }
-                let _ = ptr;
 
                 let mut pending_writes = device.pending_writes.lock();
                 let pending_writes = pending_writes.as_mut().unwrap();
 
-                let staging_buffer = unsafe { staging_buffer.flush() };
+                let staging_buffer = staging_buffer.flush();
 
                 self.use_at(device.active_submission_index.load(Ordering::Relaxed) + 1);
                 let region = wgt::BufferSize::new(self.size).map(|size| hal::BufferCopy {
@@ -832,6 +823,11 @@ impl<A: HalApi> Drop for DestroyedBuffer<A> {
     }
 }
 
+#[cfg(send_sync)]
+unsafe impl<A: HalApi> Send for StagingBuffer<A> {}
+#[cfg(send_sync)]
+unsafe impl<A: HalApi> Sync for StagingBuffer<A> {}
+
 /// A temporary buffer, consumed by the command that uses it.
 ///
 /// A [`StagingBuffer`] is designed for one-shot uploads of data to the GPU. It
@@ -857,13 +853,11 @@ pub struct StagingBuffer<A: HalApi> {
     device: Arc<Device<A>>,
     pub(crate) size: wgt::BufferSize,
     is_coherent: bool,
+    ptr: NonNull<u8>,
 }
 
 impl<A: HalApi> StagingBuffer<A> {
-    pub(crate) fn new(
-        device: &Arc<Device<A>>,
-        size: wgt::BufferSize,
-    ) -> Result<(Self, NonNull<u8>), DeviceError> {
+    pub(crate) fn new(device: &Arc<Device<A>>, size: wgt::BufferSize) -> Result<Self, DeviceError> {
         use hal::Device;
         profiling::scope!("StagingBuffer::new");
         let stage_desc = hal::BufferDescriptor {
@@ -881,9 +875,55 @@ impl<A: HalApi> StagingBuffer<A> {
             device: device.clone(),
             size,
             is_coherent: mapping.is_coherent,
+            ptr: mapping.ptr,
         };
 
-        Ok((staging_buffer, mapping.ptr))
+        Ok(staging_buffer)
+    }
+
+    /// SAFETY: You must not call any functions of `self`
+    /// until you have stopped using the returned pointer.
+    pub(crate) unsafe fn ptr(&self) -> NonNull<u8> {
+        self.ptr
+    }
+
+    #[cfg(feature = "trace")]
+    pub(crate) fn get_data(&self) -> &[u8] {
+        unsafe { std::slice::from_raw_parts(self.ptr.as_ptr(), self.size.get() as usize) }
+    }
+
+    pub(crate) fn write_zeros(&mut self) {
+        unsafe { core::ptr::write_bytes(self.ptr.as_ptr(), 0, self.size.get() as usize) };
+    }
+
+    pub(crate) fn write(&mut self, data: &[u8]) {
+        assert!(data.len() >= self.size.get() as usize);
+        // SAFETY: With the assert above, all of `copy_nonoverlapping`'s
+        // requirements are satisfied.
+        unsafe {
+            core::ptr::copy_nonoverlapping(
+                data.as_ptr(),
+                self.ptr.as_ptr(),
+                self.size.get() as usize,
+            );
+        }
+    }
+
+    /// SAFETY: The offsets and size must be in-bounds.
+    pub(crate) unsafe fn write_with_offset(
+        &mut self,
+        data: &[u8],
+        src_offset: isize,
+        dst_offset: isize,
+        size: usize,
+    ) {
+        unsafe {
+            core::ptr::copy_nonoverlapping(
+                data.as_ptr().offset(src_offset),
+                self.ptr.as_ptr().offset(dst_offset),
+                size,
+            );
+        }
     }
 
     pub(crate) fn flush(self) -> FlushedStagingBuffer<A> {
diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs
index 36dc9b0689..6f470f4ddc 100644
--- a/wgpu-hal/src/lib.rs
+++ b/wgpu-hal/src/lib.rs
@@ -714,9 +714,13 @@ pub trait Device: WasmNotSendSync {
     /// be ordered, so it is meaningful to talk about what must occur
     /// "between" them.
     ///
+    /// - Zero-sized mappings are not allowed.
+    ///
+    /// - The returned [`BufferMapping::ptr`] must not be used after a call to
+    ///   [`Device::unmap_buffer`].
+    ///
     /// [`MAP_READ`]: BufferUses::MAP_READ
     /// [`MAP_WRITE`]: BufferUses::MAP_WRITE
-    //TODO: clarify if zero-sized mapping is allowed
     unsafe fn map_buffer(
         &self,
         buffer: &<Self::A as Api>::Buffer,
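A minimal usage sketch of the reworked API follows, for readers coming to this patch cold. It assumes the crate-internal types as they exist at this commit: the `upload_example` helper, its parameter list, and the module paths in the `use` lines are illustrative only, while `StagingBuffer::new`, `write`, `flush`, and the `pending_writes.consume` hand-off are the pieces the diff above touches or refers to.

use std::sync::Arc;

// Module paths are approximate; all of these items are internal to wgpu-core.
use crate::device::queue::PendingWrites;
use crate::device::{Device, DeviceError};
use crate::hal_api::HalApi;
use crate::resource::StagingBuffer;

// Hypothetical helper, not part of the patch: it shows the intended lifecycle
// of a staging buffer after this change: create it, fill it through the safe
// methods, flush it, then hand the flushed buffer to the pending-writes queue.
fn upload_example<A: HalApi>(
    device: &Arc<Device<A>>,
    pending_writes: &mut PendingWrites<A>,
    data: &[u8],
) -> Result<(), DeviceError> {
    // Assumes `data` is non-empty; `wgt::BufferSize` is a `NonZeroU64`.
    let size = wgt::BufferSize::new(data.len() as u64).unwrap();

    // `new` now stores the mapped pointer inside the `StagingBuffer` instead of
    // also handing a raw `NonNull<u8>` back to the caller.
    let mut staging_buffer = StagingBuffer::new(device, size)?;

    // Bounds-checked copy into the mapped staging memory; no `unsafe` is needed
    // at the call site.
    staging_buffer.write(data);

    // `flush` consumes the `StagingBuffer` and yields a `FlushedStagingBuffer`,
    // so the mapped pointer can no longer be reached from safe code.
    let staging_buffer = staging_buffer.flush();

    // Every path must still hand the buffer to `device.pending_writes.consume`,
    // as the comments in queue.rs require.
    pending_writes.consume(staging_buffer);
    Ok(())
}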