From aa7f093175b1bd7cddc697d993f366a1cd22c588 Mon Sep 17 00:00:00 2001 From: hupe1980 Date: Wed, 21 Jun 2023 20:44:01 +0200 Subject: [PATCH] Refactor --- _examples/zero_shot_react_description/main.go | 2 ++ agent/agent.go | 17 ++++++++++- agent/conversational.go | 2 +- agent/mrkl.go | 2 +- chain/llm.go | 30 +++++++++---------- chain/llm_bash.go | 6 ++-- chain/llm_math.go | 6 ++-- chain/refine_documents.go | 6 ++-- chain/retrieval_qa.go | 2 +- chain/stuff_documents.go | 4 +-- chain/summarization.go | 6 ++-- evaluation/context_qa_eval_chain.go | 4 +-- evaluation/qa_eval_chain.go | 4 +-- 13 files changed, 54 insertions(+), 37 deletions(-) diff --git a/_examples/zero_shot_react_description/main.go b/_examples/zero_shot_react_description/main.go index 5a4a5cb..9eb6937 100644 --- a/_examples/zero_shot_react_description/main.go +++ b/_examples/zero_shot_react_description/main.go @@ -15,6 +15,8 @@ import ( ) func main() { + golc.Verbose = true + openai, err := llm.NewOpenAI(os.Getenv("OPENAI_API_KEY")) if err != nil { log.Fatal(err) diff --git a/agent/agent.go b/agent/agent.go index 1139d10..5a5b964 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -4,6 +4,7 @@ import ( "fmt" "strings" + "github.com/hupe1980/golc" "github.com/hupe1980/golc/schema" ) @@ -15,7 +16,21 @@ const ( ConversationalReactDescriptionAgentType AgentType = "conversational-react-description" ) -func New(llm schema.LLM, tools []schema.Tool, aType AgentType) (*Executor, error) { +type Options struct { + *schema.CallbackOptions +} + +func New(llm schema.LLM, tools []schema.Tool, aType AgentType, optFns ...func(o *Options)) (*Executor, error) { + opts := Options{ + CallbackOptions: &schema.CallbackOptions{ + Verbose: golc.Verbose, + }, + } + + for _, fn := range optFns { + fn(&opts) + } + var ( agent schema.Agent err error diff --git a/agent/conversational.go b/agent/conversational.go index f885b9b..8ddbdbd 100644 --- a/agent/conversational.go +++ b/agent/conversational.go @@ -75,7 +75,7 @@ func NewConversationalReactDescription(llm schema.LLM, tools []schema.Tool) (*Co return nil, err } - llmChain, err := chain.NewLLMChain(llm, prompt, func(o *chain.LLMChainOptions) { + llmChain, err := chain.NewLLM(llm, prompt, func(o *chain.LLMOptions) { o.Memory = memory.NewConversationBuffer() }) if err != nil { diff --git a/agent/mrkl.go b/agent/mrkl.go index 65b9d8b..662a0df 100644 --- a/agent/mrkl.go +++ b/agent/mrkl.go @@ -64,7 +64,7 @@ func NewZeroShotReactDescription(llm schema.LLM, tools []schema.Tool) (*ZeroShot return nil, err } - llmChain, err := chain.NewLLMChain(llm, prompt) + llmChain, err := chain.NewLLM(llm, prompt) if err != nil { return nil, err } diff --git a/chain/llm.go b/chain/llm.go index 712a2f0..e7fa527 100644 --- a/chain/llm.go +++ b/chain/llm.go @@ -10,21 +10,21 @@ import ( "github.com/hupe1980/golc/schema" ) -type LLMChainOptions struct { +type LLMOptions struct { *schema.CallbackOptions Memory schema.Memory OutputKey string OutputParser schema.OutputParser[any] } -type LLMChain struct { +type LLM struct { llm schema.LLM prompt *prompt.Template - opts LLMChainOptions + opts LLMOptions } -func NewLLMChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMChainOptions)) (*LLMChain, error) { - opts := LLMChainOptions{ +func NewLLM(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMOptions)) (*LLM, error) { + opts := LLMOptions{ OutputKey: "text", CallbackOptions: &schema.CallbackOptions{ Verbose: golc.Verbose, @@ -35,14 +35,14 @@ func NewLLMChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMC fn(&opts) } - return &LLMChain{ + return &LLM{ prompt: prompt, llm: llm, opts: opts, }, nil } -func (c *LLMChain) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) { +func (c *LLM) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) { promptValue, err := c.prompt.FormatPrompt(inputs) if err != nil { return nil, err @@ -60,37 +60,37 @@ func (c *LLMChain) Call(ctx context.Context, inputs schema.ChainValues) (schema. }, nil } -func (c *LLMChain) Prompt() *prompt.Template { +func (c *LLM) Prompt() *prompt.Template { return c.prompt } -func (c *LLMChain) Memory() schema.Memory { +func (c *LLM) Memory() schema.Memory { return c.opts.Memory } -func (c *LLMChain) Type() string { +func (c *LLM) Type() string { return "LLM" } -func (c *LLMChain) Verbose() bool { +func (c *LLM) Verbose() bool { return c.opts.CallbackOptions.Verbose } -func (c *LLMChain) Callbacks() []schema.Callback { +func (c *LLM) Callbacks() []schema.Callback { return c.opts.CallbackOptions.Callbacks } // InputKeys returns the expected input keys. -func (c *LLMChain) InputKeys() []string { +func (c *LLM) InputKeys() []string { return c.prompt.InputVariables() } // OutputKeys returns the output keys the chain will return. -func (c *LLMChain) OutputKeys() []string { +func (c *LLM) OutputKeys() []string { return []string{c.opts.OutputKey} } -func (c *LLMChain) getFinalOutput(generations [][]*schema.Generation) string { +func (c *LLM) getFinalOutput(generations [][]*schema.Generation) string { output := []string{} for _, generation := range generations { // Get the text of the top generated string. diff --git a/chain/llm_bash.go b/chain/llm_bash.go index fdab7a1..4133fe6 100644 --- a/chain/llm_bash.go +++ b/chain/llm_bash.go @@ -36,12 +36,12 @@ type LLMBashOptions struct { } type LLMBash struct { - llmChain *LLMChain + llmChain *LLM bashProcess *integration.BashProcess opts LLMBashOptions } -func NewLLMBash(llmChain *LLMChain, optFns ...func(o *LLMBashOptions)) (*LLMBash, error) { +func NewLLMBash(llmChain *LLM, optFns ...func(o *LLMBashOptions)) (*LLMBash, error) { opts := LLMBashOptions{ InputKey: "question", OutputKey: "answer", @@ -74,7 +74,7 @@ func NewLLMBashFromLLM(llm schema.LLM) (*LLMBash, error) { return nil, err } - llmChain, err := NewLLMChain(llm, prompt) + llmChain, err := NewLLM(llm, prompt) if err != nil { return nil, err } diff --git a/chain/llm_math.go b/chain/llm_math.go index 8d11977..efa4ca2 100644 --- a/chain/llm_math.go +++ b/chain/llm_math.go @@ -43,11 +43,11 @@ type LLMMathOptions struct { } type LLMMath struct { - llmChain *LLMChain + llmChain *LLM opts LLMMathOptions } -func NewLLMMath(llmChain *LLMChain, optFns ...func(o *LLMMathOptions)) (*LLMMath, error) { +func NewLLMMath(llmChain *LLM, optFns ...func(o *LLMMathOptions)) (*LLMMath, error) { opts := LLMMathOptions{ InputKey: "question", OutputKey: "answer", @@ -74,7 +74,7 @@ func NewLLMMathFromLLM(llm schema.LLM) (*LLMMath, error) { return nil, err } - llmChain, err := NewLLMChain(llm, prompt) + llmChain, err := NewLLM(llm, prompt) if err != nil { return nil, err } diff --git a/chain/refine_documents.go b/chain/refine_documents.go index d3f62e1..e9da64f 100644 --- a/chain/refine_documents.go +++ b/chain/refine_documents.go @@ -21,12 +21,12 @@ type RefineDocumentsOptions struct { } type RefineDocuments struct { - llmChain *LLMChain - refineLLMChain *LLMChain + llmChain *LLM + refineLLMChain *LLM opts RefineDocumentsOptions } -func NewRefineDocuments(llmChain *LLMChain, refineLLMChain *LLMChain, optFns ...func(o *RefineDocumentsOptions)) (*RefineDocuments, error) { +func NewRefineDocuments(llmChain *LLM, refineLLMChain *LLM, optFns ...func(o *RefineDocumentsOptions)) (*RefineDocuments, error) { opts := RefineDocumentsOptions{ InputKey: "inputDocuments", DocumentVariableName: "context", diff --git a/chain/retrieval_qa.go b/chain/retrieval_qa.go index 66e1ce6..1d54727 100644 --- a/chain/retrieval_qa.go +++ b/chain/retrieval_qa.go @@ -47,7 +47,7 @@ func NewRetrievalQAFromLLM(llm schema.LLM, retriever schema.Retriever) (*Retriev return nil, err } - llmChain, err := NewLLMChain(llm, stuffPrompt) + llmChain, err := NewLLM(llm, stuffPrompt) if err != nil { return nil, err } diff --git a/chain/stuff_documents.go b/chain/stuff_documents.go index 6425fb4..06d37a3 100644 --- a/chain/stuff_documents.go +++ b/chain/stuff_documents.go @@ -18,11 +18,11 @@ type StuffDocumentsOptions struct { } type StuffDocuments struct { - llmChain *LLMChain + llmChain *LLM opts StuffDocumentsOptions } -func NewStuffDocuments(llmChain *LLMChain, optFns ...func(o *StuffDocumentsOptions)) (*StuffDocuments, error) { +func NewStuffDocuments(llmChain *LLM, optFns ...func(o *StuffDocumentsOptions)) (*StuffDocuments, error) { opts := StuffDocumentsOptions{ InputKey: "inputDocuments", DocumentVariableName: "context", diff --git a/chain/summarization.go b/chain/summarization.go index c5a9138..a4fa6f1 100644 --- a/chain/summarization.go +++ b/chain/summarization.go @@ -34,7 +34,7 @@ func NewStuffSummarization(llm schema.LLM, optFns ...func(o *StuffSummarizationO return nil, err } - llmChain, err := NewLLMChain(llm, stuffPrompt, func(o *LLMChainOptions) { + llmChain, err := NewLLM(llm, stuffPrompt, func(o *LLMOptions) { o.CallbackOptions = opts.CallbackOptions }) if err != nil { @@ -79,7 +79,7 @@ func NewRefineSummarization(llm schema.LLM, optFns ...func(o *RefineSummarizatio return nil, err } - llmChain, err := NewLLMChain(llm, stuffPrompt, func(o *LLMChainOptions) { + llmChain, err := NewLLM(llm, stuffPrompt, func(o *LLMOptions) { o.CallbackOptions = opts.CallbackOptions }) if err != nil { @@ -91,7 +91,7 @@ func NewRefineSummarization(llm schema.LLM, optFns ...func(o *RefineSummarizatio return nil, err } - refineLLMChain, err := NewLLMChain(llm, refinePrompt, func(o *LLMChainOptions) { + refineLLMChain, err := NewLLM(llm, refinePrompt, func(o *LLMOptions) { o.CallbackOptions = opts.CallbackOptions }) if err != nil { diff --git a/evaluation/context_qa_eval_chain.go b/evaluation/context_qa_eval_chain.go index d51e4fd..6ef1410 100644 --- a/evaluation/context_qa_eval_chain.go +++ b/evaluation/context_qa_eval_chain.go @@ -38,7 +38,7 @@ type ContextQAEvalChainOptions struct { // ConetxtQAEvalChain is a LLM Chain specifically for evaluating QA w/o GT based on context. type ContextQAEvalChain struct { - llmChain *chain.LLMChain + llmChain *chain.LLM questionKey string contextKey string predictionKey string @@ -64,7 +64,7 @@ func NewContextQAEvalChain(llm schema.LLM, optFns ...func(o *ContextQAEvalChainO opts.Prompt = contextQAEvalPrompt } - llmChain, err := chain.NewLLMChain(llm, opts.Prompt) + llmChain, err := chain.NewLLM(llm, opts.Prompt) if err != nil { return nil, err } diff --git a/evaluation/qa_eval_chain.go b/evaluation/qa_eval_chain.go index 6b27560..502a0ca 100644 --- a/evaluation/qa_eval_chain.go +++ b/evaluation/qa_eval_chain.go @@ -34,7 +34,7 @@ type QAEvalChainOptions struct { // QAEvalChain is a LLM Chain specifically for evaluating question answering. type QAEvalChain struct { - llmChain *chain.LLMChain + llmChain *chain.LLM questionKey string answerKey string predictionKey string @@ -57,7 +57,7 @@ func NewQAEvalChain(llm schema.LLM, optFns ...func(o *QAEvalChainOptions)) (*QAE fn(&opts) } - llmChain, err := chain.NewLLMChain(llm, opts.Prompt) + llmChain, err := chain.NewLLM(llm, opts.Prompt) if err != nil { return nil, err }