Skip to content

Commit

Permalink
Support Chat Model and ChatOpenAITest
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jun 19, 2023
1 parent dcb4853 commit 4bd00bd
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import lombok.experimental.SuperBuilder;

import java.util.List;
import java.util.Map;

/**
* @author HamaWhite
Expand All @@ -36,6 +37,14 @@ public abstract class BaseChatModel implements BaseLanguageModel {
*/
private List<String> tags;

/**
 * Merges provider-specific outputs from each generation batch into a single map.
 * Base implementation carries no metadata; chat-model subclasses override this
 * to aggregate token usage and model information.
 */
public Map<String, Object> combineLlmOutputs(List<Map<String, Object>> llmOutputs) {
    return Map.of();
}

/**
 * Convenience overload of {@link #generate(List, List)} with no stop sequences.
 *
 * @param messages one message list per prompt in the batch
 * @return the combined generation result
 */
public LLMResult generate(List<List<BaseMessage>> messages) {
    return this.generate(messages, null);
}

/**
* Top Level call
*/
Expand All @@ -44,8 +53,15 @@ public LLMResult generate(List<List<BaseMessage>> messages, List<String> stop) {
.map(message -> _generate(message, stop))
.toList();

// TODO
return null;
List<Map<String, Object>> llmOutputs = results.stream()
.map(ChatResult::getLlmOutput)
.toList();
Map<String, Object> llmOutput = combineLlmOutputs(llmOutputs);

List<List<ChatGeneration>> generations = results.stream()
.map(ChatResult::getGenerations)
.toList();
return new LLMResult(generations, llmOutput);
}

