Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 30, 2023
1 parent 7151d1b commit 9978c54
Show file tree
Hide file tree
Showing 13 changed files with 311 additions and 91 deletions.
73 changes: 65 additions & 8 deletions agent/agent_test.go
Original file line number Diff line number Diff line change
@@ -1,33 +1,34 @@
package agent

import (
"context"
"reflect"
"testing"

"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/tool"
"github.com/stretchr/testify/assert"
)

func TestToolNames(t *testing.T) {
tools := []schema.Tool{
tool.NewSleep(),
tool.NewHuman(),
&mockTool{ToolName: "Tool1"},
&mockTool{ToolName: "Tool2"},
}

expected := "Sleep, Human"
expected := "Tool1, Tool2"
result := toolNames(tools)

assert.Equal(t, expected, result)
}

func TestToolDescriptions(t *testing.T) {
tools := []schema.Tool{
tool.NewSleep(),
tool.NewHuman(),
&mockTool{ToolName: "Tool1", ToolDescription: "Description1."},
&mockTool{ToolName: "Tool2", ToolDescription: "Description2."},
}

expected := `- Sleep: Make agent sleep for a specified number of seconds.
- Human: You can ask a human for guidance when you think you got stuck or you are not sure what to do next. The input should be a question for the human.`
expected := `- Tool1: Description1.
- Tool2: Description2.`

result := toolDescriptions(tools)

Expand All @@ -54,3 +55,59 @@ func TestInputsToString(t *testing.T) {
_, err = inputsToString(inputValues)
assert.Error(t, err)
}

// Compile time check to ensure mockTool satisfies the Tool interface.
var _ schema.Tool = (*mockTool)(nil)

type mockTool struct {
ToolName string
ToolDescription string
ToolArgsType any
ToolRunFunc func(ctx context.Context, input any) (string, error)
}

// Name returns the name of the tool.
func (t *mockTool) Name() string {
if t.ToolName != "" {
return t.ToolName
}

return "Mock"
}

// Description returns the description of the tool.
func (t *mockTool) Description() string {
if t.ToolDescription != "" {
return t.ToolDescription
}

return "Mock"
}

// ArgsType returns the type of the input argument expected by the tool.
func (t *mockTool) ArgsType() reflect.Type {
if t.ToolArgsType != nil {
return reflect.TypeOf(t.ToolArgsType)
}

return reflect.TypeOf("") // string
}

// Run executes the tool with the given input and returns the output.
func (t *mockTool) Run(ctx context.Context, input any) (string, error) {
if t.ToolRunFunc != nil {
return t.ToolRunFunc(ctx, input)
}

return "Mock", nil
}

// Verbose returns the verbosity setting of the tool.
func (t *mockTool) Verbose() bool {
return false
}

// Callbacks returns the registered callbacks of the tool.
func (t *mockTool) Callbacks() []schema.Callback {
return nil
}
5 changes: 0 additions & 5 deletions agent/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ func (e Executor) Call(ctx context.Context, inputs schema.ChainValues, optFns ..
fn(&opts)
}

// strInputs, err := inputsToString(inputs)
// if err != nil {
// return nil, err
// }

steps := []schema.AgentStep{}

for i := 0; i <= e.opts.MaxIterations; i++ {
Expand Down
18 changes: 13 additions & 5 deletions agent/openai_functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"

"github.com/hupe1980/golc/model"
"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/tool"
Expand All @@ -15,23 +14,27 @@ import (
// Compile time check to ensure OpenAIFunctions satisfies the agent interface.
var _ schema.Agent = (*OpenAIFunctions)(nil)

// OpenAIFunctionsOptions represents the configuration options for the OpenAIFunctions agent.
type OpenAIFunctionsOptions struct {
// OutputKey is the key to store the output of the agent in the ChainValues.
OutputKey string
}

// OpenAIFunctions is an agent that uses OpenAI chatModels and schema.Tools to perform actions.
type OpenAIFunctions struct {
model schema.ChatModel
functions []schema.FunctionDefinition
opts OpenAIFunctionsOptions
}

func NewOpenAIFunctions(model schema.Model, tools []schema.Tool) (*Executor, error) {
// NewOpenAIFunctions creates a new instance of the OpenAIFunctions agent with the given model and tools.
// It returns an error if the model is not an OpenAI chatModel or fails to convert tools to function definitions.
func NewOpenAIFunctions(model schema.ChatModel, tools []schema.Tool) (*Executor, error) {
opts := OpenAIFunctionsOptions{
OutputKey: "output",
}

chatModel, ok := model.(*chatmodel.OpenAI)
if !ok {
if model.Type() != "chatmodel.OpenAI" {
return nil, errors.New("agent only supports OpenAI chatModels")
}

Expand All @@ -47,14 +50,16 @@ func NewOpenAIFunctions(model schema.Model, tools []schema.Tool) (*Executor, err
}

agent := &OpenAIFunctions{
model: chatModel,
model: model,
functions: functions,
opts: opts,
}

return NewExecutor(agent, tools)
}

// Plan executes the agent with the given context, intermediate steps, and inputs.
// It returns the agent actions, agent finish, or an error, if any.
func (a *OpenAIFunctions) Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs schema.ChainValues) ([]*schema.AgentAction, *schema.AgentFinish, error) {
inputs["agentScratchpad"] = a.constructScratchPad(intermediateSteps)

Expand Down Expand Up @@ -111,14 +116,17 @@ func (a *OpenAIFunctions) Plan(ctx context.Context, intermediateSteps []schema.A
}, nil
}

// InputKeys returns the expected input keys for the agent.
func (a *OpenAIFunctions) InputKeys() []string {
return []string{"input"}
}

// OutputKeys returns the output keys that the agent will return.
func (a *OpenAIFunctions) OutputKeys() []string {
return []string{a.opts.OutputKey}
}

// constructScratchPad constructs the scratch pad from the given intermediate steps.
func (a *OpenAIFunctions) constructScratchPad(steps []schema.AgentStep) schema.ChatMessages {
messages := schema.ChatMessages{}

Expand Down
82 changes: 82 additions & 0 deletions agent/openai_functions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package agent

import (
"context"
"testing"

"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/schema"
"github.com/stretchr/testify/require"
)

func TestOpenAIFunctions(t *testing.T) {
t.Run("TestPlan", func(t *testing.T) {
agent, err := NewOpenAIFunctions(chatmodel.NewFake(func(ctx context.Context, messages schema.ChatMessages) (*schema.ModelResult, error) {
var generation schema.Generation
if len(messages) == 2 {
generation = schema.Generation{
Text: "text",
Message: schema.NewAIChatMessage("text", func(o *schema.ChatMessageExtension) {
o.FunctionCall = &schema.FunctionCall{
Name: "Mock",
Arguments: `{"__arg1": "tool input"}`,
}
}),
}
} else {
require.Len(t, messages, 4)
require.Equal(t, "tool output", messages[3].Content())

generation = schema.Generation{
Text: "finish text",
Message: schema.NewAIChatMessage("finish text"),
}
}

return &schema.ModelResult{
Generations: []schema.Generation{generation},
LLMOutput: map[string]any{},
}, nil
}, func(o *chatmodel.FakeOptions) {
o.ChatModelType = "chatmodel.OpenAI"
}), []schema.Tool{
&mockTool{
ToolRunFunc: func(ctx context.Context, input any) (string, error) {
require.Equal(t, "tool input", input.(string))
return "tool output", nil
},
},
})
require.NoError(t, err)

// Create the inputs for the agent
inputs := schema.ChainValues{
"input": "User Input",
}

// Execute the agent's Plan method
output, err := agent.Call(context.Background(), inputs)
require.NoError(t, err)
require.Equal(t, "finish text", output[agent.OutputKeys()[0]])
})

t.Run("TestPlanInvalidModel", func(t *testing.T) {
_, err := NewOpenAIFunctions(chatmodel.NewSimpleFake("foo"), []schema.Tool{
&mockTool{},
})
require.Error(t, err)
require.EqualError(t, err, "agent only supports OpenAI chatModels")
})

t.Run("TestPlanInvalidTool", func(t *testing.T) {
_, err := NewOpenAIFunctions(chatmodel.NewSimpleFake("foo", func(o *chatmodel.FakeOptions) {
o.ChatModelType = "chatmodel.OpenAI"
}), []schema.Tool{
&mockTool{ToolArgsType: struct {
Channel chan int `json:"channel"` // chan cannot converted to json
}{}},
})
require.Error(t, err)
require.EqualError(t, err, "unsupported type chan from chan int")
})
}
32 changes: 23 additions & 9 deletions chain/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,16 @@ import (

func TestAPI(t *testing.T) {
t.Run("Valid Input", func(t *testing.T) {
fake := llm.NewFake(func(prompt string) string {
fake := llm.NewFake(func(ctx context.Context, prompt string) (*schema.ModelResult, error) {
text := "42"
if strings.HasSuffix(prompt, "API url:") {
return "https://galaxy.org"
text = "https://galaxy.org"
}

return "42"
return &schema.ModelResult{
Generations: []schema.Generation{{Text: text}},
LLMOutput: map[string]any{},
}, nil
})

api, err := NewAPI(fake, "doc", func(o *APIOptions) {
Expand All @@ -40,11 +44,16 @@ func TestAPI(t *testing.T) {
})

t.Run("Invalid Input Key", func(t *testing.T) {
fake := llm.NewFake(func(prompt string) string {
fake := llm.NewFake(func(ctx context.Context, prompt string) (*schema.ModelResult, error) {
text := "42"
if strings.HasSuffix(prompt, "API url:") {
return "https://galaxy.org"
text = "https://galaxy.org"
}
return "42"

return &schema.ModelResult{
Generations: []schema.Generation{{Text: text}},
LLMOutput: map[string]any{},
}, nil
})

api, err := NewAPI(fake, "doc", func(o *APIOptions) {
Expand All @@ -62,11 +71,16 @@ func TestAPI(t *testing.T) {
})

t.Run("Invalid API URL", func(t *testing.T) {
fake := llm.NewFake(func(prompt string) string {
fake := llm.NewFake(func(ctx context.Context, prompt string) (*schema.ModelResult, error) {
text := "42"
if strings.HasSuffix(prompt, "API url:") {
return "https://galaxy.org"
text = "https://galaxy.org"
}
return "42"

return &schema.ModelResult{
Generations: []schema.Generation{{Text: text}},
LLMOutput: map[string]any{},
}, nil
})

api, err := NewAPI(fake, "doc", func(o *APIOptions) {
Expand Down
4 changes: 1 addition & 3 deletions chain/llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ import (

func TestLLM(t *testing.T) {
t.Run("Valid Question", func(t *testing.T) {
fake := llm.NewFake(func(prompt string) string {
return "This is a valid question."
})
fake := llm.NewSimpleFake("This is a valid question.")

llmChain, err := NewLLM(fake, prompt.NewTemplate("{{.input}}"))
require.NoError(t, err)
Expand Down
8 changes: 2 additions & 6 deletions chain/math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ import (

func TestMath(t *testing.T) {
t.Run("Valid Question", func(t *testing.T) {
fake := llm.NewFake(func(prompt string) string {
return "```text\n3 * 3\n```"
})
fake := llm.NewSimpleFake("```text\n3 * 3\n```")

mathChain, err := NewMath(fake)
require.NoError(t, err)
Expand All @@ -28,9 +26,7 @@ func TestMath(t *testing.T) {
})

t.Run("Invalid Input Key", func(t *testing.T) {
fake := llm.NewFake(func(prompt string) string {
return "```text\n3 * 3\n```"
})
fake := llm.NewSimpleFake("```text\n3 * 3\n```")

mathChain, err := NewMath(fake)
require.NoError(t, err)
Expand Down
Loading

0 comments on commit 9978c54

Please sign in to comment.