Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metal encoder & pass timestamp support #4008

Merged
merged 13 commits into from
Sep 16, 2023
Prev Previous commit
Next Next commit
Handle write_timestamp for when writing via sample_counters_in_buffer…
… is not possible
  • Loading branch information
Wumpf committed Sep 16, 2023
commit 7daae1c9c5e92e110a723fb8180c4d1ad8059bd8
186 changes: 157 additions & 29 deletions wgpu-hal/src/metal/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ impl Default for super::CommandState {
storage_buffer_length_map: Default::default(),
work_group_memory_sizes: Vec::new(),
push_constants: Vec::new(),
pending_timer_queries: Vec::new(),
}
}
}
Expand All @@ -26,10 +27,85 @@ impl super::CommandEncoder {
/// Returns the currently open blit encoder, opening a new one on the raw command buffer
/// if none is open yet.
///
/// When a new encoder has to be opened, all entries of `self.state.pending_timer_queries`
/// are drained first:
/// * If `TimestampQuerySupport::ON_BLIT_ENCODER` is not available, a separate dummy blit
///   pass is encoded whose sample buffer attachments write the pending timestamps at the
///   end of that encoder. The dummy pass is ended before the real encoder is created.
/// * Otherwise the pending timestamps are written with `sample_counters_in_buffer` on the
///   newly created encoder.
///
/// # Panics
///
/// Panics if `self.raw_cmd_buf` is `None`, if a pending query set has no
/// `counter_sample_buffer`, or if a sample buffer attachment slot cannot be obtained
/// from the blit pass descriptor.
fn enter_blit(&mut self) -> &metal::BlitCommandEncoderRef {
if self.state.blit.is_none() {
debug_assert!(self.state.render.is_none() && self.state.compute.is_none());
let cmd_buf = self.raw_cmd_buf.as_ref().unwrap();

// Take care of pending timer queries.
// If we can't use `sample_counters_in_buffer` we have to create a dummy blit encoder!
//
// There is a known bug in Metal where blit encoders won't write timestamps if they don't have a blit operation.
// See https://github.com/gpuweb/gpuweb/issues/2046#issuecomment-1205793680 & https://source.chromium.org/chromium/chromium/src/+/006c4eb70c96229834bbaf271290f40418144cd3:third_party/dawn/src/dawn/native/metal/BackendMTL.mm;l=350
//
// To make things worse:
// NOTE(review): the next two lines look like GitHub review-UI text captured with the diff, not Rust — confirm and drop.
Wumpf marked this conversation as resolved.
Show resolved Hide resolved
// * what counts as a blit operation is a bit unclear, experimenting seemed to indicate that resolve_counters doesn't count.
// * in some cases (when?) using `set_start_of_encoder_sample_index` doesn't work, so we have to use `set_end_of_encoder_sample_index` instead
//
// All this means that pretty much the only *reliable* thing as of writing is to:
// * create a dummy blit encoder using set_end_of_encoder_sample_index
// * do a dummy write that is known to be not optimized out.
// * close the encoder since we used set_end_of_encoder_sample_index and don't want to get any extra stuff in there.
// * create another encoder for whatever we actually had in mind.
let supports_sample_counters_in_buffer = self
.shared
.private_caps
.timestamp_query_support
.contains(TimestampQuerySupport::ON_BLIT_ENCODER);

if !self.state.pending_timer_queries.is_empty() && !supports_sample_counters_in_buffer {
objc::rc::autoreleasepool(|| {
let descriptor = metal::BlitPassDescriptor::new();
// Remembers the last drained query; its buffer is the target of the dummy write below.
let mut last_query = None;
// Each pending query gets its own sample buffer attachment slot (slot `i`).
for (i, (set, index)) in self.state.pending_timer_queries.drain(..).enumerate()
{
let sba_descriptor = descriptor
.sample_buffer_attachments()
.object_at(i as _)
.unwrap();
sba_descriptor
.set_sample_buffer(set.counter_sample_buffer.as_ref().unwrap());

// Here be dragons:
// As mentioned above, for some reason using the start of the encoder sometimes won't yield any results!
sba_descriptor
.set_start_of_encoder_sample_index(metal::COUNTER_DONT_SAMPLE);
sba_descriptor.set_end_of_encoder_sample_index(index as _);

last_query = Some((set, index));
}
let encoder = cmd_buf.blit_command_encoder_with_descriptor(descriptor);

// As explained above, we need to do some write:
// Conveniently, every query set comes with a buffer that we can use for this dummy write,
// since we know that it is going to be overwritten again on timer resolve and HAL doesn't define its state before that.
// `last_query` is always `Some` here: this branch only runs when the pending list was non-empty.
let raw_range = metal::NSRange {
location: last_query.as_ref().unwrap().1 as u64 * crate::QUERY_SIZE,
length: 1,
};
encoder.fill_buffer(
&last_query.as_ref().unwrap().0.raw_buffer,
raw_range,
255, // Don't write 0, so it's easier to identify if something went wrong.
);

encoder.end_encoding();
});
}

objc::rc::autoreleasepool(|| {
// NOTE(review): this rebinding shadows the `cmd_buf` declared at the top of the block;
// it looks like a removed line left over from the diff view — confirm.
let cmd_buf = self.raw_cmd_buf.as_ref().unwrap();
self.state.blit = Some(cmd_buf.new_blit_command_encoder().to_owned());
});

let encoder = self.state.blit.as_ref().unwrap();

// UNTESTED:
// If the above described issue with empty blit encoder applies to `sample_counters_in_buffer` as well, we should use the same workaround instead!
// If the dummy-encoder path above ran, the pending list is already empty and this loop does nothing.
for (set, index) in self.state.pending_timer_queries.drain(..) {
debug_assert!(supports_sample_counters_in_buffer);
encoder.sample_counters_in_buffer(
set.counter_sample_buffer.as_ref().unwrap(),
index as _,
true, // presumably the `with_barrier` flag — confirm against the metal crate
)
}
}
self.state.blit.as_ref().unwrap()
}
Expand All @@ -40,7 +116,7 @@ impl super::CommandEncoder {
}
}

