forked from tracel-ai/burn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor/wgpu/memco (tracel-ai#1340)
1 parent
bff4961
commit b256c04
Showing
35 changed files
with
1,657 additions
and
796 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
use super::{Elem, Item, Scope, Variable}; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
/// All branching types. | ||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub enum Branch { | ||
// An if statement. | ||
If(If), | ||
// An if else statement. | ||
IfElse(IfElse), | ||
// A range loop. | ||
RangeLoop(RangeLoop), | ||
// A return statement. | ||
Return, | ||
// A break statement. | ||
Break, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct If { | ||
pub cond: Variable, | ||
pub scope: Scope, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct IfElse { | ||
pub cond: Variable, | ||
pub scope_if: Scope, | ||
pub scope_else: Scope, | ||
} | ||
|
||
#[derive(Debug, Clone, Serialize, Deserialize)] | ||
pub struct RangeLoop { | ||
pub i: Variable, | ||
pub start: Variable, | ||
pub end: Variable, | ||
pub scope: Scope, | ||
} | ||
|
||
impl If { | ||
/// Registers an if statement to the given scope. | ||
pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, cond: Variable, func: F) { | ||
let mut scope = parent_scope.child(); | ||
|
||
func(&mut scope); | ||
|
||
let op = Self { cond, scope }; | ||
parent_scope.register(Branch::If(op)); | ||
} | ||
} | ||
|
||
impl IfElse { | ||
/// Registers an if else statement to the given scope. | ||
pub fn register<IF, ELSE>( | ||
parent_scope: &mut Scope, | ||
cond: Variable, | ||
func_if: IF, | ||
func_else: ELSE, | ||
) where | ||
IF: Fn(&mut Scope), | ||
ELSE: Fn(&mut Scope), | ||
{ | ||
let mut scope_if = parent_scope.child(); | ||
let mut scope_else = parent_scope.child(); | ||
|
||
func_if(&mut scope_if); | ||
func_else(&mut scope_else); | ||
|
||
parent_scope.register(Branch::IfElse(Self { | ||
cond, | ||
scope_if, | ||
scope_else, | ||
})); | ||
} | ||
} | ||
|
||
impl RangeLoop { | ||
/// Registers a range loop to the given scope. | ||
pub fn register<F: Fn(Variable, &mut Scope)>( | ||
parent_scope: &mut Scope, | ||
start: Variable, | ||
end: Variable, | ||
func: F, | ||
) { | ||
let mut scope = parent_scope.child(); | ||
let index_ty = Item::Scalar(Elem::UInt); | ||
let i = scope.create_local_undeclared(index_ty); | ||
|
||
func(i, &mut scope); | ||
|
||
parent_scope.register(Branch::RangeLoop(Self { | ||
i, | ||
start, | ||
end, | ||
scope, | ||
})); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.