Skip to content

Commit

Permalink
add Retrieval QA
Browse files Browse the repository at this point in the history
  • Loading branch information
HamaWhiteGG committed Jul 4, 2023
1 parent 2376dd9 commit 1b41a05
Show file tree
Hide file tree
Showing 23 changed files with 950 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public Object takeNextStep(Map<String, BaseTool> nameToToolMap, Map<String, Obje
* Run text through and get agent response.
*/
@Override
public Map<String, String> _call(Map<String, Object> inputs) {
public Map<String, String> innerCall(Map<String, Object> inputs) {
// Construct a mapping of tool name to tool for easy lookup
Map<String, BaseTool> nameToToolMap = tools.stream().collect(Collectors.toMap(BaseTool::getName, tool -> tool));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ private void validateOutputs(Map<String, String> outputs) {
/**
* Run the logic of this chain and return the output.
*/
public abstract Map<String, String> _call(Map<String, Object> inputs);
public abstract Map<String, String> innerCall(Map<String, Object> inputs);

/**
* Run the logic of this chain and add to output if desired.
Expand All @@ -92,7 +92,7 @@ public Map<String, String> call(String input, boolean returnOnlyOutputs) {
*/
public Map<String, String> call(Map<String, Object> inputs, boolean returnOnlyOutputs) {
inputs = prepInputs(inputs);
Map<String, String> outputs = _call(inputs);
Map<String, String> outputs = innerCall(inputs);
return prepOutputs(inputs, outputs, returnOnlyOutputs);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.chains.combine.documents.base;

import com.google.common.collect.Maps;
import com.hw.langchain.chains.base.Chain;
import com.hw.langchain.schema.Document;

import org.apache.commons.lang3.tuple.Pair;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

/**
* Base interface for chains combining documents.
*
* @author HamaWhite
*/
public abstract class BaseCombineDocumentsChain extends Chain {

protected String inputKey = "input_documents";

protected String outputKey = "output_text";

@Override
public List<String> inputKeys() {
return List.of(inputKey);
}

@Override
public List<String> outputKeys() {
return List.of(outputKey);
}

/**
* Return the prompt length given the documents passed in.
* Returns None if the method does not depend on the prompt length.
*/
public Optional<Integer> promptLength(List<Document> docs, Map<String, Object> kwargs) {
return Optional.empty();
}

/**
* Combine documents into a single string.
*/
public abstract Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<String, Object> kwargs);

@Override

public Map<String, String> innerCall(Map<String, Object> inputs) {
@SuppressWarnings("unchecked")
var docs = (List<Document>) inputs.get(inputKey);

Map<String, Object> otherKeys = Maps.filterKeys(inputs, key -> !key.equals(inputKey));
var result = this.combineDocs(docs, otherKeys);

var extraReturnDict = new HashMap<>(result.getRight());
extraReturnDict.put(outputKey, result.getLeft());
return extraReturnDict;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.chains.combine.documents.base;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.schema.Document;

import java.util.Map;
import java.util.Set;

/**
* @author HamaWhite
*/
public class BaseUtils {

private BaseUtils() {
}

/**
* Format a document into a string based on a prompt template.
*/
public static String formatDocument(Document doc, BasePromptTemplate prompt) {
Map<String, Object> baseInfo = Maps.newHashMap();
baseInfo.put("page_content", doc.getPageContent());
baseInfo.putAll(doc.getMetadata());

Set<String> missingMetadata = Sets.newHashSet(prompt.getInputVariables());
missingMetadata.removeAll(baseInfo.keySet());

if (!missingMetadata.isEmpty()) {
var requiredMetadata = prompt.getInputVariables().stream()
.filter(iv -> !"page_content".equals(iv))
.toList();

throw new IllegalArgumentException(
"Document prompt requires documents to have metadata variables: " + requiredMetadata
+ ". Received document with missing metadata: " + missingMetadata + ".");
}
Map<String, Object> documentInfo = Maps.filterKeys(baseInfo, prompt.getInputVariables()::contains);
return prompt.format(documentInfo);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.chains.combine.documents.stuff;

import com.google.common.collect.Maps;
import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain;
import com.hw.langchain.chains.llm.LLMChain;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.schema.Document;

import org.apache.commons.lang3.tuple.Pair;

import java.util.List;
import java.util.Map;

import static com.hw.langchain.chains.combine.documents.base.BaseUtils.formatDocument;
import static com.hw.langchain.chains.combine.documents.stuff.StuffUtils.getDefaultDocumentPrompt;

/**
* Chain that combines documents by stuffing into context.
*
* @author HamaWhite
*/
public class StuffDocumentsChain extends BaseCombineDocumentsChain {

/**
* LLM wrapper to use after formatting documents.
*/
private final LLMChain llmChain;

/**
* Prompt to use to format each document.
*/
private final BasePromptTemplate documentPrompt;

/**
* The variable name in the llmChain to put the documents in.
* If only one variable in the llmChain, this need not be provided.
*/
private String documentVariableName;

/**
* The string with which to join the formatted documents.
*/
private final String documentSeparator;

public StuffDocumentsChain(LLMChain llmChain, String documentVariableName) {
this(llmChain, getDefaultDocumentPrompt(), documentVariableName, "\n\n");
}

public StuffDocumentsChain(LLMChain llmChain, BasePromptTemplate documentPrompt, String documentVariableName, String documentSeparator) {
this.llmChain = llmChain;
this.documentPrompt = documentPrompt;
this.documentVariableName = documentVariableName;
this.documentSeparator = documentSeparator;

// Get default document variable name, if not provided.
getDefaultDocumentVariableName();
}

/**
* Get default document variable name, if not provided.
*/
private void getDefaultDocumentVariableName() {
List<String> llmChainVariables = llmChain.getPrompt().getInputVariables();
if (documentVariableName == null) {
if (llmChainVariables.size() == 1) {
documentVariableName = llmChainVariables.get(0);
} else {
throw new IllegalArgumentException(
"documentVariableName must be provided if there are multiple llmChainVariables");
}
} else {
if (!llmChainVariables.contains(documentVariableName)) {
throw new IllegalArgumentException("documentVariableName " + documentVariableName
+ " was not found in llmChain inputVariables: " + llmChainVariables);
}
}
}

private Map<String, Object> getInputs(List<Document> docs, Map<String, Object> kwargs) {
// Format each document according to the prompt
List<String> docStrings = docs.stream()
.map(doc -> formatDocument(doc, documentPrompt))
.toList();
// Join the documents together to put them in the prompt.
Map<String, Object> inputs = Maps.filterKeys(kwargs, llmChain.getPrompt().getInputVariables()::contains);
inputs.put(documentVariableName, String.join(documentSeparator, docStrings));
return inputs;
}

/**
* Stuff all documents into one prompt and pass to LLM.
*/
@Override
public Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<String, Object> kwargs) {
var inputs = getInputs(docs, kwargs);
// Call predict on the LLM.
return Pair.of(llmChain.predict(inputs), Map.of());
}

@Override
public String chainType() {
return "stuff_documents_chain";
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.chains.combine.documents.stuff;

import com.hw.langchain.prompts.prompt.PromptTemplate;

import java.util.List;

/**
* @author HamaWhite
*/
public class StuffUtils {

private StuffUtils() {
}

/**
* Get the default document prompt.
*
* @return The default document prompt.
*/
public static PromptTemplate getDefaultDocumentPrompt() {
return new PromptTemplate(List.of("page_content"), "{page_content}");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ public List<String> outputKeys() {
}

@Override
public Map<String, String> _call(Map<String, Object> inputs) {
public Map<String, String> innerCall(Map<String, Object> inputs) {
LLMResult response = generate(List.of(inputs));
return createOutputs(response).get(0);
}
Expand Down Expand Up @@ -175,4 +175,9 @@ public <T> T predictAndParse(Map<String, Object> kwargs) {
}
return (T) result;
}

public BasePromptTemplate getPrompt() {
return prompt;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public Map<String, String> processLLMResult(String llmOutput) {
}

@Override
public Map<String, String> _call(Map<String, Object> inputs) {
public Map<String, String> innerCall(Map<String, Object> inputs) {
var kwargs = Map.of("question", inputs.get(inputKey), "stop", List.of("```output"));
String llmOutput = llmChain.predict(kwargs);
return processLLMResult(llmOutput);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.chains.prompt.selector;

import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.prompts.base.BasePromptTemplate;

/**
* @author HamaWhite
*/
public abstract class BasePromptSelector {

/**
* Get default prompt for a language model.
*
* @param llm The BaseLanguageModel object.
* @return The BasePromptTemplate object representing the default prompt.
*/
public abstract BasePromptTemplate getPrompt(BaseLanguageModel llm);
}
Loading

0 comments on commit 1b41a05

Please sign in to comment.