@Override
Expand All @@ -54,7 +70,6 @@ public LLMResult generatePrompt(List<PromptValue> prompts, List<String> stop) {
.map(PromptValue::toMessages)
.toList();
return generate(promptMessages, stop);

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,15 @@
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.completions.Usage;

import lombok.Builder;
import lombok.experimental.SuperBuilder;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static com.hw.langchain.chat.models.openai.OpenAI.convertOpenAiToLangChain;
import static com.hw.langchain.utils.Utils.getOrEnvOrDefault;
Expand Down Expand Up @@ -134,6 +136,20 @@ public ChatOpenAI init() {
return this;
}

@Override
public Map<String, Object> combineLlmOutputs(List<Map<String, Object>> llmOutputs) {
    /*
     * Sum token usage across all batch outputs. The second nonNull filter is
     * required: a non-null output map may still lack a "token_usage" entry,
     * and the cast would then feed null into the reduction, throwing an NPE
     * at a1.getPromptTokens()/a2.getPromptTokens().
     */
    Usage usage = llmOutputs.stream()
            .filter(Objects::nonNull)
            .map(e -> (Usage) e.get("token_usage"))
            .filter(Objects::nonNull)
            .reduce((a1, a2) -> new Usage(
                    a1.getPromptTokens() + a2.getPromptTokens(),
                    a1.getCompletionTokens() + a2.getCompletionTokens(),
                    a1.getTotalTokens() + a2.getTotalTokens()))
            .orElse(new Usage());

    // model_name mirrors the Python LangChain llm_output contract.
    return Map.of("token_usage", usage, "model_name", this.model);
}

@Override
public ChatResult _generate(List<BaseMessage> messages, List<String> stop) {
var chatMessages = convertMessages(messages);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
*/
public class OpenAI {

// Utility class holding only static conversion helpers: the private
// constructor prevents instantiation.
private OpenAI() {
}

public static Message convertLangChainToOpenAI(BaseMessage message) {
if (message instanceof ChatMessage chatMessage) {
return Message.of(chatMessage.getRole(), message.getContent());
Expand All @@ -47,15 +50,19 @@ public static BaseMessage convertOpenAiToLangChain(Message message) {
Role role = message.getRole();
String content = message.getContent();
switch (role) {
case USER:
case USER -> {
return new HumanMessage(content);
case ASSISTANT:
}
case ASSISTANT -> {
content = content != null ? content : "";
return new AIMessage(content);
case SYSTEM:
}
case SYSTEM -> {
return new SystemMessage(content);
default:
}
default -> {
return new ChatMessage(role.getValue(), content);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

package com.hw.langchain.schema;

import lombok.AllArgsConstructor;
import lombok.Data;

import java.util.List;
Expand All @@ -29,16 +28,20 @@
* @author HamaWhite
*/
@Data
public class LLMResult {

    /**
     * List of the things generated. This is List<List<Generation>> because each input could have multiple generations.
     */
    private List<? extends List<? extends Generation>> generations;

    /**
     * For arbitrary LLM provider specific output.
     */
    private Map<String, Object> llmOutput;

    /**
     * Explicit constructor replacing Lombok's {@code @AllArgsConstructor}: the
     * annotation would generate a constructor with an identical signature,
     * producing a duplicate-constructor compile error (its import was already
     * removed from this file).
     *
     * @param generations one generation list per input prompt
     * @param llmOutput   arbitrary provider-specific output, may be empty
     */
    public LLMResult(List<? extends List<? extends Generation>> generations, Map<String, Object> llmOutput) {
        this.generations = generations;
        this.llmOutput = llmOutput;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,86 @@

package com.hw.langchain.chat.models.openai;

import com.hw.langchain.schema.AIMessage;
import com.hw.langchain.schema.HumanMessage;
import com.hw.langchain.schema.SystemMessage;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

/**
* @author HamaWhite
*/
@Disabled("Test requires costly OpenAI calls, can be run manually.")
class ChatOpenAITest {

@Test
void testChatOpenAI() {
var chat = ChatOpenAI.builder()
private static final Logger LOG = LoggerFactory.getLogger(ChatOpenAITest.class);

private static ChatOpenAI chat;

@BeforeAll
public static void setup() {
chat = ChatOpenAI.builder()
.temperature(0)
.build()
.init();
}

/**
* You can get completions by passing in a single message.
*/
@Test
void testChatWithSingleMessage() {
var message = new HumanMessage("Translate this sentence from English to French. I love programming.");
// var actual = chat.call(List.of(message));
var actual = chat.call(List.of(message));

var expected = new AIMessage("J'aime programmer.");
assertEquals(expected, actual);
}

/**
* You can also pass in multiple messages for OpenAI’s gpt-3.5-turbo and gpt-4 models.
*/
@Test
void testChatWithMultiMessages() {
var messages = List.of(
new SystemMessage("You are a helpful assistant that translates English to French."),
new HumanMessage("I love programming."));
var actual = chat.call(messages);

var expected = new AIMessage("J'adore la programmation.");
assertEquals(expected, actual);
}

/**
* You can go one step further and generate completions for multiple sets of messages using generate.
* This returns an LLMResult with an additional message parameter.
*/
@Test
void testGenerateWithMultiMessages() {
var batchMessages = List.of(
List.of(
new SystemMessage("You are a helpful assistant that translates English to French."),
new HumanMessage("I love programming.")),
List.of(
new SystemMessage("You are a helpful assistant that translates English to French."),
new HumanMessage("I love artificial intelligence."))

);
var result = chat.generate(batchMessages);
assertNotNull(result, "result should not be null");

// var expected = new AIMessage("xxx");
// assertEquals(expected, actual);
LOG.info("result: {}", result);
LOG.info("token_usage: {}", result.getLlmOutput().get("token_usage"));
assertThat(result.getGenerations()).isNotNull().hasSize(2);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@

import com.fasterxml.jackson.annotation.JsonProperty;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
* Usage
* @author HamaWhite
*/
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Usage {

@JsonProperty("prompt_tokens")
Expand Down

0 comments on commit 4bd00bd

Please sign in to comment.