diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/agent/Agent.java b/langchain-core/src/main/java/com/hw/langchain/agents/agent/Agent.java index 6dabfbaec..9d73a4849 100644 --- a/langchain-core/src/main/java/com/hw/langchain/agents/agent/Agent.java +++ b/langchain-core/src/main/java/com/hw/langchain/agents/agent/Agent.java @@ -18,6 +18,7 @@ package com.hw.langchain.agents.agent; +import com.hw.langchain.base.language.BaseLanguageModel; import com.hw.langchain.chains.llm.LLMChain; import com.hw.langchain.schema.AgentAction; import com.hw.langchain.schema.AgentFinish; @@ -25,6 +26,8 @@ import com.hw.langchain.tools.base.BaseTool; import org.apache.commons.lang3.tuple.Pair; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.HashMap; import java.util.List; @@ -41,6 +44,8 @@ */ public abstract class Agent extends BaseSingleActionAgent { + private static final Logger LOG = LoggerFactory.getLogger(Agent.class); + private LLMChain llmChain; private List allowedTools; @@ -63,9 +68,9 @@ public List stop() { * Construct the scratchpad that lets the agent continue its thought process. * * @param intermediateSteps Steps the LLM has taken to date, along with observations - * @return str or List[BaseMessage] + * @return String or List[BaseMessage] */ - public String constructScratchpad(List> intermediateSteps) { + public Object constructScratchpad(List> intermediateSteps) { StringBuilder thoughts = new StringBuilder(); for (Pair step : intermediateSteps) { thoughts.append(step.getKey().getLog()); @@ -102,6 +107,7 @@ public List inputKeys() { public AgentResult plan(List> intermediateSteps, Map kwargs) { var fullInputs = getFullInputs(intermediateSteps, kwargs); String fullOutput = llmChain.predict(fullInputs); + LOG.info("fullOutput: \n{}", fullOutput); return outputParser.parse(fullOutput); } @@ -110,13 +116,20 @@ public AgentResult plan(List> intermediateSteps, Map getFullInputs(List> intermediateSteps, Map kwargs) { - String thoughts = constructScratchpad(intermediateSteps); + Object thoughts = constructScratchpad(intermediateSteps); var newInputs = Map.of("agent_scratchpad", thoughts, "stop", stop()); Map fullInputs = new HashMap<>(kwargs); fullInputs.putAll(newInputs); return fullInputs; } + public static BaseSingleActionAgent fromLLMAndTools( + BaseLanguageModel llm, + List tools, + Map kwargs) { + throw new UnsupportedOperationException(); + } + public AgentFinish returnStoppedResponse(String earlyStoppingMethod, List> intermediateSteps, Map kwargs) { if (earlyStoppingMethod.equals("force")) { diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java b/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java index ffba95a4d..455b1cdcf 100644 --- a/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java +++ b/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java @@ -150,7 +150,7 @@ public Map _call(Map inputs) { // We now enter the agent loop (until it returns something). while (shouldContinue(iterations, timeElapsed)) { - Object nextStepOutput = takeNextStep(nameToToolMap, inputs, intermediateSteps); + var nextStepOutput = takeNextStep(nameToToolMap, inputs, intermediateSteps); LOG.info("NextStepOutput: {}", nextStepOutput); if (nextStepOutput instanceof AgentFinish agentFinish) { return _return(agentFinish, intermediateSteps); diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/chat/base/ChatAgent.java b/langchain-core/src/main/java/com/hw/langchain/agents/chat/base/ChatAgent.java new file mode 100644 index 000000000..810948ef7 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/agents/chat/base/ChatAgent.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.agents.chat.base; + +import com.hw.langchain.agents.agent.Agent; +import com.hw.langchain.agents.agent.AgentOutputParser; +import com.hw.langchain.agents.chat.output.parser.ChatOutputParser; +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.chains.llm.LLMChain; +import com.hw.langchain.prompts.base.BasePromptTemplate; +import com.hw.langchain.prompts.chat.BaseMessagePromptTemplate; +import com.hw.langchain.prompts.chat.ChatPromptTemplate; +import com.hw.langchain.prompts.chat.HumanMessagePromptTemplate; +import com.hw.langchain.prompts.chat.SystemMessagePromptTemplate; +import com.hw.langchain.schema.AgentAction; +import com.hw.langchain.tools.base.BaseTool; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.List; +import java.util.Map; + +import static com.hw.langchain.agents.chat.prompt.Prompt.*; +import static com.hw.langchain.agents.utils.Utils.validateToolsSingleInput; + +/** + * @author HamaWhite + */ +public class ChatAgent extends Agent { + + public ChatAgent(LLMChain llmChain, List allowedTools, AgentOutputParser outputParser) { + super(llmChain, allowedTools, outputParser); + } + + @Override + public String observationPrefix() { + return "Observation: "; + } + + @Override + public String llmPrefix() { + return "Thought:"; + } + + @Override + public String constructScratchpad(List> intermediateSteps) { + var agentScratchpad = super.constructScratchpad(intermediateSteps); + if (!(agentScratchpad instanceof String)) { + throw new IllegalArgumentException("agent_scratchpad should be of type String."); + } + String scratchpad = agentScratchpad.toString(); + if (!scratchpad.isEmpty()) { + return "This was your previous work (but I haven't seen any of it! I only see what " + + "you return as the final answer):\n" + scratchpad; + } else { + return scratchpad; + } + } + + private static AgentOutputParser getDefaultOutputParser(Map kwargs) { + return new ChatOutputParser(); + } + + public static void validateTools(List tools) { + validateToolsSingleInput(ChatAgent.class.getSimpleName(), tools); + } + + @Override + public List stop() { + return List.of("Observation:"); + } + + public static BasePromptTemplate createPrompt(List tools, String systemMessagePrefix, + String systemMessageSuffix, String humanMessage, String formatInstructions, List inputVariables) { + String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList()); + String toolStrings = + String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList()); + + formatInstructions = formatInstructions.replace("{tool_names}", toolNames); + // In Python format() method, the curly braces '{{}}' are used to represent the output '{}'. + formatInstructions = formatInstructions.replace("{{{{", "{{").replace("}}}}", "}}"); + + String template = + String.join("\n\n", systemMessagePrefix, toolStrings, formatInstructions, systemMessageSuffix); + + List messages = List.of( + SystemMessagePromptTemplate.fromTemplate(template), + HumanMessagePromptTemplate.fromTemplate(humanMessage)); + if (inputVariables == null) { + inputVariables = List.of("input", "agent_scratchpad"); + } + return new ChatPromptTemplate(inputVariables, messages); + } + + /** + * Construct an agent from an LLM and tools. + */ + public static Agent fromLLMAndTools(BaseLanguageModel llm, List tools, Map kwargs) { + return fromLLMAndTools(llm, tools, null, SYSTEM_MESSAGE_PREFIX, SYSTEM_MESSAGE_SUFFIX, HUMAN_MESSAGE, + FORMAT_INSTRUCTIONS, null, kwargs); + } + + /** + * Construct an agent from an LLM and tools. + */ + public static Agent fromLLMAndTools(BaseLanguageModel llm, List tools, AgentOutputParser outputParser, + String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions, + List inputVariables, Map kwargs) { + validateTools(tools); + + var prompt = createPrompt(tools, systemMessagePrefix, systemMessageSuffix, humanMessage, formatInstructions, + inputVariables); + var llmChain = new LLMChain(llm, prompt); + + var toolNames = tools.stream().map(BaseTool::getName).toList(); + outputParser = (outputParser != null) ? outputParser : getDefaultOutputParser(kwargs); + + return new ChatAgent(llmChain, toolNames, outputParser); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/chat/output/parser/ChatOutputParser.java b/langchain-core/src/main/java/com/hw/langchain/agents/chat/output/parser/ChatOutputParser.java new file mode 100644 index 000000000..bd507b745 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/agents/chat/output/parser/ChatOutputParser.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.agents.chat.output.parser; + +import com.google.gson.Gson; +import com.google.gson.reflect.TypeToken; +import com.hw.langchain.agents.agent.AgentOutputParser; +import com.hw.langchain.schema.AgentAction; +import com.hw.langchain.schema.AgentFinish; +import com.hw.langchain.schema.AgentResult; +import com.hw.langchain.schema.OutputParserException; + +import java.lang.reflect.Type; +import java.util.Map; + +import static com.hw.langchain.agents.chat.prompt.Prompt.FORMAT_INSTRUCTIONS; + +/** + * @author HamaWhite + */ +public class ChatOutputParser extends AgentOutputParser { + + private static final String FINAL_ANSWER_ACTION = "Final Answer:"; + + @Override + public AgentResult parse(String text) { + boolean includesAnswer = text.contains(FINAL_ANSWER_ACTION); + try { + String action = text.split("```")[1]; + Type mapType = new TypeToken>() { + }.getType(); + Map response = new Gson().fromJson(action.strip(), mapType); + + boolean includesAction = response.containsKey("action") && response.containsKey("action_input"); + if (includesAnswer && includesAction) { + throw new OutputParserException( + "Parsing LLM output produced a final answer and a parse-able action: " + text); + } + return new AgentAction(response.get("action").toString(), response.get("action_input"), text); + } catch (Exception e) { + if (!includesAnswer) { + throw new OutputParserException("Could not parse LLM output: " + text); + } + String[] splitText = text.split(FINAL_ANSWER_ACTION); + String output = splitText[splitText.length - 1].strip(); + return new AgentFinish(Map.ofEntries(Map.entry("output", output)), text); + } + } + + @Override + public String getFormatInstructions() { + return FORMAT_INSTRUCTIONS; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/chat/prompt/Prompt.java b/langchain-core/src/main/java/com/hw/langchain/agents/chat/prompt/Prompt.java new file mode 100644 index 000000000..9501140eb --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/agents/chat/prompt/Prompt.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.agents.chat.prompt; + +/** + * @author HamaWhite + */ +public class Prompt { + + public static String SYSTEM_MESSAGE_PREFIX = """ + Answer the following questions as best you can. You have access to the following tools:"""; + + public static String FORMAT_INSTRUCTIONS = + """ + The way you use the tools is by specifying a json blob. + Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here). + + The only values that should be in the "action" field are: {tool_names} + + The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB: + + ``` + {{{{ + "action": $TOOL_NAME, + "action_input": $INPUT + }}}} + ``` + + ALWAYS use the following format: + + Question: the input question you must answer + Thought: you should always think about what to do + Action: + ``` + $JSON_BLOB + ``` + Observation: the result of the action + ... (this Thought/Action/Observation can repeat N times) + Thought: I now know the final answer + Final Answer: the final answer to the original input question"""; + + public static String SYSTEM_MESSAGE_SUFFIX = """ + Begin! Reminder to always use the exact characters `Final Answer` when responding."""; + + public static String HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}"; +} diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/mrkl/base/ZeroShotAgent.java b/langchain-core/src/main/java/com/hw/langchain/agents/mrkl/base/ZeroShotAgent.java index 3383585c3..f1462f5ee 100644 --- a/langchain-core/src/main/java/com/hw/langchain/agents/mrkl/base/ZeroShotAgent.java +++ b/langchain-core/src/main/java/com/hw/langchain/agents/mrkl/base/ZeroShotAgent.java @@ -28,7 +28,6 @@ import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import static com.google.common.base.Preconditions.checkArgument; import static com.hw.langchain.agents.mrkl.prompt.Prompt.*; @@ -60,13 +59,15 @@ private static AgentOutputParser getDefaultOutputParser(Map kwar */ public static PromptTemplate createPrompt(List tools, String prefix, String suffix, String formatInstructions, List inputVariables) { - String toolStrings = tools.stream() - .map(tool -> tool.getName() + ": " + tool.getDescription()) - .collect(Collectors.joining("\n")); - + String toolStrings = + String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList()); String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList()); - String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames); - String template = String.join("\n\n", prefix, toolStrings, formattedInstructions, suffix); + + formatInstructions = formatInstructions.replace("{tool_names}", toolNames); + // In Python format() method, the curly braces '{{}}' are used to represent the output '{}'. + formatInstructions = formatInstructions.replace("{{{{", "{{").replace("}}}}", "}}"); + + String template = String.join("\n\n", prefix, toolStrings, formatInstructions, suffix); if (inputVariables == null) { inputVariables = List.of("input", "agent_scratchpad"); @@ -81,6 +82,9 @@ public static Agent fromLLMAndTools(BaseLanguageModel llm, List tools, return fromLLMAndTools(llm, tools, null, PREFIX, SUFFIX, FORMAT_INSTRUCTIONS, null, kwargs); } + /** + * Construct an agent from an LLM and tools. + */ public static Agent fromLLMAndTools(BaseLanguageModel llm, List tools, AgentOutputParser outputParser, String prefix, String suffix, String formatInstructions, List inputVariables, Map kwargs) { diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/types/Types.java b/langchain-core/src/main/java/com/hw/langchain/agents/types/Types.java index 90930994e..059ef21b6 100644 --- a/langchain-core/src/main/java/com/hw/langchain/agents/types/Types.java +++ b/langchain-core/src/main/java/com/hw/langchain/agents/types/Types.java @@ -20,6 +20,7 @@ import com.hw.langchain.agents.agent.BaseSingleActionAgent; import com.hw.langchain.agents.agent.types.AgentType; +import com.hw.langchain.agents.chat.base.ChatAgent; import com.hw.langchain.agents.mrkl.base.ZeroShotAgent; import java.util.Map; @@ -30,5 +31,6 @@ public class Types { public static final Map> AGENT_TO_CLASS = Map.of( - AgentType.ZERO_SHOT_REACT_DESCRIPTION, ZeroShotAgent.class); + AgentType.ZERO_SHOT_REACT_DESCRIPTION, ZeroShotAgent.class, + AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, ChatAgent.class); } diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/utils/Utils.java b/langchain-core/src/main/java/com/hw/langchain/agents/utils/Utils.java new file mode 100644 index 000000000..5b7390068 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/agents/utils/Utils.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.agents.utils; + +import com.hw.langchain.tools.base.BaseTool; + +import java.util.List; + +/** + * @author HamaWhite + */ +public class Utils { + + public static void validateToolsSingleInput(String className, List tools) { + for (BaseTool tool : tools) { + if (!tool.isSingleInput()) { + throw new IllegalArgumentException( + className + " does not support multi-input tool " + tool.getName() + "."); + } + } + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptTemplate.java index ba762a837..e6750c412 100644 --- a/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptTemplate.java +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptTemplate.java @@ -22,6 +22,7 @@ import com.hw.langchain.schema.PromptValue; import lombok.Data; +import lombok.NoArgsConstructor; import java.util.List; import java.util.Map; @@ -31,6 +32,7 @@ * @author HamaWhite */ @Data +@NoArgsConstructor public abstract class StringPromptTemplate extends BasePromptTemplate { public StringPromptTemplate(List inputVariables) { diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/prompt/PromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/prompt/PromptTemplate.java index a0c2a411d..16799ae72 100644 --- a/langchain-core/src/main/java/com/hw/langchain/prompts/prompt/PromptTemplate.java +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/prompt/PromptTemplate.java @@ -58,13 +58,17 @@ public PromptTemplate(List inputVariables, String template, BaseOutputPa @Override public String format(Map kwargs) { - return StringSubstitutor.replace(template, kwargs, "{", "}"); + String text = StringSubstitutor.replace(template, kwargs, "{", "}"); + // In Python format() method, the curly braces '{{}}' are used to represent the output '{}'. + return text.replace("{{", "{").replace("}}", "}"); } public static PromptTemplate fromTemplate(String template) { List variableNames = new ArrayList<>(); StringSubstitutor substitutor = new StringSubstitutor(variable -> { - variableNames.add(variable); + if (!variable.startsWith("{") && !variable.endsWith("}")) { + variableNames.add(variable); + } return null; }); substitutor.setVariablePrefix("{"); diff --git a/langchain-core/src/main/java/com/hw/langchain/tools/base/BaseTool.java b/langchain-core/src/main/java/com/hw/langchain/tools/base/BaseTool.java index 76309715a..75e09b9bd 100644 --- a/langchain-core/src/main/java/com/hw/langchain/tools/base/BaseTool.java +++ b/langchain-core/src/main/java/com/hw/langchain/tools/base/BaseTool.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.stream.Collectors; /** * Interface LangChain tools must implement. @@ -55,6 +56,21 @@ public BaseTool(String name, String description) { this.description = description; } + /** + * Whether the tool only accepts a single input. + */ + public boolean isSingleInput() { + var keys = args().keySet() + .stream() + .filter(k -> !k.equals("kwargs")) + .collect(Collectors.toSet()); + return keys.size() == 1; + } + + public Map args() { + return null; + } + /** * Use the tool. */ diff --git a/langchain-core/src/main/java/com/hw/langchain/tools/base/Tool.java b/langchain-core/src/main/java/com/hw/langchain/tools/base/Tool.java index f6f9a8784..68ec5ace6 100644 --- a/langchain-core/src/main/java/com/hw/langchain/tools/base/Tool.java +++ b/langchain-core/src/main/java/com/hw/langchain/tools/base/Tool.java @@ -43,8 +43,7 @@ public Tool(String name, String description, Function func) { /** * The tool's input arguments. */ - public Map getArgs() { - + public Map args() { // For backwards compatibility, if the function signature is ambiguous, // assume it takes a single string input. return Map.of("tool_input", Map.of("type", "string")); diff --git a/langchain-core/src/test/java/com/hw/langchain/agents/agent/AgentExecutorTest.java b/langchain-core/src/test/java/com/hw/langchain/agents/agent/AgentExecutorTest.java index f348edd6c..5f5122433 100644 --- a/langchain-core/src/test/java/com/hw/langchain/agents/agent/AgentExecutorTest.java +++ b/langchain-core/src/test/java/com/hw/langchain/agents/agent/AgentExecutorTest.java @@ -19,15 +19,19 @@ package com.hw.langchain.agents.agent; import com.hw.langchain.agents.agent.types.AgentType; +import com.hw.langchain.chat.models.openai.ChatOpenAI; import com.hw.langchain.llms.openai.OpenAI; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.List; import static com.hw.langchain.agents.initialize.Initialize.initializeAgent; import static com.hw.langchain.agents.load.tools.LoadTools.loadTools; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; /** @@ -36,13 +40,12 @@ @Disabled("Test requires costly OpenAI calls, can be run manually.") class AgentExecutorTest { + private static final Logger LOG = LoggerFactory.getLogger(AgentExecutorTest.class); + @Test - void testAgent() { + void testAgentWithLLM() { // First, let's load the language model we're going to use to control the agent. - var llm = OpenAI.builder() - .temperature(0) - .build() - .init(); + var llm = OpenAI.builder().temperature(0).build().init(); // Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in. var tools = loadTools(List.of("serpapi", "llm-math"), llm); @@ -53,7 +56,39 @@ void testAgent() { // Now let's test it out! String actual = agent.run( "What was the high temperature in SF yesterday in Fahrenheit? What is that number raised to the .023 power?"); - + LOG.info("actual: \n{}", actual); assertTrue(actual.matches("^1\\.\\d+$")); } + + @Test + void testAgentWithChatModels() { + // First, let's load the language model we're going to use to control the agent. + var chat = ChatOpenAI.builder().temperature(0).build().init(); + + // Next, let's load some tools to use. Note that the `llm-math` tool uses an LLM, so we need to pass that in. + var llm = OpenAI.builder().temperature(0).build().init(); + var tools = loadTools(List.of("serpapi", "llm-math"), llm); + + // Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use. + var agent = initializeAgent(tools, chat, AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION); + + /* + * Now let's test it out!, + * + * The correct return result for the final step is similar to the following: "The answer to the second question + * is 2.42427848557. Final Answer: Jason Sudeikis, and his age raised to the 0.23 power is 2.42427848557." + * + * However, sometimes OpenAI only returns "I now know the answer to the second part of the question." without + * including "Final Answer: xxx", which can cause parsing errors in the results. + * + * My temperature setting is 0, so this issue should not occur. If you know the answer, please feel free to let + * me know. + */ + + String result = agent.run("Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?"); + + // Jason Sudeikis, and his age raised to the 0.23 power is 2.42427848557. + LOG.info("result: \n{}", result); + assertNotNull(result, "result should not be null"); + } } diff --git a/langchain-core/src/test/java/com/hw/langchain/prompts/chat/ChatPromptTemplateTest.java b/langchain-core/src/test/java/com/hw/langchain/prompts/chat/ChatPromptTemplateTest.java new file mode 100644 index 000000000..cafdbf650 --- /dev/null +++ b/langchain-core/src/test/java/com/hw/langchain/prompts/chat/ChatPromptTemplateTest.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.chat; + +import com.hw.langchain.schema.BaseMessage; +import com.hw.langchain.schema.HumanMessage; +import com.hw.langchain.schema.SystemMessage; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * @author HamaWhite + */ +class ChatPromptTemplateTest { + + @Test + void testFormatMessages() { + var template = "You are a helpful assistant that translates {input_language} to {output_language}."; + var systemMessagePrompt = SystemMessagePromptTemplate.fromTemplate(template); + + var humanTemplate = "{text}"; + var humanMessagePrompt = HumanMessagePromptTemplate.fromTemplate(humanTemplate); + + var chatPrompt = ChatPromptTemplate.fromMessages(List.of(systemMessagePrompt, humanMessagePrompt)); + List actual = chatPrompt.formatMessages(Map.of("input_language", "English", + "output_language", "French", + "text", "I love programming.")); + + List expected = List.of( + new SystemMessage("You are a helpful assistant that translates English to French."), + new HumanMessage("I love programming.")); + assertEquals(expected, actual); + } +} \ No newline at end of file