Skip to content

Commit

Permalink
optimize code by spotless
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jun 20, 2023
1 parent 74d340e commit 6d1d3fc
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,20 @@
import com.hw.langchain.agents.agent.Agent;
import com.hw.langchain.agents.agent.AgentOutputParser;
import com.hw.langchain.agents.chat.output.parser.ChatOutputParser;
import com.hw.langchain.agents.mrkl.base.ZeroShotAgent;
import com.hw.langchain.agents.mrkl.output.parser.MRKLOutputParser;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.prompts.chat.BaseMessagePromptTemplate;
import com.hw.langchain.prompts.chat.ChatPromptTemplate;
import com.hw.langchain.prompts.chat.HumanMessagePromptTemplate;
import com.hw.langchain.prompts.chat.SystemMessagePromptTemplate;
import com.hw.langchain.prompts.prompt.PromptTemplate;
import com.hw.langchain.schema.AgentAction;
import com.hw.langchain.tools.base.BaseTool;

import org.apache.commons.lang3.tuple.Pair;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.hw.langchain.agents.chat.prompt.Prompt.*;
import static com.hw.langchain.agents.utils.Utils.validateToolsSingleInput;
Expand Down Expand Up @@ -76,12 +73,10 @@ public String constructScratchpad(List<Pair<AgentAction, String>> intermediateSt
}
}


private static AgentOutputParser getDefaultOutputParser(Map<String, Object> kwargs) {
return new ChatOutputParser();
}


public static void validateTools(List<BaseTool> tools) {
validateToolsSingleInput(ChatAgent.class.getSimpleName(), tools);
}
Expand All @@ -91,16 +86,18 @@ public List<String> stop() {
return List.of("Observation:");
}

public static BasePromptTemplate createPrompt(List<BaseTool> tools, String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions, List<String> inputVariables) {
public static BasePromptTemplate createPrompt(List<BaseTool> tools, String systemMessagePrefix,
String systemMessageSuffix, String humanMessage, String formatInstructions, List<String> inputVariables) {
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList());
String toolStrings = String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String toolStrings =
String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames);
String template = String.join("\n\n", systemMessagePrefix, toolStrings, formattedInstructions, systemMessageSuffix);
String template =
String.join("\n\n", systemMessagePrefix, toolStrings, formattedInstructions, systemMessageSuffix);

List<BaseMessagePromptTemplate> messages = List.of(
SystemMessagePromptTemplate.fromTemplate(template),
HumanMessagePromptTemplate.fromTemplate(humanMessage)
);
HumanMessagePromptTemplate.fromTemplate(humanMessage));
if (inputVariables == null) {
inputVariables = List.of("input", "agent_scratchpad");
}
Expand All @@ -119,11 +116,12 @@ public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools,
* Construct an agent from an LLM and tools.
*/
public static Agent fromLLMAndTools(BaseLanguageModel llm, List<BaseTool> tools, AgentOutputParser outputParser,
String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions,
List<String> inputVariables, Map<String, Object> kwargs) {
String systemMessagePrefix, String systemMessageSuffix, String humanMessage, String formatInstructions,
List<String> inputVariables, Map<String, Object> kwargs) {
validateTools(tools);

var prompt = createPrompt(tools, systemMessagePrefix, systemMessageSuffix, humanMessage, formatInstructions, inputVariables);
var prompt = createPrompt(tools, systemMessagePrefix, systemMessageSuffix, humanMessage, formatInstructions,
inputVariables);
var llmChain = new LLMChain(llm, prompt);

var toolNames = tools.stream().map(BaseTool::getName).toList();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.chat.output.parser;

import com.google.gson.Gson;
Expand All @@ -20,18 +38,19 @@ public class ChatOutputParser extends AgentOutputParser {

private static final String FINAL_ANSWER_ACTION = "Final Answer:";


@Override
public AgentResult parse(String text) {
boolean includesAnswer = text.contains(FINAL_ANSWER_ACTION);
try {
String action = text.split("```")[1];
Type mapType = new TypeToken<Map<String, Object>>() {}.getType();
Type mapType = new TypeToken<Map<String, Object>>() {
}.getType();
Map<String, Object> response = new Gson().fromJson(action.strip(), mapType);

boolean includesAction = response.containsKey("action") && response.containsKey("action_input");
if (includesAnswer && includesAction) {
throw new OutputParserException("Parsing LLM output produced a final answer and a parse-able action: " + text);
throw new OutputParserException(
"Parsing LLM output produced a final answer and a parse-able action: " + text);
}
return new AgentAction(response.get("action").toString(), response.get("action_input"), text);

Expand All @@ -41,7 +60,7 @@ public AgentResult parse(String text) {
}
String[] splitText = text.split(FINAL_ANSWER_ACTION);
String output = splitText[splitText.length - 1].strip();
return new AgentFinish(Map.ofEntries(Map.entry("output",output)), text);
return new AgentFinish(Map.ofEntries(Map.entry("output", output)), text);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import static com.google.common.base.Preconditions.checkArgument;
import static com.hw.langchain.agents.mrkl.prompt.Prompt.*;
Expand Down Expand Up @@ -60,7 +59,8 @@ private static AgentOutputParser getDefaultOutputParser(Map<String, Object> kwar
*/
public static PromptTemplate createPrompt(List<BaseTool> tools, String prefix, String suffix,
String formatInstructions, List<String> inputVariables) {
String toolStrings = String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String toolStrings =
String.join("\n", tools.stream().map(tool -> tool.getName() + ": " + tool.getDescription()).toList());
String toolNames = String.join(", ", tools.stream().map(BaseTool::getName).toList());
String formattedInstructions = formatInstructions.replace("{tool_names}", toolNames);
String template = String.join("\n\n", prefix, toolStrings, formattedInstructions, suffix);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.utils;

import com.hw.langchain.tools.base.BaseTool;
Expand All @@ -12,7 +30,8 @@ public class Utils {
public static void validateToolsSingleInput(String className, List<BaseTool> tools) {
for (BaseTool tool : tools) {
if (!tool.isSingleInput()) {
throw new IllegalArgumentException(className + " does not support multi-input tool " + tool.getName() + ".");
throw new IllegalArgumentException(
className + " does not support multi-input tool " + tool.getName() + ".");
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
import lombok.Data;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
Expand Down

0 comments on commit 6d1d3fc

Please sign in to comment.