Skip to content

Commit

Permalink
Refactor qa chains
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 31, 2023
1 parent ff4ceaa commit e1509c2
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func main() {
question1 := "Why don't scientists trust atoms?"

result1, err := golc.Call(context.Background(), conversationalRetrievalChain, schema.ChainValues{
"query": question1,
"question": question1,
})
if err != nil {
log.Fatal(err)
Expand All @@ -60,7 +60,7 @@ func main() {
question2 := "Can you explain it better?"

result2, err := golc.Call(context.Background(), conversationalRetrievalChain, schema.ChainValues{
"query": question2,
"question": question2,
})
if err != nil {
log.Fatal(err)
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions rag/conversational_retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type ConversationalRetrievalOptions struct {
ReturnGeneratedQuestion bool

CondenseQuestionPrompt *prompt.Template
StuffQAPrompt *prompt.Template
RetrievalQAPrompt *prompt.Template
Memory schema.Memory
InputKey string
OutputKey string
Expand All @@ -57,7 +57,7 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF
},
ReturnSourceDocuments: false,
ReturnGeneratedQuestion: false,
InputKey: "query",
InputKey: "question",
OutputKey: "answer",
}

Expand All @@ -81,7 +81,7 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF
}

retrievalQAChain, err := NewRetrievalQA(llm, retriever, func(o *RetrievalQAOptions) {
o.StuffQAPrompt = opts.StuffQAPrompt
o.RetrievalQAPrompt = opts.RetrievalQAPrompt
o.ReturnSourceDocuments = opts.ReturnSourceDocuments
o.MaxTokenLimit = opts.MaxTokenLimit
o.InputKey = opts.InputKey
Expand Down Expand Up @@ -114,7 +114,7 @@ func (c *ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainV
}

retrievalOutput, err := golc.Call(ctx, c.retrievalQAChain, schema.ChainValues{
"query": generatedQuestion,
c.retrievalQAChain.InputKeys()[0]: generatedQuestion,
}, func(co *golc.CallOptions) {
co.Callbacks = opts.CallbackManger.GetInheritableCallbacks()
co.ParentRunID = opts.CallbackManger.RunID()
Expand Down
18 changes: 0 additions & 18 deletions rag/qa.go

This file was deleted.

30 changes: 20 additions & 10 deletions rag/retrieval_qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,20 @@ import (
"github.com/hupe1980/golc/util"
)

const defaultRetrievalQAPromptTemplate = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
{{.text}}
Question: {{.question}}
Helpful Answer:`

// Compile time check to ensure RetrievalQA satisfies the Chain interface.
var _ schema.Chain = (*RetrievalQA)(nil)

type RetrievalQAOptions struct {
*schema.CallbackOptions
StuffQAPrompt *prompt.Template
InputKey string
RetrievalQAPrompt *prompt.Template
InputKey string

// Return the source documents
ReturnSourceDocuments bool
Expand All @@ -33,7 +43,7 @@ type RetrievalQA struct {

func NewRetrievalQA(llm schema.LLM, retriever schema.Retriever, optFns ...func(o *RetrievalQAOptions)) (*RetrievalQA, error) {
opts := RetrievalQAOptions{
InputKey: "query",
InputKey: "question",
ReturnSourceDocuments: false,
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
Expand All @@ -44,11 +54,11 @@ func NewRetrievalQA(llm schema.LLM, retriever schema.Retriever, optFns ...func(o
fn(&opts)
}

if opts.StuffQAPrompt == nil {
opts.StuffQAPrompt = prompt.NewTemplate(defaultStuffQAPromptTemplate)
if opts.RetrievalQAPrompt == nil {
opts.RetrievalQAPrompt = prompt.NewTemplate(defaultRetrievalQAPromptTemplate)
}

llmChain, err := chain.NewLLM(llm, opts.StuffQAPrompt)
llmChain, err := chain.NewLLM(llm, opts.RetrievalQAPrompt)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -76,18 +86,18 @@ func (c *RetrievalQA) Call(ctx context.Context, values schema.ChainValues, optFn
fn(&opts)
}

query, err := values.GetString(c.opts.InputKey)
question, err := values.GetString(c.opts.InputKey)
if err != nil {
return nil, err
}

docs, err := c.getDocuments(ctx, query, opts)
docs, err := c.getDocuments(ctx, question, opts)
if err != nil {
return nil, err
}

result, err := golc.Call(ctx, c.stuffDocumentsChain, map[string]any{
"question": query,
result, err := golc.Call(ctx, c.stuffDocumentsChain, schema.ChainValues{
"question": question,
c.stuffDocumentsChain.InputKeys()[0]: docs,
}, func(co *golc.CallOptions) {
co.Callbacks = opts.CallbackManger.GetInheritableCallbacks()
Expand Down

0 comments on commit e1509c2

Please sign in to comment.