Skip to content

Commit

Permalink
Add few shot prompt template
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Sep 17, 2023
1 parent 981e005 commit e1dc572
Show file tree
Hide file tree
Showing 20 changed files with 545 additions and 231 deletions.
2 changes: 1 addition & 1 deletion agent/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (a *ConversationalReactDescription) parseOutput(output string) ([]*schema.A

func createConversationalPrompt(tools []schema.Tool, prefix, instructions, suffix string) *prompt.Template {
return prompt.NewTemplate(strings.Join([]string{prefix, instructions, suffix}, "\n\n"), func(o *prompt.TemplateOptions) {
o.PartialValues = prompt.PartialValues{
o.PartialValues = map[string]any{
"toolNames": toolNames(tools),
"toolDescriptions": toolDescriptions(tools),
"chatHistory": "",
Expand Down
2 changes: 1 addition & 1 deletion agent/react.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func (a *ReactDescription) parseOutput(output string) ([]*schema.AgentAction, *s

func createReactDescriptioPrompt(tools []schema.Tool, prefix, instructions, suffix string) *prompt.Template {
return prompt.NewTemplate(strings.Join([]string{prefix, instructions, suffix}, "\n\n"), func(o *prompt.TemplateOptions) {
o.PartialValues = prompt.PartialValues{
o.PartialValues = map[string]any{
"toolNames": toolNames(tools),
"toolDescriptions": toolDescriptions(tools),
}
Expand Down
4 changes: 2 additions & 2 deletions chain/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ var _ schema.Chain = (*Conversation)(nil)

type ConversationOptions struct {
*schema.CallbackOptions
Prompt *prompt.Template
Prompt schema.PromptTemplate
Memory schema.Memory
OutputKey string
OutputParser schema.OutputParser[any]
Expand Down Expand Up @@ -107,7 +107,7 @@ func (c *Conversation) Call(ctx context.Context, inputs schema.ChainValues, optF
return outputs[0], nil
}

func (c *Conversation) Prompt() *prompt.Template {
func (c *Conversation) Prompt() schema.PromptTemplate {
return c.opts.Prompt
}

Expand Down
7 changes: 3 additions & 4 deletions chain/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/model"
"github.com/hupe1980/golc/outputparser"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

Expand Down Expand Up @@ -38,12 +37,12 @@ type LLMOptions struct {
// LLM is a chain implementation that uses the Language Model (LLM) to generate text based on a given prompt.
type LLM struct {
llm schema.Model
prompt *prompt.Template
prompt schema.PromptTemplate
opts LLMOptions
}

// NewLLM creates a new instance of the LLM chain.
func NewLLM(llm schema.Model, prompt *prompt.Template, optFns ...func(o *LLMOptions)) (*LLM, error) {
func NewLLM(llm schema.Model, prompt schema.PromptTemplate, optFns ...func(o *LLMOptions)) (*LLM, error) {
opts := LLMOptions{
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
Expand Down Expand Up @@ -112,7 +111,7 @@ func (c *LLM) GetNumTokens(text string) (uint, error) {
}

// Prompt returns the prompt.Template associated with the chain.
func (c *LLM) Prompt() *prompt.Template {
func (c *LLM) Prompt() schema.PromptTemplate {
return c.prompt
}

Expand Down
9 changes: 4 additions & 5 deletions chain/prompt_selector.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
package chain

import (
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

// PromptSelector is an interface for selecting prompts based on a model.
type PromptSelector interface {
GetPrompt(model schema.Model) *prompt.Template
GetPrompt(model schema.Model) schema.PromptTemplate
}

// Compile time check to ensure ConditionalPromptSelector satisfies the PromptSelector interface.
Expand All @@ -19,19 +18,19 @@ type ConditionalFunc func(model schema.Model) bool
// Conditional represents a conditional prompt configuration.
type Conditional struct {
Condition ConditionalFunc
Prompt *prompt.Template
Prompt schema.PromptTemplate
}

// ConditionalPromptSelector is a prompt selector that selects prompts based on conditions.
type ConditionalPromptSelector struct {
DefaultPrompt *prompt.Template
DefaultPrompt schema.PromptTemplate
Conditionals []Conditional
}

// GetPrompt selects a prompt template based on the provided model.
// It evaluates the conditions in order and returns the prompt associated with the first matching condition,
// or returns the default prompt if no condition is met.
func (cps *ConditionalPromptSelector) GetPrompt(model schema.Model) *prompt.Template {
func (cps *ConditionalPromptSelector) GetPrompt(model schema.Model) schema.PromptTemplate {
for _, conditional := range cps.Conditionals {
if conditional.Condition(model) {
return conditional.Prompt
Expand Down
2 changes: 1 addition & 1 deletion evaluation/context_qa_eval_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ STUDENT ANSWER: {{.result}}
GRADE:`

type ContextQAEvalChainOptions struct {
Prompt *prompt.Template
Prompt schema.PromptTemplate
QuestionKey string
ContextKey string
PredictionKey string
Expand Down
2 changes: 1 addition & 1 deletion evaluation/cot_qa_eval_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ STUDENT ANSWER: {result}
EXPLANATION:`

type COTQAEvalChainOptions struct {
Prompt *prompt.Template
Prompt schema.PromptTemplate
QuestionKey string
ContextKey string
PredictionKey string
Expand Down
2 changes: 1 addition & 1 deletion evaluation/qa_eval_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ TRUE ANSWER: {{.answer}}
GRADE:`

type QAEvalChainOptions struct {
Prompt *prompt.Template
Prompt schema.PromptTemplate
QuestionKey string
AnswerKey string
PredictionKey string
Expand Down
7 changes: 7 additions & 0 deletions prompt/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package prompt

import "errors"

var (
ErrInvalidPartialVariableType = errors.New("invalid partial variable type")
)
153 changes: 153 additions & 0 deletions prompt/few_shot.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package prompt

import (
"fmt"
"strings"
"text/template"

"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/util"
)

// Compile time check to ensure FewShotTemplate satisfies the PromptTemplate interface.
var _ schema.PromptTemplate = (*FewShotTemplate)(nil)

// FewShotTemplateOptions represents options for configuring a FewShotTemplate.
type FewShotTemplateOptions struct {
// Prefix to be added before the template.
Prefix string
// Separator between examples and the template.
Separator string
// OutputParser to parse the response.
OutputParser schema.OutputParser[any]
// PartialValues to be used in the template.
PartialValues map[string]any
// IgnoreMissingKeys allows ignoring missing keys in the template.
IgnoreMissingKeys bool
}

// FewShotTemplate is a template that combines examples with a main template.
type FewShotTemplate struct {
template string
examples []map[string]any
exampleTemplate *Template
opts FewShotTemplateOptions
}

// NewFewShotTemplate creates a new FewShotTemplate with the provided template, examples, and options.
func NewFewShotTemplate(template string, examples []map[string]any, exampleTemplate *Template, optFns ...func(o *FewShotTemplateOptions)) *FewShotTemplate {
opts := FewShotTemplateOptions{
Separator: "\n\n",
IgnoreMissingKeys: false,
}

for _, fn := range optFns {
fn(&opts)
}

return &FewShotTemplate{
template: template,
examples: examples,
exampleTemplate: exampleTemplate,
opts: opts,
}
}

// Format applies values to the template and returns the formatted result.
func (p *FewShotTemplate) Format(values map[string]any) (string, error) {
pieces := []string{}

if p.opts.Prefix != "" {
pieces = append(pieces, p.opts.Prefix)
}

for _, example := range p.examples {
e, err := p.exampleTemplate.Format(example)
if err != nil {
return "", err
}

pieces = append(pieces, e)
}

pieces = append(pieces, p.template)

formatter := NewFormatter(strings.Join(pieces, p.opts.Separator), func(o *FormatterOptions) {
o.IgnoreMissingKeys = p.opts.IgnoreMissingKeys
})

resolvedValues, err := p.resolvePartialValues()
if err != nil {
return "", err
}

return formatter.Render(util.MergeMaps(resolvedValues, values))
}

// FormatPrompt applies values to the template and returns a PromptValue representation of the formatted result.
func (p *FewShotTemplate) FormatPrompt(values map[string]any) (schema.PromptValue, error) {
prompt, err := p.Format(values)
if err != nil {
return nil, err
}

return StringPromptValue(prompt), nil
}

// Partial creates a new FewShotTemplate with partial values.
func (p *FewShotTemplate) Partial(values map[string]any) schema.PromptTemplate {
return NewFewShotTemplate(p.template, p.examples, p.exampleTemplate, func(o *FewShotTemplateOptions) {
o.Prefix = p.opts.Prefix
o.Separator = p.opts.Separator
o.OutputParser = p.opts.OutputParser
o.PartialValues = util.MergeMaps(p.opts.PartialValues, values)
o.IgnoreMissingKeys = p.opts.IgnoreMissingKeys
})
}

// OutputParser returns the output parser function and a boolean indicating if an output parser is defined.
func (p *FewShotTemplate) OutputParser() (schema.OutputParser[any], bool) {
if p.opts.OutputParser != nil {
return p.opts.OutputParser, true
}

return nil, false
}

// InputVariables returns the input variables used in the template.
func (p *FewShotTemplate) InputVariables() []string {
vars := p.exampleTemplate.InputVariables()

t := template.Must(template.New("template").Parse(p.template))

for _, f := range ListTemplateFields(t) {
name := extractNameFromField(f)
if name != "" {
if _, ok := p.opts.PartialValues[name]; !ok {
if !util.Contains(vars, name) {
vars = append(vars, name)
}
}
}
}

return vars
}

// resolvePartialValues resolves partial values to be used in the template.
func (p *FewShotTemplate) resolvePartialValues() (map[string]any, error) {
resolvedValues := make(map[string]any)

for variable, value := range p.opts.PartialValues {
switch value := value.(type) {
case string:
resolvedValues[variable] = value
case func() string:
resolvedValues[variable] = value()
default:
return nil, fmt.Errorf("%w: %v", ErrInvalidPartialVariableType, variable)
}
}

return resolvedValues, nil
}
56 changes: 56 additions & 0 deletions prompt/few_shot_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package prompt

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestFewShotTemplate(t *testing.T) {
// Define sample template, examples, and exampleTemplate
template := "{{.Greeting}}, {{.Name}}!"
examples := []map[string]interface{}{
{"Greeting": "Hello"},
{"Greeting": "Hi"},
}
exampleTemplate := NewTemplate("{{.Greeting}}")

// Create a FewShotTemplate
fsTemplate := NewFewShotTemplate(template, examples, exampleTemplate)

t.Run("Format", func(t *testing.T) {
values := map[string]interface{}{"Greeting": "Hey", "Name": "Charlie"}
formatted, err := fsTemplate.Format(values)
assert.NoError(t, err)
assert.Equal(t, "Hello\n\nHi\n\nHey, Charlie!", formatted)
})

t.Run("FormatPrompt", func(t *testing.T) {
values := map[string]interface{}{"Greeting": "Hey", "Name": "Charlie"}
promptValue, err := fsTemplate.FormatPrompt(values)
assert.NoError(t, err)
assert.IsType(t, StringPromptValue(""), promptValue)
assert.Equal(t, "Hello\n\nHi\n\nHey, Charlie!", promptValue.String())
})

t.Run("Partial", func(t *testing.T) {
values := map[string]interface{}{"Greeting": "Hey"}
partialValues := map[string]interface{}{"Name": "David"}
partialTemplate := fsTemplate.Partial(partialValues)
partialValues["Greeting"] = "Hi"
formattedPartial, err := partialTemplate.Format(values)
assert.NoError(t, err)
assert.Equal(t, "Hello\n\nHi\n\nHey, David!", formattedPartial)
})

t.Run("OutputParser", func(t *testing.T) {
outputParser, hasParser := fsTemplate.OutputParser()
assert.False(t, hasParser)
assert.Nil(t, outputParser)
})

t.Run("InputVariables", func(t *testing.T) {
inputVars := fsTemplate.InputVariables()
assert.ElementsMatch(t, inputVars, []string{"Greeting", "Name"})
})
}
12 changes: 12 additions & 0 deletions prompt/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package prompt

import (
"bytes"
"regexp"
"text/template"
"text/template/parse"

Expand Down Expand Up @@ -71,3 +72,14 @@ func listNodeFields(node parse.Node) []string {

return res
}

func extractNameFromField(input string) string {
re := regexp.MustCompile(`{{\.(.*?)}}`)
matches := re.FindStringSubmatch(input)

if len(matches) == 2 {
return matches[1]
}

return ""
}
Loading

0 comments on commit e1dc572

Please sign in to comment.