forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request HamaWhiteGG#54 from HamaWhiteGG/dev
Support Spark SQL Agent
- Loading branch information
Showing
37 changed files
with
2,337 additions
and
49 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven descriptor for the langchain-flink module; it inherits from langchain-bigdata. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>io.github.hamawhitegg</groupId>
        <artifactId>langchain-bigdata</artifactId>
        <version>0.1.9</version>
    </parent>

    <artifactId>langchain-flink</artifactId>

    <dependencies>
        <!-- Pinned to this module's own version so both artifacts are released in lockstep. -->
        <dependency>
            <groupId>io.github.hamawhitegg</groupId>
            <artifactId>langchain-core</artifactId>
            <version>${project.version}</version>
        </dependency>
    </dependencies>

</project>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven descriptor for the langchain-spark module; it inherits from langchain-bigdata. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>io.github.hamawhitegg</groupId>
        <artifactId>langchain-bigdata</artifactId>
        <version>0.1.9</version>
    </parent>

    <artifactId>langchain-spark</artifactId>

    <dependencies>
        <!-- Pinned to this module's own version so both artifacts are released in lockstep. -->
        <dependency>
            <groupId>io.github.hamawhitegg</groupId>
            <artifactId>langchain-core</artifactId>
            <version>${project.version}</version>
        </dependency>

        <!--
          NOTE(review): the dependencies below declare neither a version nor a scope, and
          scala.binary.version is not defined in this file. Presumably all of these come from
          dependencyManagement / properties of an ancestor POM; confirm before changing them.
        -->
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_${scala.binary.version}</artifactId>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_${scala.binary.version}</artifactId>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <!-- NOTE(review): verify that the managed scope of the JUnit artifacts is "test". -->
        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-api</artifactId>
        </dependency>

        <dependency>
            <groupId>org.junit.jupiter</groupId>
            <artifactId>junit-jupiter-engine</artifactId>
        </dependency>
    </dependencies>

</project>
38 changes: 38 additions & 0 deletions
38
...data/langchain-spark/src/main/java/com/hw/langchain/agents/toolkits/base/BaseToolkit.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.toolkits.base; | ||
|
||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* Base Toolkit representing a collection of related tools. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public interface BaseToolkit { | ||
|
||
/** | ||
* Get the tools in the toolkit. | ||
* | ||
* @return a list of tools in the toolkit. | ||
*/ | ||
List<BaseTool> getTools(); | ||
} |
88 changes: 88 additions & 0 deletions
88
...in-spark/src/main/java/com/hw/langchain/agents/toolkits/spark/sql/base/SparkSqlAgent.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.toolkits.spark.sql.base; | ||
|
||
import com.hw.langchain.agents.agent.Agent; | ||
import com.hw.langchain.agents.agent.AgentExecutor; | ||
import com.hw.langchain.agents.mrkl.base.ZeroShotAgent; | ||
import com.hw.langchain.agents.toolkits.spark.sql.toolkit.SparkSqlToolkit; | ||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.chains.llm.LLMChain; | ||
import com.hw.langchain.prompts.prompt.PromptTemplate; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.agents.mrkl.prompt.Prompt.FORMAT_INSTRUCTIONS; | ||
import static com.hw.langchain.agents.toolkits.spark.sql.prompt.Prompt.SQL_PREFIX; | ||
import static com.hw.langchain.agents.toolkits.spark.sql.prompt.Prompt.SQL_SUFFIX; | ||
import static com.hw.langchain.prompts.utils.FormatUtils.formatTemplate; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class SparkSqlAgent { | ||
|
||
private SparkSqlAgent() { | ||
// private constructor to hide the implicit public one | ||
throw new IllegalStateException("Utility class"); | ||
} | ||
|
||
/** | ||
* Construct a Spark SQL agent from an LLM and tools. | ||
*/ | ||
public static AgentExecutor createSparkSqlAgent(BaseLanguageModel llm, SparkSqlToolkit toolkit) { | ||
return createSparkSqlAgent(llm, toolkit, SQL_PREFIX, SQL_SUFFIX, FORMAT_INSTRUCTIONS, null, 10, 15, null, | ||
"force"); | ||
} | ||
|
||
/** | ||
* Construct a Spark SQL agent from an LLM and tools. | ||
*/ | ||
@SuppressWarnings("all") | ||
public static AgentExecutor createSparkSqlAgent( | ||
BaseLanguageModel llm, | ||
SparkSqlToolkit toolkit, | ||
String prefix, | ||
String suffix, | ||
String formatInstructions, | ||
List<String> inputVariables, | ||
int topK, | ||
Integer maxIterations, | ||
Float maxExecutionTime, | ||
String earlyStoppingMethod) { | ||
List<BaseTool> tools = toolkit.getTools(); | ||
prefix = formatTemplate(prefix, Map.of("top_k", topK)); | ||
|
||
PromptTemplate prompt = ZeroShotAgent.createPrompt(tools, prefix, suffix, formatInstructions, inputVariables); | ||
LLMChain llmChain = new LLMChain(llm, prompt); | ||
|
||
List<String> toolNames = tools.stream().map(BaseTool::getName).toList(); | ||
Agent agent = new ZeroShotAgent(llmChain, toolNames); | ||
|
||
return AgentExecutor.builder() | ||
.agent(agent) | ||
.tools(tools) | ||
.maxIterations(maxIterations) | ||
.maxExecutionTime(maxExecutionTime) | ||
.earlyStoppingMethod(earlyStoppingMethod) | ||
.build(); | ||
} | ||
} |
53 changes: 53 additions & 0 deletions
53
...ngchain-spark/src/main/java/com/hw/langchain/agents/toolkits/spark/sql/prompt/Prompt.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.toolkits.spark.sql.prompt; | ||
|
||
/**
 * Prompt fragments used when building the Spark SQL agent.
 *
 * @author HamaWhite
 */
