forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e24fdf7
commit 74d340e
Showing
11 changed files
with
343 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
134 changes: 134 additions & 0 deletions
134
langchain-core/src/main/java/com/hw/langchain/agents/chat/base/ChatAgent.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.chat.base; | ||
|
||
import com.hw.langchain.agents.agent.Agent; | ||
import com.hw.langchain.agents.agent.AgentOutputParser; | ||
import com.hw.langchain.agents.chat.output.parser.ChatOutputParser; | ||
import com.hw.langchain.agents.mrkl.base.ZeroShotAgent; | ||
import com.hw.langchain.agents.mrkl.output.parser.MRKLOutputParser; | ||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.chains.llm.LLMChain; | ||
import com.hw.langchain.prompts.base.BasePromptTemplate; | ||
import com.hw.langchain.prompts.chat.BaseMessagePromptTemplate; | ||
import com.hw.langchain.prompts.chat.ChatPromptTemplate; | ||
import com.hw.langchain.prompts.chat.HumanMessagePromptTemplate; | ||
import com.hw.langchain.prompts.chat.SystemMessagePromptTemplate; | ||
import com.hw.langchain.prompts.prompt.PromptTemplate; | ||
import com.hw.langchain.schema.AgentAction; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
import org.apache.commons.lang3.tuple.Pair; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.stream.Collectors; | ||
|
||
import static com.hw.langchain.agents.chat.prompt.Prompt.*; | ||
import static com.hw.langchain.agents.utils.Utils.validateToolsSingleInput; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class ChatAgent extends Agent { | ||
|
||
public ChatAgent(LLMChain llmChain, List<String> allowedTools, AgentOutputParser outputParser) { | ||
super(llmChain, allowedTools, outputParser); | ||
} | ||
|
||
@Override | ||
public String observationPrefix() { | ||
return "Observation: "; | ||
} | ||
|
||
@Override | ||
public String llmPrefix() { | ||
return "Thought:"; | ||
} | ||
|
||
@Override | ||
public String constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) { | ||
Object agentScratchpad = super.constructScratchpad(intermediateSteps); | ||
if (!(agentScratchpad instanceof String)) { | ||
throw new IllegalArgumentException("agent_scratchpad should be of type String."); | ||
} | ||
String scratchpad = agentScratchpad.toString(); | ||
if (!scratchpad.isEmpty()) { | ||
return "This was your previous work (but I haven't seen any of it! I only see what " | ||
+ "you return as the final answer):\n" + scratchpad; | ||
} else { | ||
return scratchpad; | ||
} | ||
} | ||
|
||
|
||
private static AgentOutputParser getDefaultOutputParser(Map<String, Object> kwargs) { | ||
return new ChatOutputParser(); | ||
} | ||
|
||
|
||
public static void validateTools(List<BaseTool> tools) { | ||
validateToolsSingleInput(ChatAgent.class.getSimpleName(), tools); | ||
} | ||
|
||
@Override | ||
public List<String> stop() { | ||
return List.of("Observation:"); | ||
} | ||
|
||
public static BasePromptTemplate createPrompt(List<BaseTool> tools, String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions, List<String> inputVariables) { | ||
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList()); | ||
String toolStrings = String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList()); | ||
String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames); | ||
String template = String.join("\n\n", systemMessagePrefix, toolStrings, formattedInstructions, systemMessageSuffix); | ||
|
||
List<BaseMessagePromptTemplate> messages = List.of( | ||
SystemMessagePromptTemplate.fromTemplate(template), | ||
HumanMessagePromptTemplate.fromTemplate(humanMessage) | ||
); | ||
if (inputVariables == null) { | ||
inputVariables = List.of("input", "agent_scratchpad"); | ||
} | ||
return new ChatPromptTemplate(inputVariables, messages); | ||
} | ||
|
||
/** | ||
* Construct an agent from an LLM and tools. | ||
*/ | ||
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, Map<String, Object> kwargs) { | ||
return fromLLMAndTools(llm, tools, null, SYSTEM_MESSAGE_PREFIX, SYSTEM_MESSAGE_SUFFIX, HUMAN_MESSAGE, | ||
FORMAT_INSTRUCTIONS, null, kwargs); | ||
} | ||
|
||
/** | ||
* Construct an agent from an LLM and tools. | ||
*/ | ||
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, AgentOutputParser outputParser, | ||
String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions, | ||
List<String> inputVariables, Map<String, Object> kwargs) { | ||
validateTools(tools); | ||
|
||
var prompt = createPrompt(tools, systemMessagePrefix, systemMessageSuffix, humanMessage, formatInstructions, inputVariables); | ||
var llmChain = new LLMChain(llm, prompt); | ||
|
||
var toolNames = tools.stream().map(BaseTool::getName).toList(); | ||
outputParser = (outputParser != null) ? outputParser : getDefaultOutputParser(kwargs); | ||
|
||
return new ChatAgent(llmChain, toolNames, outputParser); | ||
} | ||
} |
52 changes: 52 additions & 0 deletions
52
...chain-core/src/main/java/com/hw/langchain/agents/chat/output/parser/ChatOutputParser.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
package com.hw.langchain.agents.chat.output.parser; | ||
|
||
import com.google.gson.Gson; | ||
import com.google.gson.reflect.TypeToken; | ||
import com.hw.langchain.agents.agent.AgentOutputParser; | ||
import com.hw.langchain.schema.AgentAction; | ||
import com.hw.langchain.schema.AgentFinish; | ||
import com.hw.langchain.schema.AgentResult; | ||
import com.hw.langchain.schema.OutputParserException; | ||
|
||
import java.lang.reflect.Type; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.agents.chat.prompt.Prompt.FORMAT_INSTRUCTIONS; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class ChatOutputParser extends AgentOutputParser { | ||
|
||
private static final String FINAL_ANSWER_ACTION = "Final Answer:"; | ||
|
||
|
||
@Override | ||
public AgentResult parse(String text) { | ||
boolean includesAnswer = text.contains(FINAL_ANSWER_ACTION); | ||
try { | ||
String action = text.split("```")[1]; | ||
Type mapType = new TypeToken<Map<String, Object>>() {}.getType(); | ||
Map<String, Object> response = new Gson().fromJson(action.strip(), mapType); | ||
|
||
boolean includesAction = response.containsKey("action") && response.containsKey("action_input"); | ||
if (includesAnswer && includesAction) { | ||
throw new OutputParserException("Parsing LLM output produced a final answer and a parse-able action: " + text); | ||
} | ||
return new AgentAction(response.get("action").toString(), response.get("action_input"), text); | ||
|
||
} catch (Exception e) { | ||
if (!includesAnswer) { | ||
throw new OutputParserException("Could not parse LLM output: " + text); | ||
} | ||
String[] splitText = text.split(FINAL_ANSWER_ACTION); | ||
String output = splitText[splitText.length - 1].strip(); | ||
return new AgentFinish(Map.ofEntries(Map.entry("output",output)), text); | ||
} | ||
} | ||
|
||
@Override | ||
public String getFormatInstructions() { | ||
return FORMAT_INSTRUCTIONS; | ||
} | ||
} |
62 changes: 62 additions & 0 deletions
62
langchain-core/src/main/java/com/hw/langchain/agents/chat/prompt/Prompt.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.chat.prompt; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class Prompt { | ||
|
||
public static String SYSTEM_MESSAGE_PREFIX = """ | ||
Answer the following questions as best you can. You have access to the following tools:"""; | ||
|
||
public static String FORMAT_INSTRUCTIONS = | ||
""" | ||
The way you use the tools is by specifying a json blob. | ||
Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here). | ||
The only values that should be in the "action" field are: {tool_names} | ||
The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB: | ||
``` | ||
{{{{ | ||
"action": $TOOL_NAME, | ||
"action_input": $INPUT | ||
}}}} | ||
``` | ||
ALWAYS use the following format: | ||
Question: the input question you must answer | ||
Thought: you should always think about what to do | ||
Action: | ||
``` | ||
$JSON_BLOB | ||
``` | ||
Observation: the result of the action | ||
... (this Thought/Action/Observation can repeat N times) | ||
Thought: I now know the final answer | ||
Final Answer: the final answer to the original input question"""; | ||
|
||
public static String SYSTEM_MESSAGE_SUFFIX = """ | ||
Begin! Reminder to always use the exact characters `Final Answer` when responding."""; | ||
|
||
public static String HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}"; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
19 changes: 19 additions & 0 deletions
19
langchain-core/src/main/java/com/hw/langchain/agents/utils/Utils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package com.hw.langchain.agents.utils; | ||
|
||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class Utils { | ||
|
||
public static void validateToolsSingleInput(String className, List<BaseTool> tools) { | ||
for (BaseTool tool : tools) { | ||
if (!tool.isSingleInput()) { | ||
throw new IllegalArgumentException(className + " does not support multi-input tool " + tool.getName() + "."); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.