Skip to content

Commit

Permalink
Support Agents of Chat models
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jun 21, 2023
1 parent 6d1d3fc commit 432464b
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import com.hw.langchain.tools.base.BaseTool;

import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.List;
Expand All @@ -42,6 +44,8 @@
*/
public abstract class Agent extends BaseSingleActionAgent {

private static final Logger LOG = LoggerFactory.getLogger(Agent.class);

private LLMChain llmChain;

private List<String> allowedTools;
Expand Down Expand Up @@ -103,6 +107,7 @@ public List<String> inputKeys() {
public AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<String, Object> kwargs) {
var fullInputs = getFullInputs(intermediateSteps, kwargs);
String fullOutput = llmChain.predict(fullInputs);
LOG.info("fullOutput: \n{}", fullOutput);
return outputParser.parse(fullOutput);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public Map<String, String> _call(Map<String, Object> inputs) {

// We now enter the agent loop (until it returns something).
while (shouldContinue(iterations, timeElapsed)) {
Object nextStepOutput = takeNextStep(nameToToolMap, inputs, intermediateSteps);
var nextStepOutput = takeNextStep(nameToToolMap, inputs, intermediateSteps);
LOG.info("NextStepOutput: {}", nextStepOutput);
if (nextStepOutput instanceof AgentFinish agentFinish) {
return _return(agentFinish, intermediateSteps);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public String llmPrefix() {

@Override
public String constructScratchpad(List<Pair<AgentAction, String>> intermediateSteps) {
Object agentScratchpad = super.constructScratchpad(intermediateSteps);
var agentScratchpad = super.constructScratchpad(intermediateSteps);
if (!(agentScratchpad instanceof String)) {
throw new IllegalArgumentException("agent_scratchpad should be of type String.");
}
Expand Down Expand Up @@ -91,9 +91,13 @@ public static BasePromptTemplate createPrompt(List<BaseTool> tools, String syste
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList());
String toolStrings =
String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames);

formatInstructions = formatInstructions.replace("{tool_names}", toolNames);
// In Python format() method, the curly braces '{{}}' are used to represent the output '{}'.
formatInstructions = formatInstructions.replace("{{{{", "{{").replace("}}}}", "}}");

String template =
String.join("\n\n", systemMessagePrefix, toolStrings, formattedInstructions, systemMessageSuffix);
String.join("\n\n", systemMessagePrefix, toolStrings, formatInstructions, systemMessageSuffix);

List<BaseMessagePromptTemplate> messages = List.of(
SystemMessagePromptTemplate.fromTemplate(template),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ public AgentResult parse(String text) {
"Parsing LLM output produced a final answer and a parse-able action: " + text);
}
return new AgentAction(response.get("action").toString(), response.get("action_input"), text);

} catch (Exception e) {
if (!includesAnswer) {
throw new OutputParserException("Could not parse LLM output: " + text);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,12 @@ public static PromptTemplate createPrompt(List<BaseTool> tools, String prefix, S
String toolStrings =
String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList());
String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames);
String template = String.join("\n\n", prefix, toolStrings, formattedInstructions, suffix);

formatInstructions = formatInstructions.replace("{tool_names}", toolNames);
// In Python format() method, the curly braces '{{}}' are used to represent the output '{}'.
formatInstructions = formatInstructions.replace("{{{{", "{{").replace("}}}}", "}}");

String template = String.join("\n\n", prefix, toolStrings, formatInstructions, suffix);

if (inputVariables == null) {
inputVariables = List.of("input", "agent_scratchpad");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.hw.langchain.schema.PromptValue;

import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;
import java.util.Map;
Expand All @@ -31,6 +32,7 @@
* @author HamaWhite
*/
@Data
@NoArgsConstructor
public abstract class StringPromptTemplate extends BasePromptTemplate {

public StringPromptTemplate(List<String> inputVariables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,17 @@ public PromptTemplate(List<String> inputVariables, String template, BaseOutputPa

@Override
public String format(Map<String, Object> kwargs) {
return StringSubstitutor.replace(template, kwargs, "{", "}");
String text = StringSubstitutor.replace(template, kwargs, "{", "}");
// In Python format() method, the curly braces '{{}}' are used to represent the output '{}'.
return text.replace("{{", "{").replace("}}", "}");
}

public static PromptTemplate fromTemplate(String template) {
List<String> variableNames = new ArrayList<>();
StringSubstitutor substitutor = new StringSubstitutor(variable -> {
variableNames.add(variable);
if (!variable.startsWith("{") && !variable.endsWith("}")) {
variableNames.add(variable);
}
return null;
});
substitutor.setVariablePrefix("{");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@

import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;

import static com.hw.langchain.agents.initialize.Initialize.initializeAgent;
import static com.hw.langchain.agents.load.tools.LoadTools.loadTools;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
Expand All @@ -37,6 +40,8 @@
@Disabled("Test requires costly OpenAI calls, can be run manually.")
class AgentExecutorTest {

private static final Logger LOG = LoggerFactory.getLogger(AgentExecutorTest.class);

@Test
void testAgentWithLLM() {
// First, let's load the language model we're going to use to control the agent.
Expand All @@ -51,7 +56,7 @@ void testAgentWithLLM() {
// Now let's test it out!
String actual = agent.run(
"What was the high temperature in SF yesterday in Fahrenheit? What is that number raised to the .023 power?");

LOG.info("actual: \n{}", actual);
assertTrue(actual.matches("^1\\.\\d+$"));
}

Expand All @@ -67,8 +72,23 @@ void testAgentWithChatModels() {
// Finally, let's initialize an agent with the tools, the language model, and the type of agent we want to use.
var agent = initializeAgent(tools, chat, AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION);

// Now let's test it out!
String actual = agent.run("Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?");
assertTrue(actual.matches("^1\\.\\d+$"));
/*
* Now let's test it out!,
*
* The correct return result for the final step is similar to the following: "The answer to the second question
* is 2.42427848557. Final Answer: Jason Sudeikis, and his age raised to the 0.23 power is 2.42427848557."
*
* However, sometimes OpenAI only returns "I now know the answer to the second part of the question." without
* including "Final Answer: xxx", which can cause parsing errors in the results.
*
* My temperature setting is 0, so this issue should not occur. If you know the answer, please feel free to let
* me know.
*/

String result = agent.run("Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?");

// Jason Sudeikis, and his age raised to the 0.23 power is 2.42427848557.
LOG.info("result: \n{}", result);
assertNotNull(result, "result should not be null");
}
}

0 comments on commit 432464b

Please sign in to comment.