Skip to content

Commit

Permalink
EmbeddingStore (Metadata) Filter API (langchain4j#610)
Browse files Browse the repository at this point in the history
## New EmbeddingStore (metadata) `Filter` API
Many embedding stores, such as
[Pinecone](https://docs.pinecone.io/docs/metadata-filtering) and
[Milvus](https://milvus.io/docs/boolean.md) support strict filtering
(think of an SQL "WHERE" clause) during similarity search.
So, if one has an embedding store with movies, for example, one could
search not only for the most semantically similar movies to the given
user query but also apply strict filtering by metadata fields like year,
genre, rating, etc. In this case, the similarity search will be
performed only on those movies that match the filter expression.

Since LangChain4j supports (and abstracts away) many embedding stores,
there needs to be an embedding-store-agnostic way for users to define
the filter expression.

This PR introduces a `Filter` interface, which can represent both simple
(e.g., `type = "documentation"`) and composite (e.g., `type in
("documentation", "tutorial") AND year > 2020`) filter expressions in an
embedding-store-agnostic manner.

`Filter` currently supports the following operations:

- Comparison:
  - `IsEqualTo`
  - `IsNotEqualTo`
  - `IsGreaterThan`
  - `IsGreaterThanOrEqualTo`
  - `IsLessThan`
  - `IsLessThanOrEqualTo`
  - `IsIn`
  - `IsNotIn`

- Logical:
  - `And`
  - `Not`
  - `Or`

These operations are supported by most embedding stores and serve as a
good starting point. However, the list of operations will expand over
time to include other operations (e.g., `Contains`) supported by
embedding stores.

Currently, the DSL looks like this:
```java
Filter onlyDocs = metadataKey("type").isEqualTo("documentation");

Filter docsAndTutorialsAfter2020 = metadataKey("type").isIn("documentation", "tutorial").and(metadataKey("year").isGreaterThan(2020));
// or
Filter docsAndTutorialsAfter2020 = and(
    metadataKey("type").isIn("documentation", "tutorial"),
    metadataKey("year").isGreaterThan(2020)
);
```

## Filter expression as a `String`
Filter expression can also be specified as a `String`. This might be
necessary, for example, if the filter expression is generated
dynamically by the application or by the LLM (as in [self
querying](https://python.langchain.com/docs/modules/data_connection/retrievers/self_query/)).

This PR introduces a `FilterParser` interface with a simple `Filter
parse(String)` API, allowing for future support of multiple syntaxes (if
this will be required).

For the out-of-the-box filter syntax, ANSI SQL's `WHERE` clause is
proposed as a suitable candidate for several reasons:
- SQL is well-known among Java developers
- There is extensive tooling available for SQL (e.g., parsers)
- LLMs are pretty good at generating valid SQL, as there are tons of SQL
queries on the internet, which are included in the LLM training
datasets. There are also specialized LLMs that are trained for
text-to-SQL task, such as [SQLCoder](https://huggingface.co/defog).

The downside is that SQL's `WHERE` clause might not support all
operations and data types that could be supported in the future by
various embedding stores. In such case, we could extend it to a superset
of ANSI SQL `WHERE` syntax and/or provide an option to express filters
in the native syntax of the store.

An out-of-the-box implementation of the SQL `FilterParser` is provided
as a `SqlFilterParser` in a separate module
`langchain4j-embedding-store-filter-parser-sql`, using
[JSqlParser](https://github.com/JSQLParser/JSqlParser) under the hood.

`SqlFilterParser` can parse SQL "SELECT" (or just "WHERE" clause)
statement into a `Filter` object:
- `SELECT * FROM fake_table WHERE userId = '123-456'` ->
`metadataKey("userId").isEqualTo("123-456")`
- `userId = '123-456'`  ->  `metadataKey("userId").isEqualTo("123-456")`

It can also resolve `CURDATE()` and
`CURRENT_DATE`/`CURRENT_TIME`/`CURRENT_TIMESTAMP`:
`SELECT * FROM fake_table WHERE year = EXTRACT(YEAR FROM CURRENT_DATE`
-> `metadataKey("year").isEqualTo(LocalDate.now().getYear())`

## Changes in `Metadata` API
Until now, `Metadata` supported only `String` values. This PR expands
the list of supported value types to `Integer`, `Long`, `Float` and
`Double`. In the future, more types may be added (if needed).
The method `String get(String key)` will be deprecated later in favor
of:
- `String getString(String key)`
- `Integer getInteger(String key)`
- `Long getLong(String key)`
- etc

New overloaded `put(key, value)` methods are introduced to support more
value types:
- `put(String key, int value)`
- `put(String key, long value)`
- etc

## Changes in `EmbeddingStore` API
New method `search` is added that will become the main entry point for
search in the future. All `findRelevant` methods will be deprecated
later.
New `search` method accepts `EmbeddingSearchRequest` and returns
`EmbeddingSearchResult`.
`EmbeddingSearchRequest` contains all search criteria (e.g.
`maxResults`, `minScore`), including new `Filter`.
`EmbeddingSearchResult` contains a list of `EmbeddingMatch`.
```java
EmbeddingSearchResult search(EmbeddingSearchRequest request);
```

## Changes in `EmbeddingStoreContentRetriever` API
`EmbeddingStoreContentRetriever` can now be configured with a static
`filter` as well as dynamic `dynamicMaxResults`, `dynamicMinScore` and
`dynamicFilter` in the builder:
```java
ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                ...
                .maxResults(3)
                // or
                .dynamicMaxResults(query -> 3) // You can define maxResults dynamically. The value could, for example, depend on the query or the user associated with the query.
                ...
                .minScore(0.3)
                // or
                .dynamicMinScore(query -> 0.3)
                ...
                .filter(metadataKey("userId").isEqualTo("123-456")) // Assuming your TextSegments contain Metadata with key "userId"
                // or
                .dynamicFilter(query -> metadataKey("userId").isEqualTo(query.metadata().chatMemoryId().toString()))
                ...
                .build();
```
So now you can define `maxResults`, `minScore` and `filter` both
statically and dynamically (they can depend on the query, user, etc.).
These values will be propagated to the underlying `EmbeddingStore`.

##
["Self-querying"](https://python.langchain.com/docs/modules/data_connection/retrievers/self_query/)
This PR also introduces `LanguageModelSqlFilterBuilder` in
`langchain4j-embedding-store-filter-parser-sql` module which can be used
with `EmbeddingStoreContentRetriever`'s `dynamicFilter` to automatically
build a `Filter` object from the `Query` using language model and
`SqlFilterParser`.

For example:
```java
TextSegment groundhogDay = TextSegment.from("Groundhog Day", new Metadata().put("genre", "comedy").put("year", 1993));
TextSegment forrestGump = TextSegment.from("Forrest Gump", new Metadata().put("genre", "drama").put("year", 1994));
TextSegment dieHard = TextSegment.from("Die Hard", new Metadata().put("genre", "action").put("year", 1998));

// describe metadata keys as if they were columns in the SQL table
TableDefinition tableDefinition = TableDefinition.builder()
                .name("movies")
                .addColumn("genre", "VARCHAR", "one of [comedy, drama, action]")
                .addColumn("year", "INT")
                .build();

LanguageModelSqlFilterBuilder sqlFilterBuilder = new LanguageModelSqlFilterBuilder(model, tableDefinition);

ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
                .embeddingStore(embeddingStore)
                .embeddingModel(embeddingModel)
                .dynamicFilter(sqlFilterBuilder::build)
                .build();

String answer = assistant.answer("Recommend me a good drama from 90s"); // Forrest Gump
```

## Which embedding store integrations will support `Filter`?
In the long run, all (provided the embedding store itself supports it).
In the first iteration, I aim to add support to just a few:
- `InMemoryEmbeddingStore`
- Elasticsearch
- Milvus

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

## Summary by CodeRabbit

- **New Features**
- Introduced filters for checking key's value existence in a collection
for improved data handling.
- **Enhancements**
- Updated `InMemoryEmbeddingStoreTest` to extend a different class for
improved testing coverage and added a new test method.
- **Refactor**
- Made minor formatting adjustments in the assertion block for better
readability.
- **Documentation**
  - Updated class hierarchy information for clarity.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
dliubarskyi authored Mar 8, 2024
1 parent 1eba657 commit 1acb7a6
Show file tree
Hide file tree
Showing 91 changed files with 7,665 additions and 412 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ jobs:
AZURE_OPENAI_ENDPOINT: ${{ secrets.AZURE_OPENAI_ENDPOINT }}
AZURE_OPENAI_KEY: ${{ secrets.AZURE_OPENAI_KEY }}
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
ELASTICSEARCH_CLOUD_API_KEY: ${{ secrets.ELASTICSEARCH_CLOUD_API_KEY }}
ELASTICSEARCH_CLOUD_URL: ${{ secrets.ELASTICSEARCH_CLOUD_URL }}
GCP_CREDENTIALS_JSON: ${{ secrets.GCP_CREDENTIALS_JSON }}
GCP_LOCATION: ${{ secrets.GCP_LOCATION }}
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-parent</artifactId>
<version>0.28.0-SNAPSHOT</version>
<relativePath>../../langchain4j-parent/pom.xml</relativePath>
</parent>

<artifactId>langchain4j-embedding-store-filter-parser-sql</artifactId>
<name>LangChain4j :: Embedding Store Filter Parser :: SQL</name>

<dependencies>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
</dependency>

<dependency>
<groupId>com.github.jsqlparser</groupId>
<artifactId>jsqlparser</artifactId>
<version>4.8</version>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<scope>provided</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-engine</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-params</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.assertj</groupId>
<artifactId>assertj-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-core</artifactId>
<version>${project.parent.version}</version>
<classifier>tests</classifier>
<type>test-jar</type>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.tinylog</groupId>
<artifactId>tinylog-impl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.tinylog</groupId>
<artifactId>slf4j-tinylog</artifactId>
<scope>test</scope>
</dependency>

</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package dev.langchain4j.store.embedding.filter.builder.sql;

import dev.langchain4j.Experimental;
import lombok.EqualsAndHashCode;
import lombok.ToString;

import static dev.langchain4j.internal.ValidationUtils.ensureNotBlank;

@Experimental
@ToString
@EqualsAndHashCode
public class ColumnDefinition {

private final String name;
private final String type;
private final String description;

public ColumnDefinition(String name, String type) {
this(name, type, null);
}

public ColumnDefinition(String name, String type, String description) {
this.name = ensureNotBlank(name, "name");
this.type = ensureNotBlank(type, "type");
this.description = description;
}

public String name() {
return name;
}

public String type() {
return type;
}

public String description() {
return description;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
package dev.langchain4j.store.embedding.filter.builder.sql;

import dev.langchain4j.Experimental;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.chat.ChatLanguageModel;
import dev.langchain4j.model.input.Prompt;
import dev.langchain4j.model.input.PromptTemplate;
import dev.langchain4j.model.output.Response;
import dev.langchain4j.rag.content.retriever.EmbeddingStoreContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.filter.Filter;
import dev.langchain4j.store.embedding.filter.parser.sql.SqlFilterParser;
import lombok.Builder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

import static dev.langchain4j.internal.Utils.getOrDefault;
import static dev.langchain4j.internal.Utils.isNullOrBlank;
import static dev.langchain4j.internal.ValidationUtils.ensureNotNull;

/**
* Given a natural language {@link Query}, this class creates a suitable {@link Filter} using a language model.
* <br>
* This approach is also known as
* <a href="https://python.langchain.com/docs/modules/data_connection/retrievers/self_query">self-querying</a>.
* <br>
* It is useful for improving retrieval from an {@link EmbeddingStore} by narrowing down the search space.
* <br>
* For instance, if you have internal company documentation for multiple products in the same {@link EmbeddingStore}
* and want to search the documentation of a specific product without forcing the user to specify the
* {@link Filter} manually, you could use {@code LanguageModelSqlFilterBuilder} to automatically create the filter
* using a language model.
* <br>
* <br>
* First, describe the {@link Metadata} of your {@link TextSegment}
* as if it were an SQL table using {@link TableDefinition}:
* <pre>
* TableDefinition tableDefinition = TableDefinition.builder()
* .name("documentation") // table name
* .addColumn("product", "VARCHAR", "one of [iPhone, iPad, MacBook]") // column name, column type, comment
* ... other relevant metadata keys (columns) ...
* .build();
* </pre>
* Then, create a {@code LanguageModelSqlFilterBuilder} by providing a language model and a {@link TableDefinition},
* and use it with {@link EmbeddingStoreContentRetriever}:
* <pre>
* LanguageModelSqlFilterBuilder sqlFilterBuilder = new LanguageModelSqlFilterBuilder(model, tableDefinition);
* ContentRetriever contentRetriever = EmbeddingStoreContentRetriever.builder()
* .embeddingStore(embeddingStore)
* .embeddingModel(embeddingModel)
* .dynamicFilter(sqlFilterBuilder::build)
* .build();
* </pre>
* When the user asks, for example, "How to make the screen of my phone brighter?", the language model will generate
* an SQL query like {@code SELECT * from documentation WHERE product = 'iPhone'}.
* <br>
* Then, {@link SqlFilterParser} will parse the generated SQL into the following {@link Filter} object:
* {@code metadataKey("product").isEqualTo("iPhone")}.
* <br>
* This filter will be applied during similarity search in the {@link EmbeddingStore}.
* This means that only those {@link TextSegment}s with a {@link Metadata} entry {@code product = "iPhone"}
* will be considered for the search.
* <br>
* <br>
* It is recommended to use a capable language model, such as gpt-3.5-turbo,
* or the smaller one but fine-tuned for the text-to-SQL task, such as <a href="https://huggingface.co/defog">SQLCoder</a>.
* SQLCoder is also available via <a href="https://ollama.com/library/sqlcoder">Ollama</a>.
* <br>
* The default {@link PromptTemplate} in this class is suited for SQLCoder, but should work fine with
* capable language models like gpt-3.5-turbo and better.
* <br>
* You can override the default {@link PromptTemplate} using builder.
* <br>
* <br>
* In case SQL parsing fails (e.g., the generated SQL is invalid or contains text in addition to the SQL statement),
* {@code LanguageModelSqlFilterBuilder} will first try to extract the valid SQL from the input string.
* If parsing fails again, it will return {@code null}, meaning no filtering will be applied during the search.
*/
@Experimental
public class LanguageModelSqlFilterBuilder {

private static final Logger log = LoggerFactory.getLogger(LanguageModelSqlFilterBuilder.class);

private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = PromptTemplate.from(
"### Instructions:\n" +
"Your task is to convert a question into a SQL query, given a Postgres database schema.\n" +
"Adhere to these rules:\n" +
"- **Deliberately go through the question and database schema word by word** to appropriately answer the question\n" +
"- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.\n" +
"- When creating a ratio, always cast the numerator as float\n" +
"\n" +
"### Input:\n" +
"Generate a SQL query that answers the question `{{query}}`.\n" +
"This query will run on a database whose schema is represented in this string:\n" +
"{{create_table_statement}}\n" +
"\n" +
"### Response:\n" +
"Based on your instructions, here is the SQL query I have generated to answer the question `{{query}}`:\n" +
"```sql"
);

protected final ChatLanguageModel chatLanguageModel;
protected final TableDefinition tableDefinition;
protected final String createTableStatement;
protected final PromptTemplate promptTemplate;
protected final SqlFilterParser sqlFilterParser;

public LanguageModelSqlFilterBuilder(ChatLanguageModel chatLanguageModel,
TableDefinition tableDefinition) {
this(chatLanguageModel, tableDefinition, DEFAULT_PROMPT_TEMPLATE, new SqlFilterParser());
}

@Builder
private LanguageModelSqlFilterBuilder(ChatLanguageModel chatLanguageModel,
TableDefinition tableDefinition,
PromptTemplate promptTemplate,
SqlFilterParser sqlFilterParser) {
this.chatLanguageModel = ensureNotNull(chatLanguageModel, "chatLanguageModel");
this.tableDefinition = ensureNotNull(tableDefinition, "tableDefinition");
this.createTableStatement = format(tableDefinition);
this.promptTemplate = getOrDefault(promptTemplate, DEFAULT_PROMPT_TEMPLATE);
this.sqlFilterParser = getOrDefault(sqlFilterParser, SqlFilterParser::new);
}

public Filter build(Query query) {

Prompt prompt = createPrompt(query);

Response<AiMessage> response = chatLanguageModel.generate(prompt.toUserMessage());

String generatedSql = response.content().text();

String cleanedSql = clean(generatedSql);
log.trace("Cleaned SQL: '{}'", cleanedSql);

try {
return sqlFilterParser.parse(cleanedSql);
} catch (Exception e) {
log.warn("Failed parsing the following SQL: '{}'", cleanedSql, e);
// TODO implement additional strategies (configurable):
// - feed the error to the LLM and retry
// - return predefined filter
// - return partial filter if the filter is composite and some parts were parsed successfully
// - etc
return fallback(query, generatedSql, cleanedSql, e);
}
}

protected Prompt createPrompt(Query query) {
Map<String, Object> variables = new HashMap<>();
variables.put("create_table_statement", createTableStatement);
variables.put("query", query.text());
return promptTemplate.apply(variables);
}

protected String clean(String sql) {
return sql.trim();
}

protected Filter fallback(Query query, String generatedSql, String cleanedSql, Exception e) {

String extractedSql = extractSelectStatement(generatedSql);
if (isNullOrBlank(extractedSql)) {
log.trace("Cannot extract SQL, giving up");
return null;
}

try {
log.trace("Extracted SQL: '{}'", extractedSql);
return sqlFilterParser.parse(extractedSql);
} catch (Exception e2) {
log.warn("Failed parsing the following SQL, giving up: '{}'", extractedSql, e2);
return null;
}
}

protected String extractSelectStatement(String dirtySql) {
// TODO improve
if (dirtySql.contains("```sql")) {
for (String part : dirtySql.split("```sql")) {
if (part.toUpperCase().contains("SELECT") && part.toUpperCase().contains("WHERE")) {
return part.split("```")[0].trim();
}
}
} else if (dirtySql.contains("```")) {
for (String part : dirtySql.split("```")) {
if (part.toUpperCase().contains("SELECT") && part.toUpperCase().contains("WHERE")) {
return part.split("```")[0].trim();
}
}
} else {
for (String part : dirtySql.split("SELECT")) {
if (part.toUpperCase().contains("WHERE")) {
if (part.contains("\n")) {
for (String part2 : part.split("\n")) {
if (part2.toUpperCase().contains("WHERE")) {
return "SELECT " + part2.trim();
}
}
} else {
return "SELECT " + part.trim();
}
}
}
}
return null;
}

protected String format(TableDefinition tableDefinition) {
StringBuilder createTableStatement = new StringBuilder();
createTableStatement.append(String.format("CREATE TABLE %s (\n", tableDefinition.name()));
for (ColumnDefinition columnDefinition : tableDefinition.columns()) {
createTableStatement.append(String.format(" %s %s,", columnDefinition.name(), columnDefinition.type()));
if (!isNullOrBlank(columnDefinition.description())) {
createTableStatement.append(String.format(" -- %s", columnDefinition.description()));
}
createTableStatement.append("\n");
}
createTableStatement.append(")");
if (!isNullOrBlank(tableDefinition.description())) {
createTableStatement.append(String.format(" COMMENT='%s'", tableDefinition.description()));
}
createTableStatement.append(";");
return createTableStatement.toString();
}
}
Loading

0 comments on commit 1acb7a6

Please sign in to comment.