-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
216 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
package chain | ||
|
||
import ( | ||
"github.com/hupe1980/golc/prompt" | ||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
// PromptSelector is an interface for selecting prompts based on a model. | ||
type PromptSelector interface { | ||
GetPrompt(model schema.Model) *prompt.Template | ||
} | ||
|
||
// Compile time check to ensure ConditionalPromptSelector satisfies the PromptSelector interface. | ||
var _ PromptSelector = (*ConditionalPromptSelector)(nil) | ||
|
||
// ConditionalFunc represents a function that evaluates a condition based on a model. | ||
type ConditionalFunc func(model schema.Model) bool | ||
|
||
// Conditional represents a conditional prompt configuration. | ||
type Conditional struct { | ||
Condition ConditionalFunc | ||
Prompt *prompt.Template | ||
} | ||
|
||
// ConditionalPromptSelector is a prompt selector that selects prompts based on conditions. | ||
type ConditionalPromptSelector struct { | ||
DefaultPrompt *prompt.Template | ||
Conditionals []Conditional | ||
} | ||
|
||
// GetPrompt selects a prompt template based on the provided model. | ||
// It evaluates the conditions in order and returns the prompt associated with the first matching condition, | ||
// or returns the default prompt if no condition is met. | ||
func (cps *ConditionalPromptSelector) GetPrompt(model schema.Model) *prompt.Template { | ||
for _, conditional := range cps.Conditionals { | ||
if conditional.Condition(model) { | ||
return conditional.Prompt | ||
} | ||
} | ||
|
||
return cps.DefaultPrompt | ||
} | ||
|
||
// IsLLM checks if the given model is of type schema.LLM. | ||
func IsLLM(model schema.Model) bool { | ||
_, ok := model.(schema.LLM) | ||
return ok | ||
} | ||
|
||
// IsChatModel checks if the given model is of type schema.ChatModel. | ||
func IsChatModel(model schema.Model) bool { | ||
_, ok := model.(schema.ChatModel) | ||
return ok | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package chain | ||
|
||
import ( | ||
"testing" | ||
|
||
"github.com/hupe1980/golc/model/chatmodel" | ||
"github.com/hupe1980/golc/model/llm" | ||
"github.com/hupe1980/golc/prompt" | ||
"github.com/hupe1980/golc/schema" | ||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestConditionalPromptSelector(t *testing.T) { | ||
t.Run("GetPrompt", func(t *testing.T) { | ||
defaultPrompt, _ := prompt.NewTemplate("Default Prompt") | ||
llmPrompt, _ := prompt.NewTemplate("LLM Prompt") | ||
chatModelPrompt, _ := prompt.NewTemplate("ChatModel Prompt") | ||
|
||
conditional1 := Conditional{ | ||
Condition: func(model schema.Model) bool { | ||
return IsLLM(model) | ||
}, | ||
Prompt: llmPrompt, | ||
} | ||
|
||
conditional2 := Conditional{ | ||
Condition: func(model schema.Model) bool { | ||
return IsChatModel(model) | ||
}, | ||
Prompt: chatModelPrompt, | ||
} | ||
|
||
cps := ConditionalPromptSelector{ | ||
DefaultPrompt: defaultPrompt, | ||
Conditionals: []Conditional{conditional1, conditional2}, | ||
} | ||
|
||
t.Run("LLM model should return LLM Prompt", func(t *testing.T) { | ||
llmModel := llm.NewFake("dummy") | ||
prompt := cps.GetPrompt(llmModel) | ||
|
||
text, _ := prompt.Format(nil) | ||
assert.Equal(t, "LLM Prompt", text) | ||
}) | ||
|
||
t.Run("Chat model should return Chat Prompt", func(t *testing.T) { | ||
chatModel := chatmodel.NewFake("dummy") | ||
prompt := cps.GetPrompt(chatModel) | ||
|
||
text, _ := prompt.Format(nil) | ||
assert.Equal(t, "ChatModel Prompt", text) | ||
}) | ||
}) | ||
} | ||
|
||
func TestIsLLM(t *testing.T) { | ||
t.Run("LLM model should return true", func(t *testing.T) { | ||
llmModel := llm.NewFake("dummy") | ||
isLLM := IsLLM(llmModel) | ||
assert.True(t, isLLM) | ||
}) | ||
|
||
t.Run("ChatModel should return false", func(t *testing.T) { | ||
chatModel := chatmodel.NewFake("dummy") | ||
isLLM := IsLLM(chatModel) | ||
assert.False(t, isLLM) | ||
}) | ||
} | ||
|
||
func TestIsChatModel(t *testing.T) { | ||
t.Run("Chat model should return true", func(t *testing.T) { | ||
chatModel := chatmodel.NewFake("dummy") | ||
isChatModel := IsChatModel(chatModel) | ||
assert.True(t, isChatModel) | ||
}) | ||
|
||
t.Run("LLM should return false", func(t *testing.T) { | ||
otherModel := llm.NewFake("dummy") | ||
isChatModel := IsChatModel(otherModel) | ||
assert.False(t, isChatModel) | ||
}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package chatmodel | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
// Compile time check to ensure Fake satisfies the ChatModel interface. | ||
var _ schema.ChatModel = (*Fake)(nil) | ||
|
||
type Fake struct { | ||
schema.Tokenizer | ||
response string | ||
} | ||
|
||
func NewFake(response string) *Fake { | ||
return &Fake{ | ||
response: response, | ||
} | ||
} | ||
|
||
func (cm *Fake) Generate(ctx context.Context, messages schema.ChatMessages) (*schema.LLMResult, error) { | ||
return &schema.LLMResult{ | ||
Generations: [][]*schema.Generation{{newChatGeneraton(cm.response)}}, | ||
LLMOutput: map[string]any{}, | ||
}, nil | ||
} | ||
|
||
func (cm *Fake) Type() string { | ||
return "Fake" | ||
} | ||
|
||
func (cm *Fake) Verbose() bool { | ||
return false | ||
} | ||
|
||
func (cm *Fake) Callbacks() []schema.Callback { | ||
return []schema.Callback{} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
package llm | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/hupe1980/golc/schema" | ||
) | ||
|
||
// Compile time check to ensure Fake satisfies the LLM interface. | ||
var _ schema.LLM = (*Fake)(nil) | ||
|
||
type Fake struct { | ||
schema.Tokenizer | ||
response string | ||
} | ||
|
||
func NewFake(response string) *Fake { | ||
return &Fake{ | ||
response: response, | ||
} | ||
} | ||
|
||
func (l *Fake) Generate(ctx context.Context, prompts []string, stop []string) (*schema.LLMResult, error) { | ||
return &schema.LLMResult{ | ||
Generations: [][]*schema.Generation{{&schema.Generation{Text: l.response}}}, | ||
LLMOutput: map[string]any{}, | ||
}, nil | ||
} | ||
|
||
func (l *Fake) Type() string { | ||
return "Fake" | ||
} | ||
|
||
func (l *Fake) Verbose() bool { | ||
return false | ||
} | ||
|
||
func (l *Fake) Callbacks() []schema.Callback { | ||
return []schema.Callback{} | ||
} |