Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Fiber primitive for stack switching #249

Merged
merged 7 commits on Mar 3, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
clippy & fmt
  • Loading branch information
JonasKruckenberg committed Jan 21, 2025
commit 0b1ca5c3f5f8d6079c14e4b47d47ade2dc894dc0
5 changes: 4 additions & 1 deletion kernel/src/arch/riscv64/trap_handler.rs
Original file line number Diff line number Diff line change
@@ -252,6 +252,9 @@ fn default_trap_handler(
a6: usize,
a7: usize,
) -> *mut TrapFrame {
// Safety: `default_trap_entry` has to correctly set up the stack frame
let frame = unsafe { &*raw_frame };

let cause = scause::read().cause();
log::trace!("trap_handler cause {cause:?}, a1 {a1:#x} a2 {a2:#x} a3 {a3:#x} a4 {a4:#x} a5 {a5:#x} a6 {a6:#x} a7 {a7:#x}");

@@ -294,7 +297,7 @@ fn default_trap_handler(

crate::trap_handler::begin_trap(crate::trap_handler::Trap {
pc: VirtualAddress::new(epc).unwrap(),
fp: VirtualAddress::new(unsafe { (&*raw_frame).gp[8] }).unwrap(),
fp: VirtualAddress::new(frame.gp[8]).unwrap(),
faulting_address: VirtualAddress::new(tval).unwrap(),
reason,
});
53 changes: 39 additions & 14 deletions kernel/src/fiber.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::arch;
use crate::vm::{AddressSpace, UserMmap, KERNEL_ASPACE};
use crate::vm::{AddressSpace, UserMmap};
use alloc::boxed::Box;
use alloc::rc::Rc;
use core::arch::naked_asm;
use core::cell::Cell;
use core::marker::PhantomData;
use core::panic::AssertUnwindSafe;
use core::ptr;
use core::ptr::addr_of_mut;
use core::range::Range;

pub struct FiberStack(UserMmap);
@@ -19,6 +20,7 @@ impl FiberStack {
}

pub fn top(&self) -> *mut u8 {
// Safety: UserMmap guarantees the base pointer and length are valid
unsafe { self.0.as_ptr().cast_mut().byte_add(self.0.len()) }
}
}
@@ -48,33 +50,41 @@ impl<'a, Resume, Yield, Return> Fiber<'a, Resume, Yield, Return> {
/// This function returns a `Fiber` which, when resumed, will execute `func`
/// to completion. When desired the `func` can suspend itself via
/// `Fiber::suspend`.
pub fn new<F>(stack: FiberStack, func: F) -> crate::wasm::Result<Self>
pub fn new<F>(stack: FiberStack, mut f: F) -> Self
where
F: FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return + 'a,
{
extern "C" fn fiber_start<F, A, B, C>(arg0: *mut u8, top_of_stack: *mut u8)
where
F: FnOnce(A, &mut Suspend<A, B, C>) -> C,
extern "C" fn fiber_start<F, Resume, Yield, Return>(
closure_ptr: *mut u8,
top_of_stack: *mut u8,
) where
F: FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return,
{
let mut suspend = Suspend {
top_of_stack,
_phantom: PhantomData,
};
suspend.execute(unsafe { Box::from_raw(arg0.cast::<F>()) });

// Safety: code below & generics ensure the ptr is a valid `F` ptr
suspend.execute(unsafe { closure_ptr.cast::<F>().read() });
}

unsafe {
let data = Box::into_raw(Box::new(func)).cast();
let closure_ptr = addr_of_mut!(f);

fiber_init(stack.top(), fiber_start::<F, Resume, Yield, Return>, data);
// Safety: TODO
unsafe {
fiber_init(
stack.top(),
fiber_start::<F, Resume, Yield, Return>,
closure_ptr.cast(),
);
}

#[expect(tail_expr_drop_order, reason = "")]
Ok(Self {
Self {
stack: Some(stack),
done: Cell::new(false),
_phantom: PhantomData,
})
}
}

/// Resumes execution of this fiber.
@@ -96,19 +106,26 @@ impl<'a, Resume, Yield, Return> Fiber<'a, Resume, Yield, Return> {
assert!(!self.done.replace(true), "cannot resume a finished fiber");
let result = Cell::new(RunResult::Resuming(val));

// Safety: TODO
unsafe {
debug_assert!(
self.stack.as_ref().unwrap().top().addr() % 16 == 0,
"stack needs to be 16-byte aligned"
);

// Store where our result is going at the very tip-top of the
// stack, otherwise known as our reserved slot for this information.
//
// In the diagram above this is updating address 0xAff8
#[expect(clippy::cast_ptr_alignment, reason = "checked above")]
let addr = self
.stack
.as_ref()
.unwrap()
.top()
.cast::<usize>()
.offset(-1);
addr.write(&result as *const _ as usize);
addr.write(ptr::from_ref(&result) as usize);

fiber_switch(self.stack.as_ref().unwrap().top());

@@ -163,6 +180,7 @@ impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
}

fn switch(&mut self, result: RunResult<Resume, Yield, Return>) -> Resume {
// Safety: TODO
unsafe {
// Calculate 0xAff8 and then write to it
(*self.result_location()).set(result);
@@ -173,6 +191,7 @@ impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
}

unsafe fn take_resume(&self) -> Resume {
// Safety: TODO
let prev = unsafe { (*self.result_location()).replace(RunResult::Executing) };
match prev {
RunResult::Resuming(val) => val,
@@ -181,6 +200,8 @@ impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
}

unsafe fn result_location(&self) -> *const Cell<RunResult<Resume, Yield, Return>> {
#[expect(clippy::cast_ptr_alignment, reason = "checked above")]
// Safety: TODO
let ret = unsafe { self.top_of_stack.cast::<*const u8>().offset(-1).read() };
assert!(!ret.is_null());
ret.cast()
@@ -190,6 +211,7 @@ impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
where
F: FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return,
{
// Safety: TODO
let initial = unsafe { self.take_resume() };

let result = crate::panic::catch_unwind(AssertUnwindSafe(|| (func)(initial, self)));
@@ -208,6 +230,7 @@ impl<A, B, C> Drop for Fiber<'_, A, B, C> {

#[naked]
unsafe extern "C" fn fiber_switch(top_of_stack: *mut u8) {
// Safety: inline assembly
unsafe {
naked_asm! {
// We're switching to arbitrary code somewhere else, so pessimistically
@@ -286,6 +309,7 @@ unsafe extern "C" fn fiber_init(
entry: extern "C" fn(*mut u8, *mut u8),
entry_arg0: *mut u8,
) {
// Safety: inline assembly
unsafe {
naked_asm! {
"lla t0, {fiber_start}",
@@ -305,6 +329,7 @@ unsafe extern "C" fn fiber_init(

#[naked]
unsafe extern "C" fn fiber_start() {
// Safety: inline assembly
unsafe {
naked_asm! {
"
2 changes: 1 addition & 1 deletion kernel/src/main.rs
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ extern crate alloc;
mod allocator;
mod arch;
mod error;
mod fiber;
mod logger;
mod machine_info;
mod metrics;
@@ -32,7 +33,6 @@ mod time;
mod trap_handler;
mod vm;
mod wasm;
mod fiber;

use crate::error::Error;
use crate::machine_info::{HartLocalMachineInfo, MachineInfo};
2 changes: 2 additions & 0 deletions kernel/src/panic.rs
Original file line number Diff line number Diff line change
@@ -81,6 +81,7 @@ pub fn resume_unwind(payload: Box<dyn Any + Send>) -> ! {

struct RewrapBox(Box<dyn Any + Send>);

// Safety: TODO
unsafe impl PanicPayload for RewrapBox {
fn take_box(&mut self) -> *mut (dyn Any + Send) {
Box::into_raw(mem::replace(&mut self.0, Box::new(())))
@@ -98,6 +99,7 @@ pub fn resume_unwind(payload: Box<dyn Any + Send>) -> ! {
}

#[expect(tail_expr_drop_order, reason = "")]
// Safety: take_box returns an unwrapped box
rust_panic(unsafe { Box::from_raw(RewrapBox(payload).take_box()) })
}

3 changes: 2 additions & 1 deletion kernel/src/wasm/instance_allocator.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::fiber::FiberStack;
use crate::vm::AddressSpace;
use crate::wasm::indices::{DefinedMemoryIndex, DefinedTableIndex};
use crate::wasm::runtime::{FiberStack, InstanceAllocator, Memory, Table};
use crate::wasm::runtime::{InstanceAllocator, Memory, Table};
use crate::wasm::runtime::{OwnedVMContext, VMOffsets};
use crate::wasm::translate::{MemoryDesc, TableDesc, TranslatedModule};
use core::fmt;
5 changes: 4 additions & 1 deletion kernel/src/wasm/store.rs
Original file line number Diff line number Diff line change
@@ -146,7 +146,7 @@ impl Store {
slot = Some(f(this));

Ok(())
})?;
});

// Once we have the fiber representing our synchronous computation, we
// wrap that in a custom future implementation which does the
@@ -162,6 +162,7 @@ impl Store {
let stack = future.fiber.take().map(|f| f.into_stack());
drop(future);
if let Some(stack) = stack {
// Safety: we're deallocating the stack in the same store it was allocated in
unsafe {
self.alloc.deallocate_fiber_stack(stack);
}
@@ -190,6 +191,7 @@ impl Store {
current_poll_cx: *mut PollContext,
}

// Safety: TODO
unsafe impl Send for FiberFuture<'_> {}

impl FiberFuture<'_> {
@@ -229,6 +231,7 @@ impl Store {
// .stack()
// .guard_range()
// .unwrap_or(core::ptr::null_mut()..core::ptr::null_mut());
// Safety: TODO
unsafe {
// let _reset = Reset(self.current_poll_cx, *self.current_poll_cx);
*self.current_poll_cx = PollContext {
Loading