forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request HamaWhiteGG#29 from HamaWhiteGG/dev
add Retrieval QA
- Loading branch information
Showing
23 changed files
with
951 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
79 changes: 79 additions & 0 deletions
79
...c/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.chains.combine.documents.base; | ||
|
||
import com.google.common.collect.Maps; | ||
import com.hw.langchain.chains.base.Chain; | ||
import com.hw.langchain.schema.Document; | ||
|
||
import org.apache.commons.lang3.tuple.Pair; | ||
|
||
import java.util.HashMap; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
|
||
/** | ||
* Base interface for chains combining documents. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public abstract class BaseCombineDocumentsChain extends Chain { | ||
|
||
protected String inputKey = "input_documents"; | ||
|
||
protected String outputKey = "output_text"; | ||
|
||
@Override | ||
public List<String> inputKeys() { | ||
return List.of(inputKey); | ||
} | ||
|
||
@Override | ||
public List<String> outputKeys() { | ||
return List.of(outputKey); | ||
} | ||
|
||
/** | ||
* Return the prompt length given the documents passed in. | ||
* Returns None if the method does not depend on the prompt length. | ||
*/ | ||
public Optional<Integer> promptLength(List<Document> docs, Map<String, Object> kwargs) { | ||
return Optional.empty(); | ||
} | ||
|
||
/** | ||
* Combine documents into a single string. | ||
*/ | ||
public abstract Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<String, Object> kwargs); | ||
|
||
@Override | ||
|
||
public Map<String, String> innerCall(Map<String, Object> inputs) { | ||
@SuppressWarnings("unchecked") | ||
var docs = (List<Document>) inputs.get(inputKey); | ||
|
||
Map<String, Object> otherKeys = Maps.filterKeys(inputs, key -> !key.equals(inputKey)); | ||
var result = this.combineDocs(docs, otherKeys); | ||
|
||
var extraReturnDict = new HashMap<>(result.getRight()); | ||
extraReturnDict.put(outputKey, result.getLeft()); | ||
return extraReturnDict; | ||
} | ||
} |
60 changes: 60 additions & 0 deletions
60
langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.chains.combine.documents.base; | ||
|
||
import com.google.common.collect.Maps; | ||
import com.google.common.collect.Sets; | ||
import com.hw.langchain.prompts.base.BasePromptTemplate; | ||
import com.hw.langchain.schema.Document; | ||
|
||
import java.util.Map; | ||
import java.util.Set; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class BaseUtils { | ||
|
||
private BaseUtils() { | ||
} | ||
|
||
/** | ||
* Format a document into a string based on a prompt template. | ||
*/ | ||
public static String formatDocument(Document doc, BasePromptTemplate prompt) { | ||
Map<String, Object> baseInfo = Maps.newHashMap(); | ||
baseInfo.put("page_content", doc.getPageContent()); | ||
baseInfo.putAll(doc.getMetadata()); | ||
|
||
Set<String> missingMetadata = Sets.newHashSet(prompt.getInputVariables()); | ||
missingMetadata.removeAll(baseInfo.keySet()); | ||
|
||
if (!missingMetadata.isEmpty()) { | ||
var requiredMetadata = prompt.getInputVariables().stream() | ||
.filter(iv -> !"page_content".equals(iv)) | ||
.toList(); | ||
|
||
throw new IllegalArgumentException( | ||
"Document prompt requires documents to have metadata variables: " + requiredMetadata | ||
+ ". Received document with missing metadata: " + missingMetadata + "."); | ||
} | ||
Map<String, Object> documentInfo = Maps.filterKeys(baseInfo, prompt.getInputVariables()::contains); | ||
return prompt.format(documentInfo); | ||
} | ||
} |
123 changes: 123 additions & 0 deletions
123
...re/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.chains.combine.documents.stuff; | ||
|
||
import com.google.common.collect.Maps; | ||
import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain; | ||
import com.hw.langchain.chains.llm.LLMChain; | ||
import com.hw.langchain.prompts.base.BasePromptTemplate; | ||
import com.hw.langchain.schema.Document; | ||
|
||
import org.apache.commons.lang3.tuple.Pair; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static com.hw.langchain.chains.combine.documents.base.BaseUtils.formatDocument; | ||
import static com.hw.langchain.chains.combine.documents.stuff.StuffUtils.getDefaultDocumentPrompt; | ||
|
||
/** | ||
* Chain that combines documents by stuffing into context. | ||
* | ||
* @author HamaWhite | ||
*/ | ||
public class StuffDocumentsChain extends BaseCombineDocumentsChain { | ||
|
||
/** | ||
* LLM wrapper to use after formatting documents. | ||
*/ | ||
private final LLMChain llmChain; | ||
|
||
/** | ||
* Prompt to use to format each document. | ||
*/ | ||
private final BasePromptTemplate documentPrompt; | ||
|
||
/** | ||
* The variable name in the llmChain to put the documents in. | ||
* If only one variable in the llmChain, this need not be provided. | ||
*/ | ||
private String documentVariableName; | ||
|
||
/** | ||
* The string with which to join the formatted documents. | ||
*/ | ||
private final String documentSeparator; | ||
|
||
public StuffDocumentsChain(LLMChain llmChain, String documentVariableName) { | ||
this(llmChain, getDefaultDocumentPrompt(), documentVariableName, "\n\n"); | ||
} | ||
|
||
public StuffDocumentsChain(LLMChain llmChain, BasePromptTemplate documentPrompt, String documentVariableName, | ||
String documentSeparator) { | ||
this.llmChain = llmChain; | ||
this.documentPrompt = documentPrompt; | ||
this.documentVariableName = documentVariableName; | ||
this.documentSeparator = documentSeparator; | ||
|
||
// Get default document variable name, if not provided. | ||
getDefaultDocumentVariableName(); | ||
} | ||
|
||
/** | ||
* Get default document variable name, if not provided. | ||
*/ | ||
private void getDefaultDocumentVariableName() { | ||
List<String> llmChainVariables = llmChain.getPrompt().getInputVariables(); | ||
if (documentVariableName == null) { | ||
if (llmChainVariables.size() == 1) { | ||
documentVariableName = llmChainVariables.get(0); | ||
} else { | ||
throw new IllegalArgumentException( | ||
"documentVariableName must be provided if there are multiple llmChainVariables"); | ||
} | ||
} else { | ||
if (!llmChainVariables.contains(documentVariableName)) { | ||
throw new IllegalArgumentException("documentVariableName " + documentVariableName | ||
+ " was not found in llmChain inputVariables: " + llmChainVariables); | ||
} | ||
} | ||
} | ||
|
||
private Map<String, Object> getInputs(List<Document> docs, Map<String, Object> kwargs) { | ||
// Format each document according to the prompt | ||
List<String> docStrings = docs.stream() | ||
.map(doc -> formatDocument(doc, documentPrompt)) | ||
.toList(); | ||
// Join the documents together to put them in the prompt. | ||
Map<String, Object> inputs = Maps.filterKeys(kwargs, llmChain.getPrompt().getInputVariables()::contains); | ||
inputs.put(documentVariableName, String.join(documentSeparator, docStrings)); | ||
return inputs; | ||
} | ||
|
||
/** | ||
* Stuff all documents into one prompt and pass to LLM. | ||
*/ | ||
@Override | ||
public Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<String, Object> kwargs) { | ||
var inputs = getInputs(docs, kwargs); | ||
// Call predict on the LLM. | ||
return Pair.of(llmChain.predict(inputs), Map.of()); | ||
} | ||
|
||
@Override | ||
public String chainType() { | ||
return "stuff_documents_chain"; | ||
} | ||
} |
41 changes: 41 additions & 0 deletions
41
langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.chains.combine.documents.stuff; | ||
|
||
import com.hw.langchain.prompts.prompt.PromptTemplate; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public class StuffUtils { | ||
|
||
private StuffUtils() { | ||
} | ||
|
||
/** | ||
* Get the default document prompt. | ||
* | ||
* @return The default document prompt. | ||
*/ | ||
public static PromptTemplate getDefaultDocumentPrompt() { | ||
return new PromptTemplate(List.of("page_content"), "{page_content}"); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
langchain-core/src/main/java/com/hw/langchain/chains/prompt/selector/BasePromptSelector.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.chains.prompt.selector; | ||
|
||
import com.hw.langchain.base.language.BaseLanguageModel; | ||
import com.hw.langchain.prompts.base.BasePromptTemplate; | ||
|
||
/** | ||
* @author HamaWhite | ||
*/ | ||
public abstract class BasePromptSelector { | ||
|
||
/** | ||
* Get default prompt for a language model. | ||
* | ||
* @param llm The BaseLanguageModel object. | ||
* @return The BasePromptTemplate object representing the default prompt. | ||
*/ | ||
public abstract BasePromptTemplate getPrompt(BaseLanguageModel llm); | ||
} |
Oops, something went wrong.