From b9497c7c4dc54e1cdb3c996e0f577fa7dc14d8a1 Mon Sep 17 00:00:00 2001 From: hupe1980 Date: Tue, 1 Aug 2023 16:54:31 +0200 Subject: [PATCH] Refactor --- chain/structured_output.go | 29 ++++++++++++------- chain/structured_output_test.go | 9 +++--- .../en/docs/chains/structured_output.md | 7 +++-- examples/structured_output_chain/main.go | 7 +++-- 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/chain/structured_output.go b/chain/structured_output.go index 74ced69..19e6cdf 100644 --- a/chain/structured_output.go +++ b/chain/structured_output.go @@ -16,10 +16,12 @@ import ( // Compile time check to ensure StructuredOutput satisfies the Chain interface. var _ schema.Chain = (*StructuredOutput)(nil) -// OutputCandidates represents a map of candidate names to their descriptions and values used in the structured output chain. -type OutputCandidates map[string]struct { +// OutputCandidate represents a candidate for structured output containing a name, +// description, and data of any struct type. +type OutputCandidate struct { + Name string Description string - Candidate any + Data any } // StructuredOutputOptions contains options for configuring the StructuredOutput chain. @@ -31,12 +33,12 @@ type StructuredOutputOptions struct { // StructuredOutput is a chain that generates structured output using a ChatModel chain and candidate values. type StructuredOutput struct { chatModelChain *ChatModel - candidates OutputCandidates + candidatesMap map[string]OutputCandidate opts StructuredOutputOptions } // NewStructuredOutput creates a new StructuredOutput chain with the given ChatModel, prompt, and candidates. -func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, candidates OutputCandidates, optFns ...func(o *StructuredOutputOptions)) (*StructuredOutput, error) { +func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, candidates []OutputCandidate, optFns ...func(o *StructuredOutputOptions)) (*StructuredOutput, error) { opts := StructuredOutputOptions{ CallbackOptions: &schema.CallbackOptions{ Verbose: golc.Verbose, @@ -48,10 +50,15 @@ func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, fn(&opts) } + candidatesMap := make(map[string]OutputCandidate, len(candidates)) + for _, c := range candidates { + candidatesMap[c.Name] = c + } + functions := make([]schema.FunctionDefinition, 0, len(candidates)) - for name, v := range candidates { - jsonSchema, err := jsonschema.Generate(reflect.TypeOf(v.Candidate)) + for name, v := range candidatesMap { + jsonSchema, err := jsonschema.Generate(reflect.TypeOf(v.Data)) if err != nil { return nil, err } @@ -74,7 +81,7 @@ func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, return &StructuredOutput{ chatModelChain: chatModelChain, - candidates: candidates, + candidatesMap: candidatesMap, opts: opts, }, nil } @@ -108,14 +115,14 @@ func (c *StructuredOutput) Call(ctx context.Context, inputs schema.ChainValues, return nil, errors.New("unexpected output: message without function call extension") } - out := c.candidates[ext.FunctionCall.Name] + out := c.candidatesMap[ext.FunctionCall.Name] - if err := json.Unmarshal([]byte(ext.FunctionCall.Arguments), &out.Candidate); err != nil { + if err := json.Unmarshal([]byte(ext.FunctionCall.Arguments), &out.Data); err != nil { return nil, err } return schema.ChainValues{ - c.opts.OutputKey: out.Candidate, + c.opts.OutputKey: out.Data, }, nil } diff --git a/chain/structured_output_test.go b/chain/structured_output_test.go index 513bdf3..0da2f14 100644 --- a/chain/structured_output_test.go +++ b/chain/structured_output_test.go @@ -20,7 +20,7 @@ func TestStructuredOutput(t *testing.T) { Text: "", Message: schema.NewAIChatMessage("", func(o *schema.ChatMessageExtension) { o.FunctionCall = &schema.FunctionCall{ - Name: "person", + Name: "Person", Arguments: `{"name": "Max", "age": 21}`, } }), @@ -42,10 +42,11 @@ func TestStructuredOutput(t *testing.T) { } // Create a new StructuredOutput chain - structuredOutputChain, err := NewStructuredOutput(chatModel, promptTemplate, OutputCandidates{ - "person": { + structuredOutputChain, err := NewStructuredOutput(chatModel, promptTemplate, []OutputCandidate{ + { + Name: "Person", Description: "Identifying information about a person", - Candidate: &person{}, + Data: &person{}, }, }) require.NoError(t, err) diff --git a/docs/content/en/docs/chains/structured_output.md b/docs/content/en/docs/chains/structured_output.md index e7adeb7..bac2517 100644 --- a/docs/content/en/docs/chains/structured_output.md +++ b/docs/content/en/docs/chains/structured_output.md @@ -36,10 +36,11 @@ func main() { prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"), }) - structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, chain.OutputCandidates{ - "person": { + structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, []chain.OutputCandidate{ + { + Name: "Person", Description: "Identifying information about a person", - Candidate: &Person{}, + Data: &Person{}, }, }) if err != nil { diff --git a/examples/structured_output_chain/main.go b/examples/structured_output_chain/main.go index 0d15a86..e4b645e 100644 --- a/examples/structured_output_chain/main.go +++ b/examples/structured_output_chain/main.go @@ -30,10 +30,11 @@ func main() { prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"), }) - structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, chain.OutputCandidates{ - "person": { + structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, []chain.OutputCandidate{ + { + Name: "Person", Description: "Identifying information about a person", - Candidate: &Person{}, + Data: &Person{}, }, }) if err != nil {