Skip to content

Commit

Permalink
Merge branch 'feature-bigdata' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jul 22, 2023
2 parents eceb312 + d96a0c2 commit 4c00a45
Show file tree
Hide file tree
Showing 26 changed files with 1,748 additions and 47 deletions.
892 changes: 892 additions & 0 deletions docs/extras/modules/titanic.csv

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions langchain-bigdata/langchain-flink/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the langchain-flink module. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Child of the langchain-bigdata aggregator; inherits groupId/version from the parent. -->
<parent>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-bigdata</artifactId>
<version>0.1.9</version>
</parent>

<artifactId>langchain-flink</artifactId>

<dependencies>
<!-- Core langchain abstractions; kept at the same version as this module. -->
<dependency>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-core</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>

</project>
46 changes: 46 additions & 0 deletions langchain-bigdata/langchain-spark/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the langchain-spark module. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Child of the langchain-bigdata aggregator; inherits groupId/version from the parent. -->
<parent>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-bigdata</artifactId>
<version>0.1.9</version>
</parent>

<artifactId>langchain-spark</artifactId>

<dependencies>
<!-- Core langchain abstractions; kept at the same version as this module. -->
<dependency>
<groupId>io.github.hamawhitegg</groupId>
<artifactId>langchain-core</artifactId>
<version>${project.version}</version>
</dependency>

<!-- Spark dependencies carry no <version> here — presumably managed in the
     parent's dependencyManagement, along with scala.binary.version; verify. -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
</dependency>

<!-- JUnit 5 for tests; scope presumably managed by the parent — confirm. -->
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.hw.langchain.agents.toolkits.base;

import com.hw.langchain.tools.base.BaseTool;

import java.util.List;

/**
 * A toolkit is a named grouping of related {@link BaseTool}s that can be
 * handed to an agent as a unit.
 *
 * @author HamaWhite
 */
public interface BaseToolkit {

    /**
     * Returns the tools contained in this toolkit.
     *
     * @return the toolkit's tools
     */
    List<BaseTool> getTools();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package com.hw.langchain.agents.toolkits.spark.sql.base;

import com.hw.langchain.agents.agent.Agent;
import com.hw.langchain.agents.agent.AgentExecutor;
import com.hw.langchain.agents.mrkl.base.ZeroShotAgent;
import com.hw.langchain.agents.toolkits.spark.sql.toolkit.SparkSqlToolkit;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.prompt.PromptTemplate;
import com.hw.langchain.tools.base.BaseTool;

import java.util.List;
import java.util.Map;

import static com.hw.langchain.agents.mrkl.prompt.Prompt.FORMAT_INSTRUCTIONS;
import static com.hw.langchain.agents.toolkits.spark.sql.prompt.Prompt.SQL_PREFIX;
import static com.hw.langchain.agents.toolkits.spark.sql.prompt.Prompt.SQL_SUFFIX;
import static com.hw.langchain.prompts.utils.FormatUtils.formatTemplate;

/**
 * Factory for building an {@link AgentExecutor} that answers questions by
 * running Spark SQL through the tools of a {@link SparkSqlToolkit}.
 *
 * @author HamaWhite
 */
public class SparkSqlAgent {

    private SparkSqlAgent() {
        // private constructor to hide the implicit public one
        throw new IllegalStateException("Utility class");
    }

    /**
     * Construct a Spark SQL agent from an LLM and tools, using the default
     * prompt pieces, a top-k of 10, at most 15 iterations, no execution-time
     * limit and the "force" early-stopping method.
     */
    public static AgentExecutor createSparkSqlAgent(BaseLanguageModel llm, SparkSqlToolkit toolkit) {
        return createSparkSqlAgent(llm, toolkit, SQL_PREFIX, SQL_SUFFIX, FORMAT_INSTRUCTIONS, null, 10, 15, null, "force");
    }