fn enter_any(&mut self) -> Option<&metal::CommandEncoderRef> {
fn active_encoder(&mut self) -> Option<&metal::CommandEncoderRef> {
if let Some(ref encoder) = self.state.render {
Some(encoder)
} else if let Some(ref encoder) = self.state.compute {
Expand Down Expand Up @@ -127,9 +203,17 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
}

unsafe fn end_encoding(&mut self) -> Result<super::CommandBuffer, crate::DeviceError> {
// Handle pending timer query if any.
if !self.state.pending_timer_queries.is_empty() {
self.leave_blit();
self.enter_blit();
}

self.leave_blit();
debug_assert!(self.state.render.is_none());
debug_assert!(self.state.compute.is_none());
debug_assert!(self.state.pending_timer_queries.is_empty());

Ok(super::CommandBuffer {
raw: self.raw_cmd_buf.take().unwrap(),
})
Expand Down Expand Up @@ -337,26 +421,27 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
support.contains(TimestampQuerySupport::ON_BLIT_ENCODER),
&self.state.blit,
) {
encoder.sample_counters_in_buffer(sample_buffer, index as u64, with_barrier);
encoder.sample_counters_in_buffer(sample_buffer, index as _, with_barrier);
} else if let (true, Some(ref encoder)) = (
support.contains(TimestampQuerySupport::ON_RENDER_ENCODER),
&self.state.render,
) {
encoder.sample_counters_in_buffer(sample_buffer, index as u64, with_barrier);
encoder.sample_counters_in_buffer(sample_buffer, index as _, with_barrier);
} else if let (true, Some(ref encoder)) = (
support.contains(TimestampQuerySupport::ON_COMPUTE_ENCODER),
&self.state.compute,
) {
encoder.sample_counters_in_buffer(sample_buffer, index as u64, with_barrier);
encoder.sample_counters_in_buffer(sample_buffer, index as _, with_barrier);
} else {
// Otherwise, we need to create a new encoder with a descriptor that inserts the timestamps.
// We create a blit encoder since this is the simplest and most fitting encoder for operations on
// a wgpu encoder.
// If we're here it means we either have no encoder open, or it's not supported to sample within them.
// If this happens with render/compute open, this is an invalid usage!
debug_assert!(self.state.render.is_none() && self.state.compute.is_none());

// Note that as of writing creating a new encoder is not exposed by the metal crate.
// https://developer.apple.com/documentation/metal/mtlcommandbuffer/3564431-makeblitcommandencoder
// But otherwise it means we'll defer this to the next created encoder.
self.state.pending_timer_queries.push((set.clone(), index));

// TODO:
// Ensure we didn't already have a blit open.
self.leave_blit();
};
}

Expand Down Expand Up @@ -481,8 +566,29 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
}
}

let mut sba_index = 0;
let mut next_sba_descriptor = || {
let sba_descriptor = descriptor
.sample_buffer_attachments()
.object_at(sba_index)
.unwrap();

sba_descriptor.set_end_of_vertex_sample_index(metal::COUNTER_DONT_SAMPLE);
sba_descriptor.set_start_of_fragment_sample_index(metal::COUNTER_DONT_SAMPLE);

sba_index += 1;
sba_descriptor
};

for (set, index) in self.state.pending_timer_queries.drain(..) {
let sba_descriptor = next_sba_descriptor();
sba_descriptor.set_sample_buffer(set.counter_sample_buffer.as_ref().unwrap());
sba_descriptor.set_start_of_vertex_sample_index(index as _);
sba_descriptor.set_end_of_fragment_sample_index(metal::COUNTER_DONT_SAMPLE);
}

