diff --git a/chain/prompt_selector.go b/chain/prompt_selector.go new file mode 100644 index 0000000..455f128 --- /dev/null +++ b/chain/prompt_selector.go @@ -0,0 +1,54 @@ +package chain + +import ( + "github.com/hupe1980/golc/prompt" + "github.com/hupe1980/golc/schema" +) + +// PromptSelector is an interface for selecting prompts based on a model. +type PromptSelector interface { + GetPrompt(model schema.Model) *prompt.Template +} + +// Compile time check to ensure ConditionalPromptSelector satisfies the PromptSelector interface. +var _ PromptSelector = (*ConditionalPromptSelector)(nil) + +// ConditionalFunc represents a function that evaluates a condition based on a model. +type ConditionalFunc func(model schema.Model) bool + +// Conditional represents a conditional prompt configuration. +type Conditional struct { + Condition ConditionalFunc + Prompt *prompt.Template +} + +// ConditionalPromptSelector is a prompt selector that selects prompts based on conditions. +type ConditionalPromptSelector struct { + DefaultPrompt *prompt.Template + Conditionals []Conditional +} + +// GetPrompt selects a prompt template based on the provided model. +// It evaluates the conditions in order and returns the prompt associated with the first matching condition, +// or returns the default prompt if no condition is met. +func (cps *ConditionalPromptSelector) GetPrompt(model schema.Model) *prompt.Template { + for _, conditional := range cps.Conditionals { + if conditional.Condition(model) { + return conditional.Prompt + } + } + + return cps.DefaultPrompt +} + +// IsLLM checks if the given model is of type schema.LLM. +func IsLLM(model schema.Model) bool { + _, ok := model.(schema.LLM) + return ok +} + +// IsChatModel checks if the given model is of type schema.ChatModel. +func IsChatModel(model schema.Model) bool { + _, ok := model.(schema.ChatModel) + return ok +} diff --git a/chain/prompt_selector_test.go b/chain/prompt_selector_test.go new file mode 100644 index 0000000..a87ee2e --- /dev/null +++ b/chain/prompt_selector_test.go @@ -0,0 +1,82 @@ +package chain + +import ( + "testing" + + "github.com/hupe1980/golc/model/chatmodel" + "github.com/hupe1980/golc/model/llm" + "github.com/hupe1980/golc/prompt" + "github.com/hupe1980/golc/schema" + "github.com/stretchr/testify/assert" +) + +func TestConditionalPromptSelector(t *testing.T) { + t.Run("GetPrompt", func(t *testing.T) { + defaultPrompt, _ := prompt.NewTemplate("Default Prompt") + llmPrompt, _ := prompt.NewTemplate("LLM Prompt") + chatModelPrompt, _ := prompt.NewTemplate("ChatModel Prompt") + + conditional1 := Conditional{ + Condition: func(model schema.Model) bool { + return IsLLM(model) + }, + Prompt: llmPrompt, + } + + conditional2 := Conditional{ + Condition: func(model schema.Model) bool { + return IsChatModel(model) + }, + Prompt: chatModelPrompt, + } + + cps := ConditionalPromptSelector{ + DefaultPrompt: defaultPrompt, + Conditionals: []Conditional{conditional1, conditional2}, + } + + t.Run("LLM model should return LLM Prompt", func(t *testing.T) { + llmModel := llm.NewFake("dummy") + prompt := cps.GetPrompt(llmModel) + + text, _ := prompt.Format(nil) + assert.Equal(t, "LLM Prompt", text) + }) + + t.Run("Chat model should return Chat Prompt", func(t *testing.T) { + chatModel := chatmodel.NewFake("dummy") + prompt := cps.GetPrompt(chatModel) + + text, _ := prompt.Format(nil) + assert.Equal(t, "ChatModel Prompt", text) + }) + }) +} + +func TestIsLLM(t *testing.T) { + t.Run("LLM model should return true", func(t *testing.T) { + llmModel := llm.NewFake("dummy") + isLLM := IsLLM(llmModel) + assert.True(t, isLLM) + }) + + t.Run("ChatModel should return false", func(t *testing.T) { + chatModel := chatmodel.NewFake("dummy") + isLLM := IsLLM(chatModel) + assert.False(t, isLLM) + }) +} + +func TestIsChatModel(t *testing.T) { + t.Run("Chat model should return true", func(t *testing.T) { + chatModel := chatmodel.NewFake("dummy") + isChatModel := IsChatModel(chatModel) + assert.True(t, isChatModel) + }) + + t.Run("LLM should return false", func(t *testing.T) { + otherModel := llm.NewFake("dummy") + isChatModel := IsChatModel(otherModel) + assert.False(t, isChatModel) + }) +} diff --git a/model/chatmodel/fake.go b/model/chatmodel/fake.go new file mode 100644 index 0000000..f9f94ac --- /dev/null +++ b/model/chatmodel/fake.go @@ -0,0 +1,40 @@ +package chatmodel + +import ( + "context" + + "github.com/hupe1980/golc/schema" +) + +// Compile time check to ensure Fake satisfies the ChatModel interface. +var _ schema.ChatModel = (*Fake)(nil) + +type Fake struct { + schema.Tokenizer + response string +} + +func NewFake(response string) *Fake { + return &Fake{ + response: response, + } +} + +func (cm *Fake) Generate(ctx context.Context, messages schema.ChatMessages) (*schema.LLMResult, error) { + return &schema.LLMResult{ + Generations: [][]*schema.Generation{{newChatGeneraton(cm.response)}}, + LLMOutput: map[string]any{}, + }, nil +} + +func (cm *Fake) Type() string { + return "Fake" +} + +func (cm *Fake) Verbose() bool { + return false +} + +func (cm *Fake) Callbacks() []schema.Callback { + return []schema.Callback{} +} diff --git a/model/llm/fake.go b/model/llm/fake.go new file mode 100644 index 0000000..ec6e963 --- /dev/null +++ b/model/llm/fake.go @@ -0,0 +1,40 @@ +package llm + +import ( + "context" + + "github.com/hupe1980/golc/schema" +) + +// Compile time check to ensure Fake satisfies the LLM interface. +var _ schema.LLM = (*Fake)(nil) + +type Fake struct { + schema.Tokenizer + response string +} + +func NewFake(response string) *Fake { + return &Fake{ + response: response, + } +} + +func (l *Fake) Generate(ctx context.Context, prompts []string, stop []string) (*schema.LLMResult, error) { + return &schema.LLMResult{ + Generations: [][]*schema.Generation{{&schema.Generation{Text: l.response}}}, + LLMOutput: map[string]any{}, + }, nil +} + +func (l *Fake) Type() string { + return "Fake" +} + +func (l *Fake) Verbose() bool { + return false +} + +func (l *Fake) Callbacks() []schema.Callback { + return []schema.Callback{} +}