Skip to content

Commit

Permalink
Merge FunctionAnalysisError into FunctionError
Browse files Browse the repository at this point in the history
  • Loading branch information
kvark committed Mar 22, 2021
1 parent 1b3e729 commit 1d3f2bb
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 111 deletions.
4 changes: 3 additions & 1 deletion src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,9 @@ fn test_stack_size() {
});
let _ = module.functions.append(fun);
// analyse the module
let info = ModuleInfo::new(&module, ValidationFlags::empty()).unwrap();
let info = crate::valid::Validator::new(ValidationFlags::empty())
.validate(&module)
.unwrap();
// process the module
let mut writer = Writer::new(String::new());
writer.write(&module, &info, &Default::default()).unwrap();
Expand Down
99 changes: 22 additions & 77 deletions src/valid/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Figures out the following properties:
- expression reference counts
!*/

use super::ValidationFlags;
use super::{CallError, FunctionError, ModuleInfo, ValidationFlags};
use crate::arena::{Arena, Handle};
use std::ops;

Expand Down Expand Up @@ -216,32 +216,6 @@ pub enum UniformityDisruptor {
Discard,
}

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionAnalysisError {
#[error("Expression {0:?} is not a global variable!")]
ExpectedGlobalVariable(crate::Expression),
#[error("Called function {0:?} that hasn't been declared in the IR yet")]
ForwardCall(Handle<crate::Function>),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
NonUniformControlFlow(
UniformityRequirements,
Handle<crate::Expression>,
UniformityDisruptor,
),
}

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AnalysisError {
#[error("Function {0:?} analysis failed")]
Function(Handle<crate::Function>, #[source] FunctionAnalysisError),
#[error("Entry point {0:?}/'{1}' function analysis failed")]
EntryPoint(crate::ShaderStage, String, #[source] FunctionAnalysisError),
}

impl FunctionInfo {
/// Adds a value-type reference to an expression.
#[must_use]
Expand Down Expand Up @@ -314,7 +288,7 @@ impl FunctionInfo {
arguments: &[crate::FunctionArgument],
global_var_arena: &Arena<crate::GlobalVariable>,
other_functions: &[FunctionInfo],
) -> Result<(), FunctionAnalysisError> {
) -> Result<(), FunctionError> {
use crate::{Expression as E, SampleLevel as Sl};

let mut assignable_global = None;
Expand Down Expand Up @@ -404,17 +378,13 @@ impl FunctionInfo {
image: match expression_arena[image] {
crate::Expression::GlobalVariable(var) => var,
ref other => {
return Err(FunctionAnalysisError::ExpectedGlobalVariable(
other.clone(),
))
return Err(FunctionError::ExpectedGlobalVariable(other.clone()))
}
},
sampler: match expression_arena[sampler] {
crate::Expression::GlobalVariable(var) => var,
ref other => {
return Err(FunctionAnalysisError::ExpectedGlobalVariable(
other.clone(),
))
return Err(FunctionError::ExpectedGlobalVariable(other.clone()))
}
},
});
Expand Down Expand Up @@ -512,9 +482,13 @@ impl FunctionInfo {
requirements: UniformityRequirements::empty(),
},
E::Call(function) => {
let fun = other_functions
.get(function.index())
.ok_or(FunctionAnalysisError::ForwardCall(function))?;
let fun =
other_functions
.get(function.index())
.ok_or(FunctionError::InvalidCall {
function,
error: CallError::ForwardDeclaredFunction,
})?;
self.process_call(fun).result
}
E::ArrayLength(expr) => Uniformity {
Expand Down Expand Up @@ -546,7 +520,7 @@ impl FunctionInfo {
statements: &[crate::Statement],
other_functions: &[FunctionInfo],
mut disruptor: Option<UniformityDisruptor>,
) -> Result<FunctionUniformity, FunctionAnalysisError> {
) -> Result<FunctionUniformity, FunctionError> {
use crate::Statement as S;

let mut combined_uniformity = FunctionUniformity::new();
Expand All @@ -562,9 +536,7 @@ impl FunctionInfo {
&& !req.is_empty()
{
if let Some(cause) = disruptor {
return Err(FunctionAnalysisError::NonUniformControlFlow(
req, expr, cause,
));
return Err(FunctionError::NonUniformControlFlow(req, expr, cause));
}
}
requirements |= req;
Expand Down Expand Up @@ -670,9 +642,12 @@ impl FunctionInfo {
for &argument in arguments {
let _ = self.add_ref(argument);
}
let info = other_functions
.get(function.index())
.ok_or(FunctionAnalysisError::ForwardCall(function))?;
let info = other_functions.get(function.index()).ok_or(
FunctionError::InvalidCall {
function,
error: CallError::ForwardDeclaredFunction,
},
)?;
self.process_call(info)
}
};
Expand All @@ -684,22 +659,15 @@ impl FunctionInfo {
}
}

#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
functions: Vec<FunctionInfo>,
entry_points: Vec<FunctionInfo>,
}

impl ModuleInfo {
/// Builds the `FunctionInfo` based on the function, and validates the
/// uniform control flow if required by the expressions of this function.
fn process_function(
pub(super) fn process_function(
&self,
fun: &crate::Function,
global_var_arena: &Arena<crate::GlobalVariable>,
flags: ValidationFlags,
) -> Result<FunctionInfo, FunctionAnalysisError> {
) -> Result<FunctionInfo, FunctionError> {
let mut info = FunctionInfo {
flags,
uniformity: Uniformity::new(),
Expand All @@ -726,29 +694,6 @@ impl ModuleInfo {
Ok(info)
}

/// Analyze a module and return the `ModuleInfo`, if successful.
pub fn new(module: &crate::Module, flags: ValidationFlags) -> Result<Self, AnalysisError> {
let mut this = ModuleInfo {
functions: Vec::with_capacity(module.functions.len()),
entry_points: Vec::with_capacity(module.entry_points.len()),
};
for (fun_handle, fun) in module.functions.iter() {
let info = this
.process_function(fun, &module.global_variables, flags)
.map_err(|source| AnalysisError::Function(fun_handle, source))?;
this.functions.push(info);
}

for ep in module.entry_points.iter() {
let info = this
.process_function(&ep.function, &module.global_variables, flags)
.map_err(|source| AnalysisError::EntryPoint(ep.stage, ep.name.clone(), source))?;
this.entry_points.push(info);
}

Ok(this)
}

pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
&self.entry_points[index]
}
Expand Down Expand Up @@ -880,7 +825,7 @@ fn uniform_control_flow() {
};
assert_eq!(
info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None),
Err(FunctionAnalysisError::NonUniformControlFlow(
Err(FunctionError::NonUniformControlFlow(
UniformityRequirements::DERIVATIVE,
derivative_expr,
UniformityDisruptor::Expression(non_uniform_global_expr)
Expand Down
1 change: 1 addition & 0 deletions src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::{
};

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ExpressionError {
#[error("Doesn't exist")]
DoesntExist,
Expand Down
28 changes: 22 additions & 6 deletions src/valid/function.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use super::{analyzer::FunctionInfo, ExpressionError, TypeFlags, ValidationFlags};
use super::{
analyzer::{FunctionInfo, UniformityDisruptor, UniformityRequirements},
ExpressionError, ModuleInfo, TypeFlags, ValidationFlags,
};
use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, TypifyError},
};

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CallError {
#[error("Bad function")]
InvalidFunction,
Expand Down Expand Up @@ -36,12 +40,14 @@ pub enum CallError {
}

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
#[error("Initializer doesn't match the variable type")]
InitializerType,
}

#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionError {
#[error(transparent)]
Resolve(#[from] TypifyError),
Expand Down Expand Up @@ -97,6 +103,16 @@ pub enum FunctionError {
#[source]
error: CallError,
},
#[error("Expression {0:?} is not a global variable!")]
ExpectedGlobalVariable(crate::Expression),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
NonUniformControlFlow(
UniformityRequirements,
Handle<crate::Expression>,
UniformityDisruptor,
),
}

bitflags::bitflags! {
Expand Down Expand Up @@ -464,9 +480,9 @@ impl super::Validator {
pub(super) fn validate_function(
&mut self,
fun: &crate::Function,
_info: &FunctionInfo,
module: &crate::Module,
) -> Result<(), FunctionError> {
mod_info: &ModuleInfo,
) -> Result<FunctionInfo, FunctionError> {
let resolve_ctx = ResolveContext {
constants: &module.constants,
global_vars: &module.global_variables,
Expand All @@ -476,6 +492,7 @@ impl super::Validator {
};
self.typifier
.resolve_all(&fun.expressions, &module.types, &resolve_ctx)?;
let info = mod_info.process_function(fun, &module.global_variables, self.flags)?;

for (var_handle, var) in fun.local_variables.iter() {
self.validate_local_var(var, &module.types, &module.constants)
Expand Down Expand Up @@ -511,9 +528,8 @@ impl super::Validator {
}

if self.flags.contains(ValidationFlags::BLOCKS) {
self.validate_block(&fun.body, &BlockContext::new(fun, module))
} else {
Ok(())
self.validate_block(&fun.body, &BlockContext::new(fun, module))?;
}
Ok(info)
}
}
11 changes: 6 additions & 5 deletions src/valid/interface.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
analyzer::{FunctionInfo, GlobalUse},
Disalignment, FunctionError, TypeFlags,
Disalignment, FunctionError, ModuleInfo, TypeFlags,
};
use crate::arena::{Arena, Handle};

Expand Down Expand Up @@ -352,9 +352,9 @@ impl super::Validator {
pub(super) fn validate_entry_point(
&mut self,
ep: &crate::EntryPoint,
info: &FunctionInfo,
module: &crate::Module,
) -> Result<(), EntryPointError> {
mod_info: &ModuleInfo,
) -> Result<FunctionInfo, EntryPointError> {
if ep.early_depth_test.is_some() && ep.stage != crate::ShaderStage::Fragment {
return Err(EntryPointError::UnexpectedEarlyDepthTest);
}
Expand All @@ -370,6 +370,8 @@ impl super::Validator {
return Err(EntryPointError::UnexpectedWorkgroupSize);
}

let info = self.validate_function(&ep.function, module, &mod_info)?;

self.location_mask.clear();
for (index, fa) in ep.function.arguments.iter().enumerate() {
let ctx = VaryingContext {
Expand Down Expand Up @@ -439,7 +441,6 @@ impl super::Validator {
}
}

self.validate_function(&ep.function, info, module)?;
Ok(())
Ok(info)
}
}
Loading

0 comments on commit 1d3f2bb

Please sign in to comment.