if let Some(ref timestamp_writes) = desc.timestamp_writes {
let sba_descriptor = descriptor.sample_buffer_attachments().object_at(0).unwrap();
let sba_descriptor = next_sba_descriptor();
sba_descriptor.set_sample_buffer(
timestamp_writes
.query_set
Expand All @@ -491,12 +597,16 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
.unwrap(),
);

if let Some(start_index) = timestamp_writes.beginning_of_pass_write_index {
sba_descriptor.set_start_of_vertex_sample_index(start_index as _);
}
if let Some(end_index) = timestamp_writes.end_of_pass_write_index {
sba_descriptor.set_end_of_fragment_sample_index(end_index as _);
}
sba_descriptor.set_start_of_vertex_sample_index(
timestamp_writes
.beginning_of_pass_write_index
.map_or(metal::COUNTER_DONT_SAMPLE, |i| i as _),
);
sba_descriptor.set_end_of_fragment_sample_index(
timestamp_writes
.end_of_pass_write_index
.map_or(metal::COUNTER_DONT_SAMPLE, |i| i as _),
);
}

if let Some(occlusion_query_set) = desc.occlusion_query_set {
Expand Down Expand Up @@ -724,19 +834,19 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
}

unsafe fn insert_debug_marker(&mut self, label: &str) {
if let Some(encoder) = self.enter_any() {
if let Some(encoder) = self.active_encoder() {
encoder.insert_debug_signpost(label);
}
}
unsafe fn begin_debug_marker(&mut self, group_label: &str) {
if let Some(encoder) = self.enter_any() {
if let Some(encoder) = self.active_encoder() {
encoder.push_debug_group(group_label);
} else if let Some(ref buf) = self.raw_cmd_buf {
buf.push_debug_group(group_label);
}
}
unsafe fn end_debug_marker(&mut self) {
if let Some(encoder) = self.enter_any() {
if let Some(encoder) = self.active_encoder() {
encoder.pop_debug_group();
} else if let Some(ref buf) = self.raw_cmd_buf {
buf.pop_debug_group();
Expand Down Expand Up @@ -996,11 +1106,25 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
objc::rc::autoreleasepool(|| {
let descriptor = metal::ComputePassDescriptor::new();

if let Some(timestamp_writes) = desc.timestamp_writes.as_ref() {
let mut sba_index = 0;
let mut next_sba_descriptor = || {
let sba_descriptor = descriptor
.sample_buffer_attachments()
.object_at(0 as _)
.object_at(sba_index)
.unwrap();
sba_index += 1;
sba_descriptor
};

for (set, index) in self.state.pending_timer_queries.drain(..) {
let sba_descriptor = next_sba_descriptor();
sba_descriptor.set_sample_buffer(set.counter_sample_buffer.as_ref().unwrap());
sba_descriptor.set_start_of_encoder_sample_index(index as _);
sba_descriptor.set_end_of_encoder_sample_index(metal::COUNTER_DONT_SAMPLE);
}

if let Some(timestamp_writes) = desc.timestamp_writes.as_ref() {
let sba_descriptor = next_sba_descriptor();
sba_descriptor.set_sample_buffer(
timestamp_writes
.query_set
Expand All @@ -1009,12 +1133,16 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
.unwrap(),
);

if let Some(start_index) = timestamp_writes.beginning_of_pass_write_index {
sba_descriptor.set_start_of_encoder_sample_index(start_index as _);
}
if let Some(end_index) = timestamp_writes.end_of_pass_write_index {
sba_descriptor.set_end_of_encoder_sample_index(end_index as _);
}
sba_descriptor.set_start_of_encoder_sample_index(
timestamp_writes
.beginning_of_pass_write_index
.map_or(metal::COUNTER_DONT_SAMPLE, |i| i as _),
);
sba_descriptor.set_end_of_encoder_sample_index(
timestamp_writes
.end_of_pass_write_index
.map_or(metal::COUNTER_DONT_SAMPLE, |i| i as _),
);
}

let encoder = raw.compute_command_encoder_with_descriptor(descriptor);
Expand Down
5 changes: 4 additions & 1 deletion wgpu-hal/src/metal/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -722,7 +722,7 @@ pub struct ComputePipeline {
unsafe impl Send for ComputePipeline {}
unsafe impl Sync for ComputePipeline {}

#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct QuerySet {
raw_buffer: metal::Buffer,
//Metal has a custom buffer for counters.
Expand Down Expand Up @@ -805,6 +805,9 @@ struct CommandState {

work_group_memory_sizes: Vec<u32>,
push_constants: Vec<u32>,

/// Timer queries that should be executed when the next pass starts.
pending_timer_queries: Vec<(QuerySet, u32)>,
}

pub struct CommandEncoder {
Expand Down