Skip to content

Commit

Permalink
Add Agents(10%)
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jun 9, 2023
1 parent 8512d58 commit fd2457e
Show file tree
Hide file tree
Showing 20 changed files with 909 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.agent;

/**
* Class responsible for calling the language model and deciding the action.
* <p>
* This is driven by an LLMChain. The prompt in the LLMChain MUST include
* a variable called "agent_scratchpad" where the agent can put its
* intermediary work.
*
* @author HamaWhite
*/
public class Agent extends BaseSingleActionAgent {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.agent;

import com.hw.langchain.chains.base.Chain;
import com.hw.langchain.tools.base.BaseTool;

import java.util.List;
import java.util.Map;

/**
* Consists of an agent using tools.
*
* @author HamaWhite
*/
public class AgentExecutor extends Chain {

/**
* Create from agent and tools.
*/
public static AgentExecutor fromAgentAndTools(BaseSingleActionAgent agent, List<BaseTool> tools,
Map<String, Object> kwargs) {
return new AgentExecutor(agent, tools, kwargs);
}

private AgentExecutor(BaseSingleActionAgent agent, List<BaseTool> tools, Map<String, Object> kwargs) {
// Constructor implementation
}

@Override
public String chainType() {
return null;
}

@Override
public List<String> inputKeys() {
return null;
}

@Override
public List<String> outputKeys() {
return null;
}

@Override
public Map<String, String> _call(Map<String, Object> inputs) {
return null;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.agent;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.tools.base.BaseTool;

import java.util.List;
import java.util.Map;

/**
* Base Agent class.
*
* @author HamaWhite
*/
public class BaseSingleActionAgent {

public static BaseSingleActionAgent fromLLMAndTools(
BaseLanguageModel llm,
List<BaseTool> tools,
Map<String, Object> kwargs) {
throw new UnsupportedOperationException();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.agent.types;

import java.util.Optional;

/**
* @author HamaWhite
*/

public enum AgentType {

ZERO_SHOT_REACT_DESCRIPTION("zero-shot-react-description"),
REACT_DOCSTORE("react-docstore"),
SELF_ASK_WITH_SEARCH("self-ask-with-search"),
CONVERSATIONAL_REACT_DESCRIPTION("conversational-react-description"),
CHAT_ZERO_SHOT_REACT_DESCRIPTION("chat-zero-shot-react-description"),
CHAT_CONVERSATIONAL_REACT_DESCRIPTION("chat-conversational-react-description"),
STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION("structured-chat-zero-shot-react-description");

private final String value;

private AgentType(String value) {
this.value = value;
}

public String getValue() {
return value;
}

public static Optional<AgentType> fromValue(String value) {
for (AgentType agentType : AgentType.values()) {
if (agentType.getValue().equals(value)) {
return Optional.of(agentType);
}
}
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.initialize;

import com.hw.langchain.agents.agent.AgentExecutor;
import com.hw.langchain.agents.agent.BaseSingleActionAgent;
import com.hw.langchain.agents.agent.types.AgentType;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.tools.base.BaseTool;

import org.apache.commons.lang3.reflect.MethodUtils;

import lombok.SneakyThrows;

import java.lang.reflect.InvocationTargetException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.hw.langchain.agents.loading.Loading.loadAgent;
import static com.hw.langchain.agents.types.Types.AGENT_TO_CLASS;

/**
* Load agent.
*
* @author HamaWhite
*/
public class Initialize {

/**
* Load an agent executor given tools and LLM.
*
* @param tools List of tools this agent has access to.
* @param llm Language model to use as the agent.
* @param agent Agent type to use. If None and agent_path is also None, will default to AgentType.ZERO_SHOT_REACT_DESCRIPTION.
* @param agentPath Path to serialized agent to use.
* @param agentKwargs Additional key word arguments to pass to the underlying agent
* @param kwargs Additional key word arguments passed to the agent executor
* @return An agent executor
*/
@SneakyThrows({InvocationTargetException.class, NoSuchMethodException.class, IllegalAccessException.class})
public static AgentExecutor initializeAgent(List<BaseTool> tools, BaseLanguageModel llm, AgentType agent,
String agentPath, Map<String, Object> agentKwargs, Map<String, Object> kwargs) {
BaseSingleActionAgent agentObj;
if (agent == null && agentPath == null) {
agent = AgentType.ZERO_SHOT_REACT_DESCRIPTION;
}
if (agent != null && agentPath != null) {
throw new IllegalArgumentException(
"Both `agent` and `agentPath` are specified, but at most only one should be.");
}
if (agent != null) {
if (!AGENT_TO_CLASS.containsKey(agent)) {
throw new IllegalArgumentException(
"Got unknown agent type: " + agent + ". Valid types are: " + AGENT_TO_CLASS.keySet() + ".");
}
Class<? extends BaseSingleActionAgent> clazz = AGENT_TO_CLASS.get(agent);
agentKwargs = agentKwargs != null ? agentKwargs : new HashMap<>();
agentObj = (BaseSingleActionAgent) MethodUtils.invokeStaticMethod(clazz, "fromLLMAndTools", llm, tools,
agentKwargs);
} else if (agentPath != null) {
agentObj = loadAgent(agentPath, Map.of("llm", llm, "tools", tools));
} else {
throw new IllegalArgumentException(
"Somehow both `agent` and `agentPath` are null, this should never happen.");
}
return AgentExecutor.fromAgentAndTools(agentObj, tools, kwargs);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.agents.load.tools;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.math.base.LLMMathChain;
import com.hw.langchain.tools.base.BaseTool;
import com.hw.langchain.tools.base.Tool;
import com.hw.langchain.utilities.serpapi.SerpAPIWrapper;

import org.apache.commons.lang3.tuple.Pair;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
* Load tools.
*
* @author HamaWhite
*/
public class LoadTools {

public static BaseTool getLLMMath(BaseLanguageModel llm) {
return new Tool(
"Calculator",
"Useful for when you need to answer questions about math.",
LLMMathChain.fromLLM(llm)::run);
}

public static BaseTool getSerpapi(Map<String, Object> kwargs) {
return new Tool("Search",
"A search engine. Useful for when you need to answer questions about current events. Input should be a search query.",
SerpAPIWrapper.of(kwargs)::run);
}

private static Map<String, Pair<Function<Map<String, Object>, BaseTool>, List<String>>> _EXTRA_OPTIONAL_TOOLS =
Map.of(
"serpapi", Pair.of(LoadTools::getSerpapi, List.of("serpapi_api_key", "aiosession")));

private static Map<String, Function<BaseLanguageModel, BaseTool>> _LLM_TOOLS = Map.of(
"llm-math", LoadTools::getLLMMath);

/**
* Load tools based on their name.
*
* @param toolNames name of tools to load.
* @param llm language model, may be needed to initialize certain tools.
* @param kwargs keyword arguments
* @return List of tools.
*/
public static List<BaseTool> loadTools(List<String> toolNames, BaseLanguageModel llm, Map<String, Object> kwargs) {
List<BaseTool> tools = new ArrayList<>();
for (String name : toolNames) {
if (_LLM_TOOLS.containsKey(name)) {
BaseTool tool = _LLM_TOOLS.get(name).apply(llm);
tools.add(tool);
} else if (_EXTRA_OPTIONAL_TOOLS.containsKey(name)) {
var pair = _EXTRA_OPTIONAL_TOOLS.get(name);
List<String> extraKeys = pair.getRight();
Map<String, Object> subKwargs = extraKeys.stream()
.filter(kwargs::containsKey)
.collect(Collectors.toMap(key -> key, kwargs::get));

BaseTool tool = pair.getLeft().apply(subKwargs);
tools.add(tool);
} else {
throw new IllegalArgumentException("Got unknown tool " + name);
}
}
return tools;
}
}
Loading

0 comments on commit fd2457e

Please sign in to comment.