Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#151 from HamaWhiteGG/dev
Browse files Browse the repository at this point in the history
Optimize class naming in the openai-client
  • Loading branch information
HamaWhiteGG authored Dec 21, 2023
2 parents cef7500 + ed69f54 commit 1c3978e
Show file tree
Hide file tree
Showing 15 changed files with 68 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import com.hw.openai.common.OpenaiApiType;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.chat.ChatMessage;
import com.hw.openai.entity.completions.Usage;

import lombok.Builder;
Expand Down Expand Up @@ -185,7 +185,7 @@ public ChatResult innerGenerate(List<BaseMessage> messages, List<String> stop) {
return createChatResult(response);
}

public List<Message> convertMessages(List<BaseMessage> messages) {
public List<ChatMessage> convertMessages(List<BaseMessage> messages) {
return messages.stream()
.map(OpenAI::convertLangChainToOpenAI)
.toList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
package com.hw.langchain.chat.models.openai;

import com.hw.langchain.schema.*;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.chat.Role;
import com.hw.openai.entity.chat.ChatMessage;
import com.hw.openai.entity.chat.ChatMessageRole;

/**
* @author HamaWhite
Expand All @@ -30,24 +30,24 @@ public class OpenAI {
private OpenAI() {
}

public static Message convertLangChainToOpenAI(BaseMessage message) {
if (message instanceof ChatMessage chatMessage) {
return Message.of(chatMessage.getRole(), message.getContent());
public static ChatMessage convertLangChainToOpenAI(BaseMessage message) {
if (message instanceof com.hw.langchain.schema.ChatMessage chatMessage) {
return ChatMessage.of(chatMessage.getRole(), message.getContent());
} else if (message instanceof HumanMessage) {
return Message.of(message.getContent());
return ChatMessage.of(message.getContent());
} else if (message instanceof AIMessage) {
return Message.ofAssistant(message.getContent());
return ChatMessage.ofAssistant(message.getContent());
} else if (message instanceof SystemMessage) {
return Message.ofSystem(message.getContent());
return ChatMessage.ofSystem(message.getContent());
} else if (message instanceof FunctionMessage functionMessage) {
return Message.ofFunction(message.getContent(), functionMessage.getName());
return ChatMessage.ofFunction(message.getContent(), functionMessage.getName());
} else {
throw new IllegalArgumentException("Got unknown type " + message.getClass().getSimpleName());
}
}

public static BaseMessage convertOpenAiToLangChain(Message message) {
Role role = message.getRole();
public static BaseMessage convertOpenAiToLangChain(ChatMessage message) {
ChatMessageRole role = message.getRole();
String content = message.getContent();
switch (role) {
case USER -> {
Expand All @@ -61,7 +61,7 @@ public static BaseMessage convertOpenAiToLangChain(Message message) {
return new SystemMessage(content);
}
default -> {
return new ChatMessage(content, role.getValue());
return new com.hw.langchain.schema.ChatMessage(content, role.getValue());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import com.hw.openai.common.OpenaiApiType;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.chat.ChatMessage;

import lombok.Builder;
import lombok.experimental.SuperBuilder;
Expand Down Expand Up @@ -128,7 +128,7 @@ public class OpenAIChat extends BaseLLM {
* Series of messages for Chat input.
*/
@Builder.Default
private List<Message> prefixMessages = new ArrayList<>();
private List<ChatMessage> prefixMessages = new ArrayList<>();

/**
* Timeout for requests to OpenAI completion API. Default is 16 seconds.
Expand Down Expand Up @@ -173,7 +173,7 @@ public String llmType() {

@Override
protected LLMResult innerGenerate(List<String> prompts, List<String> stop) {
List<Message> messages = getChatMessages(prompts);
List<ChatMessage> messages = getChatMessages(prompts);

ChatCompletion chatCompletion = ChatCompletion.builder()
.model(model)
Expand Down Expand Up @@ -209,10 +209,10 @@ protected Flux<AsyncLLMResult> asyncInnerGenerate(List<String> prompts, List<Str
throw new UnsupportedOperationException("not supported yet.");
}

private List<Message> getChatMessages(List<String> prompts) {
private List<ChatMessage> getChatMessages(List<String> prompts) {
checkArgument(prompts.size() == 1, "OpenAIChat currently only supports single prompt, got %s", prompts);
List<Message> messages = new ArrayList<>(prefixMessages);
messages.add(Message.of(prompts.get(0)));
List<ChatMessage> messages = new ArrayList<>(prefixMessages);
messages.add(ChatMessage.of(prompts.get(0)));
return messages;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.chat.ChatMessage;

import java.util.List;

Expand All @@ -40,7 +40,7 @@ public static void main(String[] args) {
.build()
.init();

Message message = Message.of("Introduce West Lake in Hangzhou, China.");
ChatMessage message = ChatMessage.of("Introduce West Lake in Hangzhou, China.");
ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-4")
.temperature(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import com.hw.langchain.examples.runner.RunnableExample;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.chat.ChatMessage;

import java.util.List;

Expand All @@ -37,7 +37,7 @@ public static void main(String[] args) {
.build()
.init();

Message message = Message.of("Introduce West Lake in Hangzhou, China.");
ChatMessage message = ChatMessage.of("Introduce West Lake in Hangzhou, China.");
ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-4")
.messages(List.of(message))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class ChatChoice {
private Integer index;

@JsonAlias("delta")
private Message message;
private ChatMessage message;

/**
* The reason the model stopped generating tokens. This will be stopped if the model hit a natural stop point or a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import java.util.Map;

/**
* ChatCompletion
* Chat conversation.
* @author HamaWhite
*/
@Data
Expand All @@ -49,7 +49,7 @@ public class ChatCompletion implements Serializable {
* A list of messages describing the conversation so far.
*/
@NotEmpty
private List<Message> messages;
private List<ChatMessage> messages;

/**
* What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Message implements Serializable {
public class ChatMessage implements Serializable {

/**
* The role of the author of this message. One of system, user, or assistant.
*/
@NotNull
private Role role;
private ChatMessageRole role;

/**
* The contents of the message.
Expand All @@ -67,34 +67,45 @@ public class Message implements Serializable {
@JsonProperty("tool_calls")
private List<ToolCall> toolCalls;

public Message(Role role, String content) {
public ChatMessage(String content) {
this.content = content;
}

public ChatMessage(ChatMessageRole role, String content) {
this.role = role;
this.content = content;
}

public Message(Role role, String content, String name) {
public ChatMessage(ChatMessageRole role, String content, String name) {
this.role = role;
this.content = content;
this.name = name;
}

public static Message of(String role, String content) {
return new Message(Role.fromValue(role), content);
public ChatMessage(ChatMessage source) {
this.role = source.role;
this.content = source.content;
this.name = source.name;
this.toolCalls = source.toolCalls;
}

public static ChatMessage of(String role, String content) {
return new ChatMessage(ChatMessageRole.fromValue(role), content);
}

public static Message of(String content) {
return new Message(Role.USER, content);
public static ChatMessage of(String content) {
return new ChatMessage(ChatMessageRole.USER, content);
}

public static Message ofSystem(String content) {
return new Message(Role.SYSTEM, content);
public static ChatMessage ofSystem(String content) {
return new ChatMessage(ChatMessageRole.SYSTEM, content);
}

public static Message ofAssistant(String content) {
return new Message(Role.ASSISTANT, content);
public static ChatMessage ofAssistant(String content) {
return new ChatMessage(ChatMessageRole.ASSISTANT, content);
}

public static Message ofFunction(String content, String name) {
return new Message(Role.FUNCTION, content, name);
public static ChatMessage ofFunction(String content, String name) {
return new ChatMessage(ChatMessageRole.FUNCTION, content, name);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
* Role
* @author HamaWhite
*/
public enum Role {
public enum ChatMessageRole {

/**
* System role.
Expand All @@ -49,7 +49,7 @@ public enum Role {

private final String value;

Role(String value) {
ChatMessageRole(String value) {
this.value = value;
}

Expand All @@ -59,10 +59,10 @@ public String getValue() {
}

@JsonCreator
public static Role fromValue(String value) {
for (Role role : Role.values()) {
if (role.value.equalsIgnoreCase(value)) {
return role;
public static ChatMessageRole fromValue(String value) {
for (ChatMessageRole item : ChatMessageRole.values()) {
if (item.value.equalsIgnoreCase(value)) {
return item;
}
}
throw new IllegalArgumentException("Invalid Role value: " + value);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
* @author HamaWhite
*/
@Data
public class Function {
public class FunctionCall {

/**
* The name of the function to call.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ public class ToolCall {
*/
private String type;

private Function function;
private FunctionCall function;
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import com.hw.openai.common.OpenaiApiType;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.chat.ChatMessage;
import com.hw.openai.entity.completions.Completion;
import com.hw.openai.entity.embeddings.Embedding;

Expand Down Expand Up @@ -71,7 +71,7 @@ void testCompletion() {

@Test
void testChatCompletion() {
Message message = Message.of("Hello!");
ChatMessage message = ChatMessage.of("Hello!");

ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-35-turbo")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package com.hw.openai;

import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.chat.ChatMessage;
import com.hw.openai.entity.completions.Completion;
import com.hw.openai.entity.embeddings.Embedding;
import com.hw.openai.entity.models.Model;
Expand Down Expand Up @@ -115,7 +115,7 @@ void testStreamCompletion() {

@Test
void testChatCompletion() {
Message message = Message.of("Hello!");
ChatMessage message = ChatMessage.of("Hello!");

ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-3.5-turbo")
Expand All @@ -128,7 +128,7 @@ void testChatCompletion() {

@Test
void testStreamChatCompletion() {
Message message = Message.of("Hello!");
ChatMessage message = ChatMessage.of("Hello!");

ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-3.5-turbo")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void testChatFunction() {
.parameters(ChatParameterUtils.generate(Weather.class))
.build();

Message message = Message.of("What is the weather like in Boston?");
ChatMessage message = ChatMessage.of("What is the weather like in Boston?");

ChatCompletion chatCompletion = ChatCompletion.builder()
.model("gpt-4")
Expand All @@ -86,7 +86,7 @@ void testChatFunction() {
assertThat(chatChoice).isNotNull();
assertEquals("tool_calls", chatChoice.getFinishReason());

Function function = chatChoice.getMessage().getToolCalls().get(0).getFunction();
FunctionCall function = chatChoice.getMessage().getToolCalls().get(0).getFunction();
// name=get_current_weather, arguments={ "location": "Boston" }
assertEquals(functionName, function.getName());

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
<awaitility.version>4.2.0</awaitility.version>
<reflections.version>0.10.2</reflections.version>
<resilience4j.version>2.1.0</resilience4j.version>
<slf4j-log4j12.version>1.7.25</slf4j-log4j12.version>
<slf4j-log4j12.version>1.7.36</slf4j-log4j12.version>
<commons-text.version>1.10.0</commons-text.version>
<commons-lang3.version>3.12.0</commons-lang3.version>
<commons-collections4.version>4.4</commons-collections4.version>
Expand Down

0 comments on commit 1c3978e

Please sign in to comment.