public final class Prompt {

    /**
     * Leading part of the agent prompt. It contains the {@code {top_k}} placeholder and ends
     * with a line break.
     */
    public static final String SQL_PREFIX = """
            You are an agent designed to interact with Spark SQL.
            Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
            Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
            You can order the results by a relevant column to return the most interesting examples in the database.
            Never query for all the columns from a specific table, only ask for the relevant columns given the question.
            You have access to tools for interacting with the database.
            Only use the below tools. Only use the information returned by the below tools to construct your final answer.
            You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
            DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
            If the question does not seem related to the database, just return "I don't know" as the answer.
            """;

    /**
     * Trailing part of the agent prompt. It contains the {@code {input}} and
     * {@code {agent_scratchpad}} placeholders and has no trailing line break.
     */
    public static final String SQL_SUFFIX = """
            Begin!
            Question: {input}
            Thought: I should look at the tables in the database to see what I can query.
            {agent_scratchpad}""";

    private Prompt() {
        // private constructor to hide the implicit public one
        throw new IllegalStateException("Utility class");
    }
}
56 changes: 56 additions & 0 deletions
56
...ark/src/main/java/com/hw/langchain/agents/toolkits/spark/sql/toolkit/SparkSqlToolkit.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.agents.toolkits.spark.sql.toolkit; | ||
|
||
import com.hw.langchain.agents.toolkits.base.BaseToolkit; | ||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
import com.hw.langchain.tools.spark.sql.tool.InfoSparkSQLTool; | ||
import com.hw.langchain.tools.spark.sql.tool.ListSparkSqlTool; | ||
import com.hw.langchain.tools.spark.sql.tool.QueryCheckerTool; | ||
import com.hw.langchain.tools.spark.sql.tool.QuerySparkSqlTool; | ||
import com.hw.langchain.utilities.spark.sql.SparkSql; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* Toolkit for interacting with Spark SQL. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class SparkSqlToolkit implements BaseToolkit { | ||
|
||
private final SparkSql db; | ||
|
||
private final BaseLanguageModel llm; | ||
|
||
public SparkSqlToolkit(SparkSql db, BaseLanguageModel llm) { | ||
this.db = db; | ||
this.llm = llm; | ||
} | ||
|
||
@Override | ||
public List<BaseTool> getTools() { | ||
return List.of( | ||
new QuerySparkSqlTool(db), | ||
new InfoSparkSQLTool(db), | ||
new ListSparkSqlTool(db), | ||
new QueryCheckerTool(db, llm)); | ||
} | ||
} |
46 changes: 46 additions & 0 deletions
46
...bigdata/langchain-spark/src/main/java/com/hw/langchain/tools/spark/sql/prompt/Prompt.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.tools.spark.sql.prompt; | ||
|
||
/**
 * Prompt text used by the Spark SQL tools.
 *
 * @author HamaWhite
 */
public final class Prompt {

    /**
     * Template asking the model to review a Spark SQL query for a fixed list of common
     * mistakes. It contains the {@code {query}} placeholder and has no trailing line break.
     */
    public static final String QUERY_CHECKER = """
            {query}
            Double check the Spark SQL query above for common mistakes, including:
            - Using NOT IN with NULL values
            - Using UNION when UNION ALL should have been used
            - Using BETWEEN for exclusive ranges
            - Data type mismatch in predicates
            - Properly quoting identifiers
            - Using the correct number of arguments for functions
            - Casting to the correct data type
            - Using the proper columns for joins
            If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.""";

    private Prompt() {
        // private constructor to hide the implicit public one
        throw new IllegalStateException("Utility class");
    }
}
Oops, something went wrong.