diff --git a/README.md b/README.md
index 726d3c8af..f859f2667 100644
--- a/README.md
+++ b/README.md
@@ -17,7 +17,9 @@ The following example can view in the [langchain-example](langchain-examples/sr
 ## 2. User cases
 - [SQL Chains](langchain-examples/src/main/java/com/hw/langchain/examples/chains/SqlChainExample.java)
-- [Retrieval QA + Pinecone](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java)
+- [Agent with Google Search](langchain-examples/src/main/java/com/hw/langchain/examples/agents/LlmAgentExample.java)
+- [Question answering over documents](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java)
+- [Context aware text splitting and QA / Chat](langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java)
 ## 3. Quickstart Guide
 This tutorial gives you a quick walkthrough about building an end-to-end language model application with LangChain.
diff --git a/docs/extras/use_cases/question_answering/notion_db/Auto-Evaluation of Metadata Filtering.md b/docs/extras/use_cases/question_answering/notion_db/Auto-Evaluation of Metadata Filtering.md
new file mode 100755
index 000000000..655c6cffb
--- /dev/null
+++ b/docs/extras/use_cases/question_answering/notion_db/Auto-Evaluation of Metadata Filtering.md
@@ -0,0 +1,60 @@
+# Auto-Evaluation of Metadata Filtering
+
+[Lance Martin](https://twitter.com/RLanceMartin)
+
+### Introduction
+
+Q+A systems often use a two-step approach: retrieve relevant text chunks and then synthesize them into an answer. There are many ways to approach this. For example, we recently [discussed](https://blog.langchain.dev/auto-evaluation-of-anthropic-100k-context-window/) the Retriever-Less option (at the bottom of the diagram below), highlighting the Anthropic 100k context window model. Metadata filtering is an alternative approach that pre-filters chunks in a VectorDB, based on user-defined criteria encoded as metadata tags, prior to semantic search.
+
+![Untitled](images/Untitled.png)
+
+### Motivation
+
+I [previously built](https://twitter.com/RLanceMartin/status/1637852936238956546?s=20) a [QA app](https://lex-gpt.vercel.app/) based on the Lex Fridman podcast. It uses semantic search on Pinecone. However, it [failed](https://twitter.com/RLanceMartin/status/1639286900270964737?s=20) in cases where a user wanted to retrieve information about a specific episode (e.g., `summarize episode 53`) or where a guest had appeared multiple times and the user wanted information from a particular episode (e.g., `what did Elon say in episode 252`).
+
+In these cases, semantic search will look for the concept `episode 53` in the chunks, but instead we simply want to filter the chunks for `episode 53` and then perform semantic search to extract those that best summarize the episode. Metadata filtering does this, so long as 1) we have a metadata filter for episode number and 2) we can extract the value we want (e.g., `53` or `252`) from the query. The LangChain `SelfQueryRetriever` does the latter (see [docs](https://www.notion.so/Auto-Evaluation-of-Metadata-Filtering-5e3aff0fd1a04b8286b63dc64c8392c5?pvs=21)), [splitting the user input](https://twitter.com/hwchase17/status/1656791490922967041?s=20) into a semantic query and a metadata filter (for [Pinecone](https://python.langchain.com/en/latest/modules/indexes/retrievers/examples/self_query.html) or [Chroma](https://python.langchain.com/en/latest/modules/indexes/retrievers/examples/chroma_self_query.html)).
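To make the semantic-query / metadata-filter split concrete, here is a minimal, illustrative sketch that builds the split by hand using the query-constructor IR classes added in this change (`StructuredQuery`, `Comparison`, `Comparator`). The attribute name `episode_id` and the zero-padded value are assumptions for illustration only; the actual tag format is discussed in the Testing section below.

```java
import com.hw.langchain.chains.query.constructor.ir.Comparator;
import com.hw.langchain.chains.query.constructor.ir.Comparison;
import com.hw.langchain.chains.query.constructor.ir.StructuredQuery;

public class MetadataFilterSketch {

    public static void main(String[] args) {
        // User question: "What did Elon say in episode 252?"

        // 1) Metadata filter: pre-select only the chunks tagged with the target episode.
        //    "episode_id" and the zero-padded "0252" are hypothetical tag conventions.
        Comparison filter = new Comparison(Comparator.EQ, "episode_id", "0252");

        // 2) Semantic query: similarity search then runs only over the pre-filtered chunks.
        //    The last argument is the optional limit, left null here.
        StructuredQuery structuredQuery =
                new StructuredQuery("Elon Musk on the self-driving problem", filter, null);

        System.out.println(structuredQuery);
    }
}
```

In practice the filter is not written by hand: the query-constructor chain introduced in this change (`BaseUtils.loadQueryConstructorChain`) prompts an LLM to produce the structured query from the natural-language question.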
+
+### Evaluation
+
+We previously introduced [auto-evaluator](https://blog.langchain.dev/auto-evaluator-opportunities/), an open-source tool for grading LLM question-answer chains. Here, we extend auto-evaluator with a [lightweight Streamlit app](https://github.com/langchain-ai/auto-evaluator/tree/main/streamlit) that can connect to any existing Pinecone index. We add the ability to test metadata filtering using `SelfQueryRetriever` as well as some other approaches that we’ve found to be useful, as discussed below.
+
+[ret_trim.mov](images/ret_trim.mov)
+
+### Testing
+
+`SelfQueryRetriever` works well in [many cases](https://twitter.com/hwchase17/status/1656791488569954304/photo/1). For example, given [this test case](https://twitter.com/hwchase17/status/1656791488569954304?s=20):
+
+![Untitled](images/Untitled%201.png)
+
+The query can be nicely broken up into semantic query and metadata filter:
+
+```python
+semantic query: "prompt injection"
+metadata filter: "webinar_name=agents in production"
+```
+
+But, sometimes the metadata filter is not obvious based on the natural language in the question. For example, my [Lex-GPT](https://lex-gpt.vercel.app/) app used an episode ID tag derived from my initial scrape of the [Karpathy transcriptions](https://karpathy.ai/lexicap/index.html), e.g., I have `“0252”` for episode `252`. This means that the retriever will need to perform this translation step, as shown in the diagram below.
+
+![Untitled](images/Untitled%202.png)
+
+`SelfQueryRetriever` will infer metadata filters from the query using `metadata_field_info`, which you can supply to the auto-evaluator [here](https://github.com/langchain-ai/auto-evaluator/blob/main/streamlit/self_query_retriever_lex.py). To address the above translation, we also test [kor](https://eyurtsev.github.io/kor/tutorial.html), a library for structured data extraction using an LLM. Using Kor, we can explicitly specify a transformation (e.g., `252` to `“0252”`) that the auto-evaluator will use in retrieval [here](https://github.com/langchain-ai/auto-evaluator/blob/main/streamlit/kor_retriever_lex.py).
+
+Below, we can see detailed results from the app:
+
+- Kor extraction is able to perform the transformation between query and metadata format ✅
+- Self-querying attempts to filter using the episode ID (`252`) in the query and fails 🚫
+- Baseline returns docs from 3 different episodes (one from `252`), confusing the answer 🚫
+
+### Conclusion
+
+Metadata filtering is a useful and powerful augmentation for semantic search. Extracting the metadata filter from the natural language query is a challenge. The `SelfQueryRetriever` does this well in many cases, but other LangChain integrations, such as Kor, can be very useful in cases where the filter definition is not obvious from the natural language query alone.
+
+### Appendix
+
+Detailed results from the [Lex-GPT](https://lex-gpt.vercel.app/) app.
+
+| | Question | Correct Answer | Generated Answer | Answer Grading | Retrieval Grading | Latency |
+| --- | --- | --- | --- | --- | --- | --- |
+| Pinecone | What does Elon Musk say about the self driving problem in episode 252? | Elon mentions that the self-driving problem is harder than he thought because you need to build a silicon equivalent of vision that maps from camera to vector space. But, he also mentions that the disengagements has been dropping rapidly and the probability of an accident on FSD will be less than that of the average human within a year.
| In episode 252 of the Lex Fridman podcast, Elon Musk discusses the self-driving problem and is impressed by what the Tesla Autopilot team is doing. They are using multitask learning to study the problem of driving and are formulating over a hundred tasks to perceive. They are deploying it and getting data back from people that run into trouble and are trying to figure out if they should add tasks or focus on each individual task separately. Musk acknowledges that there may be a long-term solution involving fancy self-supervised learning, but there is pressure to take shortcuts due to the need to solve engineering problems in a relatively short time. | "1. Relevance: The student's answer does refer to Elon Musk discussing the self-driving problem in episode 252 of the Lex Fridman podcast. However, the details provided in the student's answer do not match the true answer's content.2. Conciseness: The student's answer is not concise, as it provides a lot of information that is not directly related to the true answer.3. Correct: The student's answer is not correct, as it does not mention the key points from the true answer, such as the difficulty of building a silicon equivalent of vision and the rapid decrease in disengagements. Incorrect" | "GRADE: Correct JUSTIFICATION: All three documents contain information about Elon Musk, Tesla Autopilot, and self-driving technology. They also include discussions from the Lex Fridman Podcast, where the specific episode (252) is mentioned. These documents provide context and information that will help the student arrive at the correct answer to the question." | 12.05227709 | +| Self-querying | What does Elon Musk say about the self driving problem in episode 252? | Elon mentions that the self-driving problem is harder than he thought because you need to build a silicon equivalent of vision that maps from camera to vector space. But, he also mentions that the disengagements has been dropping rapidly and the probability of an accident on FSD will be less than that of the average human within a year. | I'm sorry, I cannot provide an answer as there is no information about the specific episode mentioned. | "1. The first criterion is relevance, which asks if the submission refers to a real quote from the text. The student answer does not provide any information about Elon Musk's statement in episode 252, so it does not meet this criterion.2. The second criterion is conciseness, which asks if the answer is concise and to the point. The student answer is concise, but it does not provide any relevant information about the self-driving problem mentioned by Elon Musk.3. The third criterion is correctness, which asks if the answer is correct. The student answer does not provide any information about Elon Musk's statement, so it is not correct.Based on the reasoning above, the submission does not meet the criteria. Incorrect" | "GRADE: Incorrect JUSTIFICATION: There is only one document retrieved, and it does not contain any information related to the question. Therefore, the student cannot arrive at the correct answer using this document." | 7.095044136 | +| Kor Filtering | What does Elon Musk say about the self driving problem in episode 252? | Elon mentions that the self-driving problem is harder than he thought because you need to build a silicon equivalent of vision that maps from camera to vector space. 
But, he also mentions that the disengagements has been dropping rapidly and the probability of an accident on FSD will be less than that of the average human within a year. | Elon Musk says that the self-driving problem is harder than he initially thought and that to solve it, they need to recreate what humans do to drive, which is to drive with optical sensors and biological neural nets. They need to recreate that in digital form, which means cameras with advanced neural nets in silicon form. The only way to solve for full self-driving is to do this. He also mentions that the rate of disengagements has been dropping rapidly, and it looks like the probability of an accident on FSD is less than that of the average human and significantly less than before. It is looking quite likely that they will solve level four FSD next year. | "1. The student answer refers to Elon Musk's statement about the self-driving problem being harder than he initially thought, which is mentioned in the true answer.2. The student answer also discusses the need to recreate human driving using optical sensors and biological neural nets in digital form, which is related to building a silicon equivalent of vision mentioned in the true answer.3. The student answer mentions the rate of disengagements dropping rapidly and the probability of an accident on FSD being less than that of the average human, which is also mentioned in the true answer.4. The student answer adds information about solving level four FSD next year, which is not mentioned in the true answer but is still relevant to the self-driving problem.5. The student answer is not as concise as the true answer, but it covers the main points and provides additional relevant information. Correct" | "GRADE: Correct JUSTIFICATION: All three documents contain information about Elon Musk's thoughts on self-driving technology and its progress. Specifically, Doc 2 provides a direct answer to the question, mentioning the difficulty of the problem and the improvements in disengagements." 
| 14.81295681 | \ No newline at end of file diff --git a/docs/extras/use_cases/question_answering/notion_db/images/Untitled 1.png b/docs/extras/use_cases/question_answering/notion_db/images/Untitled 1.png new file mode 100755 index 000000000..876bffacf Binary files /dev/null and b/docs/extras/use_cases/question_answering/notion_db/images/Untitled 1.png differ diff --git a/docs/extras/use_cases/question_answering/notion_db/images/Untitled 2.png b/docs/extras/use_cases/question_answering/notion_db/images/Untitled 2.png new file mode 100755 index 000000000..b883afb59 Binary files /dev/null and b/docs/extras/use_cases/question_answering/notion_db/images/Untitled 2.png differ diff --git a/docs/extras/use_cases/question_answering/notion_db/images/Untitled.png b/docs/extras/use_cases/question_answering/notion_db/images/Untitled.png new file mode 100755 index 000000000..b3dc00b98 Binary files /dev/null and b/docs/extras/use_cases/question_answering/notion_db/images/Untitled.png differ diff --git a/docs/extras/use_cases/question_answering/notion_db/images/ret_trim.mov b/docs/extras/use_cases/question_answering/notion_db/images/ret_trim.mov new file mode 100755 index 000000000..faa64ba1e Binary files /dev/null and b/docs/extras/use_cases/question_answering/notion_db/images/ret_trim.mov differ diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java index 714cb9e16..13733e954 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java @@ -18,6 +18,7 @@ package com.hw.langchain.chains.base; +import com.google.common.collect.Maps; import com.hw.langchain.schema.BaseMemory; import java.util.*; @@ -34,12 +35,16 @@ public abstract class Chain { public abstract String chainType(); /** - * Input keys this chain expects. + * Retrieves the list of input keys that this chain expects. + * + * @return the list of input keys */ public abstract List inputKeys(); /** - * Output keys this chain expects. + * Retrieves the list of output keys that this chain expects. + * + * @return the list of output keys */ public abstract List outputKeys(); @@ -63,7 +68,10 @@ private void validateOutputs(Map outputs) { } /** - * Run the logic of this chain and return the output. + * Runs the logic of this chain and returns the output. + * + * @param inputs the inputs to be processed by the chain + * @return a map containing the output generated by the chain */ public abstract Map innerCall(Map inputs); @@ -108,7 +116,7 @@ private Map prepOutputs(Map inputs, Map result = new HashMap<>(); + Map result = Maps.newHashMap(); inputs.forEach((k, v) -> result.put(k, v.toString())); result.putAll(outputs); return result; @@ -141,7 +149,7 @@ private Map prepInputs(String input) { /** * Validate and prep inputs. 
*/ - private Map prepInputs(Map inputs) { + public Map prepInputs(Map inputs) { Map newInputs = new HashMap<>(inputs); if (memory != null) { Map externalContext = memory.loadMemoryVariables(inputs); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/JsonUtils.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/JsonUtils.java new file mode 100644 index 000000000..519c3a7be --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/JsonUtils.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.util.DefaultIndenter; +import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectWriter; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.hw.langchain.exception.LangChainException; + +/** + * @author HamaWhite + */ +public class JsonUtils { + + private JsonUtils() { + } + + private static final ObjectMapper OBJECT_MAPPER = createObjectMapper(); + + private static ObjectMapper createObjectMapper() { + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.enable(SerializationFeature.INDENT_OUTPUT); + return objectMapper; + } + + public static String toJsonStringWithIndent(Object object, int indent) { + try { + ObjectWriter writer = OBJECT_MAPPER.writer(getPrettyPrinter(indent)); + return writer.writeValueAsString(object); + } catch (JsonProcessingException e) { + throw new LangChainException("Failed to format attribute info.", e); + } + } + + private static DefaultPrettyPrinter getPrettyPrinter(int indent) { + DefaultPrettyPrinter printer = new DefaultPrettyPrinter(); + printer.indentArraysWith(DefaultPrettyPrinter.FixedSpaceIndenter.instance); + printer.indentObjectsWith(DefaultIndenter.SYSTEM_LINEFEED_INSTANCE.withIndent(" ".repeat(indent))); + printer.indentArraysWith(DefaultIndenter.SYSTEM_LINEFEED_INSTANCE.withIndent(" ".repeat(indent))); + return printer; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/base/BaseUtils.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/base/BaseUtils.java new file mode 100644 index 000000000..3e28ada4f --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/base/BaseUtils.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.base; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.Maps; +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.chains.llm.LLMChain; +import com.hw.langchain.chains.query.constructor.ir.Comparator; +import com.hw.langchain.chains.query.constructor.ir.Operator; +import com.hw.langchain.chains.query.constructor.schema.AttributeInfo; +import com.hw.langchain.prompts.base.BasePromptTemplate; +import com.hw.langchain.prompts.few.shot.FewShotPromptTemplate; + +import java.util.List; +import java.util.Map; + +import static com.hw.langchain.chains.query.constructor.JsonUtils.toJsonStringWithIndent; +import static com.hw.langchain.chains.query.constructor.prompt.Prompt.*; +import static com.hw.langchain.prompts.utils.FormatUtils.formatTemplate; + +/** + * @author HamaWhite + */ +public class BaseUtils { + + private BaseUtils() { + } + + public static String formatAttributeInfo(List infos) { + Map> infoMaps = Maps.newHashMap(); + for (AttributeInfo info : infos) { + Map map = new ObjectMapper().convertValue(info, new TypeReference<>() { + }); + infoMaps.put((String) map.remove("name"), map); + } + return toJsonStringWithIndent(infoMaps, 4) + .replace("{", "{{") + .replace("}", "}}"); + } + + private static BasePromptTemplate getPrompt(String documentContents, List attributeInfo, + List> examples, List allowedComparators, List allowedOperators, + boolean enableLimit) { + String attributeStr = formatAttributeInfo(attributeInfo); + allowedComparators = allowedComparators != null ? allowedComparators : List.of(Comparator.values()); + allowedOperators = allowedOperators != null ? allowedOperators : List.of(Operator.values()); + + String schema; + if (enableLimit) { + schema = formatTemplate(SCHEMA_WITH_LIMIT, createTemplateArguments(allowedComparators, allowedOperators)); + examples = examples != null ? examples : EXAMPLES_WITH_LIMIT; + } else { + schema = formatTemplate(DEFAULT_SCHEMA, createTemplateArguments(allowedComparators, allowedOperators)); + examples = examples != null ? 
examples : DEFAULT_EXAMPLES; + } + + String prefix = formatTemplate(DEFAULT_PREFIX, Map.of("schema", schema)); + String suffix = formatTemplate(DEFAULT_SUFFIX, + Map.of("i", examples.size() + 1, "content", documentContents, "attributes", attributeStr)); + + var outputParser = StructuredQueryOutputParser.fromComponents(allowedComparators, allowedOperators); + return new FewShotPromptTemplate(examples, EXAMPLE_PROMPT, prefix, suffix, List.of("query"), outputParser); + } + + private static Map createTemplateArguments(List allowedComparators, + List allowedOperators) { + return Map.of( + "allowed_comparators", + String.join(" | ", allowedComparators.stream().map(Comparator::value).toArray(String[]::new)), + "allowed_operators", + String.join(" | ", allowedOperators.stream().map(Operator::value).toArray(String[]::new))); + } + + /** + * Load a query constructor chain. + * + * @param llm BaseLanguageModel to use for the chain. + * @param documentContents The contents of the document to be queried. + * @param attributeInfo A list of AttributeInfo objects describing the attributes of the document. + * @param examples Optional list of examples to use for the chain. + * @param allowedComparators A list of allowed comparators. + * @param allowedOperators A list of allowed operators. + * @param enableLimit Whether to enable the limit operator. Defaults to False. + * @return A LLMChain that can be used to construct queries. + */ + public static LLMChain loadQueryConstructorChain(BaseLanguageModel llm, String documentContents, + List attributeInfo, List> examples, List allowedComparators, + List allowedOperators, boolean enableLimit) { + var prompt = + getPrompt(documentContents, attributeInfo, examples, allowedComparators, allowedOperators, enableLimit); + return new LLMChain(llm, prompt); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/base/StructuredQueryOutputParser.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/base/StructuredQueryOutputParser.java new file mode 100644 index 000000000..12c1b59aa --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/base/StructuredQueryOutputParser.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.hw.langchain.chains.query.constructor.base; + +import com.hw.langchain.chains.query.constructor.ir.Comparator; +import com.hw.langchain.chains.query.constructor.ir.Comparison; +import com.hw.langchain.chains.query.constructor.ir.Operator; +import com.hw.langchain.chains.query.constructor.ir.StructuredQuery; +import com.hw.langchain.schema.BaseOutputParser; +import com.hw.langchain.schema.OutputParserException; + +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static com.hw.langchain.output.parsers.json.Json.parseAndCheckJsonMarkdown; + +/** + * @author HamaWhite + */ +public class StructuredQueryOutputParser extends BaseOutputParser { + + @Override + public StructuredQuery parse(String text) throws OutputParserException { + try { + List expectedKeys = List.of("query", "filter"); + Map parsed = parseAndCheckJsonMarkdown(text, expectedKeys); + Comparison filter = astParse(parsed.get("filter").toString()); + return new StructuredQuery(parsed.get("query").toString(), filter, 0); + } catch (Exception e) { + throw new OutputParserException("Parsing text\n" + text + "\nraised following error:\n" + e); + } + } + + /** + * I couldn't find a direct Java equivalent of Lark, which is available in Python. + * For the time being, let's manually parse the simpler ones. + */ + public Comparison astParse(String filter) { + // Remove leading and trailing spaces and split the string + filter = filter.trim(); + String regex = "(\\w+)\\(\"([^\"]+)\", \"([^\"]+)\"\\)"; + Pattern pattern = Pattern.compile(regex); + Matcher matcher = pattern.matcher(filter); + + if (matcher.matches()) { + // Parse the comparator + String comparatorStr = matcher.group(1).toUpperCase(); + Comparator comparator = Comparator.valueOf(comparatorStr); + // Parse the attribute and value + String attribute = matcher.group(2); + String value = matcher.group(3); + return new Comparison(comparator, attribute, value); + } + throw new IllegalArgumentException("Invalid comparison string: " + filter); + } + + public static StructuredQueryOutputParser fromComponents(List allowedComparators, + List allowedOperators) { + // TODO: Finding the Java equivalent of Lark in Python, maybe ANTLR. + return new StructuredQueryOutputParser(); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Comparator.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Comparator.java new file mode 100644 index 000000000..5c425d985 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Comparator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.hw.langchain.chains.query.constructor.ir; + +/** + * @author HamaWhite + */ +public enum Comparator implements StringEnum { + + /** + * Enumerator of the comparison operators. + */ + EQ("eq"), + GT("gt"), + GTE("gte"), + LT("lt"), + LTE("lte"), + CONTAIN("contain"), + LIKE("like"); + + private final String value; + + Comparator(String value) { + this.value = value; + } + + @Override + public String value() { + return value; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Comparison.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Comparison.java new file mode 100644 index 000000000..19543210e --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Comparison.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.ir; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; + +import java.util.Map; + +/** + * A comparison to a value. + * + * @author HamaWhite + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Comparison extends FilterDirective { + + private Comparator comparator; + + private String attribute; + + private Object value; + + @Override + public Map accept(Visitor visitor) { + return visitor.visitComparison(this); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Expr.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Expr.java new file mode 100644 index 000000000..c62092e2f --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Expr.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.hw.langchain.chains.query.constructor.ir; + +import java.util.Map; + +/** + * @author HamaWhite + */ +public interface Expr { + + /** + * Accepts a visitor and performs operations on this expression. + * + * @param visitor the visitor object to operate on this expression + * @return the result of the visitor's operation + */ + Map accept(Visitor visitor); + +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/FilterDirective.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/FilterDirective.java new file mode 100644 index 000000000..6c944ba3f --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/FilterDirective.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.ir; + +/** + * A filtering expression. + * + * @author HamaWhite + */ +public abstract class FilterDirective implements Expr { +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Operation.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Operation.java new file mode 100644 index 000000000..de7b77512 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Operation.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.ir; + +import java.util.List; +import java.util.Map; + +/** + * A logical operation over other directives. 
+ * + * @author HamaWhite + */ +public class Operation extends FilterDirective { + + private Operator operator; + + private List arguments; + + @Override + public Map accept(Visitor visitor) { + return visitor.visitOperation(this); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Operator.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Operator.java new file mode 100644 index 000000000..0ee5d57d7 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Operator.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.ir; + +/** + * Enumerator of the operations. + * + * @author HamaWhite + */ +public enum Operator implements StringEnum { + + /** + * Enumerator of the operations. + */ + AND("and"), + OR("or"), + NOT("not"); + + private final String value; + + Operator(String value) { + this.value = value; + } + + @Override + public String value() { + return value; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/StringEnum.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/StringEnum.java new file mode 100644 index 000000000..2d5ebbd18 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/StringEnum.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.ir; + +/** + * @author HamaWhite + */ +public interface StringEnum> { + + /** + * Retrieves the string representation of the enum value. 
+ * @return the string representation + */ + String value(); +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/StructuredQuery.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/StructuredQuery.java new file mode 100644 index 000000000..f9dc042dc --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/StructuredQuery.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.ir; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.Map; + +/** + * @author HamaWhite + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class StructuredQuery implements Expr { + + private String query; + + private FilterDirective filter; + + private Integer limit; + + @Override + public Map accept(Visitor visitor) { + return visitor.visitStructuredQuery(this); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Visitor.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Visitor.java new file mode 100644 index 000000000..36aff9662 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/ir/Visitor.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.ir; + +import java.util.List; +import java.util.Map; + +/** + * Defines interface for IR translation using visitor pattern. 
+ * + * @author HamaWhite + */ + +public abstract class Visitor { + + protected List allowedComparators; + + protected List allowedOperators; + + protected Visitor(List allowedComparators, List allowedOperators) { + this.allowedComparators = allowedComparators; + this.allowedOperators = allowedOperators; + } + + public List getAllowedComparators() { + return allowedComparators; + } + + public List getAllowedOperators() { + return allowedOperators; + } + + /** + * Translates an Operation expression. + * + * @param operation the Operation expression to translate + * @return a map containing the translated result + */ + public abstract Map visitOperation(Operation operation); + + /** + * Translates a Comparison expression. + * + * @param comparison the Comparison expression to translate + * @return a map containing the translated result + */ + public abstract Map visitComparison(Comparison comparison); + + /** + * Translates a StructuredQuery expression. + * + * @param structuredQuery the StructuredQuery expression to translate + * @return a map containing the translated result + */ + public abstract Map visitStructuredQuery(StructuredQuery structuredQuery); +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/prompt/Prompt.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/prompt/Prompt.java new file mode 100644 index 000000000..58cde46d8 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/prompt/Prompt.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.hw.langchain.chains.query.constructor.prompt; + +import com.hw.langchain.prompts.prompt.PromptTemplate; + +import java.util.List; +import java.util.Map; + +/** + * @author HamaWhite + */ +public class Prompt { + + private Prompt() { + } + + private static final String SONG_DATA_SOURCE = """ + \ + ```json + { + "content": "Lyrics of a song", + "attributes": { + "artist": { + "type": "string", + "description": "Name of the song artist" + }, + "length": { + "type": "integer", + "description": "Length of the song in seconds" + }, + "genre": { + "type": "string", + "description": "The song genre, one of \"pop\", \"rock\" or \"rap\"" + } + } + } + ```\ + """.replace( + "{", "{{").replace( + "}", "}}"); + + private static final String FULL_ANSWER = """ + \ + ```json + {{ + "query": "teenager love", + "filter": "and(or(eq(\\"artist\\", \\"Taylor Swift\\"), eq(\\"artist\\", \\"Katy Perry\\")), \ + lt(\\"length\\", 180), eq(\\"genre\\", \\"pop\\"))" + }} + ```\ + """; + + private static final String NO_FILTER_ANSWER = """ + \ + ```json + {{ + "query": "", + "filter": "NO_FILTER" + }} + ```\ + """; + + private static final String WITH_LIMIT_ANSWER = """ + \ + ```json + {{ + "query": "love", + "filter": "NO_FILTER", + "limit": 2 + }} + ```\ + """; + + public static final List> DEFAULT_EXAMPLES = List.of( + createExample(1, SONG_DATA_SOURCE, + "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre", + FULL_ANSWER), + createExample(2, SONG_DATA_SOURCE, "What are songs that were not published on Spotify", NO_FILTER_ANSWER)); + + public static final List> EXAMPLES_WITH_LIMIT = List.of( + createExample(1, SONG_DATA_SOURCE, + "What are songs by Taylor Swift or Katy Perry about teenage romance under 3 minutes long in the dance pop genre", + FULL_ANSWER), + createExample(2, SONG_DATA_SOURCE, "What are songs that were not published on Spotify", NO_FILTER_ANSWER), + createExample(3, SONG_DATA_SOURCE, "What are three songs about love", WITH_LIMIT_ANSWER)); + + private static Map createExample(int i, String dataSource, String userQuery, + String structuredRequest) { + return Map.of( + "i", i, + "data_source", dataSource, + "user_query", userQuery, + "structured_request", structuredRequest); + } + + private static final String EXAMPLE_PROMPT_TEMPLATE = """ + \ + << Example {i}. >> + Data Source: + {data_source} + + User Query: + {user_query} + + Structured Request: + {structured_request} + """; + + public static final PromptTemplate EXAMPLE_PROMPT = new PromptTemplate( + List.of("i", "data_source", "user_query", "structured_request"), EXAMPLE_PROMPT_TEMPLATE); + + public static final String DEFAULT_SCHEMA = + """ + \ + << Structured Request Schema >> + When responding use a markdown code snippet with a JSON object formatted in the \ + following schema: + + ```json + {{{{ + "query": string \\ text string to compare to document contents + "filter": string \\ logical condition statement for filtering documents + }}}} + ``` + + The query string should contain only text that is expected to match the contents of \ + documents. Any conditions in the filter should not be mentioned in the query as well. + + A logical condition statement is composed of one or more comparison and logical \ + operation statements. 
+ + A comparison statement takes the form: `comp(attr, val)`: + - `comp` ({allowed_comparators}): comparator + - `attr` (string): name of attribute to apply the comparison to + - `val` (string): is the comparison value + + A logical operation statement takes the form `op(statement1, statement2, ...)`: + - `op` ({allowed_operators}): logical operator + - `statement1`, `statement2`, ... (comparison statements or logical operation \ + statements): one or more statements to apply the operation to + + Make sure that you only use the comparators and logical operators listed above and \ + no others. + Make sure that filters only refer to attributes that exist in the data source. + Make sure that filters only use the attributed names with its function names if there are functions applied on them. + Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values. + Make sure that filters take into account the descriptions of attributes and only make \ + comparisons that are feasible given the type of data being stored. + Make sure that filters are only used as needed. If there are no filters that should be \ + applied return "NO_FILTER" for the filter value.\ + """; + + public static final String SCHEMA_WITH_LIMIT = + """ + \ + << Structured Request Schema >> + When responding use a markdown code snippet with a JSON object formatted in the \ + following schema: + + ```json + {{{{ + "query": string \\ text string to compare to document contents + "filter": string \\ logical condition statement for filtering documents + "limit": int \\ the number of documents to retrieve + }}}} + ``` + + The query string should contain only text that is expected to match the contents of \ + documents. Any conditions in the filter should not be mentioned in the query as well. + + A logical condition statement is composed of one or more comparison and logical \ + operation statements. + + A comparison statement takes the form: `comp(attr, val)`: + - `comp` ({allowed_comparators}): comparator + - `attr` (string): name of attribute to apply the comparison to + - `val` (string): is the comparison value + + A logical operation statement takes the form `op(statement1, statement2, ...)`: + - `op` ({allowed_operators}): logical operator + - `statement1`, `statement2`, ... (comparison statements or logical operation \ + statements): one or more statements to apply the operation to + + Make sure that you only use the comparators and logical operators listed above and \ + no others. + Make sure that filters only refer to attributes that exist in the data source. + Make sure that filters only use the attributed names with its function names if there are functions applied on them. + Make sure that filters only use format `YYYY-MM-DD` when handling timestamp data typed values. + Make sure that filters take into account the descriptions of attributes and only make \ + comparisons that are feasible given the type of data being stored. + Make sure that filters are only used as needed. If there are no filters that should be \ + applied return "NO_FILTER" for the filter value. + Make sure the `limit` is always an int value. It is an optional parameter so leave it blank if it is does not make sense. + """; + + public static final String DEFAULT_PREFIX = """ + \ + Your goal is to structure the user's query to match the request schema provided below. + + {schema}\ + """; + + public static final String DEFAULT_SUFFIX = """ + \ + << Example {i}. 
>> + Data Source: + ```json + {{{{ + "content": "{content}", + "attributes": {attributes} + }}}} + ``` + + User Query: + {{query}} + + Structured Request: + """; + +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/schema/AttributeInfo.java b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/schema/AttributeInfo.java new file mode 100644 index 000000000..08b596bb5 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/query/constructor/schema/AttributeInfo.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.query.constructor.schema; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Information about a data source attribute. + * + * @author HamaWhite + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AttributeInfo { + + private String name; + + private String description; + + private String type; +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQA.java b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQa.java similarity index 81% rename from langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQA.java rename to langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQa.java index bdd665f79..600c8f970 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQA.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQa.java @@ -26,6 +26,7 @@ import java.util.List; +import static com.hw.langchain.chains.question.answering.ChainType.STUFF; import static com.hw.langchain.chains.question.answering.init.Init.loadQaChain; /** @@ -33,21 +34,28 @@ * * @author HamaWhite */ -public class RetrievalQA extends BaseRetrievalQA { +public class RetrievalQa extends BaseRetrievalQA { private final BaseRetriever retriever; - public RetrievalQA(BaseCombineDocumentsChain combineDocumentsChain, BaseRetriever retriever) { + public RetrievalQa(BaseCombineDocumentsChain combineDocumentsChain, BaseRetriever retriever) { super(combineDocumentsChain); this.retriever = retriever; } + /** + * Load chain from chain type. + */ + public static BaseRetrievalQA fromChainType(BaseLanguageModel llm, BaseRetriever retriever) { + return fromChainType(llm, STUFF, retriever); + } + /** * Load chain from chain type. 
*/ public static BaseRetrievalQA fromChainType(BaseLanguageModel llm, ChainType chainType, BaseRetriever retriever) { BaseCombineDocumentsChain combineDocumentsChain = loadQaChain(llm, chainType); - return new RetrievalQA(combineDocumentsChain, retriever); + return new RetrievalQa(combineDocumentsChain, retriever); } @Override diff --git a/langchain-core/src/main/java/com/hw/langchain/document/loaders/base/BaseLoader.java b/langchain-core/src/main/java/com/hw/langchain/document/loaders/base/BaseLoader.java index b12696d32..6123b487f 100644 --- a/langchain-core/src/main/java/com/hw/langchain/document/loaders/base/BaseLoader.java +++ b/langchain-core/src/main/java/com/hw/langchain/document/loaders/base/BaseLoader.java @@ -30,12 +30,17 @@ * * @author HamaWhite */ -public interface BaseLoader { +public abstract class BaseLoader { /** * Load data into document objects. * * @return a List which is materialized in memory. */ - List load(); + public abstract List load(); + + protected String errorMessage(String filePath) { + return "Error loading " + filePath; + } + } diff --git a/langchain-core/src/main/java/com/hw/langchain/document/loaders/notion/NotionDirectoryLoader.java b/langchain-core/src/main/java/com/hw/langchain/document/loaders/notion/NotionDirectoryLoader.java new file mode 100644 index 000000000..8013b82bd --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/document/loaders/notion/NotionDirectoryLoader.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.document.loaders.notion; + +import com.hw.langchain.document.loaders.base.BaseLoader; +import com.hw.langchain.exception.LangChainException; +import com.hw.langchain.schema.Document; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import java.util.stream.Stream; + +/** + * Loader that loads Notion directory dump. 
+ * + * @author HamaWhite + */ +public class NotionDirectoryLoader extends BaseLoader { + + private final String filePath; + + public NotionDirectoryLoader(String path) { + this.filePath = path; + } + + @Override + public List load() { + try (Stream pathStream = Files.walk(Path.of(filePath))) { + return pathStream + .filter(p -> p.toString().endsWith(".md")) + .flatMap(this::processFile) + .toList(); + } catch (IOException e) { + throw new LangChainException(errorMessage(filePath), e); + } + } + + private Stream processFile(Path path) { + try { + String text = Files.readString(path); + Map metadata = Map.of("source", path.toString()); + return Stream.of(new Document(text, metadata)); + } catch (IOException e) { + throw new LangChainException(errorMessage(path.toString()), e); + } + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/document/loaders/text/TextLoader.java b/langchain-core/src/main/java/com/hw/langchain/document/loaders/text/TextLoader.java index 805fa1ddb..d0aeea481 100644 --- a/langchain-core/src/main/java/com/hw/langchain/document/loaders/text/TextLoader.java +++ b/langchain-core/src/main/java/com/hw/langchain/document/loaders/text/TextLoader.java @@ -40,7 +40,7 @@ * * @author HamaWhite */ -public class TextLoader implements BaseLoader { +public class TextLoader extends BaseLoader { private static final Logger LOG = LoggerFactory.getLogger(TextLoader.class); @@ -97,8 +97,4 @@ private String loadWithDetectedEncoding(String filePath) { throw new LangChainException(errorMessage(filePath), e); } } - - private String errorMessage(String filePath) { - return "Error loading " + filePath; - } } diff --git a/langchain-core/src/main/java/com/hw/langchain/output/parsers/json/Json.java b/langchain-core/src/main/java/com/hw/langchain/output/parsers/json/Json.java index 6fc98a9c2..4361054d0 100644 --- a/langchain-core/src/main/java/com/hw/langchain/output/parsers/json/Json.java +++ b/langchain-core/src/main/java/com/hw/langchain/output/parsers/json/Json.java @@ -19,11 +19,13 @@ package com.hw.langchain.output.parsers.json; import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.hw.langchain.schema.OutputParserException; import java.util.List; +import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -68,7 +70,7 @@ public static JsonNode parseJsonMarkdown(String jsonString) { * @param expectedKeys The expected keys in the JSON string. * @return The parsed JSON object as a JsonNode. */ - public static JsonNode parseAndCheckJsonMarkdown(String markdown, List expectedKeys) { + public static Map parseAndCheckJsonMarkdown(String markdown, List expectedKeys) { JsonNode jsonNode = parseJsonMarkdown(markdown); for (String key : expectedKeys) { if (!jsonNode.has(key)) { @@ -76,6 +78,7 @@ public static JsonNode parseAndCheckJsonMarkdown(String markdown, List e "Got invalid return object. 
Expected key `%s` to be present, but got %s", key, jsonNode)); } } - return jsonNode; + return new ObjectMapper().convertValue(jsonNode, new TypeReference<>() { + }); } } diff --git a/langchain-core/src/main/java/com/hw/langchain/output/parsers/structured/ResponseSchema.java b/langchain-core/src/main/java/com/hw/langchain/output/parsers/structured/ResponseSchema.java index 0f5383ad8..7df4c30ba 100644 --- a/langchain-core/src/main/java/com/hw/langchain/output/parsers/structured/ResponseSchema.java +++ b/langchain-core/src/main/java/com/hw/langchain/output/parsers/structured/ResponseSchema.java @@ -22,13 +22,15 @@ import lombok.Data; import lombok.NoArgsConstructor; +import java.io.Serializable; + /** * @author HamaWhite */ @Data @NoArgsConstructor @AllArgsConstructor -public class ResponseSchema { +public class ResponseSchema implements Serializable { private String name; diff --git a/langchain-core/src/main/java/com/hw/langchain/output/parsers/structured/StructuredOutputParser.java b/langchain-core/src/main/java/com/hw/langchain/output/parsers/structured/StructuredOutputParser.java index 24e147ffe..9b7eb0440 100644 --- a/langchain-core/src/main/java/com/hw/langchain/output/parsers/structured/StructuredOutputParser.java +++ b/langchain-core/src/main/java/com/hw/langchain/output/parsers/structured/StructuredOutputParser.java @@ -18,7 +18,6 @@ package com.hw.langchain.output.parsers.structured; -import com.fasterxml.jackson.databind.JsonNode; import com.hw.langchain.schema.BaseOutputParser; import java.util.List; @@ -31,11 +30,11 @@ /** * @author HamaWhite */ -public class StructuredOutputParser extends BaseOutputParser { +public class StructuredOutputParser extends BaseOutputParser> { private static final String LINE_TEMPLATE = "\t\"{name}\": {type} // {description}"; - private final List responseSchemas; + private final transient List responseSchemas; public StructuredOutputParser(List responseSchemas) { this.responseSchemas = responseSchemas; @@ -53,7 +52,7 @@ private String getSubString(ResponseSchema schema) { } @Override - public JsonNode parse(String text) { + public Map parse(String text) { var expectedKeys = responseSchemas.stream() .map(ResponseSchema::getName) .toList(); diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptTemplate.java index e7c31e1b6..6894e3958 100644 --- a/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptTemplate.java +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/base/StringPromptTemplate.java @@ -22,6 +22,7 @@ import com.hw.langchain.schema.PromptValue; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import java.util.List; @@ -34,6 +35,7 @@ */ @Data @NoArgsConstructor +@EqualsAndHashCode(callSuper = true) public abstract class StringPromptTemplate extends BasePromptTemplate { protected StringPromptTemplate(List inputVariables) { @@ -44,7 +46,7 @@ protected StringPromptTemplate(List inputVariables, Map super(inputVariables, partialVariables); } - protected StringPromptTemplate(List inputVariables, BaseOutputParser outputParser) { + protected StringPromptTemplate(List inputVariables, BaseOutputParser outputParser) { super(inputVariables, outputParser); } @@ -52,5 +54,4 @@ protected StringPromptTemplate(List inputVariables, BaseOutputParser out public PromptValue formatPrompt(Map kwargs) { return new StringPromptValue(format(kwargs)); } - } diff --git 
a/langchain-core/src/main/java/com/hw/langchain/prompts/few/shot/FewShotPromptTemplate.java b/langchain-core/src/main/java/com/hw/langchain/prompts/few/shot/FewShotPromptTemplate.java new file mode 100644 index 000000000..82438fac8 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/few/shot/FewShotPromptTemplate.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.prompts.few.shot; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.hw.langchain.prompts.base.StringPromptTemplate; +import com.hw.langchain.prompts.prompt.PromptTemplate; +import com.hw.langchain.schema.BaseOutputParser; + +import lombok.EqualsAndHashCode; + +import java.util.List; +import java.util.Map; + +import static com.hw.langchain.prompts.utils.FormatUtils.formatTemplate; + +/** + * Prompt template that contains few shot examples. + * + * @author HamaWhite + */ +@EqualsAndHashCode(callSuper = true) +public class FewShotPromptTemplate extends StringPromptTemplate { + + /** + * Examples to format into the prompt. + * Either this or exampleSelector should be provided. + */ + private final List> examples; + + /** + * PromptTemplate used to format an individual example. + */ + private final PromptTemplate examplePrompt; + + /** + * A prompt template string to put before the examples. + */ + private final String prefix; + + /** + * A prompt template string to put after the examples. + */ + private final String suffix; + + /** + * String separator used to join the prefix, the examples, and suffix. + */ + private String exampleSeparator = "\n\n"; + + public FewShotPromptTemplate(List> examples, PromptTemplate examplePrompt, String prefix, + String suffix, List inputVariables, BaseOutputParser outputParser) { + super(inputVariables, outputParser); + this.examples = examples; + this.examplePrompt = examplePrompt; + this.prefix = prefix; + this.suffix = suffix; + } + + public List> getExamples() { + return this.examples; + } + + @Override + public String format(Map kwargs) { + kwargs = mergePartialAndUserVariables(kwargs); + // Get the examples to use. + List> exampleList = getExamples(); + exampleList = exampleList.stream() + .map(example -> Maps.filterKeys(example, examplePrompt.getInputVariables()::contains)) + .toList(); + // Format the examples. + List exampleStrings = exampleList.stream() + .map(examplePrompt::format) + .toList(); + // Create the overall template. 
+ List pieces = Lists.newArrayList(prefix); + pieces.addAll(exampleStrings); + pieces.add(suffix); + + String template = String.join(exampleSeparator, pieces); + return formatTemplate(template, kwargs); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/prompts/utils/FormatUtils.java b/langchain-core/src/main/java/com/hw/langchain/prompts/utils/FormatUtils.java index 7dc2a9e95..c00760267 100644 --- a/langchain-core/src/main/java/com/hw/langchain/prompts/utils/FormatUtils.java +++ b/langchain-core/src/main/java/com/hw/langchain/prompts/utils/FormatUtils.java @@ -42,15 +42,14 @@ private FormatUtils() { * @return the formatted string */ public static String formatTemplate(String template, Map kwargs) { - String result = template; + // In Python format() method, the curly braces '{{}}' are used to represent the output '{}'. + String result = template.replace("{{", "{").replace("}}", "}"); for (Map.Entry entry : kwargs.entrySet()) { String placeholder = "{" + entry.getKey() + "}"; String value = entry.getValue().toString(); result = result.replace(placeholder, value); } - - // In Python format() method, the curly braces '{{}}' are used to represent the output '{}'. - return result.replace("{{", "{").replace("}}", "}"); + return result; } /** diff --git a/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/base/BaseUtils.java b/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/base/BaseUtils.java new file mode 100644 index 000000000..5d3d67440 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/base/BaseUtils.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.retrievers.self.query.base; + +import com.hw.langchain.chains.query.constructor.ir.Visitor; +import com.hw.langchain.exception.LangChainException; +import com.hw.langchain.retrievers.self.query.pinecone.PineconeTranslator; +import com.hw.langchain.vectorstores.base.VectorStore; +import com.hw.langchain.vectorstores.pinecone.Pinecone; + +import java.util.Map; + +/** + * @author HamaWhite + */ +public class BaseUtils { + + private BaseUtils() { + } + + private static final Map, Class> BUILTIN_TRANSLATORS = Map.of( + Pinecone.class, PineconeTranslator.class); + + /** + * Get the translator class corresponding to the vector store class. 
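`FewShotPromptTemplate` formats each example with `examplePrompt`, then joins the prefix, the formatted examples, and the suffix with `exampleSeparator` (default `"\n\n"`) before filling in the user variables. A hedged usage sketch; it assumes a `(inputVariables, template)` constructor on `PromptTemplate`, and the `null` output parser is only for illustration:

```java
// Sketch only: example data and templates are made up for illustration.
PromptTemplate examplePrompt = new PromptTemplate(
        List.of("word", "antonym"), "Word: {word}\nAntonym: {antonym}");

List<Map<String, Object>> examples = List.of(
        Map.of("word", "happy", "antonym", "sad"),
        Map.of("word", "tall", "antonym", "short"));

var fewShotPrompt = new FewShotPromptTemplate(
        examples,
        examplePrompt,
        "Give the antonym of every input.", // prefix
        "Word: {input}\nAntonym:",          // suffix
        List.of("input"),
        null);                              // output parser omitted for illustration

// prefix, both examples, and the suffix joined by the default "\n\n" separator
System.out.println(fewShotPrompt.format(Map.of("input", "big")));
```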
+ * + * @param vectorStore the VectorStore instance + * @return the Visitor instance + * @throws IllegalArgumentException if the vector store type is not supported + */ + public static Visitor getBuiltinTranslator(VectorStore vectorStore) { + Class vectorStoreCls = vectorStore.getClass(); + if (!BUILTIN_TRANSLATORS.containsKey(vectorStoreCls)) { + throw new IllegalArgumentException("Self query retriever with Vector Store type " + + vectorStoreCls.getName() + " not supported."); + } + try { + return BUILTIN_TRANSLATORS.get(vectorStoreCls).getConstructor().newInstance(); + } catch (Exception e) { + throw new LangChainException(e); + } + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/base/SelfQueryRetriever.java b/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/base/SelfQueryRetriever.java new file mode 100644 index 000000000..41a1e4bd7 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/base/SelfQueryRetriever.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.retrievers.self.query.base; + +import com.google.common.collect.Maps; +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.chains.llm.LLMChain; +import com.hw.langchain.chains.query.constructor.ir.StructuredQuery; +import com.hw.langchain.chains.query.constructor.ir.Visitor; +import com.hw.langchain.chains.query.constructor.schema.AttributeInfo; +import com.hw.langchain.schema.BaseRetriever; +import com.hw.langchain.schema.Document; +import com.hw.langchain.vectorstores.base.SearchType; +import com.hw.langchain.vectorstores.base.VectorStore; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; +import java.util.Map; + +import static com.hw.langchain.chains.query.constructor.base.BaseUtils.loadQueryConstructorChain; +import static com.hw.langchain.retrievers.self.query.base.BaseUtils.getBuiltinTranslator; +import static com.hw.langchain.vectorstores.base.SearchType.SIMILARITY; + +/** + * Retriever that wraps around a vector store and uses an LLM to generate the vector store queries. + * + * @author HamaWhite + */ +public class SelfQueryRetriever implements BaseRetriever { + + private static final Logger LOG = LoggerFactory.getLogger(SelfQueryRetriever.class); + + /** + * The underlying vector store from which documents will be retrieved. + */ + private final VectorStore vectorStore; + + /** + * The LLMChain for generating the vector store queries. + */ + private final LLMChain llmChain; + + /** + * The search type to perform on the vector store. 
+ * Default value: "similarity" + */ + private final SearchType searchType; + + /** + * Keyword arguments to pass in to the vector store search. + */ + private final Map searchKwargs; + + /** + * Translator for turning internal query language into vectorstore search params. + */ + private final Visitor structuredQueryTranslator; + + /** + * Use original query instead of the revised new query from LLM. + */ + private final boolean useOriginalQuery; + + public SelfQueryRetriever(VectorStore vectorStore, LLMChain llmChain, Visitor structuredQueryTranslator, + boolean useOriginalQuery) { + this(vectorStore, llmChain, SIMILARITY, Maps.newHashMap(), structuredQueryTranslator, useOriginalQuery); + } + + public SelfQueryRetriever(VectorStore vectorStore, LLMChain llmChain, SearchType searchType, + Map searchKwargs, Visitor structuredQueryTranslator, boolean useOriginalQuery) { + this.vectorStore = vectorStore; + this.llmChain = llmChain; + this.searchType = searchType; + this.searchKwargs = searchKwargs; + this.structuredQueryTranslator = structuredQueryTranslator; + this.useOriginalQuery = useOriginalQuery; + } + + @Override + public List getRelevantDocuments(String query) { + Map inputs = llmChain.prepInputs(Map.of("query", query)); + StructuredQuery structuredQuery = llmChain.predictAndParse(inputs); + LOG.info("Structured Query: {}", structuredQuery); + + Map filter = structuredQueryTranslator.visitStructuredQuery(structuredQuery); + return vectorStore.search(structuredQuery.getQuery(), searchType, filter); + } + + public static SelfQueryRetriever fromLLM(BaseLanguageModel llm, VectorStore vectorStore, String documentContents, + List metadataFieldInfo) { + Visitor structuredQueryTranslator = getBuiltinTranslator(vectorStore); + return fromLLM(llm, vectorStore, documentContents, metadataFieldInfo, structuredQueryTranslator, false, false); + } + + public static SelfQueryRetriever fromLLM(BaseLanguageModel llm, VectorStore vectorStore, String documentContents, + List metadataFieldInfo, Visitor structuredQueryTranslator, + boolean enableLimit, boolean useOriginalQuery) { + LLMChain llmChain = loadQueryConstructorChain(llm, documentContents, metadataFieldInfo, null, + structuredQueryTranslator.getAllowedComparators(), structuredQueryTranslator.getAllowedOperators(), + enableLimit); + return new SelfQueryRetriever(vectorStore, llmChain, structuredQueryTranslator, useOriginalQuery); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/pinecone/PineconeTranslator.java b/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/pinecone/PineconeTranslator.java new file mode 100644 index 000000000..cb80e5228 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/retrievers/self/query/pinecone/PineconeTranslator.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.retrievers.self.query.pinecone; + +import com.google.common.collect.Maps; +import com.hw.langchain.chains.query.constructor.ir.*; + +import java.util.List; +import java.util.Map; + +/** + * Logic for converting internal query language elements to valid filters. + * + * @author HamaWhite + */ +public class PineconeTranslator extends Visitor { + + public PineconeTranslator() { + super(null, List.of(Operator.AND, Operator.OR)); + } + + private String formatFunc(StringEnum func) { + return "$" + func.value(); + } + + @Override + public Map visitOperation(Operation operation) { + return Map.of(); + } + + @Override + public Map visitComparison(Comparison comparison) { + return Map.of( + comparison.getAttribute(), + Map.of(formatFunc(comparison.getComparator()), comparison.getValue())); + } + + @Override + public Map visitStructuredQuery(StructuredQuery structuredQuery) { + if (structuredQuery.getFilter() != null) { + return structuredQuery.getFilter().accept(this); + } + return Maps.newHashMap(); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/text/splitter/RecursiveCharacterTextSplitter.java b/langchain-core/src/main/java/com/hw/langchain/text/splitter/RecursiveCharacterTextSplitter.java new file mode 100644 index 000000000..7717176fe --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/text/splitter/RecursiveCharacterTextSplitter.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.text.splitter; + +import lombok.Builder; +import lombok.experimental.SuperBuilder; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +import static com.hw.langchain.text.splitter.TextSplitterUtils.splitTextWithRegex; + +/** + * Implementation of splitting text that looks at characters. + * Recursively tries to split by different characters to find one that works. + * + * @author HamaWhite + */ +@SuperBuilder +public class RecursiveCharacterTextSplitter extends TextSplitter { + + @Builder.Default + protected List separators = List.of("\n\n", "\n", " ", ""); + + /** + * Split incoming text and return chunks. 
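`PineconeTranslator` turns each `Comparison` into the `{attribute: {$comparator: value}}` shape expected by Pinecone's metadata filter. A small sketch, reusing the `astParse` helper exercised by `StructuredQueryOutputParserTest` below:

```java
// Parse a filter expression into a Comparison, then translate it into a Pinecone filter.
var parser = new StructuredQueryOutputParser();
Comparison comparison = parser.astParse("eq(\"Section\", \"Testing\")");

var translator = new PineconeTranslator();
var filter = translator.visitComparison(comparison);
// filter is roughly {Section={$eq=Testing}}, with the "$" prefix added by formatFunc
```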
+ */ + public List innerSplitText(String text, List separators) { + // Get appropriate separator to use + String separator = separators.get(separators.size() - 1); + List newSeparators = new ArrayList<>(); + boolean foundSeparator = false; + for (int i = 0; i < separators.size() && !foundSeparator; i++) { + String temp = separators.get(i); + if (temp.isEmpty()) { + separator = temp; + foundSeparator = true; + } else if (Pattern.compile(temp).matcher(text).find()) { + separator = temp; + newSeparators = separators.subList(i + 1, separators.size()); + foundSeparator = true; + } + } + return recursivelySplitLongerTexts(text, separator, newSeparators); + } + + private List recursivelySplitLongerTexts(String text, String separator, List newSeparators) { + List finalChunks = new ArrayList<>(); + List splits = splitTextWithRegex(text, separator, this.keepSeparator); + + // Now go merging things, recursively splitting longer texts. + List tempGoodSplits = new ArrayList<>(); + String tempSeparator = this.keepSeparator ? "" : separator; + for (String s : splits) { + if (lengthFunction.apply(s) < this.chunkSize) { + tempGoodSplits.add(s); + } else { + if (!tempGoodSplits.isEmpty()) { + List mergedText = mergeSplits(tempGoodSplits, tempSeparator); + finalChunks.addAll(mergedText); + tempGoodSplits.clear(); + } + if (newSeparators.isEmpty()) { + finalChunks.add(s); + } else { + List otherInfo = innerSplitText(s, newSeparators); + finalChunks.addAll(otherInfo); + } + } + } + if (!tempGoodSplits.isEmpty()) { + List mergedText = mergeSplits(tempGoodSplits, tempSeparator); + finalChunks.addAll(mergedText); + } + return finalChunks; + } + + @Override + public List splitText(String text) { + return innerSplitText(text, this.separators); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/text/splitter/TextSplitterUtils.java b/langchain-core/src/main/java/com/hw/langchain/text/splitter/TextSplitterUtils.java index c90165223..45ded33fe 100644 --- a/langchain-core/src/main/java/com/hw/langchain/text/splitter/TextSplitterUtils.java +++ b/langchain-core/src/main/java/com/hw/langchain/text/splitter/TextSplitterUtils.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.regex.Matcher; import java.util.regex.Pattern; /** @@ -38,9 +39,8 @@ public static List splitTextWithRegex(String text, String separator, boo if (StringUtils.isNotEmpty(separator)) { if (keepSeparator) { // The parentheses in the pattern keep the delimiters in the result. - Pattern pattern = Pattern.compile(String.format("(%s)", Pattern.quote(separator))); - String[] parts = pattern.split(text, -1); - for (int i = 1; i < parts.length; i += 2) { + String[] parts = splitWithSeparator(text, separator); + for (int i = 1; i < parts.length - 1; i += 2) { splits.add(parts[i] + parts[i + 1]); } if (parts.length % 2 == 0) { @@ -55,4 +55,27 @@ public static List splitTextWithRegex(String text, String separator, boo } return splits.stream().filter(StringUtils::isNotEmpty).toList(); } + + /** + * Split the text by the occurrences of the separator, + * returning a list containing the resulting substrings and the separators. 
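The new `RecursiveCharacterTextSplitter` tries each separator in order and recurses with the remaining separators whenever a piece is still longer than `chunkSize`. A minimal usage sketch mirroring the builder settings used in `RetrievalMarkdownExample` below (the chunk size here is just an example value):

```java
var textSplitter = RecursiveCharacterTextSplitter.builder()
        .chunkSize(100)
        .chunkOverlap(0)
        .keepSeparator(true)
        .build();

String text = "First paragraph.\n\nSecond paragraph that is a bit longer and may be split again.";
List<String> chunks = textSplitter.splitText(text);
chunks.forEach(System.out::println);
```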
+ */ + public static String[] splitWithSeparator(String text, String separator) { + List splits = new ArrayList<>(); + Pattern pattern = Pattern.compile("(" + Pattern.quote(separator) + ")"); + Matcher matcher = pattern.matcher(text); + int prevEnd = 0; + + while (matcher.find()) { + int start = matcher.start(); + int end = matcher.end(); + splits.add(text.substring(prevEnd, start)); + splits.add(text.substring(start, end)); + prevEnd = end; + } + if (prevEnd < text.length()) { + splits.add(text.substring(prevEnd)); + } + return splits.toArray(new String[0]); + } } diff --git a/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStore.java b/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStore.java index 17a9471e4..07b2a3fc1 100644 --- a/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStore.java +++ b/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStore.java @@ -69,9 +69,9 @@ public List addDocuments(List documents, Map k return addTexts(texts, metadatas, kwargs); } - public List search(String query, SearchType searchType) { + public List search(String query, SearchType searchType, Map filter) { return switch (searchType) { - case SIMILARITY -> similaritySearch(query); + case SIMILARITY -> similaritySearch(query, filter); case MMR -> maxMarginalRelevanceSearch(query); default -> throw new IllegalArgumentException( "searchType of " + searchType + " not allowed. Expected searchType to be 'similarity' or 'mmr'."); @@ -79,19 +79,39 @@ public List search(String query, SearchType searchType) { } /** - * Return docs most similar to query. + * Returns the documents most similar to the given query. + * + * @param query the input text + * @return a list of the documents most similar to the query + */ public List similaritySearch(String query) { - return similaritySearch(query, 4); + return similaritySearch(query, null); + } + + /** + * Returns the documents most similar to the given query. + * + * @param query the input text + * @param filter a filter to apply to the search + * @return a list of the documents most similar to the query + */ + public List similaritySearch(String query, Map filter) { + return similaritySearch(query, 4, filter); } /** - * Return docs most similar to query. + * Returns the documents most similar to the given query. + * + * @param query the input text + * @param k the number of documents to return + * @param filter a filter to apply to the search + * @return a list of the documents most similar to the query + */ - public abstract List similaritySearch(String query, int k); + public abstract List similaritySearch(String query, int k, Map filter); /** * Return docs and relevance scores in the range [0, 1]. 0 is dissimilar, 1 is most similar. + * @param query input text */ public List> similaritySearchWithRelevanceScores(String query) { return similaritySearchWithRelevanceScores(query, 4); @@ -102,7 +122,7 @@ public List> similaritySearchWithRelevanceScores(String qu * * @param query input text * @param k Number of Documents to return. 
- * @return List of Tuples of (doc, similarity_score) + * @return List of Tuples of (doc, similarityScore) */ public List> similaritySearchWithRelevanceScores(String query, int k) { List> docsAndSimilarities = _similaritySearchWithRelevanceScores(query, k); @@ -116,6 +136,10 @@ public List> similaritySearchWithRelevanceScores(String qu /** * Return docs and relevance scores, normalized on a scale from 0 to 1. 0 is dissimilar, 1 is most similar. + * + * @param query input text + * @param k Number of Documents to return. + * @return List of Tuples of (doc, similarityScore) */ protected abstract List> _similaritySearchWithRelevanceScores(String query, int k); @@ -174,7 +198,12 @@ public int fromDocuments(List documents, Embeddings embedding) { } /** - * Return VectorStore initialized from texts and embeddings. + * Initializes and returns a VectorStore from the given texts, embeddings, and metadata. + * + * @param texts the list of texts + * @param embedding the embeddings for the texts + * @param metadatas the list of metadata associated with the texts + * @return the initialized VectorStore */ public abstract int fromTexts(List texts, Embeddings embedding, List> metadatas); diff --git a/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStoreRetriever.java b/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStoreRetriever.java index 32f7bfe2c..36856de99 100644 --- a/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStoreRetriever.java +++ b/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStoreRetriever.java @@ -64,7 +64,7 @@ private void validateSearchType() { @Override public List getRelevantDocuments(String query) { return switch (searchType) { - case SIMILARITY -> vectorstore.similaritySearch(query); + case SIMILARITY -> vectorstore.similaritySearch(query, Map.of()); case SIMILARITY_SCORE_THRESHOLD -> vectorstore.similaritySearchWithRelevanceScores(query) .stream() .map(Pair::getLeft) diff --git a/langchain-core/src/main/java/com/hw/langchain/vectorstores/pinecone/Pinecone.java b/langchain-core/src/main/java/com/hw/langchain/vectorstores/pinecone/Pinecone.java index 2f5eccd75..28b375179 100644 --- a/langchain-core/src/main/java/com/hw/langchain/vectorstores/pinecone/Pinecone.java +++ b/langchain-core/src/main/java/com/hw/langchain/vectorstores/pinecone/Pinecone.java @@ -100,15 +100,17 @@ public boolean delete(List ids) { /** * Return pinecone documents most similar to query, along with scores. * - * @param query Text to look up documents similar to. - * @param k Number of Documents to return. Defaults to 4. + * @param query Text to look up documents similar to. + * @param k Number of Documents to return. Defaults to 4. + * @param filter Dictionary of argument(s) to filter on metadata * @return List of Documents most similar to the query and score for each */ - private List> similaritySearchWithScore(String query, int k) { + private List> similaritySearchWithScore(String query, int k, Map filter) { List queryObj = embeddingFunction.apply(query); QueryRequest queryRequest = QueryRequest.builder() .vector(queryObj) .topK(k) + .filter(filter) .namespace(namespace) .includeMetadata(true) .build(); @@ -131,13 +133,14 @@ private List> similaritySearchWithScore(String query, int /** * Return pinecone documents most similar to query. * - * @param query Text to look up documents similar to. - * @param k Number of Documents to return. Defaults to 4. + * @param query Text to look up documents similar to. 
+ * @param k Number of Documents to return. Defaults to 4. + * @param filter Dictionary of argument(s) to filter on metadata * @return List of Documents most similar to the query and score for each */ @Override - public List similaritySearch(String query, int k) { - List> docsAndScores = similaritySearchWithScore(query, k); + public List similaritySearch(String query, int k, Map filter) { + List> docsAndScores = similaritySearchWithScore(query, k, filter); return docsAndScores.stream().map(Pair::getLeft).toList(); } @@ -236,4 +239,8 @@ private List createVectors(List idsBatch, List> embe .mapToObj(k -> new Vector(idsBatch.get(k), embeds.get(k), metadata.get(k))) .toList(); } + + public IndexClient getIndex() { + return index; + } } diff --git a/langchain-core/src/test/java/com/hw/langchain/chains/query/constructor/base/StructuredQueryOutputParserTest.java b/langchain-core/src/test/java/com/hw/langchain/chains/query/constructor/base/StructuredQueryOutputParserTest.java new file mode 100644 index 000000000..159375534 --- /dev/null +++ b/langchain-core/src/test/java/com/hw/langchain/chains/query/constructor/base/StructuredQueryOutputParserTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
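With the `filter` parameter now threaded through `search`/`similaritySearch` and into Pinecone's `QueryRequest`, a metadata-filtered lookup looks roughly like this; it assumes an initialized `Pinecone` store such as the one produced by `initializePineconeIndex` below, and the filter shape follows `PineconeTranslator`:

```java
// pinecone is assumed to be an initialized Pinecone vector store.
Map<String, Object> filter = Map.of("Section", Map.of("$eq", "Testing"));
List<Document> docs = pinecone.similaritySearch("prompt injection", 4, filter);
docs.forEach(doc -> System.out.println(doc.getPageContent()));
```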
+ */ + +package com.hw.langchain.chains.query.constructor.base; + +import com.hw.langchain.chains.query.constructor.ir.Comparator; +import com.hw.langchain.chains.query.constructor.ir.Comparison; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * @author HamaWhite + */ +class StructuredQueryOutputParserTest { + + @Test + void testAstParse() { + StructuredQueryOutputParser parser = new StructuredQueryOutputParser(); + + String filter = "eq(\"Section\", \"Testing\")"; + Comparison comparison = parser.astParse(filter); + + assertEquals(Comparator.EQ, comparison.getComparator()); + assertEquals("Section", comparison.getAttribute()); + assertEquals("Testing", comparison.getValue()); + } +} \ No newline at end of file diff --git a/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java b/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java index 9b7641d16..7d63cc3f6 100644 --- a/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java +++ b/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java @@ -75,7 +75,7 @@ void testRetrievalQAFromPinecone() { var pinecone = createPinecone(); var llm = OpenAI.builder().temperature(0).requestTimeout(30).build().init(); - var qa = RetrievalQA.fromChainType(llm, STUFF, pinecone.asRetriever()); + var qa = RetrievalQa.fromChainType(llm, STUFF, pinecone.asRetriever()); String query = "What did the president say about Ketanji Brown Jackson"; var actual = qa.run(query); diff --git a/langchain-core/src/test/java/com/hw/langchain/document/loaders/notion/NotionDirectoryLoaderTest.java b/langchain-core/src/test/java/com/hw/langchain/document/loaders/notion/NotionDirectoryLoaderTest.java new file mode 100644 index 000000000..feb9884b1 --- /dev/null +++ b/langchain-core/src/test/java/com/hw/langchain/document/loaders/notion/NotionDirectoryLoaderTest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.hw.langchain.document.loaders.notion; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +/** + * @author HamaWhite + */ +class NotionDirectoryLoaderTest { + + @Test + void testLoad() { + String path = "../docs/extras/use_cases/question_answering/notion_db/"; + + var loader = new NotionDirectoryLoader(path); + var docs = loader.load(); + assertEquals(1, docs.size()); + var mdFile = docs.get(0).getPageContent(); + assertNotNull(mdFile); + } +} \ No newline at end of file diff --git a/langchain-core/src/test/java/com/hw/langchain/output/parsers/structured/StructuredOutputParserTest.java b/langchain-core/src/test/java/com/hw/langchain/output/parsers/structured/StructuredOutputParserTest.java index 35dd88480..f5534acf6 100644 --- a/langchain-core/src/test/java/com/hw/langchain/output/parsers/structured/StructuredOutputParserTest.java +++ b/langchain-core/src/test/java/com/hw/langchain/output/parsers/structured/StructuredOutputParserTest.java @@ -18,7 +18,6 @@ package com.hw.langchain.output.parsers.structured; -import com.fasterxml.jackson.databind.ObjectMapper; import com.hw.langchain.chat.models.openai.ChatOpenAI; import com.hw.langchain.llms.openai.OpenAI; import com.hw.langchain.prompts.chat.ChatPromptTemplate; @@ -52,10 +51,7 @@ void testParse() { var text = "```json\n{\"name\": \"John\", \"age\": 30}\n```"; var result = parser.parse(text); - var expectedResult = new ObjectMapper().createObjectNode() - .put("name", "John") - .put("age", 30); - + var expectedResult = Map.of("name", "John", "age", 30); assertEquals(expectedResult, result); } @@ -91,9 +87,8 @@ void testStructuredOutputParserWithLLM() { var output = llm.call(input.toString()); var actual = outputParser.parse(output); - var expected = new ObjectMapper().createObjectNode() - .put("answer", "Paris") - .put("source", "https://www.worldatlas.com/articles/what-is-the-capital-of-france.html"); + var expected = Map.of("answer", "Paris", + "source", "https://www.worldatlas.com/articles/what-is-the-capital-of-france.html"); assertEquals(expected, actual); } @@ -114,9 +109,8 @@ void testStructuredOutputParserWithChatModel() { var output = chatModel.call(input.toMessages()); var actual = outputParser.parse(output.getContent()); - var expected = new ObjectMapper().createObjectNode() - .put("answer", "The capital of France is Paris.") - .put("source", "https://en.wikipedia.org/wiki/Paris"); + var expected = Map.of("answer", "The capital of France is Paris.", + "source", "https://en.wikipedia.org/wiki/Paris"); assertEquals(expected, actual); } } \ No newline at end of file diff --git a/langchain-core/src/test/java/com/hw/langchain/prompts/utils/FormatUtilsTest.java b/langchain-core/src/test/java/com/hw/langchain/prompts/utils/FormatUtilsTest.java index cdb62d167..8323a5a82 100644 --- a/langchain-core/src/test/java/com/hw/langchain/prompts/utils/FormatUtilsTest.java +++ b/langchain-core/src/test/java/com/hw/langchain/prompts/utils/FormatUtilsTest.java @@ -84,7 +84,7 @@ void testFormatTemplateWithQuadrupleCurlyBraces() { Map kwargs = Map.of("name", "John"); // python is 'Hello, {{name}}!' 
- String expected = "Hello, {{John}}!"; + String expected = "Hello, {John}!"; String actual = formatTemplate(template, kwargs); assertEquals(expected, actual); } diff --git a/langchain-core/src/test/java/com/hw/langchain/text/splitter/TextSplitterUtilsTest.java b/langchain-core/src/test/java/com/hw/langchain/text/splitter/TextSplitterUtilsTest.java new file mode 100644 index 000000000..f1dc35d0e --- /dev/null +++ b/langchain-core/src/test/java/com/hw/langchain/text/splitter/TextSplitterUtilsTest.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.text.splitter; + +import org.junit.jupiter.api.Test; + +import static com.hw.langchain.text.splitter.TextSplitterUtils.splitWithSeparator; +import static org.junit.jupiter.api.Assertions.*; + +/** + * @author HamaWhite + */ +class TextSplitterUtilsTest { + + @Test + void testSplitWithSeparator() { + String text = "A\nB"; + String separator = "\n"; + String[] actual = splitWithSeparator(text, separator); + + String[] expected = {"A", "\n", "B"}; + assertArrayEquals(expected, actual); + } +} \ No newline at end of file diff --git a/langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java b/langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java new file mode 100644 index 000000000..48852b267 --- /dev/null +++ b/langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalMarkdownExample.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.hw.langchain.examples.chains; + +import com.hw.langchain.chains.query.constructor.schema.AttributeInfo; +import com.hw.langchain.chains.retrieval.qa.base.RetrievalQa; +import com.hw.langchain.chat.models.openai.ChatOpenAI; +import com.hw.langchain.document.loaders.notion.NotionDirectoryLoader; +import com.hw.langchain.llms.openai.OpenAI; +import com.hw.langchain.retrievers.self.query.base.SelfQueryRetriever; +import com.hw.langchain.schema.Document; +import com.hw.langchain.text.splitter.MarkdownHeaderTextSplitter; +import com.hw.langchain.text.splitter.RecursiveCharacterTextSplitter; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.List; + +import static com.hw.langchain.examples.utils.PrintUtils.println; +import static com.hw.langchain.examples.vectorstores.PineconeExample.*; + +/** + * Context aware text splitting and QA/Chat + *

+ * export PINECONE_API_KEY=xxx + * export PINECONE_ENV=xxx + * + * @author HamaWhite + */ +public class RetrievalMarkdownExample { + + public static final String NAMESPACE = "markdown"; + + public static void main(String[] args) { + // Load Notion page as a markdown file + String path = "docs/extras/use_cases/question_answering/notion_db/"; + var loader = new NotionDirectoryLoader(path); + var docs = loader.load(); + var mdFile = docs.get(0).getPageContent(); + + // Let's create groups based on the section headers in our page + List> headersToSplitOn = List.of(Pair.of("###", "Section")); + MarkdownHeaderTextSplitter markdownSplitter = new MarkdownHeaderTextSplitter(headersToSplitOn); + List mdHeaderSplits = markdownSplitter.splitText(mdFile); + + // Define our text splitter + var textSplitter = RecursiveCharacterTextSplitter.builder() + .chunkSize(500) + .chunkOverlap(0) + .keepSeparator(true) + .build(); + var allSplits = textSplitter.splitDocuments(mdHeaderSplits); + + // Build pinecone and keep the metadata + var vectorStore = initializePineconeIndex(NAMESPACE, allSplits); + + // Define our metadata + var metadataFieldInfo = List.of( + new AttributeInfo("Section", "Part of the document that the text comes from", + "string or list[string]")); + var documentContentDescription = "Major sections of the document"; + + // Define self query retriever + var llm = OpenAI.builder().temperature(0).requestTimeout(30).build().init(); + var retriever = SelfQueryRetriever.fromLLM(llm, vectorStore, documentContentDescription, metadataFieldInfo); + + // create chat or Q+A apps that are aware of the explicit document structure. + var chat = ChatOpenAI.builder().temperature(0).build().init(); + var qaChain = RetrievalQa.fromChainType(chat, retriever); + var result = qaChain.run("Summarize the Testing section of the document"); + println(result); + } +} diff --git a/langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java b/langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java index 5afaf8c96..cb0745488 100644 --- a/langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java +++ b/langchain-examples/src/main/java/com/hw/langchain/examples/chains/RetrievalQaExample.java @@ -18,22 +18,14 @@ package com.hw.langchain.examples.chains; -import com.hw.langchain.chains.retrieval.qa.base.RetrievalQA; +import com.hw.langchain.chains.retrieval.qa.base.RetrievalQa; import com.hw.langchain.document.loaders.text.TextLoader; -import com.hw.langchain.embeddings.openai.OpenAIEmbeddings; import com.hw.langchain.llms.openai.OpenAI; import com.hw.langchain.text.splitter.CharacterTextSplitter; -import com.hw.langchain.vectorstores.pinecone.Pinecone; -import com.hw.pinecone.PineconeClient; -import com.hw.pinecone.entity.index.CreateIndexRequest; -import com.hw.pinecone.entity.index.IndexDescription; - -import org.awaitility.Awaitility; - -import java.time.Duration; import static com.hw.langchain.chains.question.answering.ChainType.STUFF; import static com.hw.langchain.examples.utils.PrintUtils.println; +import static com.hw.langchain.examples.vectorstores.PineconeExample.*; /** * Retrieval QA @@ -45,7 +37,7 @@ */ public class RetrievalQaExample { - public static final String INDEX_NAME = "langchain-demo"; + public static final String NAMESPACE = "txt"; public static void main(String[] args) { var filePath = "docs/extras/modules/state_of_the_union.txt"; @@ -54,47 +46,12 @@ public static void main(String[] args) { var textSplitter 
= CharacterTextSplitter.builder().chunkSize(1000).chunkOverlap(0).build(); var docs = textSplitter.splitDocuments(documents); - var client = PineconeClient.builder().requestTimeout(30).build().init(); - - createPineconeIndex(client); - - var embeddings = OpenAIEmbeddings.builder().requestTimeout(60).build().init(); - var pinecone = Pinecone.builder().client(client).indexName(INDEX_NAME) - .embeddingFunction(embeddings::embedQuery).build().init(); - pinecone.fromDocuments(docs, embeddings); - + var pinecone = initializePineconeIndex(NAMESPACE, docs); var llm = OpenAI.builder().temperature(0).requestTimeout(30).build().init(); - var qa = RetrievalQA.fromChainType(llm, STUFF, pinecone.asRetriever()); + var qa = RetrievalQa.fromChainType(llm, STUFF, pinecone.asRetriever()); var query = "What did the president say about Ketanji Brown Jackson"; var result = qa.run(query); println(result); } - - /** - * If the index does not exist, it creates a new index with the specified name and dimension. - * It also waits until the index is ready before returning. - */ - private static void createPineconeIndex(PineconeClient client) { - if (!client.listIndexes().contains(INDEX_NAME)) { - // the text-embedding-ada-002 model has an output dimension of 1536. - var request = CreateIndexRequest.builder() - .name(INDEX_NAME) - .dimension(1536) - .build(); - client.createIndex(request); - - awaitIndexReady(client); - } - } - - private static void awaitIndexReady(PineconeClient client) { - Awaitility.await() - .atMost(Duration.ofSeconds(120)) - .pollInterval(Duration.ofSeconds(5)) - .until(() -> { - IndexDescription indexDescription = client.describeIndex(INDEX_NAME); - return indexDescription != null && indexDescription.getStatus().isReady(); - }); - } } diff --git a/langchain-examples/src/main/java/com/hw/langchain/examples/vectorstores/PineconeExample.java b/langchain-examples/src/main/java/com/hw/langchain/examples/vectorstores/PineconeExample.java new file mode 100644 index 000000000..0b21817f7 --- /dev/null +++ b/langchain-examples/src/main/java/com/hw/langchain/examples/vectorstores/PineconeExample.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.hw.langchain.examples.vectorstores; + +import com.hw.langchain.embeddings.openai.OpenAIEmbeddings; +import com.hw.langchain.schema.Document; +import com.hw.langchain.vectorstores.pinecone.Pinecone; +import com.hw.pinecone.PineconeClient; +import com.hw.pinecone.entity.index.CreateIndexRequest; +import com.hw.pinecone.entity.index.IndexDescription; +import com.hw.pinecone.entity.vector.DescribeIndexStatsRequest; + +import org.awaitility.Awaitility; + +import java.time.Duration; +import java.util.List; + +/** + * export PINECONE_API_KEY=xxx + * export PINECONE_ENV=xxx + * + * @author HamaWhite + */ +public class PineconeExample { + + private PineconeExample() { + } + + public static final String INDEX_NAME = "langchain-demo"; + + /** + * Initializes the Pinecone client, creates the index if necessary, and performs the desired operations. + * + * @param namespace the namespace + * @param docs the list of docs + * @return the initialized Pinecone instance + */ + public static Pinecone initializePineconeIndex(String namespace, List docs) { + var client = PineconeClient.builder().requestTimeout(30).build().init(); + createPineconeIndex(client); + + var embeddings = OpenAIEmbeddings.builder().requestTimeout(60).build().init(); + var pinecone = Pinecone.builder() + .client(client) + .indexName(INDEX_NAME) + .namespace(namespace) + .embeddingFunction(embeddings::embedQuery) + .build().init(); + + var request = new DescribeIndexStatsRequest(); + var response = pinecone.getIndex().describeIndexStats(request); + if (!response.getNamespaces().containsKey(namespace)) { + pinecone.fromDocuments(docs, embeddings); + } + return pinecone; + } + + /** + * If the index does not exist, it creates a new index with the specified name and dimension. + * It also waits until the index is ready before returning. + * + * @param client the PineconeClient instance + */ + public static void createPineconeIndex(PineconeClient client) { + if (!client.listIndexes().contains(INDEX_NAME)) { + // the text-embedding-ada-002 model has an output dimension of 1536. + var request = CreateIndexRequest.builder() + .name(INDEX_NAME) + .dimension(1536) + .build(); + client.createIndex(request); + awaitIndexReady(client); + } + } + + private static void awaitIndexReady(PineconeClient client) { + Awaitility.await() + .atMost(Duration.ofSeconds(120)) + .pollInterval(Duration.ofSeconds(5)) + .until(() -> { + IndexDescription indexDescription = client.describeIndex(INDEX_NAME); + return indexDescription != null && indexDescription.getStatus().isReady(); + }); + } +} diff --git a/pinecone-client/src/main/java/com/hw/pinecone/IndexClient.java b/pinecone-client/src/main/java/com/hw/pinecone/IndexClient.java index 05f4f05e0..6dfe12ce5 100644 --- a/pinecone-client/src/main/java/com/hw/pinecone/IndexClient.java +++ b/pinecone-client/src/main/java/com/hw/pinecone/IndexClient.java @@ -32,6 +32,17 @@ public IndexClient(VectorService vectorService) { this.vectorService = vectorService; } + /** + * The DescribeIndexStats operation returns statistics about the index's contents, + * including the vector count per namespace and the number of dimensions. + * + * @param request the request object for describing index statistics + * @return a DescribeIndexStatsResponse with index statistics + */ + public DescribeIndexStatsResponse describeIndexStats(DescribeIndexStatsRequest request) { + return vectorService.describeIndexStats(request).blockingGet(); + } + /** * The Query operation searches a namespace, using a query vector. 
* It retrieves the ids of the most similar items in a namespace, along with their similarity scores. diff --git a/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/DescribeIndexStatsRequest.java b/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/DescribeIndexStatsRequest.java new file mode 100644 index 000000000..a45589f59 --- /dev/null +++ b/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/DescribeIndexStatsRequest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.pinecone.entity.vector; + +import com.fasterxml.jackson.annotation.JsonInclude; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.io.Serializable; + +/** + * @author HamaWhite + */ + +@Data +@NoArgsConstructor +@AllArgsConstructor +@JsonInclude(JsonInclude.Include.NON_NULL) +public class DescribeIndexStatsRequest implements Serializable { + + /** + * If this parameter is present, the operation only returns statistics + * for vectors that satisfy the filter. + */ + private Object filter; +} diff --git a/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/DescribeIndexStatsResponse.java b/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/DescribeIndexStatsResponse.java new file mode 100644 index 000000000..cc7990713 --- /dev/null +++ b/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/DescribeIndexStatsResponse.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.pinecone.entity.vector; + +import lombok.Data; + +import java.util.Map; + +/** + * @author HamaWhite + */ +@Data +public class DescribeIndexStatsResponse { + + /** + * A mapping for each namespace in the index from the namespace name to a summary of its contents. + * If a metadata filter expression is present, the summary will reflect only vectors matching that expression. + */ + private Map> namespaces; + + /** + * The dimension of the indexed vectors. 
+ */ + private Integer dimension; + + /** + * The fullness of the index, regardless of whether a metadata filter expression was passed. + * The granularity of this metric is 10%. + */ + private Float indexFullness; + + /** + * The total number of vectors in the index, regardless of whether a metadata filter expression was passed. + */ + private Integer totalVectorCount; +} diff --git a/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/QueryRequest.java b/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/QueryRequest.java index 2763c294a..8a0794abd 100644 --- a/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/QueryRequest.java +++ b/pinecone-client/src/main/java/com/hw/pinecone/entity/vector/QueryRequest.java @@ -54,6 +54,11 @@ public class QueryRequest implements Serializable { @Builder.Default private Integer topK = 10; + /** + * The filter to apply. You can use vector metadata to limit your search. + */ + private Object filter; + /** * The namespace to query. */ diff --git a/pinecone-client/src/main/java/com/hw/pinecone/service/VectorService.java b/pinecone-client/src/main/java/com/hw/pinecone/service/VectorService.java index 7bf74e568..951264456 100644 --- a/pinecone-client/src/main/java/com/hw/pinecone/service/VectorService.java +++ b/pinecone-client/src/main/java/com/hw/pinecone/service/VectorService.java @@ -32,6 +32,16 @@ */ public interface VectorService { + /** + * The DescribeIndexStats operation returns statistics about the index's contents, + * including the vector count per namespace and the number of dimensions. + * + * @param request the request object for describing index statistics + * @return a Single emitting the response with index statistics + */ + @POST("/describe_index_stats") + Single describeIndexStats(@Body DescribeIndexStatsRequest request); + /** * The Query operation searches a namespace, using a query vector. * It retrieves the ids of the most similar items in a namespace, along with their similarity scores. diff --git a/pinecone-client/src/test/java/com/hw/pinecone/PineconeClientTest.java b/pinecone-client/src/test/java/com/hw/pinecone/PineconeClientTest.java index 3448cd1f5..286110dda 100644 --- a/pinecone-client/src/test/java/com/hw/pinecone/PineconeClientTest.java +++ b/pinecone-client/src/test/java/com/hw/pinecone/PineconeClientTest.java @@ -44,19 +44,21 @@ class PineconeClientTest { private static IndexClient index; - private static final String indexName = "index-temp"; + private static final String INDEX_NAME = "index-temp"; + + private static final String NAMESPACE = "namespace-temp"; @BeforeAll static void setup() { client = PineconeClient.builder() - .requestTimeout(16) + .requestTimeout(60) .build() .init(); // Ensures that a temporary index is created. ensureTemporaryIndexCreated(); - index = client.indexClient(indexName); + index = client.indexClient(INDEX_NAME); } /** @@ -65,9 +67,9 @@ static void setup() { * It also waits until the index is ready before returning. 
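`describeIndexStats` is what `initializePineconeIndex` above uses to decide whether a namespace already holds vectors before re-upserting. A short sketch against an `IndexClient` (obtained via `client.indexClient(INDEX_NAME)`, as in the test below); the namespace name is just the one used by `RetrievalMarkdownExample`:

```java
// index is assumed to be an IndexClient for an existing index.
DescribeIndexStatsResponse stats = index.describeIndexStats(new DescribeIndexStatsRequest());
boolean hasMarkdownNamespace = stats.getNamespaces().containsKey("markdown");
System.out.println("dimension=" + stats.getDimension()
        + ", totalVectorCount=" + stats.getTotalVectorCount()
        + ", hasMarkdownNamespace=" + hasMarkdownNamespace);
```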
*/ private static void ensureTemporaryIndexCreated() { - if (!client.listIndexes().contains(indexName)) { + if (!client.listIndexes().contains(INDEX_NAME)) { var request = CreateIndexRequest.builder() - .name(indexName) + .name(INDEX_NAME) .dimension(3) .build(); client.createIndex(request); @@ -81,7 +83,7 @@ private static void awaitIndexReady() { .atMost(Duration.ofSeconds(120)) .pollInterval(Duration.ofSeconds(5)) .until(() -> { - IndexDescription indexDescription = client.describeIndex(indexName); + IndexDescription indexDescription = client.describeIndex(INDEX_NAME); return indexDescription != null && indexDescription.getStatus().isReady(); }); } @@ -89,19 +91,19 @@ private static void awaitIndexReady() { @AfterAll static void cleanup() { // Delete temporary index - client.deleteIndex(indexName); + client.deleteIndex(INDEX_NAME); client.close(); } @Test void testListIndexes() { List indexes = client.listIndexes(); - assertTrue(indexes.contains(indexName)); + assertTrue(indexes.contains(INDEX_NAME)); } @Test void testDescribeIndex() { - IndexDescription indexDescription = client.describeIndex(indexName); + IndexDescription indexDescription = client.describeIndex(INDEX_NAME); assertNotNull(indexDescription); assertNotNull(indexDescription.getDatabase()); assertNotNull(indexDescription.getStatus()); @@ -109,7 +111,7 @@ void testDescribeIndex() { // Assert database information Database database = indexDescription.getDatabase(); assertAll( - () -> assertEquals(indexName, database.getName()), + () -> assertEquals(INDEX_NAME, database.getName()), () -> assertEquals(COSINE, database.getMetric()), () -> assertEquals(3, database.getDimension()), () -> assertEquals(1, database.getReplicas()), @@ -124,7 +126,7 @@ void testDescribeIndex() { () -> assertTrue(status.getWaiting().isEmpty()), () -> assertTrue(status.getCrashed().isEmpty()), () -> { - String host = String.format("%s-%s.svc.%s.pinecone.io", indexName, "b43e233", + String host = String.format("%s-%s.svc.%s.pinecone.io", INDEX_NAME, "b43e233", System.getenv("PINECONE_ENV")); assertEquals(host, status.getHost()); }, @@ -137,14 +139,20 @@ void testDescribeIndex() { void testVectors() { Vector v1 = new Vector("v1", List.of(1F, 3F, 5F)); Vector v2 = new Vector("v2", List.of(5F, 3F, 1F)); - UpsertRequest upsertRequest = new UpsertRequest(List.of(v1, v2)); + UpsertRequest upsertRequest = new UpsertRequest(List.of(v1, v2), NAMESPACE); UpsertResponse upsertResponse = index.upsert(upsertRequest); assertNotNull(upsertResponse, "upsertResponse should not be null"); + DescribeIndexStatsRequest statsRequest = new DescribeIndexStatsRequest(); + DescribeIndexStatsResponse statsResponse = index.describeIndexStats(statsRequest); + assertNotNull(statsResponse, "statsResponse should not be null"); + assertTrue(statsResponse.getNamespaces().containsKey(NAMESPACE)); + QueryRequest queryRequest = QueryRequest.builder() .vector(List.of(1F, 2F, 2F)) .topK(1) + .namespace(NAMESPACE) .build(); QueryResponse queryResponse = index.query(queryRequest); @@ -152,6 +160,7 @@ void testVectors() { FetchRequest fetchRequest = FetchRequest.builder() .ids(List.of("v1", "v2")) + .namespace(NAMESPACE) .build(); FetchResponse fetchResponse = index.fetch(fetchRequest); assertNotNull(fetchResponse, "fetchResponse should not be null");