-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
343 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
package chain | ||
|
||
import ( | ||
"context" | ||
"encoding/json" | ||
"errors" | ||
"reflect" | ||
|
||
"github.com/hupe1980/golc" | ||
"github.com/hupe1980/golc/callback" | ||
"github.com/hupe1980/golc/integration/jsonschema" | ||
"github.com/hupe1980/golc/prompt" | ||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
// Compile time check to ensure StructuredOutput satisfies the Chain interface. | ||
var _ schema.Chain = (*StructuredOutput)(nil) | ||
|
||
// OutputCandidates represents a map of candidate names to their descriptions and values used in the structured output chain. | ||
type OutputCandidates map[string]struct { | ||
Description string | ||
Candidate any | ||
} | ||
|
||
// StructuredOutputOptions contains options for configuring the StructuredOutput chain. | ||
type StructuredOutputOptions struct { | ||
*schema.CallbackOptions | ||
OutputKey string | ||
} | ||
|
||
// StructuredOutput is a chain that generates structured output using a ChatModel chain and candidate values. | ||
type StructuredOutput struct { | ||
chatModelChain *ChatModel | ||
candidates OutputCandidates | ||
opts StructuredOutputOptions | ||
} | ||
|
||
// NewStructuredOutput creates a new StructuredOutput chain with the given ChatModel, prompt, and candidates. | ||
func NewStructuredOutput(chatModel schema.ChatModel, prompt prompt.ChatTemplate, candidates OutputCandidates, optFns ...func(o *StructuredOutputOptions)) (*StructuredOutput, error) { | ||
opts := StructuredOutputOptions{ | ||
CallbackOptions: &schema.CallbackOptions{ | ||
Verbose: golc.Verbose, | ||
}, | ||
OutputKey: "output", | ||
} | ||
|
||
for _, fn := range optFns { | ||
fn(&opts) | ||
} | ||
|
||
functions := make([]schema.FunctionDefinition, 0, len(candidates)) | ||
|
||
for name, v := range candidates { | ||
jsonSchema, err := jsonschema.Generate(reflect.TypeOf(v.Candidate)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
functions = append(functions, schema.FunctionDefinition{ | ||
Name: name, | ||
Description: v.Description, | ||
Parameters: schema.FunctionDefinitionParameters{ | ||
Type: "object", | ||
Properties: jsonSchema.Properties, | ||
Required: jsonSchema.Required, | ||
}, | ||
}) | ||
} | ||
|
||
chatModelChain, err := NewChatModelWithFunctions(chatModel, prompt, functions) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return &StructuredOutput{ | ||
chatModelChain: chatModelChain, | ||
candidates: candidates, | ||
opts: opts, | ||
}, nil | ||
} | ||
|
||
// Call executes the StructuredOutput chain with the given context and inputs. | ||
// It returns the outputs of the chain or an error, if any. | ||
func (c *StructuredOutput) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) { | ||
opts := schema.CallOptions{ | ||
CallbackManger: &callback.NoopManager{}, | ||
} | ||
|
||
for _, fn := range optFns { | ||
fn(&opts) | ||
} | ||
|
||
output, err := golc.Call(ctx, c.chatModelChain, inputs, func(sco *golc.CallOptions) { | ||
sco.Callbacks = opts.CallbackManger.GetInheritableCallbacks() | ||
sco.ParentRunID = opts.CallbackManger.RunID() | ||
}) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
aiMsg, ok := output["message"].(*schema.AIChatMessage) | ||
if !ok { | ||
return nil, errors.New("unexpected output: message is not a ai chat message") | ||
} | ||
|
||
ext := aiMsg.Extension() | ||
if ext.FunctionCall == nil { | ||
return nil, errors.New("unexpected output: message without function call extension") | ||
} | ||
|
||
out := c.candidates[ext.FunctionCall.Name] | ||
|
||
if err := json.Unmarshal([]byte(ext.FunctionCall.Arguments), &out.Candidate); err != nil { | ||
return nil, err | ||
} | ||
|
||
return schema.ChainValues{ | ||
c.opts.OutputKey: out.Candidate, | ||
}, nil | ||
} | ||
|
||
// Memory returns the memory associated with the chain. | ||
func (c *StructuredOutput) Memory() schema.Memory { | ||
return nil | ||
} | ||
|
||
// Type returns the type of the chain. | ||
func (c *StructuredOutput) Type() string { | ||
return "StructuredOutput" | ||
} | ||
|
||
// Verbose returns the verbosity setting of the chain. | ||
func (c *StructuredOutput) Verbose() bool { | ||
return c.opts.Verbose | ||
} | ||
|
||
// Callbacks returns the callbacks associated with the chain. | ||
func (c *StructuredOutput) Callbacks() []schema.Callback { | ||
return c.opts.Callbacks | ||
} | ||
|
||
// InputKeys returns the expected input keys. | ||
func (c *StructuredOutput) InputKeys() []string { | ||
return c.chatModelChain.InputKeys() | ||
} | ||
|
||
// OutputKeys returns the output keys the chain will return. | ||
func (c *StructuredOutput) OutputKeys() []string { | ||
return []string{c.opts.OutputKey} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package chain | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/hupe1980/golc" | ||
"github.com/hupe1980/golc/model/chatmodel" | ||
"github.com/hupe1980/golc/prompt" | ||
"github.com/hupe1980/golc/schema" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
func TestStructuredOutput(t *testing.T) { | ||
t.Run("TestStructuredOutput", func(t *testing.T) { | ||
// Create a dummy chat model for testing | ||
chatModel := chatmodel.NewFake(func(ctx context.Context, messages schema.ChatMessages) (*schema.ModelResult, error) { | ||
return &schema.ModelResult{ | ||
Generations: []schema.Generation{{ | ||
Text: "", | ||
Message: schema.NewAIChatMessage("", func(o *schema.ChatMessageExtension) { | ||
o.FunctionCall = &schema.FunctionCall{ | ||
Name: "person", | ||
Arguments: `{"name": "Max", "age": 21}`, | ||
} | ||
}), | ||
}}, | ||
LLMOutput: map[string]any{}, | ||
}, nil | ||
}) | ||
|
||
// Create a dummy prompt template for testing | ||
promptTemplate := prompt.NewChatTemplate([]prompt.MessageTemplate{ | ||
prompt.NewHumanMessageTemplate("{{.input}}"), | ||
}) | ||
|
||
// Create a dummy output candidate | ||
type person struct { | ||
Name string `json:"name" description:"The person's name"` | ||
Age int `json:"age" description:"The person's age"` | ||
FavFood string `json:"fav_food,omitempty" description:"The person's favorite food"` | ||
} | ||
|
||
// Create a new StructuredOutput chain | ||
structuredOutputChain, err := NewStructuredOutput(chatModel, promptTemplate, OutputCandidates{ | ||
"person": { | ||
Description: "Identifying information about a person", | ||
Candidate: &person{}, | ||
}, | ||
}) | ||
require.NoError(t, err) | ||
|
||
// Prepare the input values for the chain | ||
inputs := schema.ChainValues{ | ||
"input": "Max is 21", | ||
} | ||
|
||
// Call the ChatModel chain with the inputs | ||
outputs, err := golc.Call(context.Background(), structuredOutputChain, inputs) | ||
require.NoError(t, err) | ||
|
||
// Check the output key in the result | ||
require.Contains(t, outputs, structuredOutputChain.OutputKeys()[0]) | ||
|
||
// Check the output value type | ||
p, ok := outputs[structuredOutputChain.OutputKeys()[0]].(*person) | ||
require.True(t, ok) | ||
require.Equal(t, "Max", p.Name) | ||
require.Equal(t, 21, p.Age) | ||
require.Equal(t, "", p.FavFood) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
--- | ||
title: StructuredOutput | ||
description: All about structured output chains. | ||
weight: 100 | ||
--- | ||
```go | ||
package main | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"log" | ||
"os" | ||
|
||
"github.com/hupe1980/golc" | ||
"github.com/hupe1980/golc/chain" | ||
"github.com/hupe1980/golc/model/chatmodel" | ||
"github.com/hupe1980/golc/prompt" | ||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
type Person struct { | ||
Name string `json:"name" description:"The person's name"` | ||
Age int `json:"age" description:"The person's age"` | ||
FavFood string `json:"fav_food,omitempty" description:"The person's favorite food"` | ||
} | ||
|
||
func main() { | ||
chatModel, err := chatmodel.NewOpenAI(os.Getenv("OPENAI_API_KEY")) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
pt := prompt.NewChatTemplate([]prompt.MessageTemplate{ | ||
prompt.NewSystemMessageTemplate("You are a world class algorithm for extracting information in structured formats."), | ||
prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"), | ||
}) | ||
|
||
structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, chain.OutputCandidates{ | ||
"person": { | ||
Description: "Identifying information about a person", | ||
Candidate: &Person{}, | ||
}, | ||
}) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
result, err := golc.Call(context.Background(), structuredOutputChain, schema.ChainValues{ | ||
"input": "Max is 21", | ||
}) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
p, ok := result["output"].(*Person) | ||
if !ok { | ||
log.Fatal("output is not a person") | ||
} | ||
|
||
fmt.Println("Name:", p.Name) | ||
fmt.Println("Age:", p.Age) | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package main | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"log" | ||
"os" | ||
|
||
"github.com/hupe1980/golc" | ||
"github.com/hupe1980/golc/chain" | ||
"github.com/hupe1980/golc/model/chatmodel" | ||
"github.com/hupe1980/golc/prompt" | ||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
type Person struct { | ||
Name string `json:"name" description:"The person's name"` | ||
Age int `json:"age" description:"The person's age"` | ||
FavFood string `json:"fav_food,omitempty" description:"The person's favorite food"` | ||
} | ||
|
||
func main() { | ||
chatModel, err := chatmodel.NewOpenAI(os.Getenv("OPENAI_API_KEY")) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
pt := prompt.NewChatTemplate([]prompt.MessageTemplate{ | ||
prompt.NewSystemMessageTemplate("You are a world class algorithm for extracting information in structured formats."), | ||
prompt.NewHumanMessageTemplate("Use the given format to extract information from the following input:\n{{.input}}\nTips: Make sure to answer in the correct format"), | ||
}) | ||
|
||
structuredOutputChain, err := chain.NewStructuredOutput(chatModel, pt, chain.OutputCandidates{ | ||
"person": { | ||
Description: "Identifying information about a person", | ||
Candidate: &Person{}, | ||
}, | ||
}) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
result, err := golc.Call(context.Background(), structuredOutputChain, schema.ChainValues{ | ||
"input": "Max is 21", | ||
}) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
p, ok := result["output"].(*Person) | ||
if !ok { | ||
log.Fatal("output is not a person") | ||
} | ||
|
||
fmt.Println("Name:", p.Name) | ||
fmt.Println("Age:", p.Age) | ||
} |