forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request HamaWhiteGG#15 from HamaWhiteGG/dev
Support Agents of Chat models
- Loading branch information
Showing
14 changed files
with
458 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
136 changes: 136 additions & 0 deletions
136
langchain-core/src/main/java/com/hw/langchain/agents/chat/base/ChatAgent.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.chat.base; | ||
|
||
import com.hw.langchain.agents.agent.Agent; | ||
import com.hw.langchain.agents.agent.AgentOutputParser; | ||
import com.hw.langchain.agents.chat.output.parser.ChatOutputParser; | ||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.chains.llm.LLMChain; | ||
import com.hw.langchain.prompts.base.BasePromptTemplate; | ||
import com.hw.langchain.prompts.chat.BaseMessagePromptTemplate; | ||
import com.hw.langchain.prompts.chat.ChatPromptTemplate; | ||
import com.hw.langchain.prompts.chat.HumanMessagePromptTemplate; | ||
import com.hw.langchain.prompts.chat.SystemMessagePromptTemplate; | ||
import com.hw.langchain.schema.AgentAction; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import org.apache.commons.lang3.tuple.Pair; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.agents.chat.prompt.Prompt.*; | ||
import static com.hw.langchain.agents.utils.Utils.validateToolsSingleInput; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class ChatAgent extends Agent { | ||
|
||
public ChatAgent(LLMChain llmChain, List<String> allowedTools, AgentOutputParser outputParser) { | ||
super(llmChain, allowedTools, outputParser); | ||
} | ||
|
||
@Override | ||
public String observationPrefix() { | ||
return "Observation: "; | ||
} | ||
|
||
@Override | ||
public String llmPrefix() { | ||
return "Thought:"; | ||
} | ||
|
||
@Override | ||
public String constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) { | ||
var agentScratchpad = super.constructScratchpad(intermediateSteps); | ||
if (!(agentScratchpad instanceof String)) { | ||
throw new IllegalArgumentException("agent_scratchpad should be of type String."); | ||
} | ||
String scratchpad = agentScratchpad.toString(); | ||
if (!scratchpad.isEmpty()) { | ||
return "This was your previous work (but I haven't seen any of it! I only see what " | ||
+ "you return as the final answer):\n" + scratchpad; | ||
} else { | ||
return scratchpad; | ||
} | ||
} | ||
|
||
private static AgentOutputParser getDefaultOutputParser(Map<String, Object> kwargs) { | ||
return new ChatOutputParser(); | ||
} | ||
|
||
public static void validateTools(List<BaseTool> tools) { | ||
validateToolsSingleInput(ChatAgent.class.getSimpleName(), tools); | ||
} | ||
|
||
@Override | ||
public List<String> stop() { | ||
return List.of("Observation:"); | ||
} | ||
|
||
public static BasePromptTemplate createPrompt(List<BaseTool> tools, String systemMessagePrefix, | ||
String systemMessageSuffix, String humanMessage, String formatInstructions, List<String> inputVariables) { | ||
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList()); | ||
String toolStrings = | ||
String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList()); | ||
|
||
formatInstructions = formatInstructions.replace("{tool_names}", toolNames); | ||
// In Python format() method, the curly braces '{{}}' are used to represent the output '{}'. | ||
formatInstructions = formatInstructions.replace("{{{{", "{{").replace("}}}}", "}}"); | ||
|
||
String template = | ||
String.join("\n\n", systemMessagePrefix, toolStrings, formatInstructions, systemMessageSuffix); | ||
|
||
List<BaseMessagePromptTemplate> messages = List.of( | ||
SystemMessagePromptTemplate.fromTemplate(template), | ||
HumanMessagePromptTemplate.fromTemplate(humanMessage)); | ||
if (inputVariables == null) { | ||
inputVariables = List.of("input", "agent_scratchpad"); | ||
} | ||
return new ChatPromptTemplate(inputVariables, messages); | ||
} | ||
|
||
/** | ||
* Construct an agent from an LLM and tools. | ||
*/ | ||
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, Map<String, Object> kwargs) { | ||
return fromLLMAndTools(llm, tools, null, SYSTEM_MESSAGE_PREFIX, SYSTEM_MESSAGE_SUFFIX, HUMAN_MESSAGE, | ||
FORMAT_INSTRUCTIONS, null, kwargs); | ||
} | ||
|
||
/** | ||
* Construct an agent from an LLM and tools. | ||
*/ | ||
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, AgentOutputParser outputParser, | ||
String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions, | ||
List<String> inputVariables, Map<String, Object> kwargs) { | ||
validateTools(tools); | ||
|
||
var prompt = createPrompt(tools, systemMessagePrefix, systemMessageSuffix, humanMessage, formatInstructions, | ||
inputVariables); | ||
var llmChain = new LLMChain(llm, prompt); | ||
|
||
var toolNames = tools.stream().map(BaseTool::getName).toList(); | ||
outputParser = (outputParser != null) ? outputParser : getDefaultOutputParser(kwargs); | ||
|
||
return new ChatAgent(llmChain, toolNames, outputParser); | ||
} | ||
} |
70 changes: 70 additions & 0 deletions
70
...chain-core/src/main/java/com/hw/langchain/agents/chat/output/parser/ChatOutputParser.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.chat.output.parser; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.reflect.TypeToken; | ||
import com.hw.langchain.agents.agent.AgentOutputParser; | ||
import com.hw.langchain.schema.AgentAction; | ||
import com.hw.langchain.schema.AgentFinish; | ||
import com.hw.langchain.schema.AgentResult; | ||
import com.hw.langchain.schema.OutputParserException; | ||
|
||
import java.lang.reflect.Type; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.agents.chat.prompt.Prompt.FORMAT_INSTRUCTIONS; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class ChatOutputParser extends AgentOutputParser { | ||
|
||
private static final String FINAL_ANSWER_ACTION = "Final Answer:"; | ||
|
||
@Override | ||
public AgentResult parse(String text) { | ||
boolean includesAnswer = text.contains(FINAL_ANSWER_ACTION); | ||
try { | ||
String action = text.split("```")[1]; | ||
Type mapType = new TypeToken<Map<String, Object>>() { | ||
}.getType(); | ||
Map<String, Object> response = new Gson().fromJson(action.strip(), mapType); | ||
|
||
boolean includesAction = response.containsKey("action") && response.containsKey("action_input"); | ||
if (includesAnswer && includesAction) { | ||
throw new OutputParserException( | ||
"Parsing LLM output produced a final answer and a parse-able action: " + text); | ||
} | ||
return new AgentAction(response.get("action").toString(), response.get("action_input"), text); | ||
} catch (Exception e) { | ||
if (!includesAnswer) { | ||
throw new OutputParserException("Could not parse LLM output: " + text); | ||
} | ||
String[] splitText = text.split(FINAL_ANSWER_ACTION); | ||
String output = splitText[splitText.length - 1].strip(); | ||
return new AgentFinish(Map.ofEntries(Map.entry("output", output)), text); | ||
} | ||
} | ||
|
||
@Override | ||
public String getFormatInstructions() { | ||
return FORMAT_INSTRUCTIONS; | ||
} | ||
} |
62 changes: 62 additions & 0 deletions
62
langchain-core/src/main/java/com/hw/langchain/agents/chat/prompt/Prompt.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.chat.prompt; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class Prompt { | ||
|
||
public static String SYSTEM_MESSAGE_PREFIX = """ | ||
Answer the following questions as best you can. You have access to the following tools:"""; | ||
|
||
public static String FORMAT_INSTRUCTIONS = | ||
""" | ||
The way you use the tools is by specifying a json blob. | ||
Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here). | ||
The only values that should be in the "action" field are: {tool_names} | ||
The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB: | ||
``` | ||
{{{{ | ||
"action": $TOOL_NAME, | ||
"action_input": $INPUT | ||
}}}} | ||
``` | ||
ALWAYS use the following format: | ||
Question: the input question you must answer | ||
Thought: you should always think about what to do | ||
Action: | ||
``` | ||
$JSON_BLOB | ||
``` | ||
Observation: the result of the action | ||
... (this Thought/Action/Observation can repeat N times) | ||
Thought: I now know the final answer | ||
Final Answer: the final answer to the original input question"""; | ||
|
||
public static String SYSTEM_MESSAGE_SUFFIX = """ | ||
Begin! Reminder to always use the exact characters `Final Answer` when responding."""; | ||
|
||
public static String HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}"; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.