Skip to content

Commit

Permalink
Misc
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Jun 28, 2023
1 parent d4e500f commit c12b8a0
Show file tree
Hide file tree
Showing 35 changed files with 105 additions and 166 deletions.
7 changes: 2 additions & 5 deletions agent/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ func NewConversationalReactDescription(llm schema.LLM, tools []schema.Tool) (*Co
AIPrefix: "AI",
}

prompt, err := createConversationalPrompt(tools, opts.Prefix, opts.Instructions, opts.Suffix)
if err != nil {
return nil, err
}
prompt := createConversationalPrompt(tools, opts.Prefix, opts.Instructions, opts.Suffix)

llmChain, err := chain.NewLLM(llm, prompt, func(o *chain.LLMOptions) {
o.Memory = memory.NewConversationBuffer()
Expand Down Expand Up @@ -166,7 +163,7 @@ func (a *ConversationalReactDescription) parseOutput(output string) ([]schema.Ag
}, nil, nil
}

func createConversationalPrompt(tools []schema.Tool, prefix, instructions, suffix string) (*prompt.Template, error) {
func createConversationalPrompt(tools []schema.Tool, prefix, instructions, suffix string) *prompt.Template {
return prompt.NewTemplate(strings.Join([]string{prefix, instructions, suffix}, "\n\n"), func(o *prompt.TemplateOptions) {
o.PartialValues = prompt.PartialValues{
"toolNames": toolNames(tools),
Expand Down
7 changes: 2 additions & 5 deletions agent/react.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,7 @@ func NewReactDescription(llm schema.LLM, tools []schema.Tool) (*ReactDescription
OutputKey: "output",
}

prompt, err := createReactDescriptioPrompt(tools, opts.Prefix, opts.Instructions, opts.Suffix)
if err != nil {
return nil, err
}
prompt := createReactDescriptioPrompt(tools, opts.Prefix, opts.Instructions, opts.Suffix)

llmChain, err := chain.NewLLM(llm, prompt)
if err != nil {
Expand Down Expand Up @@ -153,7 +150,7 @@ func (a *ReactDescription) parseOutput(output string) ([]schema.AgentAction, *sc
}, nil, nil
}

func createReactDescriptioPrompt(tools []schema.Tool, prefix, instructions, suffix string) (*prompt.Template, error) {
func createReactDescriptioPrompt(tools []schema.Tool, prefix, instructions, suffix string) *prompt.Template {
return prompt.NewTemplate(strings.Join([]string{prefix, instructions, suffix}, "\n\n"), func(o *prompt.TemplateOptions) {
o.PartialValues = prompt.PartialValues{
"toolNames": toolNames(tools),
Expand Down
2 changes: 1 addition & 1 deletion callback/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (h *handler) OnModelNewToken(token string) error {
return nil
}

func (h *handler) OnModelEnd(result schema.LLMResult) error {
func (h *handler) OnModelEnd(result schema.ModelResult) error {
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion callback/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func (m *manager) OnModelNewToken(token string) error {
return nil
}

func (m *manager) OnModelEnd(result schema.LLMResult) error {
func (m *manager) OnModelEnd(result schema.ModelResult) error {
for _, c := range m.callbacks {
if m.verbose || c.AlwaysVerbose() {
if err := c.OnModelEnd(result); err != nil {
Expand Down
2 changes: 1 addition & 1 deletion callback/openai_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (o *OpenAIHandler) AlwaysVerbose() bool {
return true
}

func (o *OpenAIHandler) OnModelEnd(result schema.LLMResult) error {
func (o *OpenAIHandler) OnModelEnd(result schema.ModelResult) error {
if result.LLMOutput == nil {
return nil
}
Expand Down
9 changes: 2 additions & 7 deletions chain/conversation.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,7 @@ func NewConversation(llm schema.LLM, optFns ...func(o *ConversationOptions)) (*C
}

if opts.Prompt == nil {
var pErr error

opts.Prompt, pErr = prompt.NewTemplate(conversationTemplate)
if pErr != nil {
return nil, pErr
}
opts.Prompt = prompt.NewTemplate(conversationTemplate)
}

return &Conversation{
Expand Down Expand Up @@ -147,7 +142,7 @@ func (c *Conversation) OutputKeys() []string {
return []string{c.opts.OutputKey}
}

func (c *Conversation) createOutputs(llmResult *schema.LLMResult) ([]map[string]any, error) {
func (c *Conversation) createOutputs(llmResult *schema.ModelResult) ([]map[string]any, error) {
result := make([]map[string]any, len(llmResult.Generations)-1)

for _, generation := range llmResult.Generations {
Expand Down
7 changes: 1 addition & 6 deletions chain/conversational_retrieval.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,7 @@ func NewConversationalRetrieval(llm schema.LLM, retriever schema.Retriever, optF
}

if opts.CondenseQuestionPrompt == nil {
p, err := prompt.NewTemplate(defaultcondenseQuestionPromptTemplate)
if err != nil {
return nil, err
}

opts.CondenseQuestionPrompt = p
opts.CondenseQuestionPrompt = prompt.NewTemplate(defaultcondenseQuestionPromptTemplate)
}

condenseQuestionChain, err := NewLLM(llm, opts.CondenseQuestionPrompt)
Expand Down
6 changes: 3 additions & 3 deletions chain/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ func (c *LLM) OutputKeys() []string {
return []string{c.opts.OutputKey}
}

func (c *LLM) createOutputs(llmResult *schema.LLMResult) ([]map[string]any, error) {
result := make([]map[string]any, len(llmResult.Generations)-1)
func (c *LLM) createOutputs(modelResult *schema.ModelResult) ([]map[string]any, error) {
result := make([]map[string]any, len(modelResult.Generations)-1)

for _, generation := range llmResult.Generations {
for _, generation := range modelResult.Generations {
parsed, err := c.opts.OutputParser.ParseResult(generation)
if err != nil {
return nil, err
Expand Down
5 changes: 1 addition & 4 deletions chain/llm_bash.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,9 @@ func NewLLMBash(llm schema.LLM, optFns ...func(o *LLMBashOptions)) (*LLMBash, er
fn(&opts)
}

prompt, err := prompt.NewTemplate(llmBashTemplate, func(o *prompt.TemplateOptions) {
prompt := prompt.NewTemplate(llmBashTemplate, func(o *prompt.TemplateOptions) {
o.OutputParser = outputparser.NewFencedCodeBlock("```bash")
})
if err != nil {
return nil, err
}

llmChain, err := NewLLM(llm, prompt)
if err != nil {
Expand Down
5 changes: 1 addition & 4 deletions chain/llm_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,9 @@ func NewLLMMath(llm schema.LLM, optFns ...func(o *LLMMathOptions)) (*LLMMath, er
fn(&opts)
}

prompt, err := prompt.NewTemplate(llmMathTemplate, func(o *prompt.TemplateOptions) {
prompt := prompt.NewTemplate(llmMathTemplate, func(o *prompt.TemplateOptions) {
o.OutputParser = outputparser.NewFencedCodeBlock("```text")
})
if err != nil {
return nil, err
}

llmChain, err := NewLLM(llm, prompt)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions chain/prompt_selector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ import (

func TestConditionalPromptSelector(t *testing.T) {
t.Run("GetPrompt", func(t *testing.T) {
defaultPrompt, _ := prompt.NewTemplate("Default Prompt")
llmPrompt, _ := prompt.NewTemplate("LLM Prompt")
chatModelPrompt, _ := prompt.NewTemplate("ChatModel Prompt")
defaultPrompt := prompt.NewTemplate("Default Prompt")
llmPrompt := prompt.NewTemplate("LLM Prompt")
chatModelPrompt := prompt.NewTemplate("ChatModel Prompt")

conditional1 := Conditional{
Condition: func(model schema.Model) bool {
Expand Down
7 changes: 1 addition & 6 deletions chain/refine_documents.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,7 @@ func NewRefineDocuments(llmChain *LLM, refineLLMChain *LLM, optFns ...func(o *Re
}

if opts.DocumentPrompt == nil {
p, err := prompt.NewTemplate("{{.pageContent}}")
if err != nil {
return nil, err
}

opts.DocumentPrompt = p
opts.DocumentPrompt = prompt.NewTemplate("{{.pageContent}}")
}

return &RefineDocuments{
Expand Down
7 changes: 1 addition & 6 deletions chain/retrieval_qa.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,7 @@ func NewRetrievalQA(llm schema.LLM, retriever schema.Retriever, optFns ...func(o
}

if opts.StuffQAPrompt == nil {
p, err := prompt.NewTemplate(defaultStuffQAPromptTemplate)
if err != nil {
return nil, err
}

opts.StuffQAPrompt = p
opts.StuffQAPrompt = prompt.NewTemplate(defaultStuffQAPromptTemplate)
}

llmChain, err := NewLLM(llm, opts.StuffQAPrompt)
Expand Down
15 changes: 3 additions & 12 deletions chain/summarization.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ func NewStuffSummarization(llm schema.LLM, optFns ...func(o *StuffSummarizationO
fn(&opts)
}

stuffPrompt, err := prompt.NewTemplate(stuffSummarizationTemplate)
if err != nil {
return nil, err
}
stuffPrompt := prompt.NewTemplate(stuffSummarizationTemplate)

llmChain, err := NewLLM(llm, stuffPrompt, func(o *LLMOptions) {
o.CallbackOptions = opts.CallbackOptions
Expand Down Expand Up @@ -74,10 +71,7 @@ func NewRefineSummarization(llm schema.LLM, optFns ...func(o *RefineSummarizatio
fn(&opts)
}

stuffPrompt, err := prompt.NewTemplate(stuffSummarizationTemplate)
if err != nil {
return nil, err
}
stuffPrompt := prompt.NewTemplate(stuffSummarizationTemplate)

llmChain, err := NewLLM(llm, stuffPrompt, func(o *LLMOptions) {
o.CallbackOptions = opts.CallbackOptions
Expand All @@ -86,10 +80,7 @@ func NewRefineSummarization(llm schema.LLM, optFns ...func(o *RefineSummarizatio
return nil, err
}

refinePrompt, err := prompt.NewTemplate(refineSummarizationTemplate)
if err != nil {
return nil, err
}
refinePrompt := prompt.NewTemplate(refineSummarizationTemplate)

refineLLMChain, err := NewLLM(llm, refinePrompt, func(o *LLMOptions) {
o.CallbackOptions = opts.CallbackOptions
Expand Down
7 changes: 1 addition & 6 deletions evaluation/context_qa_eval_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,7 @@ func NewContextQAEvalChain(llm schema.LLM, optFns ...func(o *ContextQAEvalChainO
}

if opts.Prompt == nil {
contextQAEvalPrompt, err := prompt.NewTemplate(contextQAEvalTemplate)
if err != nil {
return nil, err
}

opts.Prompt = contextQAEvalPrompt
opts.Prompt = prompt.NewTemplate(contextQAEvalTemplate)
}

llmChain, err := chain.NewLLM(llm, opts.Prompt)
Expand Down
7 changes: 1 addition & 6 deletions evaluation/cot_qa_eval_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,7 @@ func NewCOTQAEvalChain(llm schema.LLM, optFns ...func(o *COTQAEvalChainOptions))
}

if opts.Prompt == nil {
cotQAEvalPrompt, err := prompt.NewTemplate(cotQAEvalTemplate)
if err != nil {
return nil, err
}

opts.Prompt = cotQAEvalPrompt
opts.Prompt = prompt.NewTemplate(cotQAEvalTemplate)
}

contextQAEvalChain, err := NewContextQAEvalChain(llm, func(o *ContextQAEvalChainOptions) {
Expand Down
7 changes: 1 addition & 6 deletions evaluation/qa_eval_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,8 @@ type QAEvalChain struct {
}

func NewQAEvalChain(llm schema.LLM, optFns ...func(o *QAEvalChainOptions)) (*QAEvalChain, error) {
qaEvalPrompt, err := prompt.NewTemplate(qaEvalTemplate)
if err != nil {
return nil, err
}

opts := QAEvalChainOptions{
Prompt: qaEvalPrompt,
Prompt: prompt.NewTemplate(qaEvalTemplate),
QuestionKey: "query",
AnswerKey: "answer",
PredictionKey: "result",
Expand Down
4 changes: 2 additions & 2 deletions model/chatmodel/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func NewAnthropic(apiKey string, optFns ...func(o *AnthropicOptions)) (*Anthropi
}, nil
}

func (cm *Anthropic) Generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) {
func (cm *Anthropic) Generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
res, err := cm.client.Complete(ctx, &anthropic.CompletionRequest{
Model: cm.opts.ModelName,
MaxTokens: cm.opts.MaxTokens,
Expand All @@ -65,7 +65,7 @@ func (cm *Anthropic) Generate(ctx context.Context, messages schema.ChatMessages,
return nil, err
}

return &schema.LLMResult{
return &schema.ModelResult{
Generations: [][]schema.Generation{{newChatGeneraton(res.Completion)}},
LLMOutput: map[string]any{},
}, nil
Expand Down
4 changes: 2 additions & 2 deletions model/chatmodel/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ func NewFake(response string) *Fake {
}
}

func (cm *Fake) Generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) {
return &schema.LLMResult{
func (cm *Fake) Generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
return &schema.ModelResult{
Generations: [][]schema.Generation{{newChatGeneraton(cm.response)}},
LLMOutput: map[string]any{},
}, nil
Expand Down
4 changes: 2 additions & 2 deletions model/chatmodel/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func newOpenAI(client *openai.Client, opts OpenAIOptions) (*OpenAI, error) {
}, nil
}

func (cm *OpenAI) Generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) {
func (cm *OpenAI) Generate(ctx context.Context, messages schema.ChatMessages, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
opts := schema.GenerateOptions{}

openAIMessages := []openai.ChatCompletionMessage{}
Expand Down Expand Up @@ -114,7 +114,7 @@ func (cm *OpenAI) Generate(ctx context.Context, messages schema.ChatMessages, op
text := res.Choices[0].Message.Content
role := res.Choices[0].Message.Role

return &schema.LLMResult{
return &schema.ModelResult{
Generations: [][]schema.Generation{{schema.Generation{
Text: text,
Message: openAIResponseToChatMessage(role, text),
Expand Down
4 changes: 2 additions & 2 deletions model/llm/cohere.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func NewCohere(apiKey string, optFns ...func(o *CohereOptions)) (*Cohere, error)
}, nil
}

func (l *Cohere) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) {
func (l *Cohere) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
opts := schema.GenerateOptions{}

res, err := l.client.Generate(cohere.GenerateOptions{
Expand All @@ -70,7 +70,7 @@ func (l *Cohere) Generate(ctx context.Context, prompts []string, optFns ...func(
return nil, err
}

return &schema.LLMResult{
return &schema.ModelResult{
Generations: [][]schema.Generation{{schema.Generation{Text: res.Generations[0].Text}}},
LLMOutput: map[string]any{},
}, nil
Expand Down
4 changes: 2 additions & 2 deletions model/llm/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ func NewFake(response string) *Fake {
}
}

func (l *Fake) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) {
return &schema.LLMResult{
func (l *Fake) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
return &schema.ModelResult{
Generations: [][]schema.Generation{{schema.Generation{Text: l.response}}},
LLMOutput: map[string]any{},
}, nil
Expand Down
4 changes: 2 additions & 2 deletions model/llm/hugging_face_hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func NewHuggingFaceHub(apiToken string, optFns ...func(o *HuggingFaceHubOptions)
}, nil
}

func (l *HuggingFaceHub) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) {
func (l *HuggingFaceHub) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
var (
text string
err error
Expand All @@ -79,7 +79,7 @@ func (l *HuggingFaceHub) Generate(ctx context.Context, prompts []string, optFns
return nil, err
}

return &schema.LLMResult{
return &schema.ModelResult{
Generations: [][]schema.Generation{{schema.Generation{Text: text}}},
LLMOutput: map[string]any{},
}, nil
Expand Down
4 changes: 2 additions & 2 deletions model/llm/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func NewOpenAI(apiKey string, optFns ...func(o *OpenAIOptions)) (*OpenAI, error)
}, nil
}

func (l *OpenAI) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) {
func (l *OpenAI) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
opts := schema.GenerateOptions{}

subPromps := util.ChunkBy(prompts, l.opts.BatchSize)
Expand Down Expand Up @@ -159,7 +159,7 @@ func (l *OpenAI) Generate(ctx context.Context, prompts []string, optFns ...func(
})
})

return &schema.LLMResult{
return &schema.ModelResult{
Generations: generations,
LLMOutput: map[string]any{
"ModelName": l.opts.ModelName,
Expand Down
4 changes: 2 additions & 2 deletions model/llm/sagemaker_endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func NewSagemakerEndpoint(client *sagemakerruntime.Client, endpointName string,
}, nil
}

func (l *SagemakerEndpoint) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.LLMResult, error) {
func (l *SagemakerEndpoint) Generate(ctx context.Context, prompts []string, optFns ...func(o *schema.GenerateOptions)) (*schema.ModelResult, error) {
generations := [][]schema.Generation{}

for _, prompt := range prompts {
Expand Down Expand Up @@ -130,7 +130,7 @@ func (l *SagemakerEndpoint) Generate(ctx context.Context, prompts []string, optF
}})
}

return &schema.LLMResult{
return &schema.ModelResult{
Generations: generations,
LLMOutput: map[string]any{},
}, nil
Expand Down
Loading

0 comments on commit c12b8a0

Please sign in to comment.