Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EmbeddingStore (Metadata) Filter API #610

Merged
merged 37 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
77ab190
Draft: metadata filtering
langchain4j Feb 7, 2024
b103441
removed unnecessary grouping
langchain4j Feb 7, 2024
7dd49b2
removed extension from Predicate<Metadata>
langchain4j Feb 7, 2024
65d6024
fix sql parser
langchain4j Feb 7, 2024
fd56a1b
Merge branch 'main' into metadata_filtering
langchain4j Feb 26, 2024
e7e9d3e
WIP: metadata filtering
langchain4j Feb 27, 2024
9a8bc95
WIP: metadata filtering
langchain4j Feb 27, 2024
54c2630
WIP: metadata filtering
langchain4j Feb 29, 2024
f146098
WIP: metadata filtering
langchain4j Feb 29, 2024
7bc4e01
WIP: metadata filtering
langchain4j Feb 29, 2024
eea5a4f
WIP: metadata filtering
langchain4j Feb 29, 2024
4960050
WIP: metadata filtering
langchain4j Feb 29, 2024
95fc6a3
WIP: metadata filtering
langchain4j Feb 29, 2024
39995d2
WIP: metadata filtering
langchain4j Feb 29, 2024
fe1265f
WIP: metadata filtering
langchain4j Feb 29, 2024
9eca8c5
WIP: metadata filtering
langchain4j Feb 29, 2024
7f0ad07
WIP: metadata filtering
langchain4j Feb 29, 2024
6efe55b
WIP: metadata filtering
langchain4j Feb 29, 2024
3343014
WIP: metadata filtering
langchain4j Feb 29, 2024
bc24317
WIP: metadata filtering
langchain4j Mar 5, 2024
6ce49e7
WIP: metadata filtering
langchain4j Mar 5, 2024
3562a63
Merge branch 'main' into metadata_filtering
langchain4j Mar 5, 2024
8627be0
WIP: metadata filtering
langchain4j Mar 5, 2024
d50ffa8
WIP: metadata filtering
langchain4j Mar 5, 2024
dcd90df
WIP: metadata filtering
langchain4j Mar 5, 2024
d27c615
WIP: metadata filtering
langchain4j Mar 5, 2024
466dab4
WIP: metadata filtering
langchain4j Mar 7, 2024
52c74e1
WIP: metadata filtering
langchain4j Mar 7, 2024
e80696a
WIP: metadata filtering
langchain4j Mar 7, 2024
5d262e2
WIP: metadata filtering
langchain4j Mar 7, 2024
c08e36f
WIP: metadata filtering
langchain4j Mar 7, 2024
266b67e
WIP: metadata filtering
langchain4j Mar 7, 2024
3d5049b
Merge branch 'main' into metadata_filtering
langchain4j Mar 7, 2024
ee139af
Upgrade langchain4j-opensearch to use latest version of opensearch-ja…
langchain4j Mar 7, 2024
245b60d
WIP: metadata filtering
langchain4j Mar 7, 2024
eeada28
Merge branch 'main' into metadata_filtering
langchain4j Mar 8, 2024
12a8872
WIP: metadata filtering
langchain4j Mar 8, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ jobs:
java_version: [8, 11, 17, 21]
include:
- java_version: '8'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-cassandra,!langchain4j-infinispan,!langchain4j-neo4j'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-cassandra,!langchain4j-infinispan,!langchain4j-neo4j,!langchain4j-opensearch'
- java_version: '11'
included_modules: '-pl !code-execution-engines/langchain4j-code-execution-engine-graalvm-polyglot,!langchain4j-infinispan,!langchain4j-neo4j'
- java_version: '17'
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ jobs:
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}
AZURE_OPENAI_KEY: ${{ secrets.AZURE_OPENAI_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
ELASTICSEARCH_CLOUD_API_KEY: ${{ secrets.ELASTICSEARCH_CLOUD_API_KEY }}
ELASTICSEARCH_CLOUD_URL: ${{ secrets.ELASTICSEARCH_CLOUD_URL }}
GCP_CREDENTIALS_JSON: ${{ secrets.GCP_CREDENTIALS_JSON }}
GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.28.0-SNAPSHOT</version>
<relativePath>../../langchain4j-parent/pom.xml</relativePath>
</parent>

<artifactId>langchain4j-embedding-store-filter-parser-sql</artifactId>
<name>LangChain4j :: Embedding Store Filter Parser :: SQL</name>

<dependencies>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>

<dependency>
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>4.8</version>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.parent.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package dev.langchain4j.store.embedding.filter.builder.sql;

import dev.langchain4j.Experimental;
import lombok.EqualsAndHashCode;
import lombok.ToString;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

@Experimental
@ToString
@EqualsAndHashCode
public class ColumnDefinition {

private final String name;
private final String type;
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
private final String description;

public ColumnDefinition(String name, String type) {
this(name, type, null);
}

public ColumnDefinition(String name, String type, String description) {
this.name = ensureNotBlank(name, "name");
this.type = ensureNotBlank(type, "type");
this.description = description;
}

public String name() {
return name;
}

public String type() {
return type;
}

public String description() {
return description;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package dev.langchain4j.store.embedding.filter.builder.sql;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.parser.sql.SqlFilterParser;
import lombok.Builder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
* Given a natural language {@link Query}, this class creates a suitable {@link Filter} using a language model.
* <br>
* This approach is also known as
* <a href="https://python.langchain.com/docs/modules/data_connection/retrievers/self_query">self-querying</a>.
* <br>
* It is useful for improving retrieval from an {@link EmbeddingStore} by narrowing down the search space.
* <br>
* For instance, if you have internal company documentation for multiple products in the same {@link EmbeddingStore}
* and want to search the documentation of a specific product without forcing the user to specify the
* {@link Filter} manually, you could use {@code LanguageModelSqlFilterBuilder} to automatically create the filter
* using a language model.
* <br>
* <br>
* First, describe the {@link Metadata} of your {@link TextSegment}
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
* as if it were an SQL table using {@link TableDefinition}:
* <pre>
* TableDefinition tableDefinition = TableDefinition.builder()
* .name("documentation") // table name
* .addColumn("product", "VARCHAR", "one of [iPhone, iPad, MacBook]") // column name, column type, comment
* ... other relevant metadata keys (columns) ...
* .build();
* </pre>
* Then, create a {@code LanguageModelSqlFilterBuilder} by providing a language model and a {@link TableDefinition},
* and use it with {@link EmbeddingStoreContentRetriever}:
* <pre>
* LanguageModelSqlFilterBuilder sqlFilterBuilder = new LanguageModelSqlFilterBuilder(model, tableDefinition);
* ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
* .embeddingStore(embeddingStore)
* .embeddingModel(embeddingModel)
* .dynamicFilter(sqlFilterBuilder::build)
* .build();
* </pre>
* When the user asks, for example, "How to make the screen of my phone brighter?", the language model will generate
* an SQL query like {@code SELECT * from documentation WHERE product = 'iPhone'}.
* <br>
* Then, {@link SqlFilterParser} will parse the generated SQL into the following {@link Filter} object:
* {@code key("product").eq("iPhone")}.
* <br>
* This filter will be applied during similarity search in the {@link EmbeddingStore}.
* This means that only those {@link TextSegment}s with a {@link Metadata} entry {@code product = "iPhone"}
* will be considered for the search.
* <br>
* <br>
* It is recommended to use a capable language model, such as gpt-3.5-turbo,
* or the smaller one but fine-tuned for the text-to-SQL task, such as <a href="https://huggingface.co/defog">SQLCoder</a>.
* SQLCoder is also available via <a href="https://ollama.com/library/sqlcoder">Ollama</a>.
* <br>
* The default {@link PromptTemplate} in this class is suited for SQLCoder, but should work fine with
* capable language models like gpt-3.5-turbo and better.
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
* <br>
* You can override the default {@link PromptTemplate} using builder.
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
* <br>
* <br>
* In case SQL parsing fails (e.g., the generated SQL is invalid or contains text in addition to the SQL statement),
* {@code LanguageModelSqlFilterBuilder} will first try to extract the valid SQL from the input string.
* If parsing fails again, it will return {@code null}, meaning no filtering will be applied during the search.
*/
@Experimental
public class LanguageModelSqlFilterBuilder {

private static final Logger log = LoggerFactory.getLogger(LanguageModelSqlFilterBuilder.class);

private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"### Instructions:\n" +
"Your task is to convert a question into a SQL query, given a Postgres database schema.\n" +
"Adhere to these rules:\n" +
"- **Deliberately go through the question and database schema word by word** to appropriately answer the question\n" +
"- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.\n" +
"- When creating a ratio, always cast the numerator as float\n" +
"\n" +
"### Input:\n" +
"Generate a SQL query that answers the question `{{query}}`.\n" +
"This query will run on a database whose schema is represented in this string:\n" +
"{{create_table_statement}}\n" +
"\n" +
"### Response:\n" +
"Based on your instructions, here is the SQL query I have generated to answer the question `{{query}}`:\n" +
"```sql"
);

protected final ChatLanguageModel chatLanguageModel;
protected final TableDefinition tableDefinition;
protected final String createTableStatement;
protected final PromptTemplate promptTemplate;
protected final SqlFilterParser sqlFilterParser;

public LanguageModelSqlFilterBuilder(ChatLanguageModel chatLanguageModel,
TableDefinition tableDefinition) {
this(chatLanguageModel, tableDefinition, DEFAULT_PROMPT_TEMPLATE, new SqlFilterParser());
}

@Builder
private LanguageModelSqlFilterBuilder(ChatLanguageModel chatLanguageModel,
TableDefinition tableDefinition,
PromptTemplate promptTemplate,
SqlFilterParser sqlFilterParser) {
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
this.tableDefinition = ensureNotNull(tableDefinition, "tableDefinition");
this.createTableStatement = format(tableDefinition);
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
this.sqlFilterParser = getOrDefault(sqlFilterParser, SqlFilterParser::new);
}

public Filter build(Query query) {

Prompt prompt = createPrompt(query);

Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());

String generatedSql = response.content().text();

String cleanedSql = clean(generatedSql);
log.trace("Cleaned SQL: '{}'", cleanedSql);

try {
return sqlFilterParser.parse(cleanedSql);
} catch (Exception e) {
log.warn("Failed parsing the following SQL: '{}'", cleanedSql, e);
// TODO implement additional strategies (configurable):
// - feed the error to the LLM and retry
// - return predefined filter
// - return partial filter if the filter is composite and some parts were parsed successfully
// - etc
return fallback(query, generatedSql, cleanedSql, e);
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
}
}

protected Prompt createPrompt(Query query) {
Map<String, Object> variables = new HashMap<>();
variables.put("create_table_statement", createTableStatement);
variables.put("query", query.text());
return promptTemplate.apply(variables);
}

protected String clean(String sql) {
return sql.trim();
}

protected Filter fallback(Query query, String generatedSql, String cleanedSql, Exception e) {

String extractedSql = extractSelectStatement(generatedSql);
if (isNullOrBlank(extractedSql)) {
log.trace("Cannot extract SQL, giving up");
return null;
}

try {
log.trace("Extracted SQL: '{}'", extractedSql);
return sqlFilterParser.parse(extractedSql);
} catch (Exception e2) {
log.warn("Failed parsing the following SQL, giving up: '{}'", extractedSql, e2);
return null;
}
}

protected String extractSelectStatement(String dirtySql) {
// TODO improve
if (dirtySql.contains("```sql")) {
for (String part : dirtySql.split("```sql")) {
if (part.toUpperCase().contains("SELECT") && part.toUpperCase().contains("WHERE")) {
return part.split("```")[0].trim();
}
}
} else if (dirtySql.contains("```")) {
for (String part : dirtySql.split("```")) {
if (part.toUpperCase().contains("SELECT") && part.toUpperCase().contains("WHERE")) {
return part.split("```")[0].trim();
}
}
} else {
for (String part : dirtySql.split("SELECT")) {
if (part.toUpperCase().contains("WHERE")) {
if (part.contains("\n")) {
for (String part2 : part.split("\n")) {
if (part2.toUpperCase().contains("WHERE")) {
return "SELECT " + part2.trim();
}
}
} else {
return "SELECT " + part.trim();
}
}
}
}
return null;
langchain4j marked this conversation as resolved.
Show resolved Hide resolved
}

protected String format(TableDefinition tableDefinition) {
StringBuilder createTableStatement = new StringBuilder();
createTableStatement.append(String.format("CREATE TABLE %s (\n", tableDefinition.name()));
for (ColumnDefinition columnDefinition : tableDefinition.columns()) {
createTableStatement.append(String.format(" %s %s,", columnDefinition.name(), columnDefinition.type()));
if (!isNullOrBlank(columnDefinition.description())) {
createTableStatement.append(String.format(" -- %s", columnDefinition.description()));
}
createTableStatement.append("\n");
}
createTableStatement.append(")");
if (!isNullOrBlank(tableDefinition.description())) {
createTableStatement.append(String.format(" COMMENT='%s'", tableDefinition.description()));
}
createTableStatement.append(";");
return createTableStatement.toString();
}
}
Loading
Loading