Skip to content

Commit

Permalink
Add tagging chain
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Aug 1, 2023
1 parent 2c6314c commit 9b06f3c
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 1 deletion.
49 changes: 49 additions & 0 deletions chain/tagging.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package chain

import (
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)

const defaultTaggingTemplate = `Extract the desired information from the following passage.
Only extract the properties mentioned in the 'information_extraction' function.
Passage:
{input}`

// Compile time check to ensure Tagging satisfies the Chain interface.
var _ schema.Chain = (*Tagging)(nil)

// Tagging is a chain that uses structured output to perform tagging on a passage.
// It extracts the desired information from the given passage using a structured output model.
type Tagging struct {
// StructuredOutput is the underlying structured output chain used for tagging.
*StructuredOutput
}

// NewTagging creates a new Tagging chain with the provided chat model, structured output data, and optional options.
// It returns a Tagging chain or an error if the creation fails.
func NewTagging(chatModel schema.ChatModel, data any, optFns ...func(o *StructuredOutputOptions)) (*Tagging, error) {
pt := prompt.NewChatTemplate([]prompt.MessageTemplate{
prompt.NewHumanMessageTemplate(defaultTaggingTemplate),
})

so, err := NewStructuredOutput(chatModel, pt, []OutputCandidate{{
Name: "InformationExtraction",
Description: "Extracts the relevant information from the passage.",
Data: data,
}}, optFns...)
if err != nil {
return nil, err
}

return &Tagging{
StructuredOutput: so,
}, nil
}

// Type returns the type of the chain.
func (c *Tagging) Type() string {
return "Tagging"
}
1 change: 0 additions & 1 deletion docs/content/en/docs/chains/structured_output.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ func main() {
fmt.Println("Age:", p.Age)
}
```
```
Output:
```text
Name: Max
Expand Down
62 changes: 62 additions & 0 deletions docs/content/en/docs/chains/tagging.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
---
title: Tagging
description: All about tagging chains.
weight: 110
---
```go
package main

import (
"context"
"fmt"
"log"
"os"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/schema"
)

type Tagging struct {
Sentiment string `json:"sentiment" enum:"'happy','neutral','sad'"`
Aggressiveness int `json:"aggressiveness" description:"describes how aggressive the statement is, the higher the number the more aggressive" enum:"1,2,3,4,5"`
Language string `json:"language" enum:"'spanish','english','french','german','italian'"`
}

func main() {
chatModel, err := chatmodel.NewOpenAI(os.Getenv("OPENAI_API_KEY"), func(o *chatmodel.OpenAIOptions) {
o.Temperature = 0
})
if err != nil {
log.Fatal(err)
}

taggingChain, err := chain.NewTagging(chatModel, &Tagging{})
if err != nil {
log.Fatal(err)
}

result, err := golc.Call(context.Background(), taggingChain, schema.ChainValues{
"input": "Weather is ok here, I can go outside without much more than a coat",
})
if err != nil {
log.Fatal(err)
}

t, ok := result["output"].(*Tagging)
if !ok {
log.Fatal("output is not a person")
}

fmt.Println("Sentiment:", t.Sentiment)
fmt.Println("Aggressiveness:", t.Aggressiveness)
fmt.Println("Language:", t.Language)
}
```
Output:
```text
Sentiment: neutral
Aggressiveness: 3
Language: english
```
49 changes: 49 additions & 0 deletions examples/tagging_chain/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package main

import (
"context"
"fmt"
"log"
"os"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/model/chatmodel"
"github.com/hupe1980/golc/schema"
)

type Tagging struct {
Sentiment string `json:"sentiment" enum:"'happy','neutral','sad'"`
Aggressiveness int `json:"aggressiveness" description:"describes how aggressive the statement is, the higher the number the more aggressive" enum:"1,2,3,4,5"`
Language string `json:"language" enum:"'spanish','english','french','german','italian'"`
}

func main() {
chatModel, err := chatmodel.NewOpenAI(os.Getenv("OPENAI_API_KEY"), func(o *chatmodel.OpenAIOptions) {
o.Temperature = 0
})
if err != nil {
log.Fatal(err)
}

taggingChain, err := chain.NewTagging(chatModel, &Tagging{})
if err != nil {
log.Fatal(err)
}

result, err := golc.Call(context.Background(), taggingChain, schema.ChainValues{
"input": "Weather is ok here, I can go outside without much more than a coat",
})
if err != nil {
log.Fatal(err)
}

t, ok := result["output"].(*Tagging)
if !ok {
log.Fatal("output is not a person")
}

fmt.Println("Sentiment:", t.Sentiment)
fmt.Println("Aggressiveness:", t.Aggressiveness)
fmt.Println("Language:", t.Language)
}

0 comments on commit 9b06f3c

Please sign in to comment.