From 4bd00bd6046d4bd41dd29723e6ae27c757a26a5c Mon Sep 17 00:00:00 2001 From: HamaWhite Date: Mon, 19 Jun 2023 16:37:23 +0800 Subject: [PATCH] Support Chat Model and ChatOpenAITest --- .../chat/models/base/BaseChatModel.java | 21 +++++- .../chat/models/openai/ChatOpenAI.java | 16 +++++ .../langchain/chat/models/openai/OpenAI.java | 15 ++-- .../com/hw/langchain/schema/LLMResult.java | 9 ++- .../chat/models/openai/ChatOpenAITest.java | 72 +++++++++++++++++-- .../hw/openai/entity/completions/Usage.java | 4 ++ 6 files changed, 121 insertions(+), 16 deletions(-) diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java index 31e473e1c..56bef5074 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java @@ -24,6 +24,7 @@ import lombok.experimental.SuperBuilder; import java.util.List; +import java.util.Map; /** * @author HamaWhite @@ -36,6 +37,14 @@ public abstract class BaseChatModel implements BaseLanguageModel { */ private List tags; + public Map combineLlmOutputs(List> llmOutputs) { + return Map.of(); + } + + public LLMResult generate(List> messages) { + return generate(messages, null); + } + /** * Top Level call */ @@ -44,8 +53,15 @@ public LLMResult generate(List> messages, List stop) { .map(message -> _generate(message, stop)) .toList(); - // TODO - return null; + List> llmOutputs = results.stream() + .map(ChatResult::getLlmOutput) + .toList(); + Map llmOutput = combineLlmOutputs(llmOutputs); + + List> generations = results.stream() + .map(ChatResult::getGenerations) + .toList(); + return new LLMResult(generations, llmOutput); } @Override @@ -54,7 +70,6 @@ public LLMResult generatePrompt(List prompts, List stop) { .map(PromptValue::toMessages) .toList(); return generate(promptMessages, stop); - } /** diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java index 1aad95df6..b1349f21c 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/ChatOpenAI.java @@ -26,6 +26,7 @@ import com.hw.openai.entity.chat.ChatCompletion; import com.hw.openai.entity.chat.ChatCompletionResp; import com.hw.openai.entity.chat.Message; +import com.hw.openai.entity.completions.Usage; import lombok.Builder; import lombok.experimental.SuperBuilder; @@ -33,6 +34,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import static com.hw.langchain.chat.models.openai.OpenAI.convertOpenAiToLangChain; import static com.hw.langchain.utils.Utils.getOrEnvOrDefault; @@ -134,6 +136,20 @@ public ChatOpenAI init() { return this; } + @Override + public Map combineLlmOutputs(List> llmOutputs) { + Usage usage = llmOutputs.stream() + .filter(Objects::nonNull) + .map(e -> (Usage) e.get("token_usage")) + .reduce((a1, a2) -> new Usage( + a1.getPromptTokens() + a2.getPromptTokens(), + a1.getCompletionTokens() + a2.getCompletionTokens(), + a1.getTotalTokens() + a2.getTotalTokens())) + .orElse(new Usage()); + + return Map.of("token_usage", usage, "model_name", this.model); + } + @Override public ChatResult _generate(List messages, List stop) { var chatMessages = convertMessages(messages); diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/OpenAI.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/OpenAI.java index 4c620d7c2..41e10a5b9 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/OpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/OpenAI.java @@ -27,6 +27,9 @@ */ public class OpenAI { + private OpenAI() { + } + public static Message convertLangChainToOpenAI(BaseMessage message) { if (message instanceof ChatMessage chatMessage) { return Message.of(chatMessage.getRole(), message.getContent()); @@ -47,15 +50,19 @@ public static BaseMessage convertOpenAiToLangChain(Message message) { Role role = message.getRole(); String content = message.getContent(); switch (role) { - case USER: + case USER -> { return new HumanMessage(content); - case ASSISTANT: + } + case ASSISTANT -> { content = content != null ? content : ""; return new AIMessage(content); - case SYSTEM: + } + case SYSTEM -> { return new SystemMessage(content); - default: + } + default -> { return new ChatMessage(role.getValue(), content); + } } } } diff --git a/langchain-core/src/main/java/com/hw/langchain/schema/LLMResult.java b/langchain-core/src/main/java/com/hw/langchain/schema/LLMResult.java index 5a89e22ed..714035128 100644 --- a/langchain-core/src/main/java/com/hw/langchain/schema/LLMResult.java +++ b/langchain-core/src/main/java/com/hw/langchain/schema/LLMResult.java @@ -18,7 +18,6 @@ package com.hw.langchain.schema; -import lombok.AllArgsConstructor; import lombok.Data; import java.util.List; @@ -29,16 +28,20 @@ * @author HamaWhite */ @Data -@AllArgsConstructor public class LLMResult { /** * List of the things generated. This is List> because each input could have multiple generations. */ - private List> generations; + private List> generations; /** * For arbitrary LLM provider specific output. */ private Map llmOutput; + + public LLMResult(List> generations, Map llmOutput) { + this.generations = generations; + this.llmOutput = llmOutput; + } } diff --git a/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java b/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java index b295f5828..88fd74a0a 100644 --- a/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java +++ b/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java @@ -18,26 +18,86 @@ package com.hw.langchain.chat.models.openai; +import com.hw.langchain.schema.AIMessage; import com.hw.langchain.schema.HumanMessage; +import com.hw.langchain.schema.SystemMessage; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; /** * @author HamaWhite */ +@Disabled("Test requires costly OpenAI calls, can be run manually.") class ChatOpenAITest { - @Test - void testChatOpenAI() { - var chat = ChatOpenAI.builder() + private static final Logger LOG = LoggerFactory.getLogger(ChatOpenAITest.class); + + private static ChatOpenAI chat; + + @BeforeAll + public static void setup() { + chat = ChatOpenAI.builder() .temperature(0) .build() .init(); + } + /** + * You can get completions by passing in a single message. + */ + @Test + void testChatWithSingleMessage() { var message = new HumanMessage("Translate this sentence from English to French. I love programming."); - // var actual = chat.call(List.of(message)); + var actual = chat.call(List.of(message)); + + var expected = new AIMessage("J'aime programmer."); + assertEquals(expected, actual); + } + + /** + * You can also pass in multiple messages for OpenAI’s gpt-3.5-turbo and gpt-4 models. + */ + @Test + void testChatWithMultiMessages() { + var messages = List.of( + new SystemMessage("You are a helpful assistant that translates English to French."), + new HumanMessage("I love programming.")); + var actual = chat.call(messages); + + var expected = new AIMessage("J'adore la programmation."); + assertEquals(expected, actual); + } + + /** + * You can go one step further and generate completions for multiple sets of messages using generate. + * This returns an LLMResult with an additional message parameter. + */ + @Test + void testGenerateWithMultiMessages() { + var batchMessages = List.of( + List.of( + new SystemMessage("You are a helpful assistant that translates English to French."), + new HumanMessage("I love programming.")), + List.of( + new SystemMessage("You are a helpful assistant that translates English to French."), + new HumanMessage("I love artificial intelligence.")) + + ); + var result = chat.generate(batchMessages); + assertNotNull(result, "result should not be null"); - // var expected = new AIMessage("xxx"); - // assertEquals(expected, actual); + LOG.info("result: {}", result); + LOG.info("token_usage: {}", result.getLlmOutput().get("token_usage")); + assertThat(result.getGenerations()).isNotNull().hasSize(2); } } \ No newline at end of file diff --git a/openai-client/src/main/java/com/hw/openai/entity/completions/Usage.java b/openai-client/src/main/java/com/hw/openai/entity/completions/Usage.java index e43f227b2..3af5c1309 100644 --- a/openai-client/src/main/java/com/hw/openai/entity/completions/Usage.java +++ b/openai-client/src/main/java/com/hw/openai/entity/completions/Usage.java @@ -20,13 +20,17 @@ import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.AllArgsConstructor; import lombok.Data; +import lombok.NoArgsConstructor; /** * Usage * @author HamaWhite */ @Data +@NoArgsConstructor +@AllArgsConstructor public class Usage { @JsonProperty("prompt_tokens")