    /**
     * Construct a Spark SQL agent from an LLM and tools.
     *
     * @param llm                 language model driving the agent
     * @param toolkit             supplies the Spark SQL tools
     * @param prefix              prompt prefix; may contain a {top_k} placeholder
     * @param suffix              prompt suffix appended after the tool listing
     * @param formatInstructions  instructions describing the expected output format
     * @param inputVariables      prompt input variable names, or null for the default
     * @param topK                value substituted for {top_k} in the prefix
     * @param maxIterations       iteration cap for the executor
     * @param maxExecutionTime    wall-clock cap in seconds, or null for none
     * @param earlyStoppingMethod how the executor stops when a cap is hit
     * @return the assembled executor
     */
    @SuppressWarnings("all")
    public static AgentExecutor createSparkSqlAgent(
            BaseLanguageModel llm,
            SparkSqlToolkit toolkit,
            String prefix,
            String suffix,
            String formatInstructions,
            List<String> inputVariables,
            int topK,
            Integer maxIterations,
            Float maxExecutionTime,
            String earlyStoppingMethod) {
        // Resolve the {top_k} placeholder before the prompt is assembled.
        String resolvedPrefix = formatTemplate(prefix, Map.of("top_k", topK));

        List<BaseTool> toolList = toolkit.getTools();
        PromptTemplate promptTemplate =
                ZeroShotAgent.createPrompt(toolList, resolvedPrefix, suffix, formatInstructions, inputVariables);

        List<String> toolNameList = toolList.stream().map(BaseTool::getName).toList();
        Agent zeroShotAgent = new ZeroShotAgent(new LLMChain(llm, promptTemplate), toolNameList);

        return AgentExecutor.builder()
                .agent(zeroShotAgent)
                .tools(toolList)
                .maxIterations(maxIterations)
                .maxExecutionTime(maxExecutionTime)
                .earlyStoppingMethod(earlyStoppingMethod)
                .build();
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.hw.langchain.agents.toolkits.spark.sql.prompt;

/**
 * Prompt text for the Spark SQL agent: the prefix and suffix that
 * SparkSqlAgent wraps around the tool descriptions when building its
 * ZeroShotAgent prompt.
 *
 * @author HamaWhite
 */
public class Prompt {

private Prompt() {
// private constructor to hide the implicit public one
throw new IllegalStateException("Utility class");
}

// Preamble of the agent prompt. The {top_k} placeholder is filled in by
// SparkSqlAgent via formatTemplate before the prompt is created.
public static final String SQL_PREFIX = """
You are an agent designed to interact with Spark SQL.
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
""";

// Closing section of the agent prompt; {input} is the user question and
// {agent_scratchpad} receives the agent's intermediate thoughts/actions.
public static final String SQL_SUFFIX = """
Begin!
Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}""";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package com.hw.langchain.agents.toolkits.spark.sql.toolkit;

import com.hw.langchain.agents.toolkits.base.BaseToolkit;
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.tools.base.BaseTool;
import com.hw.langchain.tools.spark.sql.tool.InfoSparkSQLTool;
import com.hw.langchain.tools.spark.sql.tool.ListSparkSqlTool;
import com.hw.langchain.tools.spark.sql.tool.QueryCheckerTool;
import com.hw.langchain.tools.spark.sql.tool.QuerySparkSqlTool;
import com.hw.langchain.utilities.spark.sql.SparkSql;

import java.util.List;

/**
 * Toolkit for interacting with Spark SQL.
 *
 * @author HamaWhite
 */
public class SparkSqlToolkit implements BaseToolkit {

    // Spark SQL wrapper every tool in the toolkit operates on.
    private final SparkSql db;

    // Language model used by the query-checker tool.
    private final BaseLanguageModel llm;

    public SparkSqlToolkit(SparkSql db, BaseLanguageModel llm) {
        this.db = db;
        this.llm = llm;
    }

    /**
     * Returns the Spark SQL tools: query execution, schema info, table
     * listing and LLM-based query checking.
     */
    @Override
    public List<BaseTool> getTools() {
        BaseTool queryTool = new QuerySparkSqlTool(db);
        BaseTool infoTool = new InfoSparkSQLTool(db);
        BaseTool listTool = new ListSparkSqlTool(db);
        BaseTool checkerTool = new QueryCheckerTool(db, llm);
        return List.of(queryTool, infoTool, listTool, checkerTool);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.hw.langchain.tools.spark.sql.prompt;


/**
 * Prompt text for the Spark SQL tools.
 *
 * @author HamaWhite
 */
public class Prompt {

private Prompt() {
// private constructor to hide the implicit public one
throw new IllegalStateException("Utility class");
}

// Template with a {query} placeholder; presumably rendered by
// QueryCheckerTool to have the LLM review a query — confirm against caller.
public static final String QUERY_CHECKER = """
{query}
Double check the Spark SQL query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins
If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.""";

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package com.hw.langchain.tools.spark.sql.tool;

import com.hw.langchain.tools.base.BaseTool;
import com.hw.langchain.utilities.spark.sql.SparkSql;
import lombok.EqualsAndHashCode;

/**
 * Common superclass for tools that talk to Spark SQL; holds the shared
 * {@link SparkSql} handle for subclasses.
 *
 * @author HamaWhite
 */
@EqualsAndHashCode(callSuper = true)
public abstract class BaseSparkSqlTool extends BaseTool {

    // Spark SQL wrapper available to every concrete Spark SQL tool.
    protected final SparkSql db;

    /**
     * @param db          the Spark SQL wrapper the tool operates on
     * @param name        tool name exposed to the agent
     * @param description tool description exposed to the agent
     */
    protected BaseSparkSqlTool(SparkSql db, String name, String description) {
        super(name, description);
        this.db = db;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.hw.langchain.tools.spark.sql.tool;

import com.hw.langchain.utilities.spark.sql.SparkSql;
import lombok.EqualsAndHashCode;

import java.util.Arrays;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Tool for getting metadata about a Spark SQL.
 *
 * @author HamaWhite
 */
@EqualsAndHashCode(callSuper = true)
public class InfoSparkSQLTool extends BaseSparkSqlTool {

    private static final String NAME = "schema_sql_db";
    private static final String DESCRIPTION = """
            Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
            Be sure that the tables actually exist by calling list_tables_sql_db first!
            Example Input: "table1, table2, table3"
            """;

    public InfoSparkSQLTool(SparkSql db) {
        super(db, NAME, DESCRIPTION);
    }

    /**
     * Get the schema and sample rows for the tables in a comma-separated list.
     *
     * @param query  comma-separated table names, e.g. {@code "table1, table2"}
     * @param kwargs additional keyword arguments (unused)
     * @return the table info reported by {@link SparkSql#getTableInfoNoThrow}
     */
    @Override
    public Object innerRun(String query, Map<String, Object> kwargs) {
        // Split on commas and strip whitespace so LLM-produced inputs like
        // "table1,table2" or "table1 , table2" work, not only "table1, table2"
        // (the previous split(", ") silently treated those as one table name).
        Set<String> tableNames = Arrays.stream(query.split(","))
                .map(String::strip)
                .filter(name -> !name.isEmpty())
                .collect(Collectors.toSet());
        return db.getTableInfoNoThrow(tableNames);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package com.hw.langchain.tools.spark.sql.tool;

import com.hw.langchain.utilities.spark.sql.SparkSql;
import lombok.EqualsAndHashCode;

import java.util.Map;

/**
 * Tool for getting table names.
 *
 * @author HamaWhite
 */
@EqualsAndHashCode(callSuper = true)
public class ListSparkSqlTool extends BaseSparkSqlTool {

    private static final String NAME = "list_tables_sql_db";
    private static final String DESCRIPTION = "Input is an empty string, output is a comma separated list of tables in the Spark SQL.";

    public ListSparkSqlTool(SparkSql db) {
        super(db, NAME, DESCRIPTION);
    }

    /**
     * List the usable table names as a comma-separated string.
     * (Fixes the previous copy-pasted Javadoc, which wrongly described this
     * as returning a table schema.)
     *
     * @param query  ignored; the tool takes an empty string as input
     * @param kwargs additional keyword arguments (unused)
     * @return comma-separated table names from {@link SparkSql#getUsableTableNames}
     */
    @Override
    public Object innerRun(String query, Map<String, Object> kwargs) {
        return String.join(", ", db.getUsableTableNames());
    }
}
Loading

0 comments on commit 4c00a45

Please sign in to comment.