-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathgolc.go
188 lines (152 loc) · 4.29 KB
/
golc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
// Package golc provides functions for executing chains.
package golc
import (
"context"
"errors"
"fmt"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/schema"
"golang.org/x/sync/errgroup"
)
// Verbose is the package-level verbosity switch for chain execution.
// It is off by default.
//
// NOTE(review): nothing in this file reads Verbose (Call consults
// chain.Verbose() instead) — confirm which code consumes it.
var Verbose = false
// CallOptions holds the optional settings accepted by Call.
type CallOptions struct {
	// Callbacks are handed to the callback manager created for the run,
	// alongside the chain's own callbacks.
	Callbacks []schema.Callback
	// ParentRunID is copied into the callback manager options as the
	// parent run identifier.
	ParentRunID string
	// IncludeRunInfo, when true, makes Call store the callback manager's
	// run ID in the returned outputs under the "runInfo" key.
	IncludeRunInfo bool
	// Stop is forwarded unchanged to the chain's Call options.
	// NOTE(review): presumably a list of stop sequences — confirm against schema.CallOptions.
	Stop []string
}
// Call executes a chain with multiple inputs.
//
// The functions in optFns are applied in order to a zero-valued CallOptions.
// Call then:
//  1. creates a callback manager and signals the start of the chain,
//  2. merges the variables loaded from the chain's memory (if any) into inputs,
//  3. invokes chain.Call with the run manager and the configured stop values,
//  4. saves inputs and outputs to the chain's memory (if any),
//  5. signals the end of the chain.
//
// When the chain has a memory that yields variables, inputs is modified in
// place. If CallOptions.IncludeRunInfo is set, the callback manager's run ID
// is stored in the outputs under the "runInfo" key.
//
// It returns the outputs of the chain or an error, if any. An error from
// chain.Call is first reported through OnChainError; if that callback itself
// fails, the callback error is returned instead.
func Call(ctx context.Context, chain schema.Chain, inputs schema.ChainValues, optFns ...func(*CallOptions)) (schema.ChainValues, error) {
	var opts CallOptions

	for _, fn := range optFns {
		fn(&opts)
	}

	cm := callback.NewManager(opts.Callbacks, chain.Callbacks(), chain.Verbose(), func(mo *callback.ManagerOptions) {
		mo.ParentRunID = opts.ParentRunID
	})

	rm, err := cm.OnChainStart(ctx, &schema.ChainStartManagerInput{
		ChainType: chain.Type(),
		Inputs:    inputs,
	})
	if err != nil {
		return nil, err
	}

	if mem := chain.Memory(); mem != nil {
		// A failed memory load must not be swallowed: running the chain
		// without its memory variables would silently produce wrong results.
		vars, err := mem.LoadMemoryVariables(ctx, inputs)
		if err != nil {
			return nil, err
		}

		// Writing to a nil map panics, so allocate one before merging.
		if inputs == nil && len(vars) > 0 {
			inputs = make(schema.ChainValues, len(vars))
		}

		for k, v := range vars {
			inputs[k] = v
		}
	}

	outputs, err := chain.Call(ctx, inputs, func(o *schema.CallOptions) {
		o.CallbackManger = rm
		o.Stop = opts.Stop
	})
	if err != nil {
		if cbErr := rm.OnChainError(ctx, &schema.ChainErrorManagerInput{
			Error: err,
		}); cbErr != nil {
			return nil, cbErr
		}

		return nil, err
	}

	if mem := chain.Memory(); mem != nil {
		if err := mem.SaveContext(ctx, inputs, outputs); err != nil {
			return nil, err
		}
	}

	if err := rm.OnChainEnd(ctx, &schema.ChainEndManagerInput{
		Outputs: outputs,
	}); err != nil {
		return nil, err
	}

	if opts.IncludeRunInfo {
		// Guard against a chain that returned nil outputs without an error.
		if outputs == nil {
			outputs = make(schema.ChainValues, 1)
		}

		outputs["runInfo"] = cm.RunID()
	}

	return outputs, nil
}
// SimpleCallOptions holds the optional settings accepted by SimpleCall.
// Every field is forwarded unchanged to the CallOptions of the underlying Call.
type SimpleCallOptions struct {
	// Callbacks is forwarded as CallOptions.Callbacks.
	Callbacks []schema.Callback
	// ParentRunID is forwarded as CallOptions.ParentRunID.
	ParentRunID string
	// Stop is forwarded as CallOptions.Stop.
	Stop []string
}
// SimpleCall executes a chain with a single input and a single output.
//
// input must be one of:
//   - a string, which is bound to the chain's only input key (the chain must
//     then declare exactly one input key), or
//   - a schema.ChainValues, which is passed to the chain as is.
//
// Any other input type is rejected with an error naming the offending type.
// The chain must declare exactly one output key, and the value produced under
// that key must be a string.
//
// It returns the output value as a string or an error, if any.
func SimpleCall(ctx context.Context, chain schema.Chain, input any, optFns ...func(*SimpleCallOptions)) (string, error) {
	opts := SimpleCallOptions{}

	for _, fn := range optFns {
		fn(&opts)
	}

	var cv schema.ChainValues

	switch v := input.(type) {
	case string:
		inputKeys := chain.InputKeys()
		if len(inputKeys) != 1 {
			return "", fmt.Errorf("invalid arguments: number of input keys must be 1, got %d", len(inputKeys))
		}

		cv = schema.ChainValues{
			inputKeys[0]: v,
		}
	case schema.ChainValues:
		// v already has the asserted type; no second assertion needed.
		cv = v
	default:
		// %T reports the dynamic type; %s would mangle non-string values.
		return "", fmt.Errorf("unsupported input type: %T", v)
	}

	outputKeys := chain.OutputKeys()
	if len(outputKeys) != 1 {
		return "", fmt.Errorf("invalid arguments: number of output keys must be 1, got %d", len(outputKeys))
	}

	outputValues, err := Call(ctx, chain, cv, func(o *CallOptions) {
		o.Callbacks = opts.Callbacks
		o.ParentRunID = opts.ParentRunID
		o.Stop = opts.Stop
	})
	if err != nil {
		return "", err
	}

	outputValue, ok := outputValues[outputKeys[0]].(string)
	if !ok {
		return "", errors.New("chain with non string return type")
	}

	return outputValue, nil
}
// BatchCallOptions holds the optional settings accepted by BatchCall.
// Every field is forwarded unchanged to the CallOptions of each underlying Call.
type BatchCallOptions struct {
	// Callbacks is forwarded as CallOptions.Callbacks.
	Callbacks []schema.Callback
	// ParentRunID is forwarded as CallOptions.ParentRunID.
	ParentRunID string
	// IncludeRunInfo is forwarded as CallOptions.IncludeRunInfo.
	IncludeRunInfo bool
	// Stop is forwarded as CallOptions.Stop.
	Stop []string
}
// BatchCall runs Call once for every entry of inputs, each in its own
// goroutine managed by an errgroup, and gathers the outputs so that the
// result at index i belongs to inputs[i].
//
// The settings collected from optFns are forwarded to every individual Call.
// If the group reports an error, that error is returned and the outputs are
// discarded.
func BatchCall(ctx context.Context, chain schema.Chain, inputs []schema.ChainValues, optFns ...func(*BatchCallOptions)) ([]schema.ChainValues, error) {
	var opts BatchCallOptions

	for _, apply := range optFns {
		apply(&opts)
	}

	// Every call gets the same settings, so the forwarding option is built once.
	forward := func(o *CallOptions) {
		o.Callbacks = opts.Callbacks
		o.ParentRunID = opts.ParentRunID
		o.IncludeRunInfo = opts.IncludeRunInfo
		o.Stop = opts.Stop
	}

	g, gctx := errgroup.WithContext(ctx)
	results := make([]schema.ChainValues, len(inputs))

	for idx := range inputs {
		// Per-iteration copies so each goroutine sees its own index and input.
		idx, in := idx, inputs[idx]

		g.Go(func() error {
			out, err := Call(gctx, chain, in, forward)
			if err != nil {
				return err
			}

			// Each goroutine writes to a distinct slot of results.
			results[idx] = out

			return nil
		})
	}

	if err := g.Wait(); err != nil {
		return nil, err
	}

	return results, nil
}