Skip to content

Commit

Permalink
Add Chat Model (60%)
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jun 18, 2023
1 parent b9dfb47 commit dcb4853
Show file tree
Hide file tree
Showing 16 changed files with 305 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ public abstract class BaseChatModel implements BaseLanguageModel {
*/
private List<String> tags;

public LLMResult generate(List<List<BaseMessage>> messages, List<String> stop) {
return generate(messages, stop, null);
}

/**
* Top Level call
*/
public LLMResult generate(List<List<BaseMessage>> messages, List<String> stop, List<String> tags) {
public LLMResult generate(List<List<BaseMessage>> messages, List<String> stop) {
// Run the model once per batch of messages.
List<ChatResult> results = messages.stream()
.map(message -> _generate(message, stop))
.toList();

// TODO
// NOTE(review): the computed ChatResults are discarded and null is returned —
// any caller that dereferences the result (e.g. call(...) via getGenerations())
// will NPE. The ChatResults still need to be folded into an LLMResult here.
return null;
}

Expand All @@ -57,6 +57,15 @@ public LLMResult generatePrompt(List<PromptValue> prompts, List<String> stop) {

}

/**
* Top Level call
*/
public abstract ChatResult _generate(List<BaseMessage> messages, List<String> stop);

/**
 * Convenience overload that calls the model with no stop words.
 *
 * @param messages the conversation messages to send to the model
 * @return the message produced by the model
 */
public BaseMessage call(List<BaseMessage> messages) {
return call(messages, null);
}

public BaseMessage call(List<BaseMessage> messages, List<String> stop) {
var generation = generate(List.of(messages), stop).getGenerations().get(0).get(0);
if (generation instanceof ChatGeneration chatGeneration) {
Expand All @@ -80,4 +89,9 @@ public BaseMessage predictMessages(List<BaseMessage> messages, List<String> stop
List<String> copyStop = stop != null ? List.copyOf(stop) : null;
return call(messages, copyStop);
}

/**
* Return type of chat model.
*/
public abstract String llmType();
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,24 @@
package com.hw.langchain.chat.models.openai;

import com.hw.langchain.chat.models.base.BaseChatModel;
import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.ChatGeneration;
import com.hw.langchain.schema.ChatResult;
import com.hw.openai.OpenAiClient;
import com.hw.openai.entity.chat.ChatCompletion;
import com.hw.openai.entity.chat.ChatCompletionResp;
import com.hw.openai.entity.chat.Message;

import lombok.Builder;
import lombok.experimental.SuperBuilder;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.hw.langchain.chat.models.openai.OpenAI.convertOpenAiToLangChain;
import static com.hw.langchain.utils.Utils.getOrEnvOrDefault;

/**
* Wrapper around OpenAI Chat large language models.
*
Expand All @@ -34,13 +45,13 @@
@SuperBuilder
public class ChatOpenAI extends BaseChatModel {

protected Object client;
protected OpenAiClient client;

/**
* Model name to use.
*/
@Builder.Default
protected String modelName = "gpt-3.5-turbo";
protected String model = "gpt-3.5-turbo";

/**
* What sampling temperature to use.
Expand Down Expand Up @@ -83,7 +94,7 @@ public class ChatOpenAI extends BaseChatModel {
/**
* Whether to stream the results or not.
*/
protected boolean streaming;
protected boolean stream;

/**
* Number of chat completions to generate for each prompt.
Expand All @@ -96,4 +107,72 @@ public class ChatOpenAI extends BaseChatModel {
*/
protected Integer maxTokens;

/**
* Validate that api key exists in environment.
*/
/**
 * Validates configuration and initializes the underlying OpenAI client.
 * <p>
 * Resolves the api key and related settings from the corresponding fields,
 * falling back to the environment variables OPENAI_API_KEY,
 * OPENAI_ORGANIZATION, OPENAI_API_BASE and OPENAI_PROXY.
 *
 * @return this instance, for call chaining
 * @throws IllegalArgumentException if {@code n < 1}, or {@code n > 1} while streaming
 */
public ChatOpenAI init() {
    // Fail fast on invalid generation settings BEFORE building the client,
    // so a misconfigured instance never constructs a network client.
    if (n < 1) {
        throw new IllegalArgumentException("n must be at least 1.");
    }
    if (n > 1 && stream) {
        throw new IllegalArgumentException("n must be 1 when streaming.");
    }

    openaiApiKey = getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
    openaiOrganization = getOrEnvOrDefault(openaiOrganization, "OPENAI_ORGANIZATION", "");
    openaiApiBase = getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
    openaiProxy = getOrEnvOrDefault(openaiProxy, "OPENAI_PROXY", "");

    this.client = OpenAiClient.builder()
            .openaiApiBase(openaiApiBase)
            .openaiApiKey(openaiApiKey)
            .openaiOrganization(openaiOrganization)
            .openaiProxy(openaiProxy)
            .requestTimeout(requestTimeout)
            .build()
            .init();

    return this;
}

/**
 * Performs a single chat completion request for one batch of messages.
 *
 * @param messages the LangChain messages forming the conversation
 * @param stop     sequences at which the model should stop generating, or null
 * @return the parsed {@link ChatResult} for this batch
 */
@Override
public ChatResult _generate(List<BaseMessage> messages, List<String> stop) {
    var openAiMessages = convertMessages(messages);

    var request = ChatCompletion.builder()
            .model(model)
            .messages(openAiMessages)
            .temperature(temperature)
            .maxTokens(maxTokens)
            .n(n)
            .stream(stream)
            .stop(stop)
            .build();

    var response = client.create(request);
    return createChatResult(response);
}

/**
 * Maps each LangChain message to its OpenAI wire representation.
 *
 * @param messages the LangChain messages to convert
 * @return the converted OpenAI messages, in the same order
 */
public List<Message> convertMessages(List<BaseMessage> messages) {
    return messages.stream()
            .map(message -> OpenAI.convertLangChainToOpenAI(message))
            .toList();
}

/**
 * Builds a {@link ChatResult} from the raw OpenAI chat completion response.
 *
 * @param response the response returned by the OpenAI API
 * @return a chat result holding one generation per choice, plus llm output
 *         metadata ("token_usage" and "model_name")
 */
public ChatResult createChatResult(ChatCompletionResp response) {
    List<ChatGeneration> generations = response.getChoices()
            .stream()
            .map(choice -> new ChatGeneration(convertOpenAiToLangChain(choice.getMessage())))
            .toList();

    Map<String, Object> llmOutput = Map.of(
            "token_usage", response.getUsage(),
            "model_name", response.getModel());
    return new ChatResult(generations, llmOutput);
}

/**
 * Identifier for this chat model implementation.
 */
@Override
public String llmType() {
return "openai-chat";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.chat.models.openai;

import com.hw.langchain.schema.*;
import com.hw.openai.entity.chat.Message;
import com.hw.openai.entity.chat.Role;

/**
* @author HamaWhite
*/
/**
 * Conversion utilities between LangChain and OpenAI message representations.
 *
 * @author HamaWhite
 */
public final class OpenAI {

    private OpenAI() {
        // utility class, no instances
    }

    /**
     * Converts a LangChain {@link BaseMessage} to an OpenAI {@link Message}.
     *
     * @param message the LangChain message to convert
     * @return the equivalent OpenAI message
     * @throws IllegalArgumentException if the message type is not supported
     */
    public static Message convertLangChainToOpenAI(BaseMessage message) {
        if (message instanceof ChatMessage chatMessage) {
            return Message.of(chatMessage.getRole(), message.getContent());
        } else if (message instanceof HumanMessage) {
            return Message.of(message.getContent());
        } else if (message instanceof AIMessage) {
            return Message.ofAssistant(message.getContent());
        } else if (message instanceof SystemMessage) {
            return Message.ofSystem(message.getContent());
        } else if (message instanceof FunctionMessage functionMessage) {
            return Message.ofFunction(message.getContent(), functionMessage.getName());
        } else {
            throw new IllegalArgumentException("Got unknown type " + message.getClass().getSimpleName());
        }
    }

    /**
     * Converts an OpenAI {@link Message} back to a LangChain {@link BaseMessage}.
     *
     * @param message the OpenAI message to convert
     * @return the equivalent LangChain message
     */
    public static BaseMessage convertOpenAiToLangChain(Message message) {
        Role role = message.getRole();
        String content = message.getContent();
        return switch (role) {
            case USER -> new HumanMessage(content);
            // assistant content may be null here (the original guarded it); normalize to empty
            case ASSISTANT -> new AIMessage(content != null ? content : "");
            case SYSTEM -> new SystemMessage(content);
            default -> new ChatMessage(role.getValue(), content);
        };
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@
@SuperBuilder
public class BaseOpenAI extends BaseLLM {

protected Object client;
protected OpenAiClient client;

/**
* Model name to use.
*/
@Builder.Default
protected String modelName = "text-davinci-003";
protected String model = "text-davinci-003";

/**
* What sampling temperature to use.
Expand Down Expand Up @@ -137,7 +137,7 @@ public class BaseOpenAI extends BaseLLM {
/**
* Whether to stream the results or not.
*/
protected boolean streaming;
protected boolean stream;

/**
* Set of special tokens that are allowed.
Expand Down Expand Up @@ -176,7 +176,7 @@ protected LLMResult _generate(List<String> prompts, List<String> stop) {
List<Choice> choices = new ArrayList<>();
List<List<String>> subPrompts = getSubPrompts(prompts);
Completion completion = Completion.builder()
.model(modelName)
.model(model)
.temperature(temperature)
.maxTokens(maxTokens)
.topP(topP)
Expand All @@ -189,7 +189,7 @@ protected LLMResult _generate(List<String> prompts, List<String> stop) {

for (var prompt : subPrompts) {
completion.setPrompt(prompt);
CompletionResp response = ((OpenAiClient) client).create(completion);
CompletionResp response = client.create(completion);
choices.addAll(response.getChoices());
}

Expand Down Expand Up @@ -219,8 +219,8 @@ private LLMResult createLLMResult(List<Choice> choices, List<String> prompts, Ma
}

Map<String, Object> llmOutput = new HashMap<>(2);
llmOutput.put("tokenUsage", tokenUsage);
llmOutput.put("modelName", modelName);
llmOutput.put("token_usage", tokenUsage);
llmOutput.put("model_name", model);

return new LLMResult(generations, llmOutput);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ public class OpenAI extends BaseOpenAI {
* Validate that api key exists in environment.
*/
public OpenAI init() {
checkArgument(!(streaming && n > 1), "Cannot stream results when n > 1.");
checkArgument(!(streaming && bestOf > 1), "Cannot stream results when bestOf > 1.");
checkArgument(!(stream && n > 1), "Cannot stream results when n > 1.");
checkArgument(!(stream && bestOf > 1), "Cannot stream results when bestOf > 1.");

openaiApiKey = Utils.getOrEnvOrDefault(openaiApiKey, "OPENAI_API_KEY");
openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,13 @@
@SuperBuilder
public class OpenAIChat extends BaseLLM {

protected Object client;
protected OpenAiClient client;

/**
* Model name to use.
*/
@Builder.Default
protected String modelName = "gpt-3.5-turbo";
protected String model = "gpt-3.5-turbo";

/**
* What sampling temperature to use.
Expand Down Expand Up @@ -131,7 +131,7 @@ public class OpenAIChat extends BaseLLM {
/**
* Whether to stream the results or not.
*/
protected boolean streaming;
protected boolean stream;

public OpenAIChat init() {
openaiApiBase = Utils.getOrEnvOrDefault(openaiApiBase, "OPENAI_API_BASE", "");
Expand Down Expand Up @@ -170,7 +170,7 @@ protected LLMResult _generate(List<String> prompts, List<String> stop) {
List<Message> messages = getChatMessages(prompts);

ChatCompletion chatCompletion = ChatCompletion.builder()
.model(modelName)
.model(model)
.temperature(temperature)
.messages(messages)
.maxTokens(maxTokens)
Expand All @@ -182,7 +182,7 @@ protected LLMResult _generate(List<String> prompts, List<String> stop) {
.stop(stop)
.build();

ChatCompletionResp response = ((OpenAiClient) client).create(chatCompletion);
ChatCompletionResp response = client.create(chatCompletion);

List<List<Generation>> generations = new ArrayList<>();
Generation generation = Generation.builder()
Expand All @@ -192,8 +192,8 @@ protected LLMResult _generate(List<String> prompts, List<String> stop) {
generations.add(List.of(generation));

Map<String, Object> llmOutput = new HashMap<>(2);
llmOutput.put("tokenUsage", response.getUsage());
llmOutput.put("modelName", modelName);
llmOutput.put("token_usage", response.getUsage());
llmOutput.put("model_name", response.getModel());

return new LLMResult(generations, llmOutput);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
package com.hw.langchain.schema;

import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.HashMap;
import java.util.Map;

/**
* Message object.
*
* @author HamaWhite
*/
@Data
@NoArgsConstructor
public abstract class BaseMessage {

protected String content;
Expand All @@ -35,6 +39,7 @@ public abstract class BaseMessage {

protected BaseMessage(String content) {
this.content = content;
this.additionalKwargs = new HashMap<>();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@

import lombok.Getter;

import java.util.Map;

/**
* Output of a single generation.
*
Expand All @@ -33,7 +31,8 @@ public class ChatGeneration extends Generation {

protected BaseMessage message;

public ChatGeneration(String text, Map<String, Object> generationInfo) {
super(text, generationInfo);
/**
 * Creates a generation whose text is taken from the message content.
 *
 * @param message the chat message produced by the model
 */
public ChatGeneration(BaseMessage message) {
super(message.getContent());
this.message = message;
}
}
Loading

0 comments on commit dcb4853

Please sign in to comment.