Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#15 from HamaWhiteGG/dev
Browse files Browse the repository at this point in the history
Support Agents of Chat models
  • Loading branch information
HamaWhiteGG authored Jun 21, 2023
2 parents a5af5fb + 432464b commit 144ce3e
Show file tree
Hide file tree
Showing 14 changed files with 458 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@

package com.hw.langchain.agents.agent;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.schema.AgentFinish;
import com.hw.langchain.schema.AgentResult;
import com.hw.langchain.tools.base.BaseTool;

import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.List;
Expand All @@ -41,6 +44,8 @@
*/
public abstract class Agent extends BaseSingleActionAgent {

private static final Logger LOG = LoggerFactory.getLogger(Agent.class);

private LLMChain llmChain;

private List<String> allowedTools;
Expand All @@ -63,9 +68,9 @@ public List<String> stop() {
* Construct the scratchpad that lets the agent continue its thought process.
*
* @param intermediateSteps Steps the LLM has taken to date, along with observations
* @return str or List[BaseMessage]
* @return String or List[BaseMessage]
*/
public String constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) {
public Object constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) {
StringBuilder thoughts = new StringBuilder();
for (Pair<AgentAction, String> step : intermediateSteps) {
thoughts.append(step.getKey().getLog());
Expand Down Expand Up @@ -102,6 +107,7 @@ public List<String> inputKeys() {
public AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<String, Object> kwargs) {
    // Merge the caller-supplied inputs with the entries produced by getFullInputs.
    Map<String, Object> chainInputs = getFullInputs(intermediateSteps, kwargs);
    // Run the LLM chain and log its raw text before handing it to the output parser.
    String llmResponse = llmChain.predict(chainInputs);
    LOG.info("fullOutput: \n{}", llmResponse);
    return outputParser.parse(llmResponse);
}

Expand All @@ -110,13 +116,20 @@ public AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<S
*/
public Map<String, Object> getFullInputs(List<Pair<AgentAction, String>> intermediateSteps,
Map<String, Object> kwargs) {
String thoughts = constructScratchpad(intermediateSteps);
Object thoughts = constructScratchpad(intermediateSteps);
var newInputs = Map.of("agent_scratchpad", thoughts, "stop", stop());
Map<String, Object> fullInputs = new HashMap<>(kwargs);
fullInputs.putAll(newInputs);
return fullInputs;
}

/**
 * Placeholder factory for constructing an agent from a language model and tools.
 *
 * <p>This base implementation builds nothing and always throws; none of the arguments are read.
 *
 * @param llm    the language model (ignored by this implementation)
 * @param tools  the tools the agent could use (ignored by this implementation)
 * @param kwargs additional options (ignored by this implementation)
 * @return never returns normally
 * @throws UnsupportedOperationException always
 */
public static BaseSingleActionAgent fromLLMAndTools(
        BaseLanguageModel llm,
        List<BaseTool> tools,
        Map<String, Object> kwargs) {
    // Give the caller a hint instead of a message-less exception.
    throw new UnsupportedOperationException(
            "fromLLMAndTools is not implemented by Agent; use the static factory of a concrete agent class.");
}

public AgentFinish returnStoppedResponse(String earlyStoppingMethod,
List<Pair<AgentAction, String>> intermediateSteps, Map<String, ?> kwargs) {
if (earlyStoppingMethod.equals("force")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public Map<String, String> _call(Map<String, Object> inputs) {

// We now enter the agent loop (until it returns something).
while (shouldContinue(iterations, timeElapsed)) {
Object nextStepOutput = takeNextStep(nameToToolMap, inputs, intermediateSteps);
var nextStepOutput = takeNextStep(nameToToolMap, inputs, intermediateSteps);
LOG.info("NextStepOutput: {}", nextStepOutput);
if (nextStepOutput instanceof AgentFinish agentFinish) {
return _return(agentFinish, intermediateSteps);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.chat.base;

import com.hw.langchain.agents.agent.Agent;
import com.hw.langchain.agents.agent.AgentOutputParser;
import com.hw.langchain.agents.chat.output.parser.ChatOutputParser;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.prompts.chat.BaseMessagePromptTemplate;
import com.hw.langchain.prompts.chat.ChatPromptTemplate;
import com.hw.langchain.prompts.chat.HumanMessagePromptTemplate;
import com.hw.langchain.prompts.chat.SystemMessagePromptTemplate;
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.tools.base.BaseTool;

import org.apache.commons.lang3.tuple.Pair;

import java.util.List;
import java.util.Map;

import static com.hw.langchain.agents.chat.prompt.Prompt.*;
import static com.hw.langchain.agents.utils.Utils.validateToolsSingleInput;

/**
* @author HamaWhite
*/
public class ChatAgent extends Agent {

public ChatAgent(LLMChain llmChain, List<String> allowedTools, AgentOutputParser outputParser) {
super(llmChain, allowedTools, outputParser);
}

@Override
public String observationPrefix() {
return "Observation: ";
}

@Override
public String llmPrefix() {
return "Thought:";
}

@Override
public String constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) {
var agentScratchpad = super.constructScratchpad(intermediateSteps);
if (!(agentScratchpad instanceof String)) {
throw new IllegalArgumentException("agent_scratchpad should be of type String.");
}
String scratchpad = agentScratchpad.toString();
if (!scratchpad.isEmpty()) {
return "This was your previous work (but I haven't seen any of it! I only see what "
+ "you return as the final answer):\n" + scratchpad;
} else {
return scratchpad;
}
}

private static AgentOutputParser getDefaultOutputParser(Map<String, Object> kwargs) {
return new ChatOutputParser();
}

public static void validateTools(List<BaseTool> tools) {
validateToolsSingleInput(ChatAgent.class.getSimpleName(), tools);
}

@Override
public List<String> stop() {
return List.of("Observation:");
}

public static BasePromptTemplate createPrompt(List<BaseTool> tools, String systemMessagePrefix,
String systemMessageSuffix, String humanMessage, String formatInstructions, List<String> inputVariables) {
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList());
String toolStrings =
String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());

formatInstructions = formatInstructions.replace("{tool_names}", toolNames);
// In Python format() method, the curly braces '{{}}' are used to represent the output '{}'.
formatInstructions = formatInstructions.replace("{{{{", "{{").replace("}}}}", "}}");

String template =
String.join("\n\n", systemMessagePrefix, toolStrings, formatInstructions, systemMessageSuffix);

List<BaseMessagePromptTemplate> messages = List.of(
SystemMessagePromptTemplate.fromTemplate(template),
HumanMessagePromptTemplate.fromTemplate(humanMessage));
if (inputVariables == null) {
inputVariables = List.of("input", "agent_scratchpad");
}
return new ChatPromptTemplate(inputVariables, messages);
}

/**
* Construct an agent from an LLM and tools.
*/
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, Map<String, Object> kwargs) {
return fromLLMAndTools(llm, tools, null, SYSTEM_MESSAGE_PREFIX, SYSTEM_MESSAGE_SUFFIX, HUMAN_MESSAGE,
FORMAT_INSTRUCTIONS, null, kwargs);
}

/**
* Construct an agent from an LLM and tools.
*/
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, AgentOutputParser outputParser,
String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions,
List<String> inputVariables, Map<String, Object> kwargs) {
validateTools(tools);

var prompt = createPrompt(tools, systemMessagePrefix, systemMessageSuffix, humanMessage, formatInstructions,
inputVariables);
var llmChain = new LLMChain(llm, prompt);

var toolNames = tools.stream().map(BaseTool::getName).toList();
outputParser = (outputParser != null) ? outputParser : getDefaultOutputParser(kwargs);

return new ChatAgent(llmChain, toolNames, outputParser);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.chat.output.parser;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.hw.langchain.agents.agent.AgentOutputParser;
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.schema.AgentFinish;
import com.hw.langchain.schema.AgentResult;
import com.hw.langchain.schema.OutputParserException;

import java.lang.reflect.Type;
import java.util.Map;

import static com.hw.langchain.agents.chat.prompt.Prompt.FORMAT_INSTRUCTIONS;

/**
* @author HamaWhite
*/
/**
 * Turns the raw text returned by a chat model into either an {@link AgentAction} (a tool call
 * described by a JSON blob inside a triple-backtick fence) or an {@link AgentFinish} (the text
 * following the {@code Final Answer:} marker).
 *
 * @author HamaWhite
 */
public class ChatOutputParser extends AgentOutputParser {

    private static final String FINAL_ANSWER_ACTION = "Final Answer:";

    /**
     * Parses the model output.
     *
     * @param text the raw model output
     * @return an {@link AgentAction} when a fenced JSON action is parsed and the text has no final
     *         answer marker; otherwise an {@link AgentFinish} holding the text after the last
     *         {@code Final Answer:} marker
     * @throws OutputParserException if no action can be parsed and the text contains no
     *                               {@code Final Answer:} marker
     */
    @Override
    public AgentResult parse(String text) {
        boolean includesAnswer = text.contains(FINAL_ANSWER_ACTION);
        try {
            // The action is the content of the first fenced section.
            String action = text.split("```")[1];
            Type mapType = new TypeToken<Map<String, Object>>() {
            }.getType();
            Map<String, Object> response = new Gson().fromJson(action.strip(), mapType);

            boolean includesAction = response.containsKey("action") && response.containsKey("action_input");
            if (includesAnswer && includesAction) {
                throw new OutputParserException(
                        "Parsing LLM output produced a final answer and a parse-able action: " + text);
            }
            return new AgentAction(response.get("action").toString(), response.get("action_input"), text);
        } catch (Exception e) {
            // Any failure above (missing fence, invalid JSON, missing keys) lands here.
            // NOTE: this also intercepts the OutputParserException thrown in the try block, so
            // text holding both a final answer and an action resolves to AgentFinish.
            if (!includesAnswer) {
                throw new OutputParserException("Could not parse LLM output: " + text);
            }
            return new AgentFinish(Map.ofEntries(Map.entry("output", extractFinalAnswer(text))), text);
        }
    }

    /**
     * Returns the stripped text that follows the last final-answer marker.
     *
     * <p>Uses {@code lastIndexOf} rather than {@code String.split}: {@code split} discards
     * trailing empty strings, so a text ending with the marker would yield the preceding text
     * (or an empty array) instead of an empty answer.
     */
    private static String extractFinalAnswer(String text) {
        int markerStart = text.lastIndexOf(FINAL_ANSWER_ACTION);
        return text.substring(markerStart + FINAL_ANSWER_ACTION.length()).strip();
    }

    /** Returns the chat agent's format instructions. */
    @Override
    public String getFormatInstructions() {
        return FORMAT_INSTRUCTIONS;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.chat.prompt;

/**
* @author HamaWhite
*/
/**
 * Prompt text used by the chat agent.
 *
 * <p>All fields are constants; they are declared {@code final} so the shared prompt text cannot
 * be reassigned at runtime.
 *
 * @author HamaWhite
 */
public class Prompt {

    /** First section of the system message. */
    public static final String SYSTEM_MESSAGE_PREFIX = """
            Answer the following questions as best you can. You have access to the following tools:""";

    /**
     * Instructions describing the expected JSON action format. Contains the {@code {tool_names}}
     * placeholder and quadrupled braces around the example JSON blob.
     */
    public static final String FORMAT_INSTRUCTIONS =
            """
            The way you use the tools is by specifying a json blob.
            Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).
            The only values that should be in the "action" field are: {tool_names}
            The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:
            ```
            {{{{
            "action": $TOOL_NAME,
            "action_input": $INPUT
            }}}}
            ```
            ALWAYS use the following format:
            Question: the input question you must answer
            Thought: you should always think about what to do
            Action:
            ```
            $JSON_BLOB
            ```
            Observation: the result of the action
            ... (this Thought/Action/Observation can repeat N times)
            Thought: I now know the final answer
            Final Answer: the final answer to the original input question""";

    /** Last section of the system message. */
    public static final String SYSTEM_MESSAGE_SUFFIX = """
            Begin! Reminder to always use the exact characters `Final Answer` when responding.""";

    /** Template of the human message, with {@code {input}} and {@code {agent_scratchpad}} placeholders. */
    public static final String HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}";
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static com.hw.langchain.agents.mrkl.prompt.Prompt.*;
Expand Down Expand Up @@ -60,13 +59,15 @@ private static AgentOutputParser getDefaultOutputParser(Map<String, Object> kwar
*/
public static PromptTemplate createPrompt(List<BaseTool> tools, String prefix, String suffix,
String formatInstructions, List<String> inputVariables) {
String toolStrings = tools.stream()
.map(tool -> tool.getName() + ": " + tool.getDescription())
.collect(Collectors.joining("\n"));

String toolStrings =
String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList());
String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames);
String template = String.join("\n\n", prefix, toolStrings, formattedInstructions, suffix);

formatInstructions = formatInstructions.replace("{tool_names}", toolNames);
// In Python's str.format(), doubled braces '{{' and '}}' are escapes that produce literal '{' and '}'.
formatInstructions = formatInstructions.replace("{{{{", "{{").replace("}}}}", "}}");

String template = String.join("\n\n", prefix, toolStrings, formatInstructions, suffix);

if (inputVariables == null) {
inputVariables = List.of("input", "agent_scratchpad");
Expand All @@ -81,6 +82,9 @@ public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools,
return fromLLMAndTools(llm, tools, null, PREFIX, SUFFIX, FORMAT_INSTRUCTIONS, null, kwargs);
}

/**
* Construct an agent from an LLM and tools.
*/
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, AgentOutputParser outputParser,
String prefix, String suffix, String formatInstructions, List<String> inputVariables,
Map<String, Object> kwargs) {
Expand Down
Loading

0 comments on commit 144ce3e

Please sign in to comment.