From cf2dd71c0e8c9ca09c6b48a5ddf3faa37da23ab6 Mon Sep 17 00:00:00 2001 From: HamaWhite Date: Tue, 20 Jun 2023 19:34:56 +0800 Subject: [PATCH] Support Prompt templates of ChatModels --- .../base/language/BaseLanguageModel.java | 13 ++- .../chat/models/base/BaseChatModel.java | 10 +++ .../langchain/chat/models/openai/OpenAI.java | 2 +- .../com/hw/langchain/llms/base/BaseLLM.java | 21 +++++ .../hw/langchain/llms/openai/BaseOpenAI.java | 11 --- .../hw/langchain/llms/openai/OpenAIChat.java | 11 --- .../prompts/base/BasePromptTemplate.java | 2 + .../prompts/chat/AIMessagePromptTemplate.java | 40 +++++++++ .../prompts/chat/BaseChatPromptTemplate.java | 52 +++++++++++ .../chat/BaseMessagePromptTemplate.java | 46 ++++++++++ .../chat/BaseStringMessagePromptTemplate.java | 65 ++++++++++++++ .../chat/ChatMessagePromptTemplate.java | 42 +++++++++ .../prompts/chat/ChatPromptTemplate.java | 89 +++++++++++++++++++ .../prompts/chat/ChatPromptValue.java | 51 +++++++++++ .../chat/HumanMessagePromptTemplate.java | 40 +++++++++ .../chat/SystemMessagePromptTemplate.java | 40 +++++++++ .../com/hw/langchain/schema/ChatMessage.java | 6 +- .../hw/langchain/chains/llm/LLMChainTest.java | 31 +++++++ .../chat/models/openai/ChatOpenAITest.java | 17 ++++ 19 files changed, 560 insertions(+), 29 deletions(-) create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/AIMessagePromptTemplate.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseChatPromptTemplate.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseMessagePromptTemplate.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseStringMessagePromptTemplate.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatMessagePromptTemplate.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatPromptTemplate.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatPromptValue.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/HumanMessagePromptTemplate.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/prompts/chat/SystemMessagePromptTemplate.java diff --git a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java index d221dfc3d..36e6f98d5 100644 --- a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java @@ -25,7 +25,8 @@ import java.util.List; /** - * BaseLanguageModel + * BaseLanguageModel is an interface for interacting with a language model. + * * @author HamaWhite */ public interface BaseLanguageModel { @@ -35,11 +36,21 @@ public interface BaseLanguageModel { */ LLMResult generatePrompt(List prompts, List stop); + /** + * Predict text from text. + */ + String predict(String text); + /** * Predict text from text. */ String predict(String text, List stop); + /** + * Predict message from messages. + */ + BaseMessage predictMessages(List messages); + /** * Predict message from messages. */ diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java index 56bef5074..39007620b 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java @@ -90,6 +90,11 @@ public BaseMessage call(List messages, List stop) { } } + @Override + public String predict(String text) { + return predict(text, null); + } + @Override public String predict(String text, List stop) { List copyStop = stop != null ? List.copyOf(stop) : null; @@ -99,6 +104,11 @@ public String predict(String text, List stop) { return result.getContent(); } + @Override + public BaseMessage predictMessages(List messages) { + return predictMessages(messages, null); + } + @Override public BaseMessage predictMessages(List messages, List stop) { List copyStop = stop != null ? List.copyOf(stop) : null; diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/OpenAI.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/OpenAI.java index 41e10a5b9..19c9d2e98 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/OpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/openai/OpenAI.java @@ -61,7 +61,7 @@ public static BaseMessage convertOpenAiToLangChain(Message message) { return new SystemMessage(content); } default -> { - return new ChatMessage(role.getValue(), content); + return new ChatMessage(content, role.getValue()); } } } diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java b/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java index 31a412eae..196120cd0 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java @@ -19,6 +19,7 @@ package com.hw.langchain.llms.base; import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.schema.BaseMessage; import com.hw.langchain.schema.LLMResult; import com.hw.langchain.schema.PromptValue; @@ -68,4 +69,24 @@ public LLMResult generatePrompt(List prompts, List stop) { .toList(); return generate(promptStrings, stop); } + + @Override + public String predict(String text) { + return predict(text, null); + } + + @Override + public String predict(String text, List stop) { + return null; + } + + @Override + public BaseMessage predictMessages(List messages) { + return predictMessages(messages, null); + } + + @Override + public BaseMessage predictMessages(List messages, List stop) { + return null; + } } diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java index b5a3e88bb..09c7159f3 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java @@ -19,7 +19,6 @@ package com.hw.langchain.llms.openai; import com.hw.langchain.llms.base.BaseLLM; -import com.hw.langchain.schema.BaseMessage; import com.hw.langchain.schema.Generation; import com.hw.langchain.schema.LLMResult; import com.hw.openai.OpenAiClient; @@ -149,16 +148,6 @@ public class BaseOpenAI extends BaseLLM { */ protected Set disallowedSpecial; - @Override - public String predict(String text, List stop) { - return null; - } - - @Override - public BaseMessage predictMessages(List messages, List stop) { - return null; - } - @Override public String llmType() { return "openai"; diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java index e9012cc05..e940c584b 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java @@ -19,7 +19,6 @@ package com.hw.langchain.llms.openai; import com.hw.langchain.llms.base.BaseLLM; -import com.hw.langchain.schema.BaseMessage; import com.hw.langchain.schema.Generation; import com.hw.langchain.schema.LLMResult; import com.hw.langchain.utils.Utils; @@ -150,16 +149,6 @@ public OpenAIChat init() { return this; } - @Override - public String predict(String text, List stop) { - return null; - } - - @Override - public BaseMessage predictMessages(List messages, List stop) { - return null; - } - @Override public String llmType() { return "openai-chat"; diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/base/BasePromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/base/BasePromptTemplate.java index 126609b70..1132f39e8 100644 --- a/langchain-core/src/main/java/com/hw/langchain/prompts/base/BasePromptTemplate.java +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/base/BasePromptTemplate.java @@ -22,6 +22,7 @@ import com.hw.langchain.schema.PromptValue; import lombok.Data; +import lombok.NoArgsConstructor; import java.util.HashMap; import java.util.List; @@ -32,6 +33,7 @@ * @author HamaWhite */ @Data +@NoArgsConstructor public abstract class BasePromptTemplate { /** diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/AIMessagePromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/AIMessagePromptTemplate.java new file mode 100644 index 000000000..288845825 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/AIMessagePromptTemplate.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.schema.AIMessage; +import com.hw.langchain.schema.BaseMessage; + +import java.util.Map; + +/** + * @author HamaWhite + */ +public class AIMessagePromptTemplate extends BaseStringMessagePromptTemplate { + + public static AIMessagePromptTemplate fromTemplate(String template) { + return BaseStringMessagePromptTemplate.fromTemplate(AIMessagePromptTemplate.class, template); + } + + @Override + public BaseMessage format(Map kwargs) { + String text = prompt.format(kwargs); + return new AIMessage(text); + } +} \ No newline at end of file diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseChatPromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseChatPromptTemplate.java new file mode 100644 index 000000000..a4dd5cad9 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseChatPromptTemplate.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.prompts.base.BasePromptTemplate; +import com.hw.langchain.schema.BaseMessage; +import com.hw.langchain.schema.PromptValue; + +import java.util.List; +import java.util.Map; + +/** + * @author HamaWhite + */ +public abstract class BaseChatPromptTemplate extends BasePromptTemplate { + + public BaseChatPromptTemplate(List inputVariables) { + super(inputVariables); + } + + @Override + public String format(Map kwargs) { + return formatPrompt(kwargs).toString(); + } + + public PromptValue formatPrompt(Map kwargs) { + List messages = formatMessages(kwargs); + return new ChatPromptValue(messages); + } + + /** + * Format kwargs into a list of messages. + */ + public abstract List formatMessages(Map kwargs); + +} diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseMessagePromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseMessagePromptTemplate.java new file mode 100644 index 000000000..fb0be7ee1 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseMessagePromptTemplate.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.schema.BaseMessage; + +import java.io.Serializable; +import java.util.List; +import java.util.Map; + +/** + * @author HamaWhite + */ +public abstract class BaseMessagePromptTemplate implements Serializable { + + /** + * To messages. + * + * @param kwargs keyword arguments + * @return a list of BaseMessage + */ + public abstract List formatMessages(Map kwargs); + + /** + * Input variables for this prompt template. + * + * @return a list of input variables + */ + public abstract List inputVariables(); +} diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseStringMessagePromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseStringMessagePromptTemplate.java new file mode 100644 index 000000000..6475193bf --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/BaseStringMessagePromptTemplate.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.exception.LangChainException; +import com.hw.langchain.prompts.base.StringPromptTemplate; +import com.hw.langchain.prompts.prompt.PromptTemplate; +import com.hw.langchain.schema.BaseMessage; + +import java.util.List; +import java.util.Map; + +/** + * @author HamaWhite + */ +public abstract class BaseStringMessagePromptTemplate extends BaseMessagePromptTemplate { + + protected StringPromptTemplate prompt; + + public static T fromTemplate(Class cls, String template) { + StringPromptTemplate prompt = PromptTemplate.fromTemplate(template); + try { + T instance = cls.getDeclaredConstructor().newInstance(); + instance.setPrompt(prompt); + return instance; + } catch (Exception e) { + throw new LangChainException("Failed to create instance of BaseStringMessagePromptTemplate", e); + } + } + + /** + * To a BaseMessage. + */ + public abstract BaseMessage format(Map kwargs); + + @Override + public List formatMessages(Map kwargs) { + return List.of(this.format(kwargs)); + } + + @Override + public List inputVariables() { + return prompt.getInputVariables(); + } + + public void setPrompt(StringPromptTemplate prompt) { + this.prompt = prompt; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatMessagePromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatMessagePromptTemplate.java new file mode 100644 index 000000000..d4ec3b252 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatMessagePromptTemplate.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.schema.BaseMessage; +import com.hw.langchain.schema.ChatMessage; + +import java.util.Map; + +/** + * @author HamaWhite + */ +public class ChatMessagePromptTemplate extends BaseStringMessagePromptTemplate { + + private String role; + + public static ChatMessagePromptTemplate fromTemplate(String template) { + return BaseStringMessagePromptTemplate.fromTemplate(ChatMessagePromptTemplate.class, template); + } + + @Override + public BaseMessage format(Map kwargs) { + String text = prompt.format(kwargs); + return new ChatMessage(text, role); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatPromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatPromptTemplate.java new file mode 100644 index 000000000..a8cc7f9db --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatPromptTemplate.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.schema.BaseMessage; + +import java.util.*; + +/** + * @author HamaWhite + */ +public class ChatPromptTemplate extends BaseChatPromptTemplate { + + /** + * Union[BaseMessagePromptTemplate, BaseMessage] + */ + private final List messages; + + public ChatPromptTemplate(List inputVariables, List messages) { + super(inputVariables); + this.messages = messages; + + validateInputVariables(); + } + + private void validateInputVariables() { + Set inputVars = new HashSet<>(); + for (var message : messages) { + if (message instanceof BaseMessagePromptTemplate promptTemplate) { + inputVars.addAll(promptTemplate.inputVariables()); + } + } + if (inputVariables != null) { + if (!inputVars.equals(new HashSet<>(inputVariables))) { + throw new IllegalArgumentException(String + .format("Got mismatched input_variables. Expected: %s. Got: %s", inputVars, inputVariables)); + } + } else { + inputVariables = List.copyOf(inputVars); + } + } + + @Override + public List formatMessages(Map kwargs) { + List result = new ArrayList<>(); + for (var messageTemplate : messages) { + if (messageTemplate instanceof BaseMessage baseMessage) { + result.add(baseMessage); + } else if (messageTemplate instanceof BaseMessagePromptTemplate promptTemplate) { + var relParams = new HashMap(); + kwargs.forEach((key, value) -> { + if (promptTemplate.inputVariables().contains(key)) { + relParams.put(key, value); + } + }); + result.addAll(promptTemplate.formatMessages(relParams)); + } else { + throw new IllegalArgumentException("Unexpected input: " + messageTemplate); + } + } + return result; + } + + public static ChatPromptTemplate fromMessages(List messages) { + Set inputVars = new HashSet<>(); + for (var message : messages) { + if (message instanceof BaseMessagePromptTemplate template) { + inputVars.addAll(template.inputVariables()); + } + } + return new ChatPromptTemplate(new ArrayList<>(inputVars), messages); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatPromptValue.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatPromptValue.java new file mode 100644 index 000000000..d475eedcf --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/ChatPromptValue.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.schema.BaseMessage; +import com.hw.langchain.schema.PromptValue; + +import java.util.List; + +import static com.hw.langchain.schema.Schema.getBufferString; + +/** + * @author HamaWhite + */ +public class ChatPromptValue implements PromptValue { + + private List messages; + + public ChatPromptValue(List messages) { + this.messages = messages; + } + + @Override + public List toMessages() { + return this.messages; + } + + /** + * Return prompt as string. + */ + @Override + public String toString() { + return getBufferString(messages); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/HumanMessagePromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/HumanMessagePromptTemplate.java new file mode 100644 index 000000000..5658cfa83 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/HumanMessagePromptTemplate.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.schema.BaseMessage; +import com.hw.langchain.schema.HumanMessage; + +import java.util.Map; + +/** + * @author HamaWhite + */ +public class HumanMessagePromptTemplate extends BaseStringMessagePromptTemplate { + + public static HumanMessagePromptTemplate fromTemplate(String template) { + return BaseStringMessagePromptTemplate.fromTemplate(HumanMessagePromptTemplate.class, template); + } + + @Override + public BaseMessage format(Map kwargs) { + String text = prompt.format(kwargs); + return new HumanMessage(text); + } +} \ No newline at end of file diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/chat/SystemMessagePromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/SystemMessagePromptTemplate.java new file mode 100644 index 000000000..68ce5f455 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/chat/SystemMessagePromptTemplate.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.schema.BaseMessage; +import com.hw.langchain.schema.SystemMessage; + +import java.util.Map; + +/** + * @author HamaWhite + */ +public class SystemMessagePromptTemplate extends BaseStringMessagePromptTemplate { + + public static SystemMessagePromptTemplate fromTemplate(String template) { + return BaseStringMessagePromptTemplate.fromTemplate(SystemMessagePromptTemplate.class, template); + } + + @Override + public BaseMessage format(Map kwargs) { + String text = prompt.format(kwargs); + return new SystemMessage(text); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/schema/ChatMessage.java b/langchain-core/src/main/java/com/hw/langchain/schema/ChatMessage.java index fcd161bbf..c6f5c7151 100644 --- a/langchain-core/src/main/java/com/hw/langchain/schema/ChatMessage.java +++ b/langchain-core/src/main/java/com/hw/langchain/schema/ChatMessage.java @@ -30,11 +30,7 @@ public class ChatMessage extends BaseMessage { private String role; - public ChatMessage(String content) { - super(content); - } - - public ChatMessage(String role, String content) { + public ChatMessage(String content, String role) { super(content); this.role = role; } diff --git a/langchain-core/src/test/java/com/hw/langchain/chains/llm/LLMChainTest.java b/langchain-core/src/test/java/com/hw/langchain/chains/llm/LLMChainTest.java index 77f792325..cb52be025 100644 --- a/langchain-core/src/test/java/com/hw/langchain/chains/llm/LLMChainTest.java +++ b/langchain-core/src/test/java/com/hw/langchain/chains/llm/LLMChainTest.java @@ -20,7 +20,11 @@ import com.hw.langchain.base.language.BaseLanguageModel; import com.hw.langchain.chains.base.Chain; +import com.hw.langchain.chat.models.openai.ChatOpenAI; import com.hw.langchain.llms.openai.OpenAI; +import com.hw.langchain.prompts.chat.ChatPromptTemplate; +import com.hw.langchain.prompts.chat.HumanMessagePromptTemplate; +import com.hw.langchain.prompts.chat.SystemMessagePromptTemplate; import com.hw.langchain.prompts.prompt.PromptTemplate; import org.junit.jupiter.api.BeforeAll; @@ -35,6 +39,7 @@ /** * LLMChainTest + * * @author HamaWhite */ @Disabled("Test requires costly OpenAI calls, can be run manually.") @@ -42,12 +47,19 @@ class LLMChainTest { private static BaseLanguageModel llm; + private static BaseLanguageModel chat; + @BeforeAll public static void setup() throws SQLException { llm = OpenAI.builder() .temperature(0) .build() .init(); + + chat = ChatOpenAI.builder() + .temperature(0) + .build() + .init(); } @Test @@ -134,4 +146,23 @@ teacher_note CHARACTER VARYING(256) " SELECT `parent_name`, `parent_mobile` FROM `parents` WHERE `student_name` IN (SELECT `name` FROM `students` WHERE `score` = 0) LIMIT 5;"; assertEquals(expected, actual); } + + @Test + void testLLMChainWithChatModels() { + var template = "You are a helpful assistant that translates {input_language} to {output_language}."; + var systemMessagePrompt = SystemMessagePromptTemplate.fromTemplate(template); + + var humanTemplate = "{text}"; + var humanMessagePrompt = HumanMessagePromptTemplate.fromTemplate(humanTemplate); + + var chatPrompt = ChatPromptTemplate.fromMessages(List.of(systemMessagePrompt, humanMessagePrompt)); + + var chain = new LLMChain(chat, chatPrompt); + String actual = chain.run(Map.of("input_language", "English", + "output_language", "French", + "text", "I love programming.")); + + String expected = "J'adore la programmation."; + assertEquals(expected, actual); + } } \ No newline at end of file diff --git a/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java b/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java index 4b9efd12a..872846c1a 100644 --- a/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java +++ b/langchain-core/src/test/java/com/hw/langchain/chat/models/openai/ChatOpenAITest.java @@ -98,4 +98,21 @@ void testGenerateWithMultiMessages() { LOG.info("token_usage: {}", result.getLlmOutput().get("token_usage")); assertThat(result.getGenerations()).isNotNull().hasSize(2); } + + @Test + void testPredictMessages() { + var message = new HumanMessage("Translate this sentence from English to French. I love programming."); + var actual = chat.predictMessages(List.of(message)); + + var expected = new AIMessage("J'aime programmer."); + assertEquals(expected, actual); + } + + @Test + void testPredict() { + var text = "Translate this sentence from English to French. I love programming."; + var actual = chat.predict(text); + var expected = "J'aime programmer."; + assertEquals(expected, actual); + } } \ No newline at end of file