Refactor
hupe1980 committed Jun 21, 2023
1 parent 5eb4f09 commit aa7f093
Showing 13 changed files with 54 additions and 37 deletions.
2 changes: 2 additions & 0 deletions _examples/zero_shot_react_description/main.go
@@ -15,6 +15,8 @@ import (
 )

 func main() {
+    golc.Verbose = true
+
     openai, err := llm.NewOpenAI(os.Getenv("OPENAI_API_KEY"))
     if err != nil {
         log.Fatal(err)
17 changes: 16 additions & 1 deletion agent/agent.go
@@ -4,6 +4,7 @@ import (
     "fmt"
     "strings"

+    "github.com/hupe1980/golc"
     "github.com/hupe1980/golc/schema"
 )
@@ -15,7 +16,21 @@ const (
     ConversationalReactDescriptionAgentType AgentType = "conversational-react-description"
 )

-func New(llm schema.LLM, tools []schema.Tool, aType AgentType) (*Executor, error) {
+type Options struct {
+    *schema.CallbackOptions
+}
+
+func New(llm schema.LLM, tools []schema.Tool, aType AgentType, optFns ...func(o *Options)) (*Executor, error) {
+    opts := Options{
+        CallbackOptions: &schema.CallbackOptions{
+            Verbose: golc.Verbose,
+        },
+    }
+
+    for _, fn := range optFns {
+        fn(&opts)
+    }
+
     var (
         agent schema.Agent
         err error
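agent.New now follows the same functional-options pattern as the chain constructors: the zero value inherits the package-level golc.Verbose flag, and a caller can override the callback settings per agent. A minimal call-site sketch under these assumptions: the agent and llm import paths mirror the package names used in this diff, and the tool slice is a placeholder the caller fills in; all other identifiers are taken from names visible in this commit.

package main

import (
    "log"
    "os"

    "github.com/hupe1980/golc/agent"
    "github.com/hupe1980/golc/llm"
    "github.com/hupe1980/golc/schema"
)

func main() {
    openai, err := llm.NewOpenAI(os.Getenv("OPENAI_API_KEY"))
    if err != nil {
        log.Fatal(err)
    }

    var tools []schema.Tool // placeholder: the tools this agent may call

    // Override the callback settings for this agent only; without the option
    // function, Verbose defaults to the package-level golc.Verbose flag.
    executor, err := agent.New(openai, tools, agent.ConversationalReactDescriptionAgentType, func(o *agent.Options) {
        o.Verbose = true
    })
    if err != nil {
        log.Fatal(err)
    }
    _ = executor
}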
2 changes: 1 addition & 1 deletion agent/conversational.go
@@ -75,7 +75,7 @@ func NewConversationalReactDescription(llm schema.LLM, tools []schema.Tool) (*Co
         return nil, err
     }

-    llmChain, err := chain.NewLLMChain(llm, prompt, func(o *chain.LLMChainOptions) {
+    llmChain, err := chain.NewLLM(llm, prompt, func(o *chain.LLMOptions) {
         o.Memory = memory.NewConversationBuffer()
     })
     if err != nil {
2 changes: 1 addition & 1 deletion agent/mrkl.go
@@ -64,7 +64,7 @@ func NewZeroShotReactDescription(llm schema.LLM, tools []schema.Tool) (*ZeroShot
         return nil, err
     }

-    llmChain, err := chain.NewLLMChain(llm, prompt)
+    llmChain, err := chain.NewLLM(llm, prompt)
     if err != nil {
         return nil, err
     }
30 changes: 15 additions & 15 deletions chain/llm.go
@@ -10,21 +10,21 @@ import (
     "github.com/hupe1980/golc/schema"
 )

-type LLMChainOptions struct {
+type LLMOptions struct {
     *schema.CallbackOptions
     Memory schema.Memory
     OutputKey string
     OutputParser schema.OutputParser[any]
 }

-type LLMChain struct {
+type LLM struct {
     llm schema.LLM
     prompt *prompt.Template
-    opts LLMChainOptions
+    opts LLMOptions
 }

-func NewLLMChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMChainOptions)) (*LLMChain, error) {
-    opts := LLMChainOptions{
+func NewLLM(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMOptions)) (*LLM, error) {
+    opts := LLMOptions{
         OutputKey: "text",
         CallbackOptions: &schema.CallbackOptions{
             Verbose: golc.Verbose,
@@ -35,14 +35,14 @@ func NewLLMChain(llm schema.LLM, prompt *prompt.Template, optFns ...func(o *LLMC
         fn(&opts)
     }

-    return &LLMChain{
+    return &LLM{
         prompt: prompt,
         llm: llm,
         opts: opts,
     }, nil
 }

-func (c *LLMChain) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
+func (c *LLM) Call(ctx context.Context, inputs schema.ChainValues) (schema.ChainValues, error) {
     promptValue, err := c.prompt.FormatPrompt(inputs)
     if err != nil {
         return nil, err
@@ -60,37 +60,37 @@ func (c *LLMChain) Call(ctx context.Context, inputs schema.ChainValues) (schema.
     }, nil
 }

-func (c *LLMChain) Prompt() *prompt.Template {
+func (c *LLM) Prompt() *prompt.Template {
     return c.prompt
 }

-func (c *LLMChain) Memory() schema.Memory {
+func (c *LLM) Memory() schema.Memory {
     return c.opts.Memory
 }

-func (c *LLMChain) Type() string {
+func (c *LLM) Type() string {
     return "LLM"
 }

-func (c *LLMChain) Verbose() bool {
+func (c *LLM) Verbose() bool {
     return c.opts.CallbackOptions.Verbose
 }

-func (c *LLMChain) Callbacks() []schema.Callback {
+func (c *LLM) Callbacks() []schema.Callback {
     return c.opts.CallbackOptions.Callbacks
 }

 // InputKeys returns the expected input keys.
-func (c *LLMChain) InputKeys() []string {
+func (c *LLM) InputKeys() []string {
     return c.prompt.InputVariables()
 }

 // OutputKeys returns the output keys the chain will return.
-func (c *LLMChain) OutputKeys() []string {
+func (c *LLM) OutputKeys() []string {
     return []string{c.opts.OutputKey}
 }

-func (c *LLMChain) getFinalOutput(generations [][]*schema.Generation) string {
+func (c *LLM) getFinalOutput(generations [][]*schema.Generation) string {
     output := []string{}
     for _, generation := range generations {
         // Get the text of the top generated string.
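For code that builds on this package, the change is a rename: chain.NewLLMChain becomes chain.NewLLM, the option struct becomes chain.LLMOptions, and the chain value becomes *chain.LLM; the defaults (OutputKey "text", Verbose seeded from golc.Verbose) are unchanged. A hedged call-site sketch using only identifiers visible in this commit; how the *prompt.Template is constructed is outside this diff, and schema.ChainValues is assumed to be a map keyed by input variable name.

package example

import (
    "context"

    "github.com/hupe1980/golc/chain"
    "github.com/hupe1980/golc/prompt"
    "github.com/hupe1980/golc/schema"
)

// buildAnswerChain wires an existing prompt template into the renamed chain
// type and overrides the default "text" output key.
func buildAnswerChain(model schema.LLM, p *prompt.Template) (*chain.LLM, error) {
    return chain.NewLLM(model, p, func(o *chain.LLMOptions) {
        o.OutputKey = "answer"
    })
}

// askQuestion shows that the Call signature is untouched by the rename.
// Assumption: schema.ChainValues can be built as a map literal; the key must
// match an input variable of the prompt template.
func askQuestion(ctx context.Context, c *chain.LLM, question string) (schema.ChainValues, error) {
    return c.Call(ctx, schema.ChainValues{"question": question})
}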
6 changes: 3 additions & 3 deletions chain/llm_bash.go
@@ -36,12 +36,12 @@ type LLMBashOptions struct {
 }

 type LLMBash struct {
-    llmChain *LLMChain
+    llmChain *LLM
     bashProcess *integration.BashProcess
     opts LLMBashOptions
 }

-func NewLLMBash(llmChain *LLMChain, optFns ...func(o *LLMBashOptions)) (*LLMBash, error) {
+func NewLLMBash(llmChain *LLM, optFns ...func(o *LLMBashOptions)) (*LLMBash, error) {
     opts := LLMBashOptions{
         InputKey: "question",
         OutputKey: "answer",
@@ -74,7 +74,7 @@ func NewLLMBashFromLLM(llm schema.LLM) (*LLMBash, error) {
         return nil, err
     }

-    llmChain, err := NewLLMChain(llm, prompt)
+    llmChain, err := NewLLM(llm, prompt)
     if err != nil {
         return nil, err
     }
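LLMBash keeps its public factory; only the embedded chain field changes type. A rough usage sketch under these assumptions: LLMBash exposes the same Call method as the LLM chain above (not shown in this hunk), and schema.ChainValues is a map; the default input and output keys ("question"/"answer") come from the options in this diff.

package example

import (
    "context"

    "github.com/hupe1980/golc/chain"
    "github.com/hupe1980/golc/schema"
)

// runBashChain builds the bash chain straight from a model; internally this
// now goes through chain.NewLLM instead of the removed NewLLMChain.
func runBashChain(ctx context.Context, model schema.LLM, question string) (schema.ChainValues, error) {
    bash, err := chain.NewLLMBashFromLLM(model)
    if err != nil {
        return nil, err
    }

    // Assumption: Call has the same shape as on *chain.LLM.
    return bash.Call(ctx, schema.ChainValues{"question": question})
}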
6 changes: 3 additions & 3 deletions chain/llm_math.go
@@ -43,11 +43,11 @@ type LLMMathOptions struct {
 }

 type LLMMath struct {
-    llmChain *LLMChain
+    llmChain *LLM
     opts LLMMathOptions
 }

-func NewLLMMath(llmChain *LLMChain, optFns ...func(o *LLMMathOptions)) (*LLMMath, error) {
+func NewLLMMath(llmChain *LLM, optFns ...func(o *LLMMathOptions)) (*LLMMath, error) {
     opts := LLMMathOptions{
         InputKey: "question",
         OutputKey: "answer",
@@ -74,7 +74,7 @@ func NewLLMMathFromLLM(llm schema.LLM) (*LLMMath, error) {
         return nil, err
     }

-    llmChain, err := NewLLMChain(llm, prompt)
+    llmChain, err := NewLLM(llm, prompt)
     if err != nil {
         return nil, err
     }
6 changes: 3 additions & 3 deletions chain/refine_documents.go
@@ -21,12 +21,12 @@ type RefineDocumentsOptions struct {
 }

 type RefineDocuments struct {
-    llmChain *LLMChain
-    refineLLMChain *LLMChain
+    llmChain *LLM
+    refineLLMChain *LLM
     opts RefineDocumentsOptions
 }

-func NewRefineDocuments(llmChain *LLMChain, refineLLMChain *LLMChain, optFns ...func(o *RefineDocumentsOptions)) (*RefineDocuments, error) {
+func NewRefineDocuments(llmChain *LLM, refineLLMChain *LLM, optFns ...func(o *RefineDocumentsOptions)) (*RefineDocuments, error) {
     opts := RefineDocumentsOptions{
         InputKey: "inputDocuments",
         DocumentVariableName: "context",
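The refine-documents chain is the one place where two LLM chains are composed, which is why both constructor parameters change type to *chain.LLM. A sketch of the wiring under the same assumptions as above; the two prompt templates are placeholders built elsewhere.

package example

import (
    "github.com/hupe1980/golc/chain"
    "github.com/hupe1980/golc/prompt"
    "github.com/hupe1980/golc/schema"
)

// buildRefineChain composes an initial chain and a refine chain, both of the
// renamed *chain.LLM type, into a RefineDocuments chain.
func buildRefineChain(model schema.LLM, initialPrompt, refinePrompt *prompt.Template) (*chain.RefineDocuments, error) {
    initial, err := chain.NewLLM(model, initialPrompt)
    if err != nil {
        return nil, err
    }

    refine, err := chain.NewLLM(model, refinePrompt)
    if err != nil {
        return nil, err
    }

    return chain.NewRefineDocuments(initial, refine)
}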
2 changes: 1 addition & 1 deletion chain/retrieval_qa.go
@@ -47,7 +47,7 @@ func NewRetrievalQAFromLLM(llm schema.LLM, retriever schema.Retriever) (*Retriev
         return nil, err
     }

-    llmChain, err := NewLLMChain(llm, stuffPrompt)
+    llmChain, err := NewLLM(llm, stuffPrompt)
     if err != nil {
         return nil, err
     }
4 changes: 2 additions & 2 deletions chain/stuff_documents.go
@@ -18,11 +18,11 @@ type StuffDocumentsOptions struct {
 }

 type StuffDocuments struct {
-    llmChain *LLMChain
+    llmChain *LLM
     opts StuffDocumentsOptions
 }

-func NewStuffDocuments(llmChain *LLMChain, optFns ...func(o *StuffDocumentsOptions)) (*StuffDocuments, error) {
+func NewStuffDocuments(llmChain *LLM, optFns ...func(o *StuffDocumentsOptions)) (*StuffDocuments, error) {
     opts := StuffDocumentsOptions{
         InputKey: "inputDocuments",
         DocumentVariableName: "context",
6 changes: 3 additions & 3 deletions chain/summarization.go
@@ -34,7 +34,7 @@ func NewStuffSummarization(llm schema.LLM, optFns ...func(o *StuffSummarizationO
         return nil, err
     }

-    llmChain, err := NewLLMChain(llm, stuffPrompt, func(o *LLMChainOptions) {
+    llmChain, err := NewLLM(llm, stuffPrompt, func(o *LLMOptions) {
         o.CallbackOptions = opts.CallbackOptions
     })
     if err != nil {
@@ -79,7 +79,7 @@ func NewRefineSummarization(llm schema.LLM, optFns ...func(o *RefineSummarizatio
         return nil, err
     }

-    llmChain, err := NewLLMChain(llm, stuffPrompt, func(o *LLMChainOptions) {
+    llmChain, err := NewLLM(llm, stuffPrompt, func(o *LLMOptions) {
         o.CallbackOptions = opts.CallbackOptions
     })
     if err != nil {
@@ -91,7 +91,7 @@ func NewRefineSummarization(llm schema.LLM, optFns ...func(o *RefineSummarizatio
         return nil, err
     }

-    refineLLMChain, err := NewLLMChain(llm, refinePrompt, func(o *LLMChainOptions) {
+    refineLLMChain, err := NewLLM(llm, refinePrompt, func(o *LLMOptions) {
         o.CallbackOptions = opts.CallbackOptions
     })
     if err != nil {
4 changes: 2 additions & 2 deletions evaluation/context_qa_eval_chain.go
@@ -38,7 +38,7 @@ type ContextQAEvalChainOptions struct {

 // ConetxtQAEvalChain is a LLM Chain specifically for evaluating QA w/o GT based on context.
 type ContextQAEvalChain struct {
-    llmChain *chain.LLMChain
+    llmChain *chain.LLM
     questionKey string
     contextKey string
     predictionKey string
@@ -64,7 +64,7 @@ func NewContextQAEvalChain(llm schema.LLM, optFns ...func(o *ContextQAEvalChainO
         opts.Prompt = contextQAEvalPrompt
     }

-    llmChain, err := chain.NewLLMChain(llm, opts.Prompt)
+    llmChain, err := chain.NewLLM(llm, opts.Prompt)
     if err != nil {
         return nil, err
     }
4 changes: 2 additions & 2 deletions evaluation/qa_eval_chain.go
@@ -34,7 +34,7 @@ type QAEvalChainOptions struct {

 // QAEvalChain is a LLM Chain specifically for evaluating question answering.
 type QAEvalChain struct {
-    llmChain *chain.LLMChain
+    llmChain *chain.LLM
     questionKey string
     answerKey string
     predictionKey string
@@ -57,7 +57,7 @@ func NewQAEvalChain(llm schema.LLM, optFns ...func(o *QAEvalChainOptions)) (*QAE
         fn(&opts)
     }

-    llmChain, err := chain.NewLLMChain(llm, opts.Prompt)
+    llmChain, err := chain.NewLLM(llm, opts.Prompt)
     if err != nil {
         return nil, err
     }
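Taken together, the commit is a pure rename for downstream users: NewLLMChain becomes NewLLM, LLMChainOptions becomes LLMOptions, and fields or parameters typed *LLMChain (or *chain.LLMChain) become *LLM (or *chain.LLM). A minimal sketch of the only edit an external call site needs; model and tmpl stand in for an existing schema.LLM and *prompt.Template.

package example

import (
    "github.com/hupe1980/golc/chain"
    "github.com/hupe1980/golc/prompt"
    "github.com/hupe1980/golc/schema"
)

// migrate shows the whole downstream change: the constructor and the option
// type are renamed; the call shape is otherwise identical.
func migrate(model schema.LLM, tmpl *prompt.Template) (*chain.LLM, error) {
    // Before: chain.NewLLMChain(model, tmpl, func(o *chain.LLMChainOptions) { ... })
    return chain.NewLLM(model, tmpl, func(o *chain.LLMOptions) {
        o.OutputKey = "answer" // the default remains "text"
    })
}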
