Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#7 from HamaWhiteGG/dev
Browse files Browse the repository at this point in the history
Support Memory: Add State to Chains and Agents
  • Loading branch information
HamaWhiteGG authored Jun 15, 2023
2 parents a0c63ee + a8cf70c commit 3dbc04d
Show file tree
Hide file tree
Showing 27 changed files with 835 additions and 94 deletions.
42 changes: 42 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,48 @@ Final Answer: 1.09874643447

1.09874643447
```

### 2.7 Memory: Add State to Chains and Agents
So far, all the chains and agents we’ve gone through have been stateless.
But often, you may want a chain or agent to have some concept of "memory" so that it may remember information about
its previous interactions. The clearest and simple example of this is when designing a chatBot -
you want it to remember previous messages so it can use context from that to have a better conversation.

```java
var llm = OpenAI.builder()
.temperature(0)
.build()
.init();

var conversation = new ConversationChain(llm);

var output = conversation.predict(Map.of("input", "Hi there!"));
System.out.println("Finished chain.\n'" + output + "'");

output = conversation.predict(Map.of("input", "I'm doing well! Just having a conversation with an AI."));
System.out.println("Finished chain.\n'" + output + "'");
```

```shell
The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.

Current conversation:

Human: Hi there!
AI:
Finished chain.
' Hi there! It's nice to meet you. How can I help you today?'
The following is a friendly conversation between a human and an AI. The AI is talkative and provides lots of specific details from its context. If the AI does not know the answer to a question, it truthfully says it does not know.
Current conversation:
Human: Hi there!
AI: Hi there! It's nice to meet you. How can I help you today?
Human: I'm doing well! Just having a conversation with an AI.
AI:
Finished chain.
' That's great! It's always nice to have a conversation with someone new. What would you like to talk about?'
```
## 3. Run Test Cases from Source
```
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public List<String> inputKeys() {
public abstract String llmPrefix();

@Override
public AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<String, ?> kwargs) {
public AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<String, Object> kwargs) {
var fullInputs = getFullInputs(intermediateSteps, kwargs);
String fullOutput = llmChain.predict(fullInputs);
return outputParser.parse(fullOutput);
Expand All @@ -108,7 +108,8 @@ public AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<S
/**
* Create the full inputs for the LLMChain from intermediate steps.
*/
public Map<String, ?> getFullInputs(List<Pair<AgentAction, String>> intermediateSteps, Map<String, ?> kwargs) {
public Map<String, Object> getFullInputs(List<Pair<AgentAction, String>> intermediateSteps,
Map<String, Object> kwargs) {
String thoughts = constructScratchpad(intermediateSteps);
var newInputs = Map.of("agent_scratchpad", thoughts, "stop", stop());
Map<String, Object> fullInputs = new HashMap<>(kwargs);
Expand Down Expand Up @@ -146,9 +147,9 @@ public AgentFinish returnStoppedResponse(String earlyStoppingMethod,

// We try to extract a final answer
AgentResult agentResult = this.outputParser.parse(fullOutput);
if (agentResult instanceof AgentFinish) {
if (agentResult instanceof AgentFinish agentFinish) {
// If we can extract, we send the correct stuff
return (AgentFinish) agentResult;
return agentFinish;
} else {
// If we can extract, but the tool is not the final tool, we just return the full output
return new AgentFinish(Map.of("output", fullOutput), fullOutput);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.schema.AgentFinish;
import com.hw.langchain.schema.AgentResult;
import com.hw.langchain.schema.OutputParserException;
import com.hw.langchain.tools.base.BaseTool;

import org.apache.commons.lang3.tuple.Pair;
Expand Down Expand Up @@ -107,52 +106,39 @@ public Map<String, String> _return(AgentFinish output, List<Pair<AgentAction, St
*
* @return AgentFinish or List<Pair<AgentAction, String>>
*/
public Object takeNextStep(Map<String, BaseTool> nameToToolMap, Map<String, ?> inputs,
public Object takeNextStep(Map<String, BaseTool> nameToToolMap, Map<String, Object> inputs,
List<Pair<AgentAction, String>> intermediateSteps) {
AgentResult output = null;
try {
// Call the LLM to see what to do.
output = agent.plan(intermediateSteps, inputs);
LOG.info("Plan output: {}", output);
} catch (OutputParserException e) {
LOG.error("Error parsing output", e);
}
// Call the LLM to see what to do.
AgentResult output = agent.plan(intermediateSteps, inputs);
LOG.info("Plan output: {}", output);
if (output instanceof AgentFinish) {
return output;
}
List<AgentAction> actions;
if (output instanceof AgentAction) {
actions = List.of((AgentAction) output);
} else {
actions = (List<AgentAction>) output;
}
List<Pair<AgentAction, String>> result = new ArrayList<>();
for (AgentAction agentAction : actions) {
} else if (output instanceof AgentAction agentAction) {
String observation;
if (nameToToolMap.containsKey(agentAction.getTool())) {
BaseTool tool = nameToToolMap.get(agentAction.getTool());
var tool = nameToToolMap.get(agentAction.getTool());
boolean returnDirect = tool.isReturnDirect();
Map<String, Object> toolRunKwargs = agent.toolRunLoggingKwargs();
var toolRunKwargs = agent.toolRunLoggingKwargs();
if (returnDirect) {
toolRunKwargs.put("llm_prefix", "");
}
// We then call the tool on the tool input to get an observation
observation = tool.run(agentAction.getToolInput(), toolRunKwargs).toString();
LOG.info("Observation: {}", observation);
} else {
Map<String, Object> toolRunKwargs = agent.toolRunLoggingKwargs();
var toolRunKwargs = agent.toolRunLoggingKwargs();
observation = new InvalidTool().run(agentAction.getTool(), toolRunKwargs).toString();
}
result.add(Pair.of(agentAction, observation));
return List.of(Pair.of(agentAction, observation));
}
return result;
return null;
}

/**
* Run text through and get agent response.
*/
@Override
public Map<String, String> _call(Map<String, ?> inputs) {
public Map<String, String> _call(Map<String, Object> inputs) {
// Construct a mapping of tool name to tool for easy lookup
Map<String, BaseTool> nameToToolMap = tools.stream().collect(Collectors.toMap(BaseTool::getName, tool -> tool));

Expand All @@ -166,9 +152,8 @@ public Map<String, String> _call(Map<String, ?> inputs) {
while (shouldContinue(iterations, timeElapsed)) {
Object nextStepOutput = takeNextStep(nameToToolMap, inputs, intermediateSteps);
LOG.info("NextStepOutput: {}", nextStepOutput);

if (nextStepOutput instanceof AgentFinish) {
return _return((AgentFinish) nextStepOutput, intermediateSteps);
if (nextStepOutput instanceof AgentFinish agentFinish) {
return _return(agentFinish, intermediateSteps);
}

var nextOutput = (List<Pair<AgentAction, String>>) nextStepOutput;
Expand All @@ -194,10 +179,7 @@ private boolean shouldContinue(int iterations, double timeElapsed) {
if (maxIterations != null && iterations >= maxIterations) {
return false;
}
if (maxExecutionTime != null && timeElapsed >= maxExecutionTime) {
return false;
}
return true;
return maxExecutionTime == null || timeElapsed < maxExecutionTime;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public List<String> returnValues() {
* @param kwargs User inputs.
* @return Action specifying what tool to use.
*/
public abstract AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<String, ?> kwargs);
public abstract AgentResult plan(List<Pair<AgentAction, String>> intermediateSteps, Map<String, Object> kwargs);

public static BaseSingleActionAgent fromLLMAndTools(
BaseLanguageModel llm,
Expand Down
124 changes: 89 additions & 35 deletions langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,19 @@

package com.hw.langchain.chains.base;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.hw.langchain.schema.BaseMemory;

import static com.google.common.base.Preconditions.checkArgument;
import java.util.*;

/**
* Base interface that all chains should implement.
*
* @author HamaWhite
*/
public abstract class Chain {

protected BaseMemory memory;

public abstract String chainType();

/**
Expand All @@ -42,10 +43,29 @@ public abstract class Chain {
*/
public abstract List<String> outputKeys();

/**
* Check that all inputs are present
*/
private void validateInputs(Map<String, Object> inputs) {
Set<String> missingKeys = new HashSet<>(inputKeys());
missingKeys.removeAll(inputs.keySet());
if (!missingKeys.isEmpty()) {
throw new IllegalArgumentException(String.format("Missing some input keys: %s", missingKeys));
}
}

private void validateOutputs(Map<String, String> outputs) {
Set<String> missingKeys = new HashSet<>(outputKeys());
missingKeys.removeAll(outputs.keySet());
if (!missingKeys.isEmpty()) {
throw new IllegalArgumentException(String.format("Missing some output keys: %s", missingKeys));
}
}

/**
* Run the logic of this chain and return the output.
*/
public abstract Map<String, String> _call(Map<String, ?> inputs);
public abstract Map<String, String> _call(Map<String, Object> inputs);

/**
* Run the logic of this chain and add to output if desired.
Expand All @@ -57,7 +77,7 @@ public abstract class Chain {
* Defaults to False.
*/
public Map<String, String> call(String input, boolean returnOnlyOutputs) {
Map<String, String> inputs = prepInputs(input);
Map<String, Object> inputs = prepInputs(input);
return call(inputs, returnOnlyOutputs);
}

Expand All @@ -70,55 +90,89 @@ public Map<String, String> call(String input, boolean returnOnlyOutputs) {
* If False, both input keys and new keys generated by this chain will be returned.
* Defaults to False.
*/
public Map<String, String> call(Map<String, ?> inputs, boolean returnOnlyOutputs) {
public Map<String, String> call(Map<String, Object> inputs, boolean returnOnlyOutputs) {
inputs = prepInputs(inputs);
Map<String, String> outputs = _call(inputs);
return prepOutputs(inputs, outputs, returnOnlyOutputs);
}

/**
* Run the chain as text in, text out
* Validate and prep outputs.
*/
public String run(String args) {
validateOutputKeys();
return call(args, false).get(outputKeys().get(0));
private Map<String, String> prepOutputs(Map<String, Object> inputs, Map<String, String> outputs,
boolean returnOnlyOutputs) {
validateOutputs(outputs);
if (memory != null) {
memory.saveContext(inputs, outputs);
}
if (returnOnlyOutputs) {
return outputs;
} else {
Map<String, String> result = new HashMap<>();
inputs.forEach((k, v) -> result.put(k, v.toString()));
result.putAll(outputs);
return result;
}
}

/**
* Run the chain as multiple variables, text out.
* Validate and prep inputs.
*/
public String run(Map<String, Object> args) {
validateOutputKeys();
return call(args, false).get(outputKeys().get(0));
private Map<String, Object> prepInputs(String input) {
Set<String> inputKeys = new HashSet<>(inputKeys());
if (memory != null) {
// If there are multiple input keys, but some get set by memory so that only one is not set,
// we can still figure out which key it is.
Set<String> memoryVariables = new HashSet<>(memory.memoryVariables());
inputKeys.removeAll(memoryVariables);
}
if (inputKeys.size() != 1) {
throw new IllegalArgumentException(
String.format(
"A single string input was passed in, but this chain expects multiple inputs (%s). " +
"When a chain expects multiple inputs, please call it by passing in a dictionary, "
+
"eg `chain(Map.of('foo', 1, 'bar', 2))`",
inputKeys));
}
return Map.of(new ArrayList<>(inputKeys).get(0), input);
}

private void validateOutputKeys() {
List<String> outputKeys = outputKeys();
checkArgument(outputKeys.size() == 1,
"run not supported when there is not exactly one output key. Got %s",
outputKeys);
/**
* Validate and prep inputs.
*/
private Map<String, Object> prepInputs(Map<String, Object> inputs) {
Map<String, Object> newInputs = new HashMap<>(inputs);
if (memory != null) {
Map<String, Object> externalContext = memory.loadMemoryVariables(inputs);
newInputs.putAll(externalContext);
}
validateInputs(newInputs);
return newInputs;
}

/**
* Validate and prep inputs.
* Run the chain as text in, text out
*/
private Map<String, String> prepInputs(String inputs) {
List<String> inputKeys = inputKeys();
checkArgument(inputKeys.size() == 1,
"A single string input was passed in, but this chain expects multiple inputs (%s)", inputKeys);
return Map.of(inputKeys.get(0), inputs);
public String run(String args) {
if (outputKeys().size() != 1) {
throw new IllegalArgumentException(
"The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
+ ".");
}
return call(args, false).get(outputKeys().get(0));
}

/**
* Validate and prep outputs.
* Run the chain as multiple variables, text out.
*/
private Map<String, String> prepOutputs(Map<String, ?> inputs, Map<String, String> outputs,
boolean returnOnlyOutputs) {
if (returnOnlyOutputs) {
return outputs;
public String run(Map<String, Object> args) {
if (outputKeys().size() != 1) {
throw new IllegalArgumentException(
"The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
+ ".");
}
Map<String, String> resultMap = new HashMap<>();
inputs.forEach((key, value) -> resultMap.put(key, String.valueOf(value)));
resultMap.putAll(outputs);
return resultMap;
return call(args, false).get(outputKeys().get(0));
}

}
Loading

0 comments on commit 3dbc04d

Please sign in to comment.