Skip to content

Commit

Permalink
Rename conversational retrieval chain
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 31, 2023
1 parent e1509c2 commit ed1daba
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func main() {
log.Fatal(err)
}

conversationalRetrievalChain, err := rag.NewConversationalRetrieval(openai, &mockRetriever{}, func(o *rag.ConversationalRetrievalOptions) {
conversationalRetrievalQAChain, err := rag.NewConversationalRetrievalQA(openai, &mockRetriever{}, func(o *rag.ConversationalRetrievalQAOptions) {
o.ReturnGeneratedQuestion = true
})
if err != nil {
Expand All @@ -46,7 +46,7 @@ func main() {

question1 := "Why don't scientists trust atoms?"

result1, err := golc.Call(context.Background(), conversationalRetrievalChain, schema.ChainValues{
result1, err := golc.Call(context.Background(), conversationalRetrievalQAChain, schema.ChainValues{
"question": question1,
})
if err != nil {
Expand All @@ -59,7 +59,7 @@ func main() {

question2 := "Can you explain it better?"

result2, err := golc.Call(context.Background(), conversationalRetrievalChain, schema.ChainValues{
result2, err := golc.Call(context.Background(), conversationalRetrievalQAChain, schema.ChainValues{
"question": question2,
})
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ Chat History:
Follow Up Input: {{.query}}
Standalone question:`

// Compile time check to ensure ConversationalRetrieval satisfies the Chain interface.
var _ schema.Chain = (*ConversationalRetrieval)(nil)
// Compile time check to ensure ConversationalRetrievalQA satisfies the Chain interface.
var _ schema.Chain = (*ConversationalRetrievalQA)(nil)

// ConversationalRetrievalOptions represents the options for the ConversationalRetrieval chain.
type ConversationalRetrievalOptions struct {
// ConversationalRetrievalQAOptions represents the options for the ConversationalRetrievalQA chain.
type ConversationalRetrievalQAOptions struct {
*schema.CallbackOptions

// Return the source documents
Expand All @@ -42,16 +42,16 @@ type ConversationalRetrievalOptions struct {
MaxTokenLimit uint
}

// ConversationalRetrieval is a chain implementation for conversational retrieval.
type ConversationalRetrieval struct {
// ConversationalRetrievalQA is a chain implementation for conversational retrieval.
type ConversationalRetrievalQA struct {
condenseQuestionChain *chain.LLM
retrievalQAChain *RetrievalQA
opts ConversationalRetrievalOptions
opts ConversationalRetrievalQAOptions
}

// NewConversationalRetrieval creates a new instance of the ConversationalRetrieval chain.
func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optFns ...func(o *ConversationalRetrievalOptions)) (*ConversationalRetrieval, error) {
opts := ConversationalRetrievalOptions{
// NewConversationalRetrievalQA creates a new instance of the ConversationalRetrievalQA chain.
func NewConversationalRetrievalQA(llm schema.LLM, retriever schema.Retriever, optFns ...func(o *ConversationalRetrievalQAOptions)) (*ConversationalRetrievalQA, error) {
opts := ConversationalRetrievalQAOptions{
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
Expand Down Expand Up @@ -90,16 +90,16 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF
return nil, err
}

return &ConversationalRetrieval{
return &ConversationalRetrievalQA{
condenseQuestionChain: condenseQuestionChain,
retrievalQAChain: retrievalQAChain,
opts: opts,
}, nil
}

// Call executes the ConversationalRetrieval chain with the given context and inputs.
// Call executes the ConversationalRetrievalQA chain with the given context and inputs.
// It returns the outputs of the chain or an error, if any.
func (c *ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
func (c *ConversationalRetrievalQA) Call(ctx context.Context, inputs schema.ChainValues, optFns ...func(o *schema.CallOptions)) (schema.ChainValues, error) {
opts := schema.CallOptions{
CallbackManger: &callback.NoopManager{},
}
Expand Down Expand Up @@ -143,7 +143,7 @@ func (c *ConversationalRetrieval) Call(ctx context.Context, inputs schema.ChainV
return returns, nil
}

func (c *ConversationalRetrieval) generateQuestion(ctx context.Context, inputs schema.ChainValues, opts schema.CallOptions) (string, error) {
func (c *ConversationalRetrievalQA) generateQuestion(ctx context.Context, inputs schema.ChainValues, opts schema.CallOptions) (string, error) {
if inputs["history"] == "" {
return inputs.GetString(c.opts.InputKey)
}
Expand All @@ -160,32 +160,32 @@ func (c *ConversationalRetrieval) generateQuestion(ctx context.Context, inputs s
}

// Memory returns the memory associated with the chain.
func (c *ConversationalRetrieval) Memory() schema.Memory {
func (c *ConversationalRetrievalQA) Memory() schema.Memory {
return c.opts.Memory
}

// Type returns the type of the chain.
func (c *ConversationalRetrieval) Type() string {
return "ConversationalRetrieval"
func (c *ConversationalRetrievalQA) Type() string {
return "ConversationalRetrievalQA"
}

// Verbose returns the verbosity setting of the chain.
func (c *ConversationalRetrieval) Verbose() bool {
func (c *ConversationalRetrievalQA) Verbose() bool {
return c.opts.CallbackOptions.Verbose
}

// Callbacks returns the callbacks associated with the chain.
func (c *ConversationalRetrieval) Callbacks() []schema.Callback {
func (c *ConversationalRetrievalQA) Callbacks() []schema.Callback {
return c.opts.CallbackOptions.Callbacks
}

// InputKeys returns the expected input keys.
func (c *ConversationalRetrieval) InputKeys() []string {
func (c *ConversationalRetrievalQA) InputKeys() []string {
return []string{c.opts.InputKey}
}

// OutputKeys returns the output keys the chain will return.
func (c *ConversationalRetrieval) OutputKeys() []string {
func (c *ConversationalRetrievalQA) OutputKeys() []string {
outputKeys := []string{c.opts.OutputKey}
if c.opts.ReturnSourceDocuments {
outputKeys = append(outputKeys, "sourceDocuments")
Expand Down

0 comments on commit ed1daba

Please sign in to comment.