Skip to content

Commit

Permalink
Add retriever callback support
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 27, 2023
1 parent fd7c658 commit 5729574
Show file tree
Hide file tree
Showing 13 changed files with 166 additions and 10 deletions.
8 changes: 8 additions & 0 deletions examples/conversational_retrieval/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ func (r *mockRetriever) GetRelevantDocuments(ctx context.Context, query string)
}, nil
}

func (r *mockRetriever) Verbose() bool {
return false
}

func (r *mockRetriever) Callbacks() []schema.Callback {
return nil
}

func main() {
golc.Verbose = true

Expand Down
8 changes: 8 additions & 0 deletions examples/retrieval_qa/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ func (r *mockRetriever) GetRelevantDocuments(ctx context.Context, query string)
}, nil
}

func (r *mockRetriever) Verbose() bool {
return false
}

func (r *mockRetriever) Callbacks() []schema.Callback {
return nil
}

func main() {
openai, err := llm.NewOpenAI(os.Getenv("OPENAI_API_KEY"))
if err != nil {
Expand Down
10 changes: 7 additions & 3 deletions rag/retrieval_qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/retriever"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/util"
)
Expand Down Expand Up @@ -86,7 +87,7 @@ func (c *RetrievalQA) Call(ctx context.Context, values schema.ChainValues, optFn
return nil, ErrInputValuesWrongType
}

docs, err := c.getDocuments(ctx, query)
docs, err := c.getDocuments(ctx, query, opts)
if err != nil {
return nil, err
}
Expand All @@ -109,8 +110,11 @@ func (c *RetrievalQA) Call(ctx context.Context, values schema.ChainValues, optFn
return result, nil
}

func (c *RetrievalQA) getDocuments(ctx context.Context, query string) ([]schema.Document, error) {
docs, err := c.retriever.GetRelevantDocuments(ctx, query)
func (c *RetrievalQA) getDocuments(ctx context.Context, query string, opts schema.CallOptions) ([]schema.Document, error) {
docs, err := retriever.Run(ctx, c.retriever, query, func(o *retriever.Options) {
o.Callbacks = opts.CallbackManger.GetInheritableCallbacks()
o.ParentRunID = opts.CallbackManger.RunID()
})
if err != nil {
return nil, err
}
Expand Down
15 changes: 15 additions & 0 deletions retriever/amazon_kendra.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/kendra"
"github.com/aws/aws-sdk-go-v2/service/kendra/types"
"github.com/hupe1980/golc"
"github.com/hupe1980/golc/schema"
)

Expand All @@ -27,6 +28,7 @@ type AmazonKendraClient interface {
}

type AmazonKendraOptions struct {
*schema.CallbackOptions
// Number of documents to query for
TopK int32

Expand All @@ -44,6 +46,9 @@ type AmazonKendra struct {
func NewAmazonKendra(client AmazonKendraClient, index string, optFns ...func(o *AmazonKendraOptions)) *AmazonKendra {
opts := AmazonKendraOptions{
TopK: 3,
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
}

for _, fn := range optFns {
Expand All @@ -61,6 +66,16 @@ func (r *AmazonKendra) GetRelevantDocuments(ctx context.Context, query string) (
return r.kendraQuery(ctx, query)
}

// Verbose returns the verbosity setting of the retriever.
func (r *AmazonKendra) Verbose() bool {
return r.opts.CallbackOptions.Verbose
}

// Callbacks returns the registered callbacks of the retriever.
func (r *AmazonKendra) Callbacks() []schema.Callback {
return r.opts.CallbackOptions.Callbacks
}

func (r *AmazonKendra) kendraQuery(ctx context.Context, query string) ([]schema.Document, error) {
query = strings.TrimSpace(query)

Expand Down
15 changes: 15 additions & 0 deletions retriever/azure_cognitive_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"net/http"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/schema"
)

Expand All @@ -23,6 +24,7 @@ type AzureCognitiveSearchRequest struct {

// AzureCognitiveSearchOptions contains options for configuring the AzureCognitiveSearch retriever.
type AzureCognitiveSearchOptions struct {
*schema.CallbackOptions
// Number of documents to query for
TopK uint

Expand Down Expand Up @@ -51,6 +53,9 @@ func NewAzureCognitiveSearch(apiKey, serviceName, indexName string, optFns ...fu
APIVersion: "2020-06-30",
ContentKey: "content",
HTTPClient: http.DefaultClient,
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
}

for _, fn := range optFns {
Expand Down Expand Up @@ -103,6 +108,16 @@ func (r *AzureCognitiveSearch) GetRelevantDocuments(ctx context.Context, query s
return docs, nil
}

// Verbose returns the verbosity setting of the retriever.
func (r *AzureCognitiveSearch) Verbose() bool {
return r.opts.CallbackOptions.Verbose
}

// Callbacks returns the registered callbacks of the retriever.
func (r *AzureCognitiveSearch) Callbacks() []schema.Callback {
return r.opts.CallbackOptions.Callbacks
}

// doRequest sends an HTTP request to the Azure Cognitive Search service.
func (r *AzureCognitiveSearch) doRequest(ctx context.Context, method string, url string, payload any) ([]byte, error) {
var body io.Reader
Expand Down
28 changes: 27 additions & 1 deletion retriever/merger.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,33 @@ package retriever
import (
"context"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/schema"
)

// Compile time check to ensure Merger satisfies the Retriever interface.
var _ schema.Retriever = (*Merger)(nil)

type MergerOptions struct {
*schema.CallbackOptions
}

type Merger struct {
retrievers []schema.Retriever
opts MergerOptions
}

func NewMerger(retrievers ...schema.Retriever) *Merger {
func NewMerger(retrievers []schema.Retriever, optFns ...func(o *MergerOptions)) *Merger {
opts := MergerOptions{
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
}

for _, fn := range optFns {
fn(&opts)
}

return &Merger{
retrievers: retrievers,
}
Expand Down Expand Up @@ -53,3 +69,13 @@ func (r *Merger) GetRelevantDocuments(ctx context.Context, query string) ([]sche

return mergedDocuments, nil
}

// Verbose returns the verbosity setting of the retriever.
func (r *Merger) Verbose() bool {
return r.opts.CallbackOptions.Verbose
}

// Callbacks returns the registered callbacks of the retriever.
func (r *Merger) Callbacks() []schema.Callback {
return r.opts.CallbackOptions.Callbacks
}
6 changes: 3 additions & 3 deletions retriever/merger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestMergeDocuments(t *testing.T) {
}

t.Run("MergeDocuments returns merged documents from 2 retrievers", func(t *testing.T) {
merger := NewMerger(retriever1, retriever2)
merger := NewMerger([]schema.Retriever{retriever1, retriever2})

query := "test query"
expectedDocuments := []schema.Document{
Expand All @@ -50,7 +50,7 @@ func TestMergeDocuments(t *testing.T) {
})

t.Run("MergeDocuments returns merged documents from 3 retrievers", func(t *testing.T) {
merger := NewMerger(retriever1, retriever2, retriever3)
merger := NewMerger([]schema.Retriever{retriever1, retriever2, retriever3})

query := "test query"
expectedDocuments := []schema.Document{
Expand All @@ -67,7 +67,7 @@ func TestMergeDocuments(t *testing.T) {
})

t.Run("MergeDocuments handles empty retriever results", func(t *testing.T) {
merger := NewMerger(retriever1, retriever4)
merger := NewMerger([]schema.Retriever{retriever1, retriever4})

query := "test query"
expectedDocuments := []schema.Document{{PageContent: "Document 1"}}
Expand Down
51 changes: 50 additions & 1 deletion retriever/retriever.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,57 @@
// Package retriever provides functionality for retrieving relevant documents using various services.
package retriever

import "net/http"
import (
"context"
"net/http"

"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/schema"
)

type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}

type Options struct {
Callbacks []schema.Callback
ParentRunID string
}

func Run(ctx context.Context, retriever schema.Retriever, query string, optFns ...func(*Options)) ([]schema.Document, error) {
opts := Options{}

for _, fn := range optFns {
fn(&opts)
}

cm := callback.NewManager(opts.Callbacks, retriever.Callbacks(), retriever.Verbose(), func(mo *callback.ManagerOptions) {
mo.ParentRunID = opts.ParentRunID
})

rm, err := cm.OnRetrieverStart(ctx, &schema.RetrieverStartManagerInput{
Query: query,
})
if err != nil {
return nil, err
}

docs, err := retriever.GetRelevantDocuments(ctx, query)
if err != nil {
if cbErr := rm.OnRetrieverError(ctx, &schema.RetrieverErrorManagerInput{
Error: err,
}); cbErr != nil {
return nil, cbErr
}

return nil, err
}

if err := rm.OnRetrieverEnd(ctx, &schema.RetrieverEndManagerInput{
Docs: docs,
}); err != nil {
return nil, err
}

return docs, nil
}
11 changes: 11 additions & 0 deletions retriever/retriever_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ import (
"github.com/hupe1980/golc/schema"
)

// Compile time check to ensure retrieverMock satisfies the Retriever interface.
var _ schema.Retriever = (*retrieverMock)(nil)

type retrieverMock struct {
GetRelevantDocumentsFunc func(ctx context.Context, query string) ([]schema.Document, error)
}
Expand All @@ -20,6 +23,14 @@ func (m *retrieverMock) GetRelevantDocuments(ctx context.Context, query string)
return nil, nil
}

func (m *retrieverMock) Verbose() bool {
return false
}

func (m *retrieverMock) Callbacks() []schema.Callback {
return nil
}

type mockHTTPClient struct {
doFunc func(req *http.Request) (*http.Response, error)
}
Expand Down
15 changes: 15 additions & 0 deletions retriever/vector_store_retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package retriever
import (
"context"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/schema"
)

Expand All @@ -16,6 +17,7 @@ const (
)

type VectorStoreOptions struct {
*schema.CallbackOptions
SearchType VectorStoreSearchType
}

Expand All @@ -27,6 +29,9 @@ type VectorStore struct {
func NewVectorStore(vectorStore schema.VectorStore, optFns ...func(o *VectorStoreOptions)) *VectorStore {
opts := VectorStoreOptions{
SearchType: VectorStoreSearchTypeSimilarity,
CallbackOptions: &schema.CallbackOptions{
Verbose: golc.Verbose,
},
}

for _, fn := range optFns {
Expand All @@ -43,3 +48,13 @@ func NewVectorStore(vectorStore schema.VectorStore, optFns ...func(o *VectorStor
func (r *VectorStore) GetRelevantDocuments(ctx context.Context, query string) ([]schema.Document, error) {
return r.v.SimilaritySearch(ctx, query)
}

// Verbose returns the verbosity setting of the retriever.
func (r *VectorStore) Verbose() bool {
return r.opts.CallbackOptions.Verbose
}

// Callbacks returns the registered callbacks of the retriever.
func (r *VectorStore) Callbacks() []schema.Callback {
return r.opts.CallbackOptions.Callbacks
}
1 change: 1 addition & 0 deletions schema/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ type CallbackManager interface {
OnChatModelStart(ctx context.Context, input *ChatModelStartManagerInput) (CallbackManagerForModelRun, error)
OnChainStart(ctx context.Context, input *ChainStartManagerInput) (CallbackManagerForChainRun, error)
OnToolStart(ctx context.Context, input *ToolStartManagerInput) (CallbackManagerForToolRun, error)
OnRetrieverStart(ctx context.Context, input *RetrieverStartManagerInput) (CallbackManagerForRetrieverRun, error)
RunID() string
}

Expand Down
4 changes: 4 additions & 0 deletions schema/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ type DocumentLoader interface {

type Retriever interface {
GetRelevantDocuments(ctx context.Context, query string) ([]Document, error)
// Verbose returns the verbosity setting of the retriever.
Verbose() bool
// Callbacks returns the registered callbacks of the retriever.
Callbacks() []Callback
}

type TextSplitter interface {
Expand Down
4 changes: 2 additions & 2 deletions vectorstore/vectorstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
)

// ToRetriever takes a vector store and returns a retriever
func ToRetriever(vectorStore schema.VectorStore) schema.Retriever {
return retriever.NewVectorStore(vectorStore)
func ToRetriever(vectorStore schema.VectorStore, optFns ...func(o *retriever.VectorStoreOptions)) schema.Retriever {
return retriever.NewVectorStore(vectorStore, optFns...)
}

func float64ToFloat32(v []float64) []float32 {
Expand Down

0 comments on commit 5729574

Please sign in to comment.