Skip to content

Commit

Permalink
Split chain package
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jul 19, 2023
1 parent 2227a67 commit fffd90e
Show file tree
Hide file tree
Showing 12 changed files with 45 additions and 26 deletions.
4 changes: 4 additions & 0 deletions chain/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ func (c *LLM) Call(ctx context.Context, inputs schema.ChainValues, optFns ...fun
return outputs[0], nil
}

func (c *LLM) GetNumTokens(text string) (uint, error) {
return c.llm.GetNumTokens(text)
}

func (c *LLM) Prompt() *prompt.Template {
return c.prompt
}
Expand Down
4 changes: 2 additions & 2 deletions examples/conversational_retrieval/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"os"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/model/llm"
"github.com/hupe1980/golc/rag"
"github.com/hupe1980/golc/schema"
)

Expand All @@ -29,7 +29,7 @@ func main() {
log.Fatal(err)
}

conversationalRetrievalChain, err := chain.NewConversationalRetrieval(openai, &mockRetriever{}, func(o *chain.ConversationalRetrievalOptions) {
conversationalRetrievalChain, err := rag.NewConversationalRetrieval(openai, &mockRetriever{}, func(o *rag.ConversationalRetrievalOptions) {
o.ReturnGeneratedQuestion = true
})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions examples/llm_refine_summarization/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
"strings"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/documentloader"
"github.com/hupe1980/golc/model/llm"
"github.com/hupe1980/golc/rag"
"github.com/hupe1980/golc/textsplitter"
)

Expand All @@ -22,7 +22,7 @@ func main() {
log.Fatal(err)
}

llmSummarizationChain, err := chain.NewRefineSummarization(openai)
llmSummarizationChain, err := rag.NewRefineSummarization(openai)
if err != nil {
log.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions examples/llm_stuff_summarization/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/documentloader"
"github.com/hupe1980/golc/model/llm"
"github.com/hupe1980/golc/rag"
"github.com/hupe1980/golc/schema"
)

Expand All @@ -25,7 +25,7 @@ func main() {
log.Fatal(err)
}

llmSummarizationChain, err := chain.NewStuffSummarization(openai)
llmSummarizationChain, err := rag.NewStuffSummarization(openai)
if err != nil {
log.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions examples/retrieval_qa/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"os"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/model/llm"
"github.com/hupe1980/golc/rag"
"github.com/hupe1980/golc/schema"
)

Expand All @@ -27,7 +27,7 @@ func main() {
log.Fatal(err)
}

retrievalQAChain, err := chain.NewRetrievalQA(openai, &mockRetriever{})
retrievalQAChain, err := rag.NewRetrievalQA(openai, &mockRetriever{})
if err != nil {
log.Fatal(err)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package chain
package rag

import (
"context"
"fmt"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/memory"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
Expand Down Expand Up @@ -44,7 +45,7 @@ type ConversationalRetrievalOptions struct {

// ConversationalRetrieval is a chain implementation for conversational retrieval.
type ConversationalRetrieval struct {
condenseQuestionChain *LLM
condenseQuestionChain *chain.LLM
retrievalQAChain *RetrievalQA
opts ConversationalRetrievalOptions
}
Expand Down Expand Up @@ -75,7 +76,7 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF
opts.CondenseQuestionPrompt = prompt.NewTemplate(defaultcondenseQuestionPromptTemplate)
}

condenseQuestionChain, err := NewLLM(llm, opts.CondenseQuestionPrompt)
condenseQuestionChain, err := chain.NewLLM(llm, opts.CondenseQuestionPrompt)
if err != nil {
return nil, err
}
Expand Down
10 changes: 10 additions & 0 deletions rag/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package rag

import "errors"

var (
ErrNoInputValues = errors.New("no input values")
ErrInvalidInputValues = errors.New("invalid input values")
ErrInputValuesWrongType = errors.New("input key is of wrong type")
ErrNoOutputParser = errors.New("no output parser")
)
2 changes: 1 addition & 1 deletion chain/qa.go → rag/qa.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package chain
package rag

const defaultStuffQAPromptTemplate = `Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
Expand Down
9 changes: 5 additions & 4 deletions chain/refine_documents.go → rag/refine_documents.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package chain
package rag

import (
"context"
Expand All @@ -7,6 +7,7 @@ import (

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/util"
Expand All @@ -25,12 +26,12 @@ type RefineDocumentsOptions struct {
}

type RefineDocuments struct {
llmChain *LLM
refineLLMChain *LLM
llmChain *chain.LLM
refineLLMChain *chain.LLM
opts RefineDocumentsOptions
}

func NewRefineDocuments(llmChain *LLM, refineLLMChain *LLM, optFns ...func(o *RefineDocumentsOptions)) (*RefineDocuments, error) {
func NewRefineDocuments(llmChain *chain.LLM, refineLLMChain *chain.LLM, optFns ...func(o *RefineDocumentsOptions)) (*RefineDocuments, error) {
opts := RefineDocumentsOptions{
InputKey: "inputDocuments",
DocumentVariableName: "context",
Expand Down
7 changes: 4 additions & 3 deletions chain/retrieval_qa.go → rag/retrieval_qa.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package chain
package rag

import (
"context"
"fmt"

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/util"
Expand Down Expand Up @@ -47,7 +48,7 @@ func NewRetrievalQA(llm schema.LLM, retriever schema.Retriever, optFns ...func(o
opts.StuffQAPrompt = prompt.NewTemplate(defaultStuffQAPromptTemplate)
}

llmChain, err := NewLLM(llm, opts.StuffQAPrompt)
llmChain, err := chain.NewLLM(llm, opts.StuffQAPrompt)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -120,7 +121,7 @@ func (c *RetrievalQA) getDocuments(ctx context.Context, query string) ([]schema.
tokens := make([]uint, len(docs))

for i, doc := range docs {
t, err := c.stuffDocumentsChain.llmChain.llm.GetNumTokens(doc.PageContent)
t, err := c.stuffDocumentsChain.llmChain.GetNumTokens(doc.PageContent)
if err != nil {
return nil, err
}
Expand Down
7 changes: 4 additions & 3 deletions chain/stuff_documents.go → rag/stuff_documents.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package chain
package rag

import (
"context"
Expand All @@ -7,6 +7,7 @@ import (

"github.com/hupe1980/golc"
"github.com/hupe1980/golc/callback"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/schema"
"github.com/hupe1980/golc/util"
)
Expand All @@ -22,11 +23,11 @@ type StuffDocumentsOptions struct {
}

type StuffDocuments struct {
llmChain *LLM
llmChain *chain.LLM
opts StuffDocumentsOptions
}

func NewStuffDocuments(llmChain *LLM, optFns ...func(o *StuffDocumentsOptions)) (*StuffDocuments, error) {
func NewStuffDocuments(llmChain *chain.LLM, optFns ...func(o *StuffDocumentsOptions)) (*StuffDocuments, error) {
opts := StuffDocumentsOptions{
InputKey: "inputDocuments",
DocumentVariableName: "context",
Expand Down
9 changes: 5 additions & 4 deletions chain/summarization.go → rag/summarization.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package chain
package rag

import (
"github.com/hupe1980/golc"
"github.com/hupe1980/golc/chain"
"github.com/hupe1980/golc/prompt"
"github.com/hupe1980/golc/schema"
)
Expand Down Expand Up @@ -31,7 +32,7 @@ func NewStuffSummarization(llm schema.LLM, optFns ...func(o *StuffSummarizationO

stuffPrompt := prompt.NewTemplate(stuffSummarizationTemplate)

llmChain, err := NewLLM(llm, stuffPrompt, func(o *LLMOptions) {
llmChain, err := chain.NewLLM(llm, stuffPrompt, func(o *chain.LLMOptions) {
o.CallbackOptions = opts.CallbackOptions
})
if err != nil {
Expand Down Expand Up @@ -73,7 +74,7 @@ func NewRefineSummarization(llm schema.LLM, optFns ...func(o *RefineSummarizatio

stuffPrompt := prompt.NewTemplate(stuffSummarizationTemplate)

llmChain, err := NewLLM(llm, stuffPrompt, func(o *LLMOptions) {
llmChain, err := chain.NewLLM(llm, stuffPrompt, func(o *chain.LLMOptions) {
o.CallbackOptions = opts.CallbackOptions
})
if err != nil {
Expand All @@ -82,7 +83,7 @@ func NewRefineSummarization(llm schema.LLM, optFns ...func(o *RefineSummarizatio

refinePrompt := prompt.NewTemplate(refineSummarizationTemplate)

refineLLMChain, err := NewLLM(llm, refinePrompt, func(o *LLMOptions) {
refineLLMChain, err := chain.NewLLM(llm, refinePrompt, func(o *chain.LLMOptions) {
o.CallbackOptions = opts.CallbackOptions
})
if err != nil {
Expand Down

0 comments on commit fffd90e

Please sign in to comment.