Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Aug 1, 2023
1 parent e9245e4 commit b9497c7
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 21 deletions.
29 changes: 18 additions & 11 deletions chain/structured_output.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ import (
// Compile time check to ensure StructuredOutput satisfies the Chain interface.
var _ schema.Chain = (*StructuredOutput)(nil)

// OutputCandidates represents a map of candidate names to their descriptions and values used in the structured output chain.
type OutputCandidates map[string]struct {
// OutputCandidate represents a candidate for structured output containing a name,
// description, and data of any struct type.
type OutputCandidate struct {
Name string
Description string
Candidate any
Data any
}

// StructuredOutputOptions contains options for configuring the StructuredOutput chain.
Expand All @@ -31,12 +33,12 @@ type StructuredOutputOptions struct {
// StructuredOutput is a chain that generates structured output using a ChatModel chain and candidate values.
type StructuredOutput struct {
chatModelChain *ChatModel
candidates OutputCandidates
candidatesMap map[string]OutputCandidate
opts StructuredOutputOptions
}

// NewStructuredOutput creates a new StructuredOutput chain with the given ChatModel, prompt, and candidates.
func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, candidates OutputCandidates, optFns ...func(o *StructuredOutputOptions)) (*StructuredOutput, error) {
func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, candidates []OutputCandidate, optFns ...func(o *StructuredOutputOptions)) (*StructuredOutput, error) {
opts := StructuredOutputOptions{
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
Expand All @@ -48,10 +50,15 @@ func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate,
fn(&opts)
}

candidatesMap := make(map[string]OutputCandidate, len(candidates))
for _, c := range candidates {
candidatesMap[c.Name] = c
}

functions := make([]schema.FunctionDefinition, 0, len(candidates))

for name, v := range candidates {
jsonSchema, err := jsonschema.Generate(reflect.TypeOf(v.Candidate))
for name, v := range candidatesMap {
jsonSchema, err := jsonschema.Generate(reflect.TypeOf(v.Data))
if err != nil {
return nil, err
}
Expand All @@ -74,7 +81,7 @@ func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate,

return &StructuredOutput{
chatModelChain: chatModelChain,
candidates: candidates,
candidatesMap: candidatesMap,
opts: opts,
}, nil
}
Expand Down Expand Up @@ -108,14 +115,14 @@ func (c *StructuredOutput) Call(ctx context.Context, inputs schema.ChainValues,
return nil, errors.New("unexpected output: message without function call extension")
}

out := c.candidates[ext.FunctionCall.Name]
out := c.candidatesMap[ext.FunctionCall.Name]

if err := json.Unmarshal([]byte(ext.FunctionCall.Arguments), &out.Candidate); err != nil {
if err := json.Unmarshal([]byte(ext.FunctionCall.Arguments), &out.Data); err != nil {
return nil, err
}

return schema.ChainValues{
c.opts.OutputKey: out.Candidate,
c.opts.OutputKey: out.Data,
}, nil
}

Expand Down
9 changes: 5 additions & 4 deletions chain/structured_output_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func TestStructuredOutput(t *testing.T) {
Text: "",
Message: schema.NewAIChatMessage("", func(o *schema.ChatMessageExtension) {
o.FunctionCall = &schema.FunctionCall{
Name: "person",
Name: "Person",
Arguments: `{"name": "Max", "age": 21}`,
}
}),
Expand All @@ -42,10 +42,11 @@ func TestStructuredOutput(t *testing.T) {
}

// Create a new StructuredOutput chain
structuredOutputChain, err := NewStructuredOutput(chatModel, promptTemplate, OutputCandidates{
"person": {
structuredOutputChain, err := NewStructuredOutput(chatModel, promptTemplate, []OutputCandidate{
{
Name: "Person",
Description: "Identifying information about a person",
Candidate: &person{},
Data: &person{},
},
})
require.NoError(t, err)
Expand Down
7 changes: 4 additions & 3 deletions docs/content/en/docs/chains/structured_output.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ func main() {
prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"),
})

structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, chain.OutputCandidates{
"person": {
structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, []chain.OutputCandidate{
{
Name: "Person",
Description: "Identifying information about a person",
Candidate: &Person{},
Data: &Person{},
},
})
if err != nil {
Expand Down
7 changes: 4 additions & 3 deletions examples/structured_output_chain/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ func main() {
prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"),
})

structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, chain.OutputCandidates{
"person": {
structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, []chain.OutputCandidate{
{
Name: "Person",
Description: "Identifying information about a person",
Candidate: &Person{},
Data: &Person{},
},
})
if err != nil {
Expand Down

0 comments on commit b9497c7

Please sign in to comment.