forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support Spark SQL Agent HamaWhiteGG#53
- Loading branch information
1 parent
631646f
commit d96a0c2
Showing
24 changed files
with
1,670 additions
and
47 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
20 changes: 20 additions & 0 deletions
20
...data/langchain-spark/src/main/java/com/hw/langchain/agents/toolkits/base/BaseToolkit.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package com.hw.langchain.agents.toolkits.base; | ||
|
||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* Base Toolkit representing a collection of related tools. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public interface BaseToolkit { | ||
|
||
/** | ||
* Get the tools in the toolkit. | ||
* | ||
* @return a list of tools in the toolkit. | ||
*/ | ||
List<BaseTool> getTools(); | ||
} |
70 changes: 70 additions & 0 deletions
70
...in-spark/src/main/java/com/hw/langchain/agents/toolkits/spark/sql/base/SparkSqlAgent.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
package com.hw.langchain.agents.toolkits.spark.sql.base; | ||
|
||
import com.hw.langchain.agents.agent.Agent; | ||
import com.hw.langchain.agents.agent.AgentExecutor; | ||
import com.hw.langchain.agents.mrkl.base.ZeroShotAgent; | ||
import com.hw.langchain.agents.toolkits.spark.sql.toolkit.SparkSqlToolkit; | ||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.chains.llm.LLMChain; | ||
import com.hw.langchain.prompts.prompt.PromptTemplate; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.agents.mrkl.prompt.Prompt.FORMAT_INSTRUCTIONS; | ||
import static com.hw.langchain.agents.toolkits.spark.sql.prompt.Prompt.SQL_PREFIX; | ||
import static com.hw.langchain.agents.toolkits.spark.sql.prompt.Prompt.SQL_SUFFIX; | ||
import static com.hw.langchain.prompts.utils.FormatUtils.formatTemplate; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class SparkSqlAgent { | ||
|
||
private SparkSqlAgent() { | ||
// private constructor to hide the implicit public one | ||
throw new IllegalStateException("Utility class"); | ||
} | ||
|
||
/** | ||
* Construct a Spark SQL agent from an LLM and tools. | ||
*/ | ||
public static AgentExecutor createSparkSqlAgent(BaseLanguageModel llm, SparkSqlToolkit toolkit) { | ||
return createSparkSqlAgent(llm, toolkit, SQL_PREFIX, SQL_SUFFIX, FORMAT_INSTRUCTIONS, null, 10, 15, null, "force"); | ||
} | ||
|
||
/** | ||
* Construct a Spark SQL agent from an LLM and tools. | ||
*/ | ||
@SuppressWarnings("all") | ||
public static AgentExecutor createSparkSqlAgent( | ||
BaseLanguageModel llm, | ||
SparkSqlToolkit toolkit, | ||
String prefix, | ||
String suffix, | ||
String formatInstructions, | ||
List<String> inputVariables, | ||
int topK, | ||
Integer maxIterations, | ||
Float maxExecutionTime, | ||
String earlyStoppingMethod | ||
) { | ||
List<BaseTool> tools = toolkit.getTools(); | ||
prefix = formatTemplate(prefix, Map.of("top_k", topK)); | ||
|
||
PromptTemplate prompt = ZeroShotAgent.createPrompt(tools, prefix, suffix, formatInstructions, inputVariables); | ||
LLMChain llmChain = new LLMChain(llm, prompt); | ||
|
||
List<String> toolNames = tools.stream().map(BaseTool::getName).toList(); | ||
Agent agent = new ZeroShotAgent(llmChain, toolNames); | ||
|
||
return AgentExecutor.builder() | ||
.agent(agent) | ||
.tools(tools) | ||
.maxIterations(maxIterations) | ||
.maxExecutionTime(maxExecutionTime) | ||
.earlyStoppingMethod(earlyStoppingMethod) | ||
.build(); | ||
} | ||
} |
34 changes: 34 additions & 0 deletions
34
...ngchain-spark/src/main/java/com/hw/langchain/agents/toolkits/spark/sql/prompt/Prompt.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package com.hw.langchain.agents.toolkits.spark.sql.prompt; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class Prompt { | ||
|
||
private Prompt() { | ||
// private constructor to hide the implicit public one | ||
throw new IllegalStateException("Utility class"); | ||
} | ||
|
||
public static final String SQL_PREFIX = """ | ||
You are an agent designed to interact with Spark SQL. | ||
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer. | ||
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. | ||
You can order the results by a relevant column to return the most interesting examples in the database. | ||
Never query for all the columns from a specific table, only ask for the relevant columns given the question. | ||
You have access to tools for interacting with the database. | ||
Only use the below tools. Only use the information returned by the below tools to construct your final answer. | ||
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. | ||
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. | ||
If the question does not seem related to the database, just return "I don't know" as the answer. | ||
"""; | ||
|
||
public static final String SQL_SUFFIX = """ | ||
Begin! | ||
Question: {input} | ||
Thought: I should look at the tables in the database to see what I can query. | ||
{agent_scratchpad}"""; | ||
} |
39 changes: 39 additions & 0 deletions
39
...ark/src/main/java/com/hw/langchain/agents/toolkits/spark/sql/toolkit/SparkSqlToolkit.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
package com.hw.langchain.agents.toolkits.spark.sql.toolkit; | ||
|
||
import com.hw.langchain.agents.toolkits.base.BaseToolkit; | ||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.tools.base.BaseTool; | ||
import com.hw.langchain.tools.spark.sql.tool.InfoSparkSQLTool; | ||
import com.hw.langchain.tools.spark.sql.tool.ListSparkSqlTool; | ||
import com.hw.langchain.tools.spark.sql.tool.QueryCheckerTool; | ||
import com.hw.langchain.tools.spark.sql.tool.QuerySparkSqlTool; | ||
import com.hw.langchain.utilities.spark.sql.SparkSql; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* Toolkit for interacting with Spark SQL. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class SparkSqlToolkit implements BaseToolkit { | ||
|
||
private final SparkSql db; | ||
|
||
private final BaseLanguageModel llm; | ||
|
||
public SparkSqlToolkit(SparkSql db, BaseLanguageModel llm) { | ||
this.db = db; | ||
this.llm = llm; | ||
} | ||
|
||
@Override | ||
public List<BaseTool> getTools() { | ||
return List.of( | ||
new QuerySparkSqlTool(db), | ||
new InfoSparkSQLTool(db), | ||
new ListSparkSqlTool(db), | ||
new QueryCheckerTool(db, llm) | ||
); | ||
} | ||
} |
28 changes: 28 additions & 0 deletions
28
...bigdata/langchain-spark/src/main/java/com/hw/langchain/tools/spark/sql/prompt/Prompt.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package com.hw.langchain.tools.spark.sql.prompt; | ||
|
||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class Prompt { | ||
|
||
private Prompt() { | ||
// private constructor to hide the implicit public one | ||
throw new IllegalStateException("Utility class"); | ||
} | ||
|
||
public static final String QUERY_CHECKER = """ | ||
{query} | ||
Double check the Spark SQL query above for common mistakes, including: | ||
- Using NOT IN with NULL values | ||
- Using UNION when UNION ALL should have been used | ||
- Using BETWEEN for exclusive ranges | ||
- Data type mismatch in predicates | ||
- Properly quoting identifiers | ||
- Using the correct number of arguments for functions | ||
- Casting to the correct data type | ||
- Using the proper columns for joins | ||
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""; | ||
|
||
} |
21 changes: 21 additions & 0 deletions
21
...langchain-spark/src/main/java/com/hw/langchain/tools/spark/sql/tool/BaseSparkSqlTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package com.hw.langchain.tools.spark.sql.tool; | ||
|
||
import com.hw.langchain.tools.base.BaseTool; | ||
import com.hw.langchain.utilities.spark.sql.SparkSql; | ||
import lombok.EqualsAndHashCode; | ||
|
||
/** | ||
* Base tool for interacting with Spark SQL. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@EqualsAndHashCode(callSuper = true) | ||
public abstract class BaseSparkSqlTool extends BaseTool { | ||
|
||
protected final SparkSql db; | ||
|
||
protected BaseSparkSqlTool(SparkSql db, String name, String description) { | ||
super(name, description); | ||
this.db = db; | ||
} | ||
} |
36 changes: 36 additions & 0 deletions
36
...langchain-spark/src/main/java/com/hw/langchain/tools/spark/sql/tool/InfoSparkSQLTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
package com.hw.langchain.tools.spark.sql.tool; | ||
|
||
import com.hw.langchain.utilities.spark.sql.SparkSql; | ||
import lombok.EqualsAndHashCode; | ||
|
||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
/** | ||
* Tool for getting metadata about a Spark SQL. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@EqualsAndHashCode(callSuper = true) | ||
public class InfoSparkSQLTool extends BaseSparkSqlTool { | ||
|
||
private static final String NAME = "schema_sql_db"; | ||
private static final String DESCRIPTION = """ | ||
Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables. | ||
Be sure that the tables actually exist by calling list_tables_sql_db first! | ||
Example Input: "table1, table2, table3" | ||
"""; | ||
|
||
public InfoSparkSQLTool(SparkSql db) { | ||
super(db, NAME, DESCRIPTION); | ||
} | ||
|
||
/** | ||
* Get the schema for tables in a comma-separated list. | ||
*/ | ||
@Override | ||
public Object innerRun(String query, Map<String, Object> kwargs) { | ||
return db.getTableInfoNoThrow(Set.of(query.split(", "))); | ||
} | ||
} |
30 changes: 30 additions & 0 deletions
30
...langchain-spark/src/main/java/com/hw/langchain/tools/spark/sql/tool/ListSparkSqlTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
package com.hw.langchain.tools.spark.sql.tool; | ||
|
||
import com.hw.langchain.utilities.spark.sql.SparkSql; | ||
import lombok.EqualsAndHashCode; | ||
|
||
import java.util.Map; | ||
|
||
/** | ||
* Tool for getting tables names. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@EqualsAndHashCode(callSuper = true) | ||
public class ListSparkSqlTool extends BaseSparkSqlTool { | ||
|
||
private static final String NAME = "list_tables_sql_db"; | ||
private static final String DESCRIPTION = "Input is an empty string, output is a comma separated list of tables in the Spark SQL."; | ||
|
||
public ListSparkSqlTool(SparkSql db) { | ||
super(db, NAME, DESCRIPTION); | ||
} | ||
|
||
/** | ||
* Get the schema for a specific table. | ||
*/ | ||
@Override | ||
public Object innerRun(String query, Map<String, Object> kwargs) { | ||
return String.join(", ",db.getUsableTableNames()); | ||
} | ||
} |
50 changes: 50 additions & 0 deletions
50
...langchain-spark/src/main/java/com/hw/langchain/tools/spark/sql/tool/QueryCheckerTool.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
package com.hw.langchain.tools.spark.sql.tool; | ||
|
||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.chains.llm.LLMChain; | ||
import com.hw.langchain.prompts.prompt.PromptTemplate; | ||
import com.hw.langchain.utilities.spark.sql.SparkSql; | ||
import lombok.EqualsAndHashCode; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.tools.spark.sql.prompt.Prompt.QUERY_CHECKER; | ||
|
||
/** | ||
* Use an LLM to check if a query is correct. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
@EqualsAndHashCode(callSuper = true) | ||
public class QueryCheckerTool extends BaseSparkSqlTool { | ||
|
||
private final BaseLanguageModel llm; | ||
|
||
private LLMChain llmChain; | ||
|
||
private static final String NAME = "query_checker_sql_db"; | ||
private static final String DESCRIPTION = """ | ||
Use this tool to double check if your query is correct before executing it. | ||
Always use this tool before executing a query with query_sql_db! | ||
"""; | ||
|
||
public QueryCheckerTool(SparkSql db, BaseLanguageModel llm) { | ||
super(db, NAME, DESCRIPTION); | ||
this.llm = llm; | ||
|
||
initializeLlmChain(); | ||
} | ||
|
||
private void initializeLlmChain() { | ||
llmChain = new LLMChain(llm, new PromptTemplate(List.of("query"), QUERY_CHECKER)); | ||
} | ||
|
||
/** | ||
* Use the LLM to check the query. | ||
*/ | ||
@Override | ||
public Object innerRun(String query, Map<String, Object> kwargs) { | ||
return llmChain.predict(Map.of("query", query)); | ||
} | ||
} |
Oops, something went wrong.