Skip to content

Commit

Permalink
Refactor rag chains
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 31, 2023
1 parent 4bd201b commit 33efa9a
Show file tree
Hide file tree
Showing 11 changed files with 253 additions and 106 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ func main() {
log.Fatal(err)
}

completion, err := golc.Call(ctx, llmSummarizationChain, map[string]any{"inputDocuments": docs})
completion, err := golc.SimpleCall(ctx, llmSummarizationChain, docs)
if err != nil {
log.Fatal(err)
}

fmt.Println(completion["text"])
fmt.Println(completion)
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@ import (
"strings"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/documentloader"
"github.com/hupe1980/golc/model/llm"
"github.com/hupe1980/golc/rag"
"github.com/hupe1980/golc/schema"
)

func main() {
ctx := context.Background()

golc.Verbose = true

ctx := context.Background()

openai, err := llm.NewOpenAI(os.Getenv("OPENAI_API_KEY"))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -47,16 +45,10 @@ assistance, among others, to enhance human-computer interactions and support lan
log.Fatal(err)
}

info := callback.NewOpenAIHandler()

completion, err := golc.SimpleCall(ctx, llmSummarizationChain, docs, func(sco *golc.SimpleCallOptions) {
sco.Callbacks = []schema.Callback{info}
})
completion, err := golc.SimpleCall(ctx, llmSummarizationChain, docs)
if err != nil {
log.Fatal(err)
}

fmt.Println(completion)
fmt.Println("\n\n---")
fmt.Println(info)
}
57 changes: 29 additions & 28 deletions rag/conversational_retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rag

import (
"context"
"fmt"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
Expand Down Expand Up @@ -100,7 +99,7 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
func (c *ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
opts := schema.CallOptions{
CallbackManger: &callback.NoopManager{},
}
Expand All @@ -109,23 +108,9 @@ func (c ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainVa
fn(&opts)
}

generatedQuestion := inputs[c.opts.InputKey]

if inputs["history"] != "" {
output, err := golc.Call(ctx, c.condenseQuestionChain, inputs, func(co *golc.CallOptions) {
co.Callbacks = opts.CallbackManger.GetInheritableCallbacks()
co.ParentRunID = opts.CallbackManger.RunID()
})
if err != nil {
return nil, err
}

gq, ok := output[c.condenseQuestionChain.OutputKeys()[0]].(string)
if !ok {
return nil, fmt.Errorf("cannot convert generated question from output: %v", generatedQuestion)
}

generatedQuestion = gq
generatedQuestion, err := c.generateQuestion(ctx, inputs, opts)
if err != nil {
return nil, err
}

retrievalOutput, err := golc.Call(ctx, c.retrievalQAChain, schema.ChainValues{
Expand All @@ -138,9 +123,9 @@ func (c ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainVa
return nil, err
}

answer, ok := retrievalOutput[c.retrievalQAChain.OutputKeys()[0]].(string)
if !ok {
return nil, fmt.Errorf("cannot convert answer from output: %v", generatedQuestion)
answer, err := retrievalOutput.GetString(c.retrievalQAChain.OutputKeys()[0])
if err != nil {
return nil, err
}

returns := schema.ChainValues{
Expand All @@ -158,33 +143,49 @@ func (c ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainVa
return returns, nil
}

// generateQuestion determines the question that is handed to the retrieval QA chain.
//
// When inputs carry no chat history (the "history" entry is absent, nil, or
// the empty string) the value stored under the chain's configured input key
// is returned via inputs.GetString. Otherwise the condense-question chain is
// called with the full inputs, and the value under that chain's first output
// key is returned via GetString.
//
// The inheritable callbacks and the run ID of opts.CallbackManger are passed
// to the nested chain call. Any error from the nested call or from GetString
// is returned unchanged.
func (c *ConversationalRetrieval) generateQuestion(ctx context.Context, inputs schema.ChainValues, opts schema.CallOptions) (string, error) {
	// A missing "history" entry yields a nil interface value, which never
	// compares equal to "". Check for nil explicitly so that a conversation
	// without any history skips the condense step instead of running it on
	// nothing.
	if history := inputs["history"]; history == nil || history == "" {
		return inputs.GetString(c.opts.InputKey)
	}

	output, err := golc.Call(ctx, c.condenseQuestionChain, inputs, func(co *golc.CallOptions) {
		co.Callbacks = opts.CallbackManger.GetInheritableCallbacks()
		co.ParentRunID = opts.CallbackManger.RunID()
	})
	if err != nil {
		return "", err
	}

	return output.GetString(c.condenseQuestionChain.OutputKeys()[0])
}

// Memory returns the memory associated with the chain.
func (c ConversationalRetrieval) Memory() schema.Memory {
func (c *ConversationalRetrieval) Memory() schema.Memory {
return c.opts.Memory
}

// Type returns the type of the chain.
func (c ConversationalRetrieval) Type() string {
func (c *ConversationalRetrieval) Type() string {
return "ConversationalRetrieval"
}

// Verbose returns the verbosity setting of the chain.
func (c ConversationalRetrieval) Verbose() bool {
func (c *ConversationalRetrieval) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c ConversationalRetrieval) Callbacks() []schema.Callback {
func (c *ConversationalRetrieval) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}

// InputKeys returns the expected input keys.
func (c ConversationalRetrieval) InputKeys() []string {
func (c *ConversationalRetrieval) InputKeys() []string {
return []string{c.opts.InputKey}
}

// OutputKeys returns the output keys the chain will return.
func (c ConversationalRetrieval) OutputKeys() []string {
func (c *ConversationalRetrieval) OutputKeys() []string {
outputKeys := []string{c.opts.OutputKey}
if c.opts.ReturnSourceDocuments {
outputKeys = append(outputKeys, "sourceDocuments")
Expand Down
6 changes: 2 additions & 4 deletions rag/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package rag
import "errors"

var (
ErrNoInputValues = errors.New("no input values")
ErrInvalidInputValues = errors.New("invalid input values")
ErrInputValuesWrongType = errors.New("input key is of wrong type")
ErrNoOutputParser = errors.New("no output parser")
ErrNoInputValues = errors.New("no input values")
ErrNoOutputParser = errors.New("no output parser")
)
133 changes: 133 additions & 0 deletions rag/map_reduce_documents.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package rag

import (
"context"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/util"
)

// Compile time check to ensure MapReduceDocuments satisfies the Chain interface.
var _ schema.Chain = (*MapReduceDocuments)(nil)

// MapReduceDocumentsOptions configures a MapReduceDocuments chain.
// NewMapReduceDocuments fills in defaults before applying option functions.
type MapReduceDocumentsOptions struct {
// CallbackOptions is embedded; the chain reads its promoted Verbose and
// Callbacks fields.
*schema.CallbackOptions
// InputKey is the key under which Call looks up the input documents
// (default "inputDocuments").
InputKey string
// DocumentVariableName is the key under which each document's PageContent
// is passed to the map chain (default "text").
DocumentVariableName string
}

// MapReduceDocuments is a chain that runs mapChain once per input document
// and then passes the per-document results to combineChain, whose output is
// the output of the whole chain.
type MapReduceDocuments struct {
// mapChain is batch-called with one input set per document.
mapChain *chain.LLM
// combineChain receives the map results, wrapped as documents, and
// produces the values returned by Call.
combineChain *StuffDocuments
// opts holds the options resolved by NewMapReduceDocuments.
opts MapReduceDocumentsOptions
}

// NewMapReduceDocuments assembles a MapReduceDocuments chain from a map chain
// and a combine chain.
//
// Defaults: verbosity follows golc.Verbose, documents are read from the
// "inputDocuments" key, and each document's content is exposed to the map
// chain under "text". The functions in optFns are applied in order on top of
// those defaults. The returned error is always nil.
func NewMapReduceDocuments(mapChain *chain.LLM, combineChain *StuffDocuments, optFns ...func(o *MapReduceDocumentsOptions)) (*MapReduceDocuments, error) {
	options := MapReduceDocumentsOptions{
		CallbackOptions:      &schema.CallbackOptions{Verbose: golc.Verbose},
		InputKey:             "inputDocuments",
		DocumentVariableName: "text",
	}

	for i := range optFns {
		optFns[i](&options)
	}

	mapReduce := &MapReduceDocuments{
		mapChain:     mapChain,
		combineChain: combineChain,
		opts:         options,
	}

	return mapReduce, nil
}

// Call runs the map-reduce pipeline over the documents found in inputs.
//
// The documents stored under the configured input key are fanned out to the
// map chain, one batch entry per document: each entry holds all remaining
// inputs plus the document's page content under DocumentVariableName. Every
// map result is wrapped in a new document that keeps the metadata of the
// document it was produced from. Those documents, together with the
// remaining inputs, are passed to the combine chain, whose result is
// returned as is.
//
// Errors from reading the inputs, from the batch call, from extracting a map
// result, or from the combine chain are returned unchanged.
func (c *MapReduceDocuments) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
	callOpts := schema.CallOptions{CallbackManger: &callback.NoopManager{}}
	for _, apply := range optFns {
		apply(&callOpts)
	}

	documents, err := inputs.GetDocuments(c.opts.InputKey)
	if err != nil {
		return nil, err
	}

	// Everything except the documents themselves is forwarded to both stages.
	shared := schema.ChainValues(util.OmitByKeys(inputs, []string{c.opts.InputKey}))

	// Map stage: one input set per document.
	mapInputs := make([]schema.ChainValues, 0, len(documents))
	for _, doc := range documents {
		perDoc := shared.Clone()
		perDoc[c.opts.DocumentVariableName] = doc.PageContent
		mapInputs = append(mapInputs, perDoc)
	}

	mapOutputs, err := golc.BatchCall(ctx, c.mapChain, mapInputs, func(co *golc.BatchCallOptions) {
		co.Callbacks = callOpts.CallbackManger.GetInheritableCallbacks()
		co.ParentRunID = callOpts.CallbackManger.RunID()
	})
	if err != nil {
		return nil, err
	}

	// Turn each map result back into a document, preserving the source metadata.
	mapped := make([]schema.Document, len(documents))
	for i := range documents {
		text, err := mapOutputs[i].GetString(c.mapChain.OutputKeys()[0])
		if err != nil {
			return nil, err
		}

		mapped[i] = schema.Document{PageContent: text, Metadata: documents[i].Metadata}
	}

	// Reduce stage: the combine chain sees the mapped documents under its own input key.
	reduceInputs := shared.Clone()
	reduceInputs[c.combineChain.InputKeys()[0]] = mapped

	return golc.Call(ctx, c.combineChain, reduceInputs, func(co *golc.CallOptions) {
		co.Callbacks = callOpts.CallbackManger.GetInheritableCallbacks()
		co.ParentRunID = callOpts.CallbackManger.RunID()
	})
}

// Memory reports the memory attached to the chain. MapReduceDocuments keeps
// none, so the result is always nil.
func (c *MapReduceDocuments) Memory() schema.Memory {
	var none schema.Memory // zero value of the interface, i.e. nil

	return none
}

// Type identifies the chain by name; it always yields "MapReduceDocuments".
func (c *MapReduceDocuments) Type() string {
	const chainType = "MapReduceDocuments"

	return chainType
}

// Verbose reports whether verbose mode is enabled in the chain's callback options.
func (c *MapReduceDocuments) Verbose() bool {
	return c.opts.CallbackOptions.Verbose
}

// Callbacks yields the callbacks configured in the chain's callback options.
func (c *MapReduceDocuments) Callbacks() []schema.Callback {
	return c.opts.CallbackOptions.Callbacks
}

// InputKeys lists the keys the chain expects in its inputs: a single entry,
// the configured input key.
func (c *MapReduceDocuments) InputKeys() []string {
	key := c.opts.InputKey

	return []string{key}
}

// OutputKeys lists the keys present in the chain's result. They are exactly
// the output keys of the combine chain, since Call returns its result.
func (c *MapReduceDocuments) OutputKeys() []string {
	keys := c.combineChain.OutputKeys()

	return keys
}
2 changes: 1 addition & 1 deletion rag/qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package rag

const defaultStuffQAPromptTemplate = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{{.context}}
{{.text}}
Question: {{.question}}
Helpful Answer:`
Expand Down
32 changes: 11 additions & 21 deletions rag/refine_documents.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rag

import (
"context"
"fmt"
"strings"

"github.com/hupe1980/golc"
Expand Down Expand Up @@ -33,13 +32,13 @@ type RefineDocuments struct {

func NewRefineDocuments(llmChain *chain.LLM, refineLLMChain *chain.LLM, optFns ...func(o *RefineDocumentsOptions)) (*RefineDocuments, error) {
opts := RefineDocumentsOptions{
InputKey: "inputDocuments",
DocumentVariableName: "context",
InitialResponseName: "existingAnswer",
OutputKey: "text",
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
InputKey: "inputDocuments",
OutputKey: "outputText",
DocumentVariableName: "text",
InitialResponseName: "existingAnswer",
}

for _, fn := range optFns {
Expand All @@ -57,7 +56,7 @@ func NewRefineDocuments(llmChain *chain.LLM, refineLLMChain *chain.LLM, optFns .
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// Call executes the RefineDocuments chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *RefineDocuments) Call(ctx context.Context, values schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
opts := schema.CallOptions{
Expand All @@ -68,18 +67,9 @@ func (c *RefineDocuments) Call(ctx context.Context, values schema.ChainValues, o
fn(&opts)
}

input, ok := values[c.opts.InputKey]
if !ok {
return nil, fmt.Errorf("%w: no value for inputKey %s", ErrInvalidInputValues, c.opts.InputKey)
}

docs, ok := input.([]schema.Document)
if !ok {
return nil, ErrInputValuesWrongType
}

if len(docs) == 0 {
return nil, fmt.Errorf("%w: documents slice has no elements", ErrInvalidInputValues)
docs, err := values.GetDocuments(c.opts.InputKey)
if err != nil {
return nil, err
}

rest := util.OmitByKeys(values, []string{c.opts.InputKey})
Expand Down Expand Up @@ -112,7 +102,7 @@ func (c *RefineDocuments) Call(ctx context.Context, values schema.ChainValues, o
}
}

return map[string]any{
return schema.ChainValues{
c.opts.OutputKey: strings.TrimSpace(res),
}, nil
}
Expand All @@ -129,12 +119,12 @@ func (c *RefineDocuments) Type() string {

// Verbose returns the verbosity setting of the chain.
func (c *RefineDocuments) Verbose() bool {
return c.opts.CallbackOptions.Verbose
return c.opts.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *RefineDocuments) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
return c.opts.Callbacks
}

// InputKeys returns the expected input keys.
Expand Down
Loading

0 comments on commit 33efa9a

Please sign in to comment.