Skip to content

Commit

Permalink
feat: 流模式接口兼容Azure
Browse files Browse the repository at this point in the history
  • Loading branch information
Leizhenpeng committed Jul 8, 2023
1 parent 379b1e8 commit 82b31e9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
2 changes: 2 additions & 0 deletions code/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ require (
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

//replace github.com/sashabaranov/go-openai v1.13.0 => github.com/Leizhenpeng/go-openai v0.0.3
6 changes: 1 addition & 5 deletions code/services/openai/gpt3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,15 @@ import (

// TestCompletions sends a fixed translation prompt through the ChatGPT
// client (configured from ../../config.yaml) and logs the reply.
// It is an integration-style test: it requires a valid config/API key.
func TestCompletions(t *testing.T) {
	config := initialization.LoadConfig("../../config.yaml")

	msgs := []Messages{
		{Role: "system", Content: "你是一个专业的翻译官,负责中英文翻译。"},
		{Role: "user", Content: "翻译这段话: The assistant messages help store prior responses. They can also be written by a developer to help give examples of desired behavior."},
	}

	gpt := NewChatGPT(*config)

	resp, err := gpt.Completions(msgs, Balance)
	if err != nil {
		// Fatalf (not Errorf): resp is meaningless on a non-nil error,
		// so continuing to read resp.Content below would be a bug.
		t.Fatalf("TestCompletions failed with error: %v", err)
	}

	// t.Logf instead of fmt.Println: output is attributed to this test
	// and only shown with `go test -v`.
	t.Logf("%s %s", resp.Content, resp.Role)
}

Expand Down Expand Up @@ -157,7 +153,7 @@ func TestChatGPT_streamChat(t *testing.T) {

// 启动一个协程来模拟流式聊天
go func() {
err := c.StreamChat(ctx, tc.msg, responseStream)
err := c.StreamChat(ctx, tc.msg, Balance, responseStream)
if err != nil {
t.Errorf("streamChat() error = %v, wantErr %v", err, tc.wantErr)
}
Expand Down
16 changes: 14 additions & 2 deletions code/services/openai/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,19 @@ func (c *ChatGPT) StreamChatWithHistory(ctx context.Context,
aiMode AIMode,
responseStream chan string,
) error {

config := go_openai.DefaultConfig(c.ApiKey[0])
config.BaseURL = c.ApiUrl + "/v1"
if c.Platform != OpenAI {
baseUrl := fmt.Sprintf("https://%s.%s",
c.AzureConfig.ResourceName, "openai.azure.com")
config = go_openai.DefaultAzureConfig(c.AzureConfig.
ApiToken, baseUrl)
config.AzureModelMapperFunc = func(model string) string {
return c.AzureConfig.DeploymentName

}
}

proxyClient, parseProxyError := GetProxyClient(c.HttpProxy)
if parseProxyError != nil {
Expand All @@ -43,12 +54,12 @@ func (c *ChatGPT) StreamChatWithHistory(ctx context.Context,
var temperature float32
temperature = float32(aiMode)
req := go_openai.ChatCompletionRequest{
Model: c.Model,
Model: "gpt-35-turbo",
Messages: msg,
N: 1,
Temperature: temperature,
MaxTokens: maxTokens,
TopP: 1,
//TopP: 1,
//Moderation: true,
//ModerationStop: true,
}
Expand All @@ -60,6 +71,7 @@ func (c *ChatGPT) StreamChatWithHistory(ctx context.Context,
defer stream.Close()
for {
response, err := stream.Recv()
fmt.Println("response: ", response)
if errors.Is(err, io.EOF) {
//fmt.Println("Stream finished")
return nil
Expand Down

0 comments on commit 82b31e9

Please sign in to comment.