Skip to content

Commit

Permalink
Support Spark SQL Agent HamaWhiteGG#53
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jul 22, 2023
1 parent 631646f commit d96a0c2
Show file tree
Hide file tree
Showing 24 changed files with 1,670 additions and 47 deletions.
892 changes: 892 additions & 0 deletions docs/extras/modules/titanic.csv

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions langchain-bigdata/langchain-spark/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,31 @@
<artifactId>langchain-core</artifactId>
<version>${project.version}</version>
</dependency>

<!-- Spark runtime and SQL modules used by the Spark SQL agent and tools.
     No <version> elements here: versions are presumably managed by the parent
     POM's dependencyManagement - confirm before building standalone. -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
</dependency>

<!-- Lombok: compile-time annotation processing (@EqualsAndHashCode on the tools). -->
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>

<!-- NOTE(review): the JUnit dependencies look test-only; their scope is
     presumably set to "test" via the parent dependencyManagement - confirm,
     otherwise add <scope>test</scope> here. -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.hw.langchain.agents.toolkits.base;

import com.hw.langchain.tools.base.BaseTool;

import java.util.List;

/**
 * A toolkit bundles the related tools an agent needs for a particular
 * task or integration, so they can be handed to the agent as one unit.
 *
 * @author HamaWhite
 */
public interface BaseToolkit {

    /**
     * Returns the tools contained in this toolkit.
     *
     * @return the list of tools this toolkit exposes
     */
    List<BaseTool> getTools();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.hw.langchain.agents.toolkits.spark.sql.base;

import com.hw.langchain.agents.agent.Agent;
import com.hw.langchain.agents.agent.AgentExecutor;
import com.hw.langchain.agents.mrkl.base.ZeroShotAgent;
import com.hw.langchain.agents.toolkits.spark.sql.toolkit.SparkSqlToolkit;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.prompt.PromptTemplate;
import com.hw.langchain.tools.base.BaseTool;

import java.util.List;
import java.util.Map;

import static com.hw.langchain.agents.mrkl.prompt.Prompt.FORMAT_INSTRUCTIONS;
import static com.hw.langchain.agents.toolkits.spark.sql.prompt.Prompt.SQL_PREFIX;
import static com.hw.langchain.agents.toolkits.spark.sql.prompt.Prompt.SQL_SUFFIX;
import static com.hw.langchain.prompts.utils.FormatUtils.formatTemplate;

/**
 * Factory for building an {@link AgentExecutor} that answers questions by
 * generating and running Spark SQL through the tools of a {@link SparkSqlToolkit}.
 *
 * @author HamaWhite
 */
public class SparkSqlAgent {

    private SparkSqlAgent() {
        // private constructor to hide the implicit public one
        throw new IllegalStateException("Utility class");
    }

    /**
     * Construct a Spark SQL agent from an LLM and tools with default settings:
     * the standard SQL prompt, at most 10 query results, at most 15 agent
     * iterations, no execution-time limit, and "force" early stopping.
     *
     * @param llm     language model that drives the agent
     * @param toolkit toolkit providing the Spark SQL tools
     * @return a ready-to-run agent executor
     */
    public static AgentExecutor createSparkSqlAgent(BaseLanguageModel llm, SparkSqlToolkit toolkit) {
        return createSparkSqlAgent(llm, toolkit, SQL_PREFIX, SQL_SUFFIX, FORMAT_INSTRUCTIONS, null, 10, 15, null, "force");
    }

    /**
     * Construct a Spark SQL agent from an LLM and tools.
     *
     * @param llm                 language model that drives the agent
     * @param toolkit             toolkit providing the Spark SQL tools
     * @param prefix              prompt prefix; may contain a {top_k} placeholder
     * @param suffix              prompt suffix appended after the tool descriptions
     * @param formatInstructions  instructions describing the expected output format
     * @param inputVariables      prompt input variable names, or null for the defaults
     * @param topK                maximum number of rows a generated query should return
     * @param maxIterations       maximum number of agent reasoning iterations
     * @param maxExecutionTime    wall-clock limit in seconds, or null for no limit
     * @param earlyStoppingMethod strategy applied when a limit is hit (e.g. "force")
     * @return a ready-to-run agent executor
     */
    // S107: the long parameter list deliberately mirrors the Python LangChain API.
    @SuppressWarnings("java:S107")
    public static AgentExecutor createSparkSqlAgent(
            BaseLanguageModel llm,
            SparkSqlToolkit toolkit,
            String prefix,
            String suffix,
            String formatInstructions,
            List<String> inputVariables,
            int topK,
            Integer maxIterations,
            Float maxExecutionTime,
            String earlyStoppingMethod
    ) {
        List<BaseTool> tools = toolkit.getTools();
        // Fill the row-limit placeholder before the prompt is assembled.
        prefix = formatTemplate(prefix, Map.of("top_k", topK));

        PromptTemplate prompt = ZeroShotAgent.createPrompt(tools, prefix, suffix, formatInstructions, inputVariables);
        LLMChain llmChain = new LLMChain(llm, prompt);

        List<String> toolNames = tools.stream().map(BaseTool::getName).toList();
        Agent agent = new ZeroShotAgent(llmChain, toolNames);

        return AgentExecutor.builder()
                .agent(agent)
                .tools(tools)
                .maxIterations(maxIterations)
                .maxExecutionTime(maxExecutionTime)
                .earlyStoppingMethod(earlyStoppingMethod)
                .build();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.hw.langchain.agents.toolkits.spark.sql.prompt;

/**
* @author HamaWhite
*/
public class Prompt {

private Prompt() {
// private constructor to hide the implicit public one
throw new IllegalStateException("Utility class");
}

public static final String SQL_PREFIX = """
You are an agent designed to interact with Spark SQL.
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
""";

public static final String SQL_SUFFIX = """
Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}""";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.hw.langchain.agents.toolkits.spark.sql.toolkit;

import com.hw.langchain.agents.toolkits.base.BaseToolkit;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.tools.base.BaseTool;
import com.hw.langchain.tools.spark.sql.tool.InfoSparkSQLTool;
import com.hw.langchain.tools.spark.sql.tool.ListSparkSqlTool;
import com.hw.langchain.tools.spark.sql.tool.QueryCheckerTool;
import com.hw.langchain.tools.spark.sql.tool.QuerySparkSqlTool;
import com.hw.langchain.utilities.spark.sql.SparkSql;

import java.util.List;

/**
 * Toolkit for interacting with Spark SQL.
 *
 * @author HamaWhite
 */
public class SparkSqlToolkit implements BaseToolkit {

    /** Database wrapper shared by every tool in the toolkit. */
    private final SparkSql db;

    /** Language model handed to the query-checker tool. */
    private final BaseLanguageModel llm;

    public SparkSqlToolkit(SparkSql db, BaseLanguageModel llm) {
        this.db = db;
        this.llm = llm;
    }

    /**
     * Builds the standard set of Spark SQL tools: query execution, schema
     * inspection, table listing and LLM-based query checking.
     */
    @Override
    public List<BaseTool> getTools() {
        var queryTool = new QuerySparkSqlTool(db);
        var infoTool = new InfoSparkSQLTool(db);
        var listTool = new ListSparkSqlTool(db);
        var checkerTool = new QueryCheckerTool(db, llm);
        return List.of(queryTool, infoTool, listTool, checkerTool);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.hw.langchain.tools.spark.sql.prompt;


/**
* @author HamaWhite
*/
public class Prompt {

private Prompt() {
// private constructor to hide the implicit public one
throw new IllegalStateException("Utility class");
}

public static final String QUERY_CHECKER = """
{query}
Double check the Spark SQL query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.""";

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.hw.langchain.tools.spark.sql.tool;

import com.hw.langchain.tools.base.BaseTool;
import com.hw.langchain.utilities.spark.sql.SparkSql;
import lombok.EqualsAndHashCode;

/**
 * Common superclass for tools that talk to a Spark SQL database:
 * holds the shared database handle and forwards name/description
 * to {@link BaseTool}.
 *
 * @author HamaWhite
 */
@EqualsAndHashCode(callSuper = true)
public abstract class BaseSparkSqlTool extends BaseTool {

    /** Database wrapper every concrete Spark SQL tool operates on. */
    protected final SparkSql db;

    protected BaseSparkSqlTool(SparkSql db, String name, String description) {
        super(name, description);
        this.db = db;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.hw.langchain.tools.spark.sql.tool;

import com.hw.langchain.utilities.spark.sql.SparkSql;
import lombok.EqualsAndHashCode;

import java.util.Map;
import java.util.Set;

/**
 * Tool for getting metadata about a Spark SQL.
 *
 * @author HamaWhite
 */
@EqualsAndHashCode(callSuper = true)
public class InfoSparkSQLTool extends BaseSparkSqlTool {

    private static final String NAME = "schema_sql_db";
    private static final String DESCRIPTION = """
            Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
            Be sure that the tables actually exist by calling list_tables_sql_db first!
            Example Input: "table1, table2, table3"
            """;

    public InfoSparkSQLTool(SparkSql db) {
        super(db, NAME, DESCRIPTION);
    }

    /**
     * Get the schema and sample rows for the tables in a comma-separated list.
     *
     * @param query  comma-separated table names, e.g. "table1, table2"
     * @param kwargs extra arguments (unused)
     * @return the table info produced by the database wrapper
     */
    @Override
    public Object innerRun(String query, Map<String, Object> kwargs) {
        // Split on commas with optional surrounding whitespace so that both
        // "t1, t2" and "t1,t2" work (previously only the exact ", " separator did).
        // NOTE(review): Set.of throws IllegalArgumentException if the LLM repeats
        // a table name - consider de-duplicating if that shows up in practice.
        return db.getTableInfoNoThrow(Set.of(query.strip().split("\\s*,\\s*")));
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.hw.langchain.tools.spark.sql.tool;

import com.hw.langchain.utilities.spark.sql.SparkSql;
import lombok.EqualsAndHashCode;

import java.util.Map;

/**
 * Tool for getting tables names.
 *
 * @author HamaWhite
 */
@EqualsAndHashCode(callSuper = true)
public class ListSparkSqlTool extends BaseSparkSqlTool {

    private static final String NAME = "list_tables_sql_db";
    private static final String DESCRIPTION = "Input is an empty string, output is a comma separated list of tables in the Spark SQL.";

    public ListSparkSqlTool(SparkSql db) {
        super(db, NAME, DESCRIPTION);
    }

    /**
     * Get a comma-separated list of the usable table names in the database.
     * (The previous javadoc incorrectly said this returned a table schema.)
     *
     * @param query  ignored; the tool expects an empty string
     * @param kwargs extra arguments (unused)
     * @return the table names joined with ", "
     */
    @Override
    public Object innerRun(String query, Map<String, Object> kwargs) {
        return String.join(", ", db.getUsableTableNames());
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package com.hw.langchain.tools.spark.sql.tool;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.prompt.PromptTemplate;
import com.hw.langchain.utilities.spark.sql.SparkSql;
import lombok.EqualsAndHashCode;

import java.util.List;
import java.util.Map;

import static com.hw.langchain.tools.spark.sql.prompt.Prompt.QUERY_CHECKER;

/**
 * Use an LLM to check if a query is correct.
 *
 * @author HamaWhite
 */
@EqualsAndHashCode(callSuper = true)
public class QueryCheckerTool extends BaseSparkSqlTool {

    private final BaseLanguageModel llm;

    // Final and assigned directly in the constructor (previously a mutable
    // field set through a separate init method for no benefit).
    private final LLMChain llmChain;

    private static final String NAME = "query_checker_sql_db";
    private static final String DESCRIPTION = """
            Use this tool to double check if your query is correct before executing it.
            Always use this tool before executing a query with query_sql_db!
            """;

    public QueryCheckerTool(SparkSql db, BaseLanguageModel llm) {
        super(db, NAME, DESCRIPTION);
        this.llm = llm;
        this.llmChain = new LLMChain(llm, new PromptTemplate(List.of("query"), QUERY_CHECKER));
    }

    /**
     * Use the LLM to check the query.
     *
     * @param query  the Spark SQL query to review
     * @param kwargs extra arguments (unused)
     * @return the LLM's reviewed (possibly rewritten) query
     */
    @Override
    public Object innerRun(String query, Map<String, Object> kwargs) {
        return llmChain.predict(Map.of("query", query));
    }
}
Loading

0 comments on commit d96a0c2

Please sign in to comment.