From 1b41a053705bdb25ec7510988125eed5d9147759 Mon Sep 17 00:00:00 2001 From: HamaWhite Date: Wed, 5 Jul 2023 00:54:52 +0800 Subject: [PATCH 1/2] add Retrieval QA --- .../langchain/agents/agent/AgentExecutor.java | 2 +- .../com/hw/langchain/chains/base/Chain.java | 4 +- .../base/BaseCombineDocumentsChain.java | 79 ++++++++++++ .../combine/documents/base/BaseUtils.java | 60 +++++++++ .../documents/stuff/StuffDocumentsChain.java | 122 ++++++++++++++++++ .../combine/documents/stuff/StuffUtils.java | 41 ++++++ .../com/hw/langchain/chains/llm/LLMChain.java | 7 +- .../chains/llm/math/base/LLMMathChain.java | 2 +- .../prompt/selector/BasePromptSelector.java | 36 ++++++ .../selector/ConditionalPromptSelector.java | 55 ++++++++ .../prompt/selector/PromptSelectorUtils.java | 52 ++++++++ .../chains/question/answering/ChainType.java | 55 ++++++++ .../question/answering/StuffPrompt.java | 68 ++++++++++ .../chains/question/answering/init/Init.java | 62 +++++++++ .../retrieval/qa/base/BaseRetrievalQA.java | 90 +++++++++++++ .../chains/retrieval/qa/base/RetrievalQA.java | 62 +++++++++ .../chains/retrieval/qa/promt/Prompt.java | 44 +++++++ .../sql/database/base/SQLDatabaseChain.java | 2 +- .../base/SQLDatabaseSequentialChain.java | 2 +- .../vectorstores/base/VectorStore.java | 6 + .../vectorstores/pinecone/Pinecone.java | 5 +- .../retrieval/qa/base/RetrievalQATest.java | 90 +++++++++++++ .../vectorstores/pinecone/PineconeTest.java | 23 ++-- 23 files changed, 950 insertions(+), 19 deletions(-) create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseUtils.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffUtils.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/BasePromptSelector.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/ConditionalPromptSelector.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/PromptSelectorUtils.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/question/answering/ChainType.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/question/answering/StuffPrompt.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/question/answering/init/Init.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQA.java create mode 100644 langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/promt/Prompt.java create mode 100644 langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java b/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java index 455b1cdcf..d6b1be8f8 100644 --- a/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java +++ b/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java @@ -138,7 +138,7 @@ public Object takeNextStep(Map nameToToolMap, Map _call(Map inputs) { + public Map innerCall(Map inputs) { // Construct a mapping of tool name to tool for easy lookup Map nameToToolMap = tools.stream().collect(Collectors.toMap(BaseTool::getName, tool -> tool)); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java index 55e194f63..714cb9e16 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java @@ -65,7 +65,7 @@ private void validateOutputs(Map outputs) { /** * Run the logic of this chain and return the output. */ - public abstract Map _call(Map inputs); + public abstract Map innerCall(Map inputs); /** * Run the logic of this chain and add to output if desired. @@ -92,7 +92,7 @@ public Map call(String input, boolean returnOnlyOutputs) { */ public Map call(Map inputs, boolean returnOnlyOutputs) { inputs = prepInputs(inputs); - Map outputs = _call(inputs); + Map outputs = innerCall(inputs); return prepOutputs(inputs, outputs, returnOnlyOutputs); } diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java new file mode 100644 index 000000000..00f899271 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.combine.documents.base; + +import com.google.common.collect.Maps; +import com.hw.langchain.chains.base.Chain; +import com.hw.langchain.schema.Document; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Base interface for chains combining documents. + * + * @author HamaWhite + */ +public abstract class BaseCombineDocumentsChain extends Chain { + + protected String inputKey = "input_documents"; + + protected String outputKey = "output_text"; + + @Override + public List inputKeys() { + return List.of(inputKey); + } + + @Override + public List outputKeys() { + return List.of(outputKey); + } + + /** + * Return the prompt length given the documents passed in. + * Returns None if the method does not depend on the prompt length. + */ + public Optional promptLength(List docs, Map kwargs) { + return Optional.empty(); + } + + /** + * Combine documents into a single string. + */ + public abstract Pair> combineDocs(List docs, Map kwargs); + + @Override + + public Map innerCall(Map inputs) { + @SuppressWarnings("unchecked") + var docs = (List) inputs.get(inputKey); + + Map otherKeys = Maps.filterKeys(inputs, key -> !key.equals(inputKey)); + var result = this.combineDocs(docs, otherKeys); + + var extraReturnDict = new HashMap<>(result.getRight()); + extraReturnDict.put(outputKey, result.getLeft()); + return extraReturnDict; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseUtils.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseUtils.java new file mode 100644 index 000000000..b5eafdeec --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseUtils.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.combine.documents.base; + +import com.google.common.collect.Maps; +import com.google.common.collect.Sets; +import com.hw.langchain.prompts.base.BasePromptTemplate; +import com.hw.langchain.schema.Document; + +import java.util.Map; +import java.util.Set; + +/** + * @author HamaWhite + */ +public class BaseUtils { + + private BaseUtils() { + } + + /** + * Format a document into a string based on a prompt template. + */ + public static String formatDocument(Document doc, BasePromptTemplate prompt) { + Map baseInfo = Maps.newHashMap(); + baseInfo.put("page_content", doc.getPageContent()); + baseInfo.putAll(doc.getMetadata()); + + Set missingMetadata = Sets.newHashSet(prompt.getInputVariables()); + missingMetadata.removeAll(baseInfo.keySet()); + + if (!missingMetadata.isEmpty()) { + var requiredMetadata = prompt.getInputVariables().stream() + .filter(iv -> !"page_content".equals(iv)) + .toList(); + + throw new IllegalArgumentException( + "Document prompt requires documents to have metadata variables: " + requiredMetadata + + ". Received document with missing metadata: " + missingMetadata + "."); + } + Map documentInfo = Maps.filterKeys(baseInfo, prompt.getInputVariables()::contains); + return prompt.format(documentInfo); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java new file mode 100644 index 000000000..663911c08 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.combine.documents.stuff; + +import com.google.common.collect.Maps; +import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain; +import com.hw.langchain.chains.llm.LLMChain; +import com.hw.langchain.prompts.base.BasePromptTemplate; +import com.hw.langchain.schema.Document; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.List; +import java.util.Map; + +import static com.hw.langchain.chains.combine.documents.base.BaseUtils.formatDocument; +import static com.hw.langchain.chains.combine.documents.stuff.StuffUtils.getDefaultDocumentPrompt; + +/** + * Chain that combines documents by stuffing into context. + * + * @author HamaWhite + */ +public class StuffDocumentsChain extends BaseCombineDocumentsChain { + + /** + * LLM wrapper to use after formatting documents. + */ + private final LLMChain llmChain; + + /** + * Prompt to use to format each document. + */ + private final BasePromptTemplate documentPrompt; + + /** + * The variable name in the llmChain to put the documents in. + * If only one variable in the llmChain, this need not be provided. + */ + private String documentVariableName; + + /** + * The string with which to join the formatted documents. + */ + private final String documentSeparator; + + public StuffDocumentsChain(LLMChain llmChain, String documentVariableName) { + this(llmChain, getDefaultDocumentPrompt(), documentVariableName, "\n\n"); + } + + public StuffDocumentsChain(LLMChain llmChain, BasePromptTemplate documentPrompt, String documentVariableName, String documentSeparator) { + this.llmChain = llmChain; + this.documentPrompt = documentPrompt; + this.documentVariableName = documentVariableName; + this.documentSeparator = documentSeparator; + + // Get default document variable name, if not provided. + getDefaultDocumentVariableName(); + } + + /** + * Get default document variable name, if not provided. + */ + private void getDefaultDocumentVariableName() { + List llmChainVariables = llmChain.getPrompt().getInputVariables(); + if (documentVariableName == null) { + if (llmChainVariables.size() == 1) { + documentVariableName = llmChainVariables.get(0); + } else { + throw new IllegalArgumentException( + "documentVariableName must be provided if there are multiple llmChainVariables"); + } + } else { + if (!llmChainVariables.contains(documentVariableName)) { + throw new IllegalArgumentException("documentVariableName " + documentVariableName + + " was not found in llmChain inputVariables: " + llmChainVariables); + } + } + } + + private Map getInputs(List docs, Map kwargs) { + // Format each document according to the prompt + List docStrings = docs.stream() + .map(doc -> formatDocument(doc, documentPrompt)) + .toList(); + // Join the documents together to put them in the prompt. + Map inputs = Maps.filterKeys(kwargs, llmChain.getPrompt().getInputVariables()::contains); + inputs.put(documentVariableName, String.join(documentSeparator, docStrings)); + return inputs; + } + + /** + * Stuff all documents into one prompt and pass to LLM. + */ + @Override + public Pair> combineDocs(List docs, Map kwargs) { + var inputs = getInputs(docs, kwargs); + // Call predict on the LLM. + return Pair.of(llmChain.predict(inputs), Map.of()); + } + + @Override + public String chainType() { + return "stuff_documents_chain"; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffUtils.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffUtils.java new file mode 100644 index 000000000..22a825f95 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffUtils.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.combine.documents.stuff; + +import com.hw.langchain.prompts.prompt.PromptTemplate; + +import java.util.List; + +/** + * @author HamaWhite + */ +public class StuffUtils { + + private StuffUtils() { + } + + /** + * Get the default document prompt. + * + * @return The default document prompt. + */ + public static PromptTemplate getDefaultDocumentPrompt() { + return new PromptTemplate(List.of("page_content"), "{page_content}"); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java index 30265dc31..8382da744 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java @@ -97,7 +97,7 @@ public List outputKeys() { } @Override - public Map _call(Map inputs) { + public Map innerCall(Map inputs) { LLMResult response = generate(List.of(inputs)); return createOutputs(response).get(0); } @@ -175,4 +175,9 @@ public T predictAndParse(Map kwargs) { } return (T) result; } + + public BasePromptTemplate getPrompt() { + return prompt; + } + } diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/llm/math/base/LLMMathChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/llm/math/base/LLMMathChain.java index 4569c533c..de6d26db5 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/llm/math/base/LLMMathChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/llm/math/base/LLMMathChain.java @@ -127,7 +127,7 @@ public Map processLLMResult(String llmOutput) { } @Override - public Map _call(Map inputs) { + public Map innerCall(Map inputs) { var kwargs = Map.of("question", inputs.get(inputKey), "stop", List.of("```output")); String llmOutput = llmChain.predict(kwargs); return processLLMResult(llmOutput); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/BasePromptSelector.java b/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/BasePromptSelector.java new file mode 100644 index 000000000..4907b2c46 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/BasePromptSelector.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.prompt.selector; + +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.prompts.base.BasePromptTemplate; + +/** + * @author HamaWhite + */ +public abstract class BasePromptSelector { + + /** + * Get default prompt for a language model. + * + * @param llm The BaseLanguageModel object. + * @return The BasePromptTemplate object representing the default prompt. + */ + public abstract BasePromptTemplate getPrompt(BaseLanguageModel llm); +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/ConditionalPromptSelector.java b/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/ConditionalPromptSelector.java new file mode 100644 index 000000000..210382966 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/ConditionalPromptSelector.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.prompt.selector; + +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.prompts.base.BasePromptTemplate; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.List; +import java.util.function.Predicate; + +/** + * Prompt collection that goes through conditionals. + * + * @author HamaWhite + */ +public class ConditionalPromptSelector extends BasePromptSelector { + + private final BasePromptTemplate defaultPrompt; + + private final List, BasePromptTemplate>> conditionals; + + public ConditionalPromptSelector(BasePromptTemplate defaultPrompt, + List, BasePromptTemplate>> conditionals) { + this.defaultPrompt = defaultPrompt; + this.conditionals = conditionals; + } + + @Override + public BasePromptTemplate getPrompt(BaseLanguageModel llm) { + for (var condition : conditionals) { + if (condition.getLeft().test(llm)) { + return condition.getRight(); + } + } + return defaultPrompt; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/PromptSelectorUtils.java b/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/PromptSelectorUtils.java new file mode 100644 index 000000000..8ff1866ef --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/PromptSelectorUtils.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.prompt.selector; + +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.chat.models.base.BaseChatModel; +import com.hw.langchain.llms.base.BaseLLM; + +/** + * @author HamaWhite + */ +public class PromptSelectorUtils { + + private PromptSelectorUtils() { + } + + /** + * Check if the language model is a LLM. + * + * @param llm The language model to check. + * @return true if the language model is a BaseLLM model, false otherwise. + */ + public static boolean isLLM(BaseLanguageModel llm) { + return llm instanceof BaseLLM; + } + + /** + * Check if the language model is a chat model. + * + * @param llm The language model to check. + * @return true if the language model is a BaseChatModel model, false otherwise. + */ + public static boolean isChatModel(BaseLanguageModel llm) { + return llm instanceof BaseChatModel; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/ChainType.java b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/ChainType.java new file mode 100644 index 000000000..f3d2baa99 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/ChainType.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.question.answering; + +/** + * @author HamaWhite + */ +public enum ChainType { + + /** + * Chain type for "stuff". + */ + STUFF("stuff"), + + /** + * Chain type for "map_reduce". + */ + MAP_REDUCE("map_reduce"), + + /** + * Chain type for "refine". + */ + REFINE("refine"), + + /** + * Chain type for "map_rerank". + */ + MAP_RERANK("map_rerank"); + + private final String value; + + ChainType(String value) { + this.value = value; + } + + public String getValue() { + return value; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/StuffPrompt.java b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/StuffPrompt.java new file mode 100644 index 000000000..b5ac98770 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/StuffPrompt.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.question.answering; + +import com.hw.langchain.chains.prompt.selector.BasePromptSelector; +import com.hw.langchain.chains.prompt.selector.ConditionalPromptSelector; +import com.hw.langchain.chains.prompt.selector.PromptSelectorUtils; +import com.hw.langchain.prompts.base.BasePromptTemplate; +import com.hw.langchain.prompts.chat.ChatPromptTemplate; +import com.hw.langchain.prompts.chat.HumanMessagePromptTemplate; +import com.hw.langchain.prompts.chat.SystemMessagePromptTemplate; +import com.hw.langchain.prompts.prompt.PromptTemplate; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.List; + +/** + * @author HamaWhite + */ +public class StuffPrompt { + + private StuffPrompt(){ + } + + private static final String PROMPT_TEMPLATE = + """ + Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. + + {context} + + Question: {question} + Helpful Answer:"""; + + public static final PromptTemplate PROMPT = new PromptTemplate(List.of("context", "question"), PROMPT_TEMPLATE); + + private static final String SYSTEM_TEMPLATE = """ + Use the following pieces of context to answer the users question. + If you don't know the answer, just say that you don't know, don't try to make up an answer. + ---------------- + {context}"""; + + private static final List MESSAGES = List.of( + SystemMessagePromptTemplate.fromTemplate(SYSTEM_TEMPLATE), + HumanMessagePromptTemplate.fromTemplate("{question}")); + + private static final BasePromptTemplate CHAT_PROMPT = ChatPromptTemplate.fromMessages(MESSAGES); + + public static final BasePromptSelector PROMPT_SELECTOR = + new ConditionalPromptSelector(PROMPT, List.of(Pair.of(PromptSelectorUtils::isChatModel, CHAT_PROMPT))); + +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/init/Init.java b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/init/Init.java new file mode 100644 index 000000000..f6c0c0900 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/init/Init.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.question.answering.init; + +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain; +import com.hw.langchain.chains.combine.documents.stuff.StuffDocumentsChain; +import com.hw.langchain.chains.llm.LLMChain; +import com.hw.langchain.chains.question.answering.ChainType; +import com.hw.langchain.prompts.base.BasePromptTemplate; + +import java.util.Map; +import java.util.function.Function; + +import static com.hw.langchain.chains.question.answering.ChainType.STUFF; +import static com.hw.langchain.chains.question.answering.StuffPrompt.PROMPT_SELECTOR; + +/** + * @author HamaWhite + */ +public class Init { + + private Init() { + } + + private static final Map> LOADER_MAPPING = Map.of( + STUFF, Init::loadStuffChain); + + public static StuffDocumentsChain loadStuffChain(BaseLanguageModel llm) { + return loadStuffChain(llm, PROMPT_SELECTOR.getPrompt(llm), "context"); + } + + public static StuffDocumentsChain loadStuffChain(BaseLanguageModel llm, BasePromptTemplate prompt, + String documentVariableName) { + LLMChain llmChain = new LLMChain(llm, prompt); + return new StuffDocumentsChain(llmChain, documentVariableName); + } + + public static BaseCombineDocumentsChain loadQaChain(BaseLanguageModel llm) { + return loadQaChain(llm, STUFF); + } + + public static BaseCombineDocumentsChain loadQaChain(BaseLanguageModel llm, ChainType chainType) { + return LOADER_MAPPING.get(chainType).apply(llm); + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java new file mode 100644 index 000000000..1e8fe2037 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.retrieval.qa.base; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.hw.langchain.chains.base.Chain; +import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain; +import com.hw.langchain.schema.Document; + +import java.util.List; +import java.util.Map; + +/** + * @author HamaWhite + */ +public abstract class BaseRetrievalQA extends Chain { + + /** + * Chain to use to combine the documents. + */ + private final BaseCombineDocumentsChain combineDocumentsChain; + + private final String inputKey = "query"; + + private final String outputKey = "result"; + + /** + * Return the source documents. + */ + private boolean returnSourceDocuments; + + protected BaseRetrievalQA(BaseCombineDocumentsChain combineDocumentsChain) { + this.combineDocumentsChain = combineDocumentsChain; + } + + @Override + public List inputKeys() { + return List.of(inputKey); + } + + @Override + public List outputKeys() { + List outputKeys = Lists.newArrayList(outputKey); + if (returnSourceDocuments) { + outputKeys.add("source_documents"); + } + return outputKeys; + } + + /** + * Get documents to do question answering over. + */ + public abstract List getDocs(String question); + + /** + * Run getRelevantText and llm on input query. + */ + @Override + public Map innerCall(Map inputs) { + var question = inputs.get(inputKey).toString(); + + List docs = getDocs(question); + String answer = combineDocumentsChain.run(Map.of("input_documents", docs, "question", question)); + + Map result = Maps.newHashMap(); + result.put(outputKey, answer); + if (this.returnSourceDocuments) { + result.put("source_documents", docs.toString()); + } + return result; + } + +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQA.java b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQA.java new file mode 100644 index 000000000..bdd665f79 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQA.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.retrieval.qa.base; + +import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain; +import com.hw.langchain.chains.question.answering.ChainType; +import com.hw.langchain.schema.BaseRetriever; +import com.hw.langchain.schema.Document; + +import java.util.List; + +import static com.hw.langchain.chains.question.answering.init.Init.loadQaChain; + +/** + * Chain for question-answering against an index. + * + * @author HamaWhite + */ +public class RetrievalQA extends BaseRetrievalQA { + + private final BaseRetriever retriever; + + public RetrievalQA(BaseCombineDocumentsChain combineDocumentsChain, BaseRetriever retriever) { + super(combineDocumentsChain); + this.retriever = retriever; + } + + /** + * Load chain from chain type. + */ + public static BaseRetrievalQA fromChainType(BaseLanguageModel llm, ChainType chainType, BaseRetriever retriever) { + BaseCombineDocumentsChain combineDocumentsChain = loadQaChain(llm, chainType); + return new RetrievalQA(combineDocumentsChain, retriever); + } + + @Override + public List getDocs(String question) { + return retriever.getRelevantDocuments(question); + } + + @Override + public String chainType() { + return "retrieval_qa"; + } +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/promt/Prompt.java b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/promt/Prompt.java new file mode 100644 index 000000000..68f7a7148 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/promt/Prompt.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.retrieval.qa.promt; + +import com.hw.langchain.prompts.prompt.PromptTemplate; + +import java.util.List; + +/** + * @author HamaWhite + */ +public class Prompt { + + private Prompt() { + } + + private static final String TEMPLATE = + """ + Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer. + + {context} + + Question: {question} + Helpful Answer:"""; + + public static final PromptTemplate PROMPT_TEMPLATE = new PromptTemplate(List.of("context", "question"), TEMPLATE); + +} diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.java index dd4eb74c8..2f26a32ca 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.java @@ -118,7 +118,7 @@ public List outputKeys() { } @Override - public Map _call(Map inputs) { + public Map innerCall(Map inputs) { String inputText = inputs.get(this.inputKey) + "\nSQLQuery:"; // If not present, then defaults to null which is all tables. var tableNamesToUse = (List) inputs.get("table_names_to_use"); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseSequentialChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseSequentialChain.java index af0baa795..025cec545 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseSequentialChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseSequentialChain.java @@ -101,7 +101,7 @@ public List outputKeys() { } @Override - public Map _call(Map inputs) { + public Map innerCall(Map inputs) { List tableNameList = sqlChain.getDatabase().getUsableTableNames(); String tableNames = String.join(", ", tableNameList); var llmInputs = Map.of("query", inputs.get(inputKey), diff --git a/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStore.java b/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStore.java index 19d159897..17a9471e4 100644 --- a/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStore.java +++ b/langchain-core/src/main/java/com/hw/langchain/vectorstores/base/VectorStore.java @@ -28,6 +28,8 @@ import java.util.List; import java.util.Map; +import static com.hw.langchain.vectorstores.base.SearchType.SIMILARITY; + /** * @author HamaWhite */ @@ -176,6 +178,10 @@ public int fromDocuments(List documents, Embeddings embedding) { */ public abstract int fromTexts(List texts, Embeddings embedding, List> metadatas); + public VectorStoreRetriever asRetriever() { + return asRetriever(SIMILARITY); + } + public VectorStoreRetriever asRetriever(SearchType searchType) { return new VectorStoreRetriever(this, searchType); } diff --git a/langchain-core/src/main/java/com/hw/langchain/vectorstores/pinecone/Pinecone.java b/langchain-core/src/main/java/com/hw/langchain/vectorstores/pinecone/Pinecone.java index 5094249ae..2f5eccd75 100644 --- a/langchain-core/src/main/java/com/hw/langchain/vectorstores/pinecone/Pinecone.java +++ b/langchain-core/src/main/java/com/hw/langchain/vectorstores/pinecone/Pinecone.java @@ -18,6 +18,7 @@ package com.hw.langchain.vectorstores.pinecone; +import com.google.common.collect.Maps; import com.hw.langchain.embeddings.base.Embeddings; import com.hw.langchain.schema.Document; import com.hw.langchain.vectorstores.base.VectorStore; @@ -117,7 +118,7 @@ private List> similaritySearchWithScore(String query, int for (var res : results.getMatches()) { var metadata = res.getMetadata(); if (metadata.containsKey(textKey)) { - var text = metadata.get(textKey).toString(); + var text = metadata.remove(textKey).toString(); Document document = new Document(text, metadata); docs.add(Pair.of(document, res.getScore())); } else { @@ -220,7 +221,7 @@ private List> createMetadata(List linesBatch, List()); + metadata.add(Maps.newHashMap()); } } for (int j = 0; j < linesBatch.size(); j++) { diff --git a/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java b/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java new file mode 100644 index 000000000..3a8f4ab1b --- /dev/null +++ b/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.hw.langchain.chains.retrieval.qa.base; + +import com.hw.langchain.embeddings.openai.OpenAIEmbeddings; +import com.hw.langchain.llms.openai.OpenAI; +import com.hw.langchain.vectorstores.pinecone.Pinecone; +import com.hw.pinecone.PineconeClient; +import com.hw.langchain.vectorstores.pinecone.PineconeTest; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import static com.hw.langchain.chains.question.answering.ChainType.STUFF; +import static com.hw.langchain.vectorstores.pinecone.PineconeTest.INDEX_NAME; +import static org.junit.jupiter.api.Assertions.*; + +/** + * Retrieval QA + * + * @author HamaWhite + */ +@Disabled("Test requires costly OpenAI and Pinecone calls, can be run manually.") +class RetrievalQATest { + + private OpenAIEmbeddings embeddings; + + private PineconeClient client; + + @BeforeEach + void setup() { + client = PineconeClient.builder() + .requestTimeout(30) + .build() + .init(); + + embeddings = OpenAIEmbeddings.builder() + .model("text-embedding-ada-002") + .requestTimeout(60) + .build() + .init(); + } + + private Pinecone createPinecone() { + return Pinecone.builder() + .client(client) + .indexName(INDEX_NAME) + .embeddingFunction(embeddings::embedQuery) + .build() + .init(); + } + + /** + * Please run the {@link PineconeTest#testFromDocuments()} to write the text data into the Pinecone index. + */ + @Test + void testRetrievalQAFromPinecone() { + var pinecone = createPinecone(); + + var llm = OpenAI.builder().temperature(0).requestTimeout(30).build().init(); + var qa = RetrievalQA.fromChainType(llm, STUFF, pinecone.asRetriever()); + + String query = "What did the president say about Ketanji Brown Jackson"; + var actual = qa.run(query); + + var expected = " The president said that Ketanji Brown Jackson is one of the nation's top legal minds, a " + + "former top litigator in private practice, a former federal public defender, and from a family of " + + "public school educators and police officers. He also said that she is a consensus builder and has " + + "received a broad range of support from the Fraternal Order of Police to former judges appointed by " + + "Democrats and Republicans."; + assertEquals(expected, actual); + } +} \ No newline at end of file diff --git a/langchain-core/src/test/java/com/hw/langchain/vectorstores/pinecone/PineconeTest.java b/langchain-core/src/test/java/com/hw/langchain/vectorstores/pinecone/PineconeTest.java index 92ceb28e5..d60ef7673 100644 --- a/langchain-core/src/test/java/com/hw/langchain/vectorstores/pinecone/PineconeTest.java +++ b/langchain-core/src/test/java/com/hw/langchain/vectorstores/pinecone/PineconeTest.java @@ -26,9 +26,7 @@ import com.hw.pinecone.entity.index.IndexDescription; import org.awaitility.Awaitility; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import java.time.Duration; @@ -41,10 +39,11 @@ * * @author HamaWhite */ +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) @Disabled("Test requires costly OpenAI and Pinecone calls, can be run manually.") -class PineconeTest { +public class PineconeTest { - private final String indexName = "langchain-demo"; + public static final String INDEX_NAME = "langchain-demo"; private final String query = "What did the president say about Ketanji Brown Jackson"; @@ -69,7 +68,7 @@ void setup() { private Pinecone createPinecone() { return Pinecone.builder() .client(client) - .indexName(indexName) + .indexName(INDEX_NAME) .embeddingFunction(embeddings::embedQuery) .build() .init(); @@ -81,9 +80,9 @@ private Pinecone createPinecone() { * It also waits until the index is ready before returning. */ private void ensureIndexCreated() { - if (!client.listIndexes().contains(indexName)) { + if (!client.listIndexes().contains(INDEX_NAME)) { var request = CreateIndexRequest.builder() - .name(indexName) + .name(INDEX_NAME) .dimension(1536) .build(); client.createIndex(request); @@ -97,13 +96,14 @@ private void awaitIndexReady() { .atMost(Duration.ofSeconds(120)) .pollInterval(Duration.ofSeconds(5)) .until(() -> { - IndexDescription indexDescription = client.describeIndex(indexName); + IndexDescription indexDescription = client.describeIndex(INDEX_NAME); return indexDescription != null && indexDescription.getStatus().isReady(); }); } @Test - void testFromDocuments() { + @Order(1) + public void testFromDocuments() { String filePath = "../docs/extras/modules/state_of_the_union.txt"; var loader = new TextLoader(filePath); var documents = loader.load(); @@ -119,6 +119,7 @@ void testFromDocuments() { } @Test + @Order(2) void testSimilaritySearch() { var pinecone = createPinecone(); var docs = pinecone.similaritySearch(query); @@ -138,6 +139,7 @@ void testSimilaritySearch() { } @Test + @Order(3) void testGetRelevantDocuments() { var pinecone = createPinecone(); var retriever = pinecone.asRetriever(MMR); @@ -147,6 +149,7 @@ void testGetRelevantDocuments() { } @Test + @Order(4) void testMaxMarginalRelevanceSearch() { var pinecone = createPinecone(); var foundDocs = pinecone.maxMarginalRelevanceSearch(query, 2, 10, 0.5f); From 7e0ee7c93e1ec5b3cabbb8fa719d1e0d8b67d7d4 Mon Sep 17 00:00:00 2001 From: HamaWhite Date: Wed, 5 Jul 2023 00:56:30 +0800 Subject: [PATCH 2/2] add Retrieval QA --- .../chains/combine/documents/stuff/StuffDocumentsChain.java | 3 ++- .../com/hw/langchain/chains/question/answering/ChainType.java | 2 +- .../hw/langchain/chains/question/answering/StuffPrompt.java | 2 +- .../hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java index 663911c08..17f54ff59 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java @@ -64,7 +64,8 @@ public StuffDocumentsChain(LLMChain llmChain, String documentVariableName) { this(llmChain, getDefaultDocumentPrompt(), documentVariableName, "\n\n"); } - public StuffDocumentsChain(LLMChain llmChain, BasePromptTemplate documentPrompt, String documentVariableName, String documentSeparator) { + public StuffDocumentsChain(LLMChain llmChain, BasePromptTemplate documentPrompt, String documentVariableName, + String documentSeparator) { this.llmChain = llmChain; this.documentPrompt = documentPrompt; this.documentVariableName = documentVariableName; diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/ChainType.java b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/ChainType.java index f3d2baa99..6bc2dd26d 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/ChainType.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/ChainType.java @@ -37,7 +37,7 @@ public enum ChainType { * Chain type for "refine". */ REFINE("refine"), - + /** * Chain type for "map_rerank". */ diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/StuffPrompt.java b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/StuffPrompt.java index b5ac98770..2012fa519 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/StuffPrompt.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/question/answering/StuffPrompt.java @@ -36,7 +36,7 @@ */ public class StuffPrompt { - private StuffPrompt(){ + private StuffPrompt() { } private static final String PROMPT_TEMPLATE = diff --git a/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java b/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java index 3a8f4ab1b..9b7641d16 100644 --- a/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java +++ b/langchain-core/src/test/java/com/hw/langchain/chains/retrieval/qa/base/RetrievalQATest.java @@ -21,8 +21,8 @@ import com.hw.langchain.embeddings.openai.OpenAIEmbeddings; import com.hw.langchain.llms.openai.OpenAI; import com.hw.langchain.vectorstores.pinecone.Pinecone; -import com.hw.pinecone.PineconeClient; import com.hw.langchain.vectorstores.pinecone.PineconeTest; +import com.hw.pinecone.PineconeClient; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled;