From ed1dabaeab2c325b9cb9116ef63ec132a62764e2 Mon Sep 17 00:00:00 2001 From: hupe1980 Date: Mon, 31 Jul 2023 16:33:47 +0200 Subject: [PATCH] Rename conversational retrieval chain --- .../main.go | 6 +-- ...eval.go => conversational_retrieval_qa.go} | 42 +++++++++---------- 2 files changed, 24 insertions(+), 24 deletions(-) rename examples/{conversational_retrieval_rag => conversational_retrieval_qa_rag}/main.go (87%) rename rag/{conversational_retrieval.go => conversational_retrieval_qa.go} (72%) diff --git a/examples/conversational_retrieval_rag/main.go b/examples/conversational_retrieval_qa_rag/main.go similarity index 87% rename from examples/conversational_retrieval_rag/main.go rename to examples/conversational_retrieval_qa_rag/main.go index 16cf735..d72f0d9 100644 --- a/examples/conversational_retrieval_rag/main.go +++ b/examples/conversational_retrieval_qa_rag/main.go @@ -37,7 +37,7 @@ func main() { log.Fatal(err) } - conversationalRetrievalChain, err := rag.NewConversationalRetrieval(openai, &mockRetriever{}, func(o *rag.ConversationalRetrievalOptions) { + conversationalRetrievalQAChain, err := rag.NewConversationalRetrievalQA(openai, &mockRetriever{}, func(o *rag.ConversationalRetrievalQAOptions) { o.ReturnGeneratedQuestion = true }) if err != nil { @@ -46,7 +46,7 @@ func main() { question1 := "Why don't scientists trust atoms?" - result1, err := golc.Call(context.Background(), conversationalRetrievalChain, schema.ChainValues{ + result1, err := golc.Call(context.Background(), conversationalRetrievalQAChain, schema.ChainValues{ "question": question1, }) if err != nil { @@ -59,7 +59,7 @@ func main() { question2 := "Can you explain it better?" - result2, err := golc.Call(context.Background(), conversationalRetrievalChain, schema.ChainValues{ + result2, err := golc.Call(context.Background(), conversationalRetrievalQAChain, schema.ChainValues{ "question": question2, }) if err != nil { diff --git a/rag/conversational_retrieval.go b/rag/conversational_retrieval_qa.go similarity index 72% rename from rag/conversational_retrieval.go rename to rag/conversational_retrieval_qa.go index 2616686..bf1b772 100644 --- a/rag/conversational_retrieval.go +++ b/rag/conversational_retrieval_qa.go @@ -18,11 +18,11 @@ Chat History: Follow Up Input: {{.query}} Standalone question:` -// Compile time check to ensure ConversationalRetrieval satisfies the Chain interface. -var _ schema.Chain = (*ConversationalRetrieval)(nil) +// Compile time check to ensure ConversationalRetrievalQA satisfies the Chain interface. +var _ schema.Chain = (*ConversationalRetrievalQA)(nil) -// ConversationalRetrievalOptions represents the options for the ConversationalRetrieval chain. -type ConversationalRetrievalOptions struct { +// ConversationalRetrievalQAOptions represents the options for the ConversationalRetrievalQA chain. +type ConversationalRetrievalQAOptions struct { *schema.CallbackOptions // Return the source documents @@ -42,16 +42,16 @@ type ConversationalRetrievalOptions struct { MaxTokenLimit uint } -// ConversationalRetrieval is a chain implementation for conversational retrieval. -type ConversationalRetrieval struct { +// ConversationalRetrievalQA is a chain implementation for conversational retrieval. +type ConversationalRetrievalQA struct { condenseQuestionChain *chain.LLM retrievalQAChain *RetrievalQA - opts ConversationalRetrievalOptions + opts ConversationalRetrievalQAOptions } -// NewConversationalRetrieval creates a new instance of the ConversationalRetrieval chain. -func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optFns ...func(o *ConversationalRetrievalOptions)) (*ConversationalRetrieval, error) { - opts := ConversationalRetrievalOptions{ +// NewConversationalRetrievalQA creates a new instance of the ConversationalRetrievalQA chain. +func NewConversationalRetrievalQA(llm schema.LLM, retriever schema.Retriever, optFns ...func(o *ConversationalRetrievalQAOptions)) (*ConversationalRetrievalQA, error) { + opts := ConversationalRetrievalQAOptions{ CallbackOptions: &schema.CallbackOptions{ Verbose: golc.Verbose, }, @@ -90,16 +90,16 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF return nil, err } - return &ConversationalRetrieval{ + return &ConversationalRetrievalQA{ condenseQuestionChain: condenseQuestionChain, retrievalQAChain: retrievalQAChain, opts: opts, }, nil } -// Call executes the ConversationalRetrieval chain with the given context and inputs. +// Call executes the ConversationalRetrievalQA chain with the given context and inputs. // It returns the outputs of the chain or an error, if any. -func (c *ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) { +func (c *ConversationalRetrievalQA) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) { opts := schema.CallOptions{ CallbackManger: &callback.NoopManager{}, } @@ -143,7 +143,7 @@ func (c *ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainV return returns, nil } -func (c *ConversationalRetrieval) generateQuestion(ctx context.Context, inputs schema.ChainValues, opts schema.CallOptions) (string, error) { +func (c *ConversationalRetrievalQA) generateQuestion(ctx context.Context, inputs schema.ChainValues, opts schema.CallOptions) (string, error) { if inputs["history"] == "" { return inputs.GetString(c.opts.InputKey) } @@ -160,32 +160,32 @@ func (c *ConversationalRetrieval) generateQuestion(ctx context.Context, inputs s } // Memory returns the memory associated with the chain. -func (c *ConversationalRetrieval) Memory() schema.Memory { +func (c *ConversationalRetrievalQA) Memory() schema.Memory { return c.opts.Memory } // Type returns the type of the chain. -func (c *ConversationalRetrieval) Type() string { - return "ConversationalRetrieval" +func (c *ConversationalRetrievalQA) Type() string { + return "ConversationalRetrievalQA" } // Verbose returns the verbosity setting of the chain. -func (c *ConversationalRetrieval) Verbose() bool { +func (c *ConversationalRetrievalQA) Verbose() bool { return c.opts.CallbackOptions.Verbose } // Callbacks returns the callbacks associated with the chain. -func (c *ConversationalRetrieval) Callbacks() []schema.Callback { +func (c *ConversationalRetrievalQA) Callbacks() []schema.Callback { return c.opts.CallbackOptions.Callbacks } // InputKeys returns the expected input keys. -func (c *ConversationalRetrieval) InputKeys() []string { +func (c *ConversationalRetrievalQA) InputKeys() []string { return []string{c.opts.InputKey} } // OutputKeys returns the output keys the chain will return. -func (c *ConversationalRetrieval) OutputKeys() []string { +func (c *ConversationalRetrievalQA) OutputKeys() []string { outputKeys := []string{c.opts.OutputKey} if c.opts.ReturnSourceDocuments { outputKeys = append(outputKeys, "sourceDocuments")