Skip to content

Commit

Permalink
Add Agents of ChatModels (70%)
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jun 20, 2023
1 parent e24fdf7 commit 74d340e
Show file tree
Hide file tree
Showing 11 changed files with 343 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package com.hw.langchain.agents.agent;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.schema.AgentFinish;
Expand Down Expand Up @@ -63,9 +64,9 @@ public List<String> stop() {
* Construct the scratchpad that lets the agent continue its thought process.
*
* @param intermediateSteps Steps the LLM has taken to date, along with observations
* @return str or List[BaseMessage]
* @return String or List[BaseMessage]
*/
public String constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) {
public Object constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) {
StringBuilder thoughts = new StringBuilder();
for (Pair<AgentAction, String> step : intermediateSteps) {
thoughts.append(step.getKey().getLog());
Expand Down Expand Up @@ -110,13 +111,20 @@ public AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<S
*/
public Map<String, Object> getFullInputs(List<Pair<AgentAction, String>> intermediateSteps,
Map<String, Object> kwargs) {
String thoughts = constructScratchpad(intermediateSteps);
Object thoughts = constructScratchpad(intermediateSteps);
var newInputs = Map.of("agent_scratchpad", thoughts, "stop", stop());
Map<String, Object> fullInputs = new HashMap<>(kwargs);
fullInputs.putAll(newInputs);
return fullInputs;
}

public static BaseSingleActionAgent fromLLMAndTools(
BaseLanguageModel llm,
List<BaseTool> tools,
Map<String, Object> kwargs) {
throw new UnsupportedOperationException();
}

public AgentFinish returnStoppedResponse(String earlyStoppingMethod,
List<Pair<AgentAction, String>> intermediateSteps, Map<String, ?> kwargs) {
if (earlyStoppingMethod.equals("force")) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.chat.base;

import com.hw.langchain.agents.agent.Agent;
import com.hw.langchain.agents.agent.AgentOutputParser;
import com.hw.langchain.agents.chat.output.parser.ChatOutputParser;
import com.hw.langchain.agents.mrkl.base.ZeroShotAgent;
import com.hw.langchain.agents.mrkl.output.parser.MRKLOutputParser;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.prompts.chat.BaseMessagePromptTemplate;
import com.hw.langchain.prompts.chat.ChatPromptTemplate;
import com.hw.langchain.prompts.chat.HumanMessagePromptTemplate;
import com.hw.langchain.prompts.chat.SystemMessagePromptTemplate;
import com.hw.langchain.prompts.prompt.PromptTemplate;
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.tools.base.BaseTool;
import org.apache.commons.lang3.tuple.Pair;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.hw.langchain.agents.chat.prompt.Prompt.*;
import static com.hw.langchain.agents.utils.Utils.validateToolsSingleInput;

/**
* @author HamaWhite
*/
public class ChatAgent extends Agent {

public ChatAgent(LLMChain llmChain, List<String> allowedTools, AgentOutputParser outputParser) {
super(llmChain, allowedTools, outputParser);
}

@Override
public String observationPrefix() {
return "Observation: ";
}

@Override
public String llmPrefix() {
return "Thought:";
}

@Override
public String constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) {
Object agentScratchpad = super.constructScratchpad(intermediateSteps);
if (!(agentScratchpad instanceof String)) {
throw new IllegalArgumentException("agent_scratchpad should be of type String.");
}
String scratchpad = agentScratchpad.toString();
if (!scratchpad.isEmpty()) {
return "This was your previous work (but I haven't seen any of it! I only see what "
+ "you return as the final answer):\n" + scratchpad;
} else {
return scratchpad;
}
}


private static AgentOutputParser getDefaultOutputParser(Map<String, Object> kwargs) {
return new ChatOutputParser();
}


public static void validateTools(List<BaseTool> tools) {
validateToolsSingleInput(ChatAgent.class.getSimpleName(), tools);
}

@Override
public List<String> stop() {
return List.of("Observation:");
}

public static BasePromptTemplate createPrompt(List<BaseTool> tools, String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions, List<String> inputVariables) {
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList());
String toolStrings = String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames);
String template = String.join("\n\n", systemMessagePrefix, toolStrings, formattedInstructions, systemMessageSuffix);

List<BaseMessagePromptTemplate> messages = List.of(
SystemMessagePromptTemplate.fromTemplate(template),
HumanMessagePromptTemplate.fromTemplate(humanMessage)
);
if (inputVariables == null) {
inputVariables = List.of("input", "agent_scratchpad");
}
return new ChatPromptTemplate(inputVariables, messages);
}

/**
* Construct an agent from an LLM and tools.
*/
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, Map<String, Object> kwargs) {
return fromLLMAndTools(llm, tools, null, SYSTEM_MESSAGE_PREFIX, SYSTEM_MESSAGE_SUFFIX, HUMAN_MESSAGE,
FORMAT_INSTRUCTIONS, null, kwargs);
}

/**
* Construct an agent from an LLM and tools.
*/
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, AgentOutputParser outputParser,
String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions,
List<String> inputVariables, Map<String, Object> kwargs) {
validateTools(tools);

var prompt = createPrompt(tools, systemMessagePrefix, systemMessageSuffix, humanMessage, formatInstructions, inputVariables);
var llmChain = new LLMChain(llm, prompt);

var toolNames = tools.stream().map(BaseTool::getName).toList();
outputParser = (outputParser != null) ? outputParser : getDefaultOutputParser(kwargs);

return new ChatAgent(llmChain, toolNames, outputParser);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.hw.langchain.agents.chat.output.parser;

import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import com.hw.langchain.agents.agent.AgentOutputParser;
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.schema.AgentFinish;
import com.hw.langchain.schema.AgentResult;
import com.hw.langchain.schema.OutputParserException;

import java.lang.reflect.Type;
import java.util.Map;

import static com.hw.langchain.agents.chat.prompt.Prompt.FORMAT_INSTRUCTIONS;

/**
* @author HamaWhite
*/
public class ChatOutputParser extends AgentOutputParser {

private static final String FINAL_ANSWER_ACTION = "Final Answer:";


@Override
public AgentResult parse(String text) {
boolean includesAnswer = text.contains(FINAL_ANSWER_ACTION);
try {
String action = text.split("```")[1];
Type mapType = new TypeToken<Map<String, Object>>() {}.getType();
Map<String, Object> response = new Gson().fromJson(action.strip(), mapType);

boolean includesAction = response.containsKey("action") && response.containsKey("action_input");
if (includesAnswer && includesAction) {
throw new OutputParserException("Parsing LLM output produced a final answer and a parse-able action: " + text);
}
return new AgentAction(response.get("action").toString(), response.get("action_input"), text);

} catch (Exception e) {
if (!includesAnswer) {
throw new OutputParserException("Could not parse LLM output: " + text);
}
String[] splitText = text.split(FINAL_ANSWER_ACTION);
String output = splitText[splitText.length - 1].strip();
return new AgentFinish(Map.ofEntries(Map.entry("output",output)), text);
}
}

@Override
public String getFormatInstructions() {
return FORMAT_INSTRUCTIONS;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.chat.prompt;

/**
* @author HamaWhite
*/
public class Prompt {

public static String SYSTEM_MESSAGE_PREFIX = """
Answer the following questions as best you can. You have access to the following tools:""";

public static String FORMAT_INSTRUCTIONS =
"""
The way you use the tools is by specifying a json blob.
Specifically, this json should have a `action` key (with the name of the tool to use) and a `action_input` key (with the input to the tool going here).
The only values that should be in the "action" field are: {tool_names}
The $JSON_BLOB should only contain a SINGLE action, do NOT return a list of multiple actions. Here is an example of a valid $JSON_BLOB:
```
{{{{
"action": $TOOL_NAME,
"action_input": $INPUT
}}}}
```
ALWAYS use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action:
```
$JSON_BLOB
```
Observation: the result of the action
... (this Thought/Action/Observation can repeat N times)
Thought: I now know the final answer
Final Answer: the final answer to the original input question""";

public static String SYSTEM_MESSAGE_SUFFIX = """
Begin! Reminder to always use the exact characters `Final Answer` when responding.""";

public static String HUMAN_MESSAGE = "{input}\n\n{agent_scratchpad}";
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,7 @@ private static AgentOutputParser getDefaultOutputParser(Map<String, Object> kwar
*/
public static PromptTemplate createPrompt(List<BaseTool> tools, String prefix, String suffix,
String formatInstructions, List<String> inputVariables) {
String toolStrings = tools.stream()
.map(tool -> tool.getName() + ": " + tool.getDescription())
.collect(Collectors.joining("\n"));

String toolStrings = String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList());
String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames);
String template = String.join("\n\n", prefix, toolStrings, formattedInstructions, suffix);
Expand All @@ -81,6 +78,9 @@ public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools,
return fromLLMAndTools(llm, tools, null, PREFIX, SUFFIX, FORMAT_INSTRUCTIONS, null, kwargs);
}

/**
* Construct an agent from an LLM and tools.
*/
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, AgentOutputParser outputParser,
String prefix, String suffix, String formatInstructions, List<String> inputVariables,
Map<String, Object> kwargs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import com.hw.langchain.agents.agent.BaseSingleActionAgent;
import com.hw.langchain.agents.agent.types.AgentType;
import com.hw.langchain.agents.chat.base.ChatAgent;
import com.hw.langchain.agents.mrkl.base.ZeroShotAgent;

import java.util.Map;
Expand All @@ -30,5 +31,6 @@
public class Types {

public static final Map<AgentType, Class<? extends BaseSingleActionAgent>> AGENT_TO_CLASS = Map.of(
AgentType.ZERO_SHOT_REACT_DESCRIPTION, ZeroShotAgent.class);
AgentType.ZERO_SHOT_REACT_DESCRIPTION, ZeroShotAgent.class,
AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, ChatAgent.class);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package com.hw.langchain.agents.utils;

import com.hw.langchain.tools.base.BaseTool;

import java.util.List;

/**
* @author HamaWhite
*/
public class Utils {

public static void validateToolsSingleInput(String className, List<BaseTool> tools) {
for (BaseTool tool : tools) {
if (!tool.isSingleInput()) {
throw new IllegalArgumentException(className + " does not support multi-input tool " + tool.getName() + ".");
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import lombok.Data;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
* Interface LangChain tools must implement.
Expand Down Expand Up @@ -55,6 +58,21 @@ public BaseTool(String name, String description) {
this.description = description;
}

/**
* Whether the tool only accepts a single input.
*/
public boolean isSingleInput() {
var keys = args().keySet()
.stream()
.filter(k -> !k.equals("kwargs"))
.collect(Collectors.toSet());
return keys.size() == 1;
}

public Map<String, Object> args() {
return null;
}

/**
* Use the tool.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,7 @@ public Tool(String name, String description, Function<String, String> func) {
/**
* The tool's input arguments.
*/
public Map<String, Object> getArgs() {

public Map<String, Object> args() {
// For backwards compatibility, if the function signature is ambiguous,
// assume it takes a single string input.
return Map.of("tool_input", Map.of("type", "string"));
Expand Down
Loading

0 comments on commit 74d340e

Please sign in to comment.