diff --git a/experimental/checkpoint_test.go b/experimental/checkpoint_test.go index 14bb4629fc5..868933b1a51 100644 --- a/experimental/checkpoint_test.go +++ b/experimental/checkpoint_test.go @@ -112,7 +112,11 @@ func TestSnapshotMultipleWasmInvocations(t *testing.T) { // snapshot returned zero require.Equal(t, uint64(0), res[0]) - res, err = mod.ExportedFunction("restore").Call(ctx, snapshotPtr) - // Fails, snapshot and restore are called from different wasm invocations. - require.Error(t, err) + // Fails, snapshot and restore are called from different wasm invocations. Currently, this + // results in a panic. + err = require.CapturePanic(func() { + _, _ = mod.ExportedFunction("restore").Call(ctx, snapshotPtr) + }) + require.EqualError(t, err, "unhandled snapshot restore, this generally indicates restore was called from a different "+ + "exported function invocation than snapshot") } diff --git a/internal/engine/compiler/engine.go b/internal/engine/compiler/engine.go index ef42db74252..e29bd098988 100644 --- a/internal/engine/compiler/engine.go +++ b/internal/engine/compiler/engine.go @@ -1202,7 +1202,6 @@ func (s *snapshot) Restore(ret []uint64) { panic(s) } -// Restore implements the same method as documented on experimental.Snapshot. func (s *snapshot) doRestore() { ce := s.ce ce.stackContext.stackPointer = s.stackPointer @@ -1212,6 +1211,7 @@ func (s *snapshot) doRestore() { copy(ce.stack[s.hostBase:], s.ret) } +// Error implements the same method on error. func (s *snapshot) Error() string { return "unhandled snapshot restore, this generally indicates restore was called from a different " + "exported function invocation than snapshot" diff --git a/internal/engine/interpreter/interpreter.go b/internal/engine/interpreter/interpreter.go index d47e7c3129b..cddca1038c4 100644 --- a/internal/engine/interpreter/interpreter.go +++ b/internal/engine/interpreter/interpreter.go @@ -216,6 +216,53 @@ func functionFromUintptr(ptr uintptr) *function { return *(**function)(unsafe.Pointer(wrapped)) } +type snapshot struct { + stack []uint64 + frames []*callFrame + pc uint64 + + ret []uint64 + + ce *callEngine +} + +// Snapshot implements the same method as documented on experimental.Snapshotter. +func (ce *callEngine) Snapshot() experimental.Snapshot { + stack := make([]uint64, len(ce.stack)) + copy(stack, ce.stack) + + frames := make([]*callFrame, len(ce.frames)) + copy(frames, ce.frames) + + return &snapshot{ + stack: stack, + frames: frames, + ce: ce, + } +} + +// Restore implements the same method as documented on experimental.Snapshot. +func (s *snapshot) Restore(ret []uint64) { + s.ret = ret + panic(s) +} + +func (s *snapshot) doRestore() { + ce := s.ce + + ce.stack = s.stack + ce.frames = s.frames + ce.frames[len(ce.frames)-1].pc = s.pc + + copy(ce.stack[len(ce.stack)-len(s.ret):], s.ret) +} + +// Error implements the same method on error. +func (s *snapshot) Error() string { + return "unhandled snapshot restore, this generally indicates restore was called from a different " + + "exported function invocation than snapshot" +} + // stackIterator implements experimental.StackIterator. type stackIterator struct { stack []uint64 @@ -512,6 +559,10 @@ func (ce *callEngine) call(ctx context.Context, params, results []uint64) (_ []u } } + if ctx.Value(experimental.EnableSnapshotterKey{}) != nil { + ctx = context.WithValue(ctx, experimental.SnapshotterKey{}, ce) + } + defer func() { // If the module closed during the call, and the call didn't err for another reason, set an ExitError. if err == nil { @@ -555,6 +606,12 @@ type functionListenerInvocation struct { // with the call frame stack traces. Also, reset the state of callEngine // so that it can be used for the subsequent calls. func (ce *callEngine) recoverOnCall(ctx context.Context, m *wasm.ModuleInstance, v interface{}) (err error) { + if s, ok := v.(*snapshot); ok { + // A snapshot that wasn't handled was created by a different call engine possibly from a nested wasm invocation, + // let it propagate up to be handled by the caller. + panic(s) + } + builder := wasmdebug.NewErrorBuilder() frameCount := len(ce.frames) functionListeners := make([]functionListenerInvocation, 0, 16) @@ -669,7 +726,25 @@ func (ce *callEngine) callNativeFunc(ctx context.Context, m *wasm.ModuleInstance ce.drop(op.Us[v+1]) frame.pc = op.Us[v] case wazeroir.OperationKindCall: - ce.callFunction(ctx, f.moduleInstance, &functions[op.U1]) + func() { + defer func() { + if r := recover(); r != nil { + if s, ok := r.(*snapshot); ok { + if s.ce == ce { + s.doRestore() + frame = ce.frames[len(ce.frames)-1] + body = frame.f.parent.body + bodyLen = uint64(len(body)) + } else { + panic(r) + } + } else { + panic(r) + } + } + }() + ce.callFunction(ctx, f.moduleInstance, &functions[op.U1]) + }() frame.pc++ case wazeroir.OperationKindCallIndirect: offset := ce.popValue()