Skip to content

Commit

Permalink
Add structured output chain
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 31, 2023
1 parent e5c28a9 commit e9245e4
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 0 deletions.
150 changes: 150 additions & 0 deletions chain/structured_output.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package chain

import (
"context"
"encoding/json"
"errors"
"reflect"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/integration/jsonschema"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

// Compile time check to ensure StructuredOutput satisfies the Chain interface.
var _ schema.Chain = (*StructuredOutput)(nil)

// OutputCandidates represents a map of candidate names to their descriptions and values used in the structured output chain.
type OutputCandidates map[string]struct {
Description string
Candidate any
}

// StructuredOutputOptions contains options for configuring the StructuredOutput chain.
type StructuredOutputOptions struct {
*schema.CallbackOptions
OutputKey string
}

// StructuredOutput is a chain that generates structured output using a ChatModel chain and candidate values.
type StructuredOutput struct {
chatModelChain *ChatModel
candidates OutputCandidates
opts StructuredOutputOptions
}

// NewStructuredOutput creates a new StructuredOutput chain with the given ChatModel, prompt, and candidates.
func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, candidates OutputCandidates, optFns ...func(o *StructuredOutputOptions)) (*StructuredOutput, error) {
opts := StructuredOutputOptions{
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
OutputKey: "output",
}

for _, fn := range optFns {
fn(&opts)
}

functions := make([]schema.FunctionDefinition, 0, len(candidates))

for name, v := range candidates {
jsonSchema, err := jsonschema.Generate(reflect.TypeOf(v.Candidate))
if err != nil {
return nil, err
}

functions = append(functions, schema.FunctionDefinition{
Name: name,
Description: v.Description,
Parameters: schema.FunctionDefinitionParameters{
Type: "object",
Properties: jsonSchema.Properties,
Required: jsonSchema.Required,
},
})
}

chatModelChain, err := NewChatModelWithFunctions(chatModel, prompt, functions)
if err != nil {
return nil, err
}

return &StructuredOutput{
chatModelChain: chatModelChain,
candidates: candidates,
opts: opts,
}, nil
}

// Call executes the StructuredOutput chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *StructuredOutput) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
opts := schema.CallOptions{
CallbackManger: &callback.NoopManager{},
}

for _, fn := range optFns {
fn(&opts)
}

output, err := golc.Call(ctx, c.chatModelChain, inputs, func(sco *golc.CallOptions) {
sco.Callbacks = opts.CallbackManger.GetInheritableCallbacks()
sco.ParentRunID = opts.CallbackManger.RunID()
})
if err != nil {
return nil, err
}

aiMsg, ok := output["message"].(*schema.AIChatMessage)
if !ok {
return nil, errors.New("unexpected output: message is not a ai chat message")
}

ext := aiMsg.Extension()
if ext.FunctionCall == nil {
return nil, errors.New("unexpected output: message without function call extension")
}

out := c.candidates[ext.FunctionCall.Name]

if err := json.Unmarshal([]byte(ext.FunctionCall.Arguments), &out.Candidate); err != nil {
return nil, err
}

return schema.ChainValues{
c.opts.OutputKey: out.Candidate,
}, nil
}

// Memory returns the memory associated with the chain.
func (c *StructuredOutput) Memory() schema.Memory {
return nil
}

// Type returns the type of the chain.
func (c *StructuredOutput) Type() string {
return "StructuredOutput"
}

// Verbose returns the verbosity setting of the chain.
func (c *StructuredOutput) Verbose() bool {
return c.opts.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *StructuredOutput) Callbacks() []schema.Callback {
return c.opts.Callbacks
}

// InputKeys returns the expected input keys.
func (c *StructuredOutput) InputKeys() []string {
return c.chatModelChain.InputKeys()
}

// OutputKeys returns the output keys the chain will return.
func (c *StructuredOutput) OutputKeys() []string {
return []string{c.opts.OutputKey}
}
72 changes: 72 additions & 0 deletions chain/structured_output_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package chain

import (
"context"
"testing"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
"github.com/stretchr/testify/require"
)

func TestStructuredOutput(t *testing.T) {
t.Run("TestStructuredOutput", func(t *testing.T) {
// Create a dummy chat model for testing
chatModel := chatmodel.NewFake(func(ctx context.Context, messages schema.ChatMessages) (*schema.ModelResult, error) {
return &schema.ModelResult{
Generations: []schema.Generation{{
Text: "",
Message: schema.NewAIChatMessage("", func(o *schema.ChatMessageExtension) {
o.FunctionCall = &schema.FunctionCall{
Name: "person",
Arguments: `{"name": "Max", "age": 21}`,
}
}),
}},
LLMOutput: map[string]any{},
}, nil
})

// Create a dummy prompt template for testing
promptTemplate := prompt.NewChatTemplate([]prompt.MessageTemplate{
prompt.NewHumanMessageTemplate("{{.input}}"),
})

// Create a dummy output candidate
type person struct {
Name string `json:"name" description:"The person's name"`
Age int `json:"age" description:"The person's age"`
FavFood string `json:"fav_food,omitempty" description:"The person's favorite food"`
}

// Create a new StructuredOutput chain
structuredOutputChain, err := NewStructuredOutput(chatModel, promptTemplate, OutputCandidates{
"person": {
Description: "Identifying information about a person",
Candidate: &person{},
},
})
require.NoError(t, err)

// Prepare the input values for the chain
inputs := schema.ChainValues{
"input": "Max is 21",
}

// Call the ChatModel chain with the inputs
outputs, err := golc.Call(context.Background(), structuredOutputChain, inputs)
require.NoError(t, err)

// Check the output key in the result
require.Contains(t, outputs, structuredOutputChain.OutputKeys()[0])

// Check the output value type
p, ok := outputs[structuredOutputChain.OutputKeys()[0]].(*person)
require.True(t, ok)
require.Equal(t, "Max", p.Name)
require.Equal(t, 21, p.Age)
require.Equal(t, "", p.FavFood)
})
}
64 changes: 64 additions & 0 deletions docs/content/en/docs/chains/structured_output.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
---
title: StructuredOutput
description: All about structured output chains.
weight: 100
---
```go
package main

import (
"context"
"fmt"
"log"
"os"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

type Person struct {
Name string `json:"name" description:"The person's name"`
Age int `json:"age" description:"The person's age"`
FavFood string `json:"fav_food,omitempty" description:"The person's favorite food"`
}

func main() {
chatModel, err := chatmodel.NewOpenAI(os.Getenv("OPENAI_API_KEY"))
if err != nil {
log.Fatal(err)
}

pt := prompt.NewChatTemplate([]prompt.MessageTemplate{
prompt.NewSystemMessageTemplate("You are a world class algorithm for extracting information in structured formats."),
prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"),
})

structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, chain.OutputCandidates{
"person": {
Description: "Identifying information about a person",
Candidate: &Person{},
},
})
if err != nil {
log.Fatal(err)
}

result, err := golc.Call(context.Background(), structuredOutputChain, schema.ChainValues{
"input": "Max is 21",
})
if err != nil {
log.Fatal(err)
}

p, ok := result["output"].(*Person)
if !ok {
log.Fatal("output is not a person")
}

fmt.Println("Name:", p.Name)
fmt.Println("Age:", p.Age)
}
```
57 changes: 57 additions & 0 deletions examples/structured_output_chain/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package main

import (
"context"
"fmt"
"log"
"os"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

type Person struct {
Name string `json:"name" description:"The person's name"`
Age int `json:"age" description:"The person's age"`
FavFood string `json:"fav_food,omitempty" description:"The person's favorite food"`
}

func main() {
chatModel, err := chatmodel.NewOpenAI(os.Getenv("OPENAI_API_KEY"))
if err != nil {
log.Fatal(err)
}

pt := prompt.NewChatTemplate([]prompt.MessageTemplate{
prompt.NewSystemMessageTemplate("You are a world class algorithm for extracting information in structured formats."),
prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"),
})

structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, chain.OutputCandidates{
"person": {
Description: "Identifying information about a person",
Candidate: &Person{},
},
})
if err != nil {
log.Fatal(err)
}

result, err := golc.Call(context.Background(), structuredOutputChain, schema.ChainValues{
"input": "Max is 21",
})
if err != nil {
log.Fatal(err)
}

p, ok := result["output"].(*Person)
if !ok {
log.Fatal("output is not a person")
}

fmt.Println("Name:", p.Name)
fmt.Println("Age:", p.Age)
}

0 comments on commit e9245e4

Please sign in to comment.