forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8512d58
commit fd2457e
Showing
20 changed files
with
909 additions
and
22 deletions.
There are no files selected for viewing
31 changes: 31 additions & 0 deletions
31
langchain-core/src/main/java/com/hw/langchain/agents/agent/Agent.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.agent; | ||
|
||
/** | ||
* Class responsible for calling the language model and deciding the action. | ||
* <p> | ||
* This is driven by an LLMChain. The prompt in the LLMChain MUST include | ||
* a variable called "agent_scratchpad" where the agent can put its | ||
* intermediary work. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class Agent extends BaseSingleActionAgent { | ||
} |
65 changes: 65 additions & 0 deletions
65
langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.agent; | ||
|
||
import com.hw.langchain.chains.base.Chain; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
/** | ||
* Consists of an agent using tools. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class AgentExecutor extends Chain { | ||
|
||
/** | ||
* Create from agent and tools. | ||
*/ | ||
public static AgentExecutor fromAgentAndTools(BaseSingleActionAgent agent, List<BaseTool> tools, | ||
Map<String, Object> kwargs) { | ||
return new AgentExecutor(agent, tools, kwargs); | ||
} | ||
|
||
private AgentExecutor(BaseSingleActionAgent agent, List<BaseTool> tools, Map<String, Object> kwargs) { | ||
// Constructor implementation | ||
} | ||
|
||
@Override | ||
public String chainType() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public List<String> inputKeys() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public List<String> outputKeys() { | ||
return null; | ||
} | ||
|
||
@Override | ||
public Map<String, String> _call(Map<String, Object> inputs) { | ||
return null; | ||
} | ||
} |
41 changes: 41 additions & 0 deletions
41
langchain-core/src/main/java/com/hw/langchain/agents/agent/BaseSingleActionAgent.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.agent; | ||
|
||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
/** | ||
* Base Agent class. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class BaseSingleActionAgent { | ||
|
||
public static BaseSingleActionAgent fromLLMAndTools( | ||
BaseLanguageModel llm, | ||
List<BaseTool> tools, | ||
Map<String, Object> kwargs) { | ||
throw new UnsupportedOperationException(); | ||
} | ||
|
||
} |
55 changes: 55 additions & 0 deletions
55
langchain-core/src/main/java/com/hw/langchain/agents/agent/types/AgentType.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.agent.types; | ||
|
||
import java.util.Optional; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
|
||
public enum AgentType { | ||
|
||
ZERO_SHOT_REACT_DESCRIPTION("zero-shot-react-description"), | ||
REACT_DOCSTORE("react-docstore"), | ||
SELF_ASK_WITH_SEARCH("self-ask-with-search"), | ||
CONVERSATIONAL_REACT_DESCRIPTION("conversational-react-description"), | ||
CHAT_ZERO_SHOT_REACT_DESCRIPTION("chat-zero-shot-react-description"), | ||
CHAT_CONVERSATIONAL_REACT_DESCRIPTION("chat-conversational-react-description"), | ||
STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION("structured-chat-zero-shot-react-description"); | ||
|
||
private final String value; | ||
|
||
private AgentType(String value) { | ||
this.value = value; | ||
} | ||
|
||
public String getValue() { | ||
return value; | ||
} | ||
|
||
public static Optional<AgentType> fromValue(String value) { | ||
for (AgentType agentType : AgentType.values()) { | ||
if (agentType.getValue().equals(value)) { | ||
return Optional.of(agentType); | ||
} | ||
} | ||
return Optional.empty(); | ||
} | ||
} |
85 changes: 85 additions & 0 deletions
85
langchain-core/src/main/java/com/hw/langchain/agents/initialize/Initialize.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.initialize; | ||
|
||
import com.hw.langchain.agents.agent.AgentExecutor; | ||
import com.hw.langchain.agents.agent.BaseSingleActionAgent; | ||
import com.hw.langchain.agents.agent.types.AgentType; | ||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import org.apache.commons.lang3.reflect.MethodUtils; | ||
|
||
import lombok.SneakyThrows; | ||
|
||
import java.lang.reflect.InvocationTargetException; | ||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.agents.loading.Loading.loadAgent; | ||
import static com.hw.langchain.agents.types.Types.AGENT_TO_CLASS; | ||
|
||
/** | ||
* Load agent. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class Initialize { | ||
|
||
/** | ||
* Load an agent executor given tools and LLM. | ||
* | ||
* @param tools List of tools this agent has access to. | ||
* @param llm Language model to use as the agent. | ||
* @param agent Agent type to use. If None and agent_path is also None, will default to AgentType.ZERO_SHOT_REACT_DESCRIPTION. | ||
* @param agentPath Path to serialized agent to use. | ||
* @param agentKwargs Additional key word arguments to pass to the underlying agent | ||
* @param kwargs Additional key word arguments passed to the agent executor | ||
* @return An agent executor | ||
*/ | ||
@SneakyThrows({InvocationTargetException.class, NoSuchMethodException.class, IllegalAccessException.class}) | ||
public static AgentExecutor initializeAgent(List<BaseTool> tools, BaseLanguageModel llm, AgentType agent, | ||
String agentPath, Map<String, Object> agentKwargs, Map<String, Object> kwargs) { | ||
BaseSingleActionAgent agentObj; | ||
if (agent == null && agentPath == null) { | ||
agent = AgentType.ZERO_SHOT_REACT_DESCRIPTION; | ||
} | ||
if (agent != null && agentPath != null) { | ||
throw new IllegalArgumentException( | ||
"Both `agent` and `agentPath` are specified, but at most only one should be."); | ||
} | ||
if (agent != null) { | ||
if (!AGENT_TO_CLASS.containsKey(agent)) { | ||
throw new IllegalArgumentException( | ||
"Got unknown agent type: " + agent + ". Valid types are: " + AGENT_TO_CLASS.keySet() + "."); | ||
} | ||
Class<? extends BaseSingleActionAgent> clazz = AGENT_TO_CLASS.get(agent); | ||
agentKwargs = agentKwargs != null ? agentKwargs : new HashMap<>(); | ||
agentObj = (BaseSingleActionAgent) MethodUtils.invokeStaticMethod(clazz, "fromLLMAndTools", llm, tools, | ||
agentKwargs); | ||
} else if (agentPath != null) { | ||
agentObj = loadAgent(agentPath, Map.of("llm", llm, "tools", tools)); | ||
} else { | ||
throw new IllegalArgumentException( | ||
"Somehow both `agent` and `agentPath` are null, this should never happen."); | ||
} | ||
return AgentExecutor.fromAgentAndTools(agentObj, tools, kwargs); | ||
} | ||
} |
91 changes: 91 additions & 0 deletions
91
langchain-core/src/main/java/com/hw/langchain/agents/load/tools/LoadTools.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.load.tools; | ||
|
||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.chains.llm.math.base.LLMMathChain; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
import com.hw.langchain.tools.base.Tool; | ||
import com.hw.langchain.utilities.serpapi.SerpAPIWrapper; | ||
|
||
import org.apache.commons.lang3.tuple.Pair; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.function.Function; | ||
import java.util.stream.Collectors; | ||
|
||
/** | ||
* Load tools. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class LoadTools { | ||
|
||
public static BaseTool getLLMMath(BaseLanguageModel llm) { | ||
return new Tool( | ||
"Calculator", | ||
"Useful for when you need to answer questions about math.", | ||
LLMMathChain.fromLLM(llm)::run); | ||
} | ||
|
||
public static BaseTool getSerpapi(Map<String, Object> kwargs) { | ||
return new Tool("Search", | ||
"A search engine. Useful for when you need to answer questions about current events. Input should be a search query.", | ||
SerpAPIWrapper.of(kwargs)::run); | ||
} | ||
|
||
private static Map<String, Pair<Function<Map<String, Object>, BaseTool>, List<String>>> _EXTRA_OPTIONAL_TOOLS = | ||
Map.of( | ||
"serpapi", Pair.of(LoadTools::getSerpapi, List.of("serpapi_api_key", "aiosession"))); | ||
|
||
private static Map<String, Function<BaseLanguageModel, BaseTool>> _LLM_TOOLS = Map.of( | ||
"llm-math", LoadTools::getLLMMath); | ||
|
||
/** | ||
* Load tools based on their name. | ||
* | ||
* @param toolNames name of tools to load. | ||
* @param llm language model, may be needed to initialize certain tools. | ||
* @param kwargs keyword arguments | ||
* @return List of tools. | ||
*/ | ||
public static List<BaseTool> loadTools(List<String> toolNames, BaseLanguageModel llm, Map<String, Object> kwargs) { | ||
List<BaseTool> tools = new ArrayList<>(); | ||
for (String name : toolNames) { | ||
if (_LLM_TOOLS.containsKey(name)) { | ||
BaseTool tool = _LLM_TOOLS.get(name).apply(llm); | ||
tools.add(tool); | ||
} else if (_EXTRA_OPTIONAL_TOOLS.containsKey(name)) { | ||
var pair = _EXTRA_OPTIONAL_TOOLS.get(name); | ||
List<String> extraKeys = pair.getRight(); | ||
Map<String, Object> subKwargs = extraKeys.stream() | ||
.filter(kwargs::containsKey) | ||
.collect(Collectors.toMap(key -> key, kwargs::get)); | ||
|
||
BaseTool tool = pair.getLeft().apply(subKwargs); | ||
tools.add(tool); | ||
} else { | ||
throw new IllegalArgumentException("Got unknown tool " + name); | ||
} | ||
} | ||
return tools; | ||
} | ||
} |
Oops, something went wrong.