Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 19, 2023
1 parent 51b03eb commit d980435
Show file tree
Hide file tree
Showing 26 changed files with 456 additions and 142 deletions.
19 changes: 19 additions & 0 deletions agent/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ var _ schema.Chain = (*Executor)(nil)

type ExecutorOptions struct {
MaxIterations int
Memory schema.Memory
Callbacks []schema.Callback
Verbose bool
}

type Executor struct {
Expand Down Expand Up @@ -91,6 +94,22 @@ func (e Executor) Call(ctx context.Context, values schema.ChainValues) (schema.C
return nil, ErrNotFinished
}

func (e Executor) Memory() schema.Memory {
return e.opts.Memory
}

func (e Executor) Type() string {
return "Executor"
}

func (e Executor) Verbose() bool {
return e.opts.Verbose
}

func (e Executor) Callbacks() []schema.Callback {
return e.opts.Callbacks
}

func (e Executor) InputKeys() []string {
return e.agent.InputKeys()
}
Expand Down
62 changes: 60 additions & 2 deletions chain/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ func (bc *baseChain) Call(ctx context.Context, inputs schema.ChainValues) (schem
}

if bc.memory != nil {
vars, _ := bc.memory.LoadMemoryVariables(inputs)
vars, _ := bc.memory.LoadMemoryVariables(ctx, inputs)
for k, v := range vars {
inputs[k] = v
}
Expand All @@ -47,7 +47,7 @@ func (bc *baseChain) Call(ctx context.Context, inputs schema.ChainValues) (schem
}

if bc.memory != nil {
if err := bc.memory.SaveContext(inputs, outputs); err != nil {
if err := bc.memory.SaveContext(ctx, inputs, outputs); err != nil {
return nil, err
}
}
Expand Down Expand Up @@ -110,3 +110,61 @@ func (bc *baseChain) InputKeys() []string {
func (bc *baseChain) OutputKeys() []string {
return bc.outputKeys
}

func Call(ctx context.Context, chain schema.Chain, inputs schema.ChainValues) (schema.ChainValues, error) {
cm := callback.NewManager(chain.Callbacks(), chain.Verbose())

if err := cm.OnChainStart(chain.Type(), &inputs); err != nil {
return nil, err
}

if chain.Memory() != nil {
vars, _ := chain.Memory().LoadMemoryVariables(ctx, inputs)
for k, v := range vars {
inputs[k] = v
}
}

outputs, err := chain.Call(ctx, inputs)
if err != nil {
if cbError := cm.OnChainError(err); cbError != nil {
return nil, cbError
}

return nil, err
}

if chain.Memory() != nil {
if err := chain.Memory().SaveContext(ctx, inputs, outputs); err != nil {
return nil, err
}
}

if err := cm.OnChainEnd(&outputs); err != nil {
return nil, err
}

return outputs, nil
}

func Run(ctx context.Context, chain schema.Chain, input any) (string, error) {
if len(chain.InputKeys()) != 1 {
return "", ErrMultipleInputsInRun
}

if len(chain.OutputKeys()) != 1 {
return "", ErrMultipleOutputsInRun
}

outputValues, err := Call(ctx, chain, map[string]any{chain.InputKeys()[0]: input})
if err != nil {
return "", err
}

outputValue, ok := outputValues[chain.OutputKeys()[0]].(string)
if !ok {
return "", ErrWrongOutputTypeInRun
}

return outputValue, nil
}
26 changes: 26 additions & 0 deletions chain/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,32 @@ func (c *Conversation) Prompt() *prompt.Template {
return c.opts.Prompt
}

func (c *Conversation) Memory() schema.Memory {
return c.memory
}

func (c *Conversation) Type() string {
return "Conversation"
}

func (c *Conversation) Verbose() bool {
return c.opts.callbackOptions.Verbose
}

func (c *Conversation) Callbacks() []schema.Callback {
return c.opts.callbackOptions.Callbacks
}

// InputKeys returns the expected input keys.
func (c *Conversation) InputKeys() []string {
return []string{"input"}
}

// OutputKeys returns the output keys the chain will return.
func (c *Conversation) OutputKeys() []string {
return []string{c.opts.OutputKey}
}

func (c *Conversation) getFinalOutput(generations [][]*schema.Generation) string {
output := []string{}
for _, generation := range generations {
Expand Down
26 changes: 26 additions & 0 deletions chain/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,32 @@ func (c *LLMChain) Prompt() *prompt.Template {
return c.prompt
}

func (c *LLMChain) Memory() schema.Memory {
return c.opts.Memory
}

func (c *LLMChain) Type() string {
return "LLM"
}

func (c *LLMChain) Verbose() bool {
return c.opts.callbackOptions.Verbose
}

func (c *LLMChain) Callbacks() []schema.Callback {
return c.opts.callbackOptions.Callbacks
}

// InputKeys returns the expected input keys.
func (c *LLMChain) InputKeys() []string {
return c.prompt.InputVariables()
}

// OutputKeys returns the output keys the chain will return.
func (c *LLMChain) OutputKeys() []string {
return []string{c.opts.OutputKey}
}

func (c *LLMChain) getFinalOutput(generations [][]*schema.Generation) string {
output := []string{}
for _, generation := range generations {
Expand Down
40 changes: 33 additions & 7 deletions chain/llm_bash.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,23 @@ func NewLLMBashFromLLM(llm schema.LLM) (*LLMBash, error) {
return NewLLMBash(llmChain)
}

func (lc *LLMBash) call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[lc.opts.InputKey]
func (c *LLMBash) call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[c.opts.InputKey]
if !ok {
return nil, fmt.Errorf("%w: no value for inputKey %s", ErrInvalidInputValues, lc.opts.InputKey)
return nil, fmt.Errorf("%w: no value for inputKey %s", ErrInvalidInputValues, c.opts.InputKey)
}

question, ok := input.(string)
if !ok {
return nil, ErrInputValuesWrongType
}

t, err := lc.llmChain.Run(ctx, question)
t, err := c.llmChain.Run(ctx, question)
if err != nil {
return nil, err
}

outputParser, ok := lc.llmChain.Prompt().OutputParser()
outputParser, ok := c.llmChain.Prompt().OutputParser()
if !ok {
return nil, ErrNoOutputParser
}
Expand All @@ -119,12 +119,38 @@ func (lc *LLMBash) call(ctx context.Context, values schema.ChainValues) (schema.
return nil, err
}

output, err := lc.bashProcess.Run(ctx, commands.([]string))
output, err := c.bashProcess.Run(ctx, commands.([]string))
if err != nil {
return nil, err
}

return schema.ChainValues{
lc.opts.OutputKey: output,
c.opts.OutputKey: output,
}, nil
}

func (c *LLMBash) Memory() schema.Memory {
return nil
}

func (c *LLMBash) Type() string {
return "LLMBash"
}

func (c *LLMBash) Verbose() bool {
return c.opts.callbackOptions.Verbose
}

func (c *LLMBash) Callbacks() []schema.Callback {
return c.opts.callbackOptions.Callbacks
}

// InputKeys returns the expected input keys.
func (c *LLMBash) InputKeys() []string {
return []string{c.opts.InputKey}
}

// OutputKeys returns the output keys the chain will return.
func (c *LLMBash) OutputKeys() []string {
return []string{c.opts.OutputKey}
}
42 changes: 34 additions & 8 deletions chain/llm_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,23 @@ func NewLLMMathFromLLM(llm schema.LLM) (*LLMMath, error) {
return NewLLMMath(llmChain)
}

func (lc *LLMMath) call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[lc.opts.InputKey]
func (c *LLMMath) call(ctx context.Context, values schema.ChainValues) (schema.ChainValues, error) {
input, ok := values[c.opts.InputKey]
if !ok {
return nil, fmt.Errorf("%w: no value for inputKey %s", ErrInvalidInputValues, lc.opts.InputKey)
return nil, fmt.Errorf("%w: no value for inputKey %s", ErrInvalidInputValues, c.opts.InputKey)
}

question, ok := input.(string)
if !ok {
return nil, ErrInputValuesWrongType
}

t, err := lc.llmChain.Run(ctx, question)
t, err := c.llmChain.Run(ctx, question)
if err != nil {
return nil, err
}

outputParser, ok := lc.llmChain.Prompt().OutputParser()
outputParser, ok := c.llmChain.Prompt().OutputParser()
if !ok {
return nil, ErrNoOutputParser
}
Expand All @@ -123,21 +123,47 @@ func (lc *LLMMath) call(ctx context.Context, values schema.ChainValues) (schema.
return nil, fmt.Errorf("unknown format from LLM: %s", t)
}

output, err := lc.evaluateExpression(parsed.([]string)[0])
output, err := c.evaluateExpression(parsed.([]string)[0])
if err != nil {
return nil, err
}

return schema.ChainValues{
lc.opts.OutputKey: output,
c.opts.OutputKey: output,
}, nil
}

func (lc *LLMMath) evaluateExpression(expression string) (string, error) {
func (c *LLMMath) evaluateExpression(expression string) (string, error) {
output, err := expr.Eval(expression, nil)
if err != nil {
return "", err
}

return fmt.Sprintf("%f", output), nil
}

func (c *LLMMath) Memory() schema.Memory {
return nil
}

func (c *LLMMath) Type() string {
return "LLMMath"
}

func (c *LLMMath) Verbose() bool {
return c.opts.callbackOptions.Verbose
}

func (c *LLMMath) Callbacks() []schema.Callback {
return c.opts.callbackOptions.Callbacks
}

// InputKeys returns the expected input keys.
func (c *LLMMath) InputKeys() []string {
return []string{c.opts.InputKey}
}

// OutputKeys returns the output keys the chain will return.
func (c *LLMMath) OutputKeys() []string {
return []string{c.opts.OutputKey}
}
Loading

0 comments on commit d980435

Please sign in to comment.