Skip to content

Commit

Permalink
[spv-in] add support for specialization constants
Browse files Browse the repository at this point in the history
  • Loading branch information
teoxoy authored and jimblandy committed Apr 11, 2024
1 parent b7519bb commit 9df6819
Show file tree
Hide file tree
Showing 10 changed files with 1,610 additions and 93 deletions.
6 changes: 4 additions & 2 deletions naga/src/front/spv/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ pub enum Error {
ControlFlowGraphCycle(crate::front::spv::BlockId),
#[error("recursive function call %{0}")]
FunctionCallCycle(spirv::Word),
#[error("invalid array size {0:?}")]
InvalidArraySize(Handle<crate::Constant>),
#[error("invalid array size %{0}")]
InvalidArraySize(spirv::Word),
#[error("invalid barrier scope %{0}")]
InvalidBarrierScope(spirv::Word),
#[error("invalid barrier memory semantics %{0}")]
Expand All @@ -130,6 +130,8 @@ pub enum Error {
come from a binding)"
)]
NonBindingArrayOfImageOrSamplers,
#[error("naga only supports specialization constant IDs up to 65535 but was given {0}")]
SpecIdTooHigh(u32),
}

impl Error {
Expand Down
7 changes: 5 additions & 2 deletions naga/src/front/spv/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,11 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
})
},
local_variables: Arena::new(),
expressions: self
.make_expression_storage(&module.global_variables, &module.constants),
expressions: self.make_expression_storage(
&module.global_variables,
&module.constants,
&module.overrides,
),
named_expressions: crate::NamedExpressions::default(),
body: crate::Block::new(),
}
Expand Down
13 changes: 8 additions & 5 deletions naga/src/front/spv/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -507,11 +507,14 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
}
spirv::ImageOperands::CONST_OFFSET => {
let offset_constant = self.next()?;
let offset_handle = self.lookup_constant.lookup(offset_constant)?.handle;
let offset_handle = ctx.global_expressions.append(
crate::Expression::Constant(offset_handle),
Default::default(),
);
let offset_expr = self
.lookup_constant
.lookup(offset_constant)?
.inner
.to_expr();
let offset_handle = ctx
.global_expressions
.append(offset_expr, Default::default());
offset = Some(offset_handle);
words_left -= 1;
}
Expand Down
172 changes: 88 additions & 84 deletions naga/src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ struct Decoration {
location: Option<spirv::Word>,
desc_set: Option<spirv::Word>,
desc_index: Option<spirv::Word>,
specialization_constant_id: Option<spirv::Word>,
storage_buffer: bool,
offset: Option<spirv::Word>,
array_stride: Option<NonZeroU32>,
Expand Down Expand Up @@ -277,9 +278,24 @@ struct LookupType {
base_id: Option<spirv::Word>,
}

#[derive(Debug)]
enum Constant {
Constant(Handle<crate::Constant>),
Override(Handle<crate::Override>),
}

impl Constant {
const fn to_expr(&self) -> crate::Expression {
match *self {
Self::Constant(c) => crate::Expression::Constant(c),
Self::Override(o) => crate::Expression::Override(o),
}
}
}

#[derive(Debug)]
struct LookupConstant {
handle: Handle<crate::Constant>,
inner: Constant,
type_id: spirv::Word,
}

Expand Down Expand Up @@ -751,6 +767,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
spirv::Decoration::RowMajor => {
dec.matrix_major = Some(Majority::Row);
}
spirv::Decoration::SpecId => {
dec.specialization_constant_id = Some(self.next()?);
}
other => {
log::warn!("Unknown decoration {:?}", other);
for _ in base_words + 1..inst.wc {
Expand Down Expand Up @@ -1385,10 +1404,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
inst.expect(5)?;
let init_id = self.next()?;
let lconst = self.lookup_constant.lookup(init_id)?;
Some(
ctx.expressions
.append(crate::Expression::Constant(lconst.handle), span),
)
Some(ctx.expressions.append(lconst.inner.to_expr(), span))
} else {
None
};
Expand Down Expand Up @@ -3642,9 +3658,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let exec_scope_const = self.lookup_constant.lookup(exec_scope_id)?;
let semantics_const = self.lookup_constant.lookup(semantics_id)?;

let exec_scope = resolve_constant(ctx.gctx(), exec_scope_const.handle)
let exec_scope = resolve_constant(ctx.gctx(), &exec_scope_const.inner)
.ok_or(Error::InvalidBarrierScope(exec_scope_id))?;
let semantics = resolve_constant(ctx.gctx(), semantics_const.handle)
let semantics = resolve_constant(ctx.gctx(), &semantics_const.inner)
.ok_or(Error::InvalidBarrierMemorySemantics(semantics_id))?;

if exec_scope == spirv::Scope::Workgroup as u32 {
Expand Down Expand Up @@ -3705,6 +3721,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
&mut self,
globals: &Arena<crate::GlobalVariable>,
constants: &Arena<crate::Constant>,
overrides: &Arena<crate::Override>,
) -> Arena<crate::Expression> {
let mut expressions = Arena::new();
#[allow(clippy::panic)]
Expand All @@ -3729,8 +3746,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}
// register constants
for (&id, con) in self.lookup_constant.iter() {
let span = constants.get_span(con.handle);
let handle = expressions.append(crate::Expression::Constant(con.handle), span);
let (expr, span) = match con.inner {
Constant::Constant(c) => (crate::Expression::Constant(c), constants.get_span(c)),
Constant::Override(o) => (crate::Expression::Override(o), overrides.get_span(o)),
};
let handle = expressions.append(expr, span);
self.lookup_expression.insert(
id,
LookupExpression {
Expand Down Expand Up @@ -3935,11 +3955,17 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::TypeImage => self.parse_type_image(inst, &mut module),
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
Op::Constant => self.parse_constant(inst, &mut module),
Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
Op::ConstantComposite | Op::SpecConstantComposite => {
self.parse_composite_constant(inst, &mut module)
}
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),
Op::ConstantFalse => self.parse_bool_constant(inst, false, &mut module),
Op::ConstantTrue | Op::SpecConstantTrue => {
self.parse_bool_constant(inst, true, &mut module)
}
Op::ConstantFalse | Op::SpecConstantFalse => {
self.parse_bool_constant(inst, false, &mut module)
}
Op::Variable => self.parse_global_variable(inst, &mut module),
Op::Function => {
self.switch(ModuleState::Function, inst.op)?;
Expand Down Expand Up @@ -4496,9 +4522,9 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let length_id = self.next()?;
let length_const = self.lookup_constant.lookup(length_id)?;

let size = resolve_constant(module.to_ctx(), length_const.handle)
let size = resolve_constant(module.to_ctx(), &length_const.inner)
.and_then(NonZeroU32::new)
.ok_or(Error::InvalidArraySize(length_const.handle))?;
.ok_or(Error::InvalidArraySize(length_id))?;

let decor = self.future_decor.remove(&id).unwrap_or_default();
let base = self.lookup_type.lookup(type_id)?.handle;
Expand Down Expand Up @@ -4911,28 +4937,13 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
_ => return Err(Error::UnsupportedType(type_lookup.handle)),
};

let decor = self.future_decor.remove(&id).unwrap_or_default();

let span = self.span_from_with_op(start);

let init = module
.global_expressions
.append(crate::Expression::Literal(literal), span);
self.lookup_constant.insert(
id,
LookupConstant {
handle: module.constants.append(
crate::Constant {
name: decor.name,
ty,
init,
},
span,
),
type_id,
},
);
Ok(())

self.insert_parsed_constant(module, id, type_id, ty, init, span)
}

fn parse_composite_constant(
Expand All @@ -4957,32 +4968,17 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let constant = self.lookup_constant.lookup(component_id)?;
let expr = module
.global_expressions
.append(crate::Expression::Constant(constant.handle), span);
.append(constant.inner.to_expr(), span);
components.push(expr);
}

let decor = self.future_decor.remove(&id).unwrap_or_default();

let span = self.span_from_with_op(start);

let init = module
.global_expressions
.append(crate::Expression::Compose { ty, components }, span);
self.lookup_constant.insert(
id,
LookupConstant {
handle: module.constants.append(
crate::Constant {
name: decor.name,
ty,
init,
},
span,
),
type_id,
},
);
Ok(())

self.insert_parsed_constant(module, id, type_id, ty, init, span)
}

fn parse_null_constant(
Expand All @@ -5000,22 +4996,11 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;

let decor = self.future_decor.remove(&id).unwrap_or_default();

let init = module
.global_expressions
.append(crate::Expression::ZeroValue(ty), span);
let handle = module.constants.append(
crate::Constant {
name: decor.name,
ty,
init,
},
span,
);
self.lookup_constant
.insert(id, LookupConstant { handle, type_id });
Ok(())

self.insert_parsed_constant(module, id, type_id, ty, init, span)
}

fn parse_bool_constant(
Expand All @@ -5034,26 +5019,44 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let type_lookup = self.lookup_type.lookup(type_id)?;
let ty = type_lookup.handle;

let decor = self.future_decor.remove(&id).unwrap_or_default();

let init = module.global_expressions.append(
crate::Expression::Literal(crate::Literal::Bool(value)),
span,
);
self.lookup_constant.insert(
id,
LookupConstant {
handle: module.constants.append(
crate::Constant {
name: decor.name,
ty,
init,
},
span,
),
type_id,
},
);

self.insert_parsed_constant(module, id, type_id, ty, init, span)
}

fn insert_parsed_constant(
&mut self,
module: &mut crate::Module,
id: u32,
type_id: u32,
ty: Handle<crate::Type>,
init: Handle<crate::Expression>,
span: crate::Span,
) -> Result<(), Error> {
let decor = self.future_decor.remove(&id).unwrap_or_default();

let inner = if let Some(id) = decor.specialization_constant_id {
let o = crate::Override {
name: decor.name,
id: Some(id.try_into().map_err(|_| Error::SpecIdTooHigh(id))?),
ty,
init: Some(init),
};
Constant::Override(module.overrides.append(o, span))
} else {
let c = crate::Constant {
name: decor.name,
ty,
init,
};
Constant::Constant(module.constants.append(c, span))
};

self.lookup_constant
.insert(id, LookupConstant { inner, type_id });
Ok(())
}

Expand All @@ -5076,7 +5079,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
let lconst = self.lookup_constant.lookup(init_id)?;
let expr = module
.global_expressions
.append(crate::Expression::Constant(lconst.handle), span);
.append(lconst.inner.to_expr(), span);
Some(expr)
} else {
None
Expand Down Expand Up @@ -5291,10 +5294,11 @@ fn make_index_literal(
Ok(expr)
}

fn resolve_constant(
gctx: crate::proc::GlobalCtx,
constant: Handle<crate::Constant>,
) -> Option<u32> {
fn resolve_constant(gctx: crate::proc::GlobalCtx, constant: &Constant) -> Option<u32> {
let constant = match *constant {
Constant::Constant(constant) => constant,
Constant::Override(_) => return None,
};
match gctx.global_expressions[gctx.constants[constant].init] {
crate::Expression::Literal(crate::Literal::U32(id)) => Some(id),
crate::Expression::Literal(crate::Literal::I32(id)) => Some(id as u32),
Expand Down
Binary file added naga/tests/in/spv/spec-constants.spv
Binary file not shown.
Loading

0 comments on commit 9df6819

Please sign in to comment.