Merge pull request HamaWhiteGG#101 from kael-aiur/dev
good
HamaWhiteGG authored Sep 11, 2023
2 parents 4153d69 + c0d9ff9 commit bbc4302
Showing 24 changed files with 388 additions and 31 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -27,6 +27,7 @@ logPath_IS_UNDEFINED
target

# other ignore
.java-version
*.log
*.tmp
Thumbs.db
1 change: 1 addition & 0 deletions .java-version
@@ -0,0 +1 @@
17
6 changes: 6 additions & 0 deletions langchain-core/pom.xml
@@ -144,6 +144,12 @@
<groupId>io.netty</groupId>
<artifactId>netty-resolver-dns</artifactId>
</dependency>

<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-core</artifactId>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
@@ -138,7 +138,7 @@ public Object takeNextStep(Map<String, BaseTool> nameToToolMap, Map<String, Obje
* Run text through and get agent response.
*/
@Override
public Map<String, String> innerCall(Map<String, Object> inputs) {
protected Map<String, String> innerCall(Map<String, Object> inputs) {
// Construct a mapping of tool name to tool for easy lookup
Map<String, BaseTool> nameToToolMap = tools.stream().collect(Collectors.toMap(BaseTool::getName, tool -> tool));

@@ -18,10 +18,13 @@

package com.hw.langchain.base.language;

import com.hw.langchain.schema.AsyncLLMResult;
import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.schema.PromptValue;

import reactor.core.publisher.Flux;

import java.util.List;

/**
@@ -39,7 +42,9 @@ public interface BaseLanguageModel {
/**
* Predict text from text.
*/
String predict(String text);
default String predict(String text) {
return predict(text, null);
}

/**
* Predict text from text.
@@ -49,10 +54,48 @@
/**
* Predict message from messages.
*/
BaseMessage predictMessages(List<BaseMessage> messages);
default BaseMessage predictMessages(List<BaseMessage> messages) {
return predictMessages(messages, null);
}

/**
* Predict message from messages.
*/
BaseMessage predictMessages(List<BaseMessage> messages, List<String> stop);

/**
* Take in a list of prompt values and return a Flux&lt;AsyncLLMResult&gt; for every PromptValue.
*/
default List<Flux<AsyncLLMResult>> asyncGeneratePrompt(List<PromptValue> prompts) {
return asyncGeneratePrompt(prompts, null);
}

/**
* Take in a list of prompt values and return a Flux&lt;AsyncLLMResult&gt; for every PromptValue.
*/
default List<Flux<AsyncLLMResult>> asyncGeneratePrompt(List<PromptValue> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

/**
* Predict text from text async.
*/
default Flux<String> apredict(String text) {
return apredict(text, null);
}

/**
* Predict text from text async.
*/
default Flux<String> apredict(String text, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

/**
* Predict message from messages async.
*/
default Flux<BaseMessage> apredictMessages(List<BaseMessage> messages, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

}
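For context, here is a minimal caller-side sketch of how the reactive methods added above are meant to be consumed, assuming reactor-core is on the classpath; the Flux literal is only a stand-in for whatever a concrete BaseLanguageModel implementation's apredict(...) would actually emit.

import reactor.core.publisher.Flux;

public class ApredictConsumerSketch {
    public static void main(String[] args) {
        // Stand-in for llm.apredict("..."): a stream of partial completions.
        Flux<String> completion = Flux.just("Hello", ", ", "world", "!");

        completion
                .doOnNext(System.out::print) // handle each chunk as it arrives
                .blockLast();                // block only for this demo; real callers would subscribe
    }
}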
@@ -103,7 +103,7 @@ public List<String> outputKeys() {
}

@Override
public Map<String, String> innerCall(Map<String, Object> inputs) {
protected Map<String, String> innerCall(Map<String, Object> inputs) {
var question = inputs.get(QUESTION_KEY);
String apiUrl = apiRequestChain.predict(Map.of(QUESTION_KEY, question, API_DOCS, apiDocs));
apiUrl = apiUrl.strip();
@@ -21,6 +21,10 @@
import com.google.common.collect.Maps;
import com.hw.langchain.schema.BaseMemory;

import org.apache.commons.lang3.StringUtils;

import reactor.core.publisher.Flux;

import java.util.*;

/**
@@ -73,7 +77,17 @@ private void validateOutputs(Map<String, String> outputs) {
* @param inputs the inputs to be processed by the chain
* @return a map containing the output generated by the chain
*/
public abstract Map<String, String> innerCall(Map<String, Object> inputs);
protected abstract Map<String, String> innerCall(Map<String, Object> inputs);

/**
* Runs the logic of this chain and returns the async output.
*
* @param inputs the inputs to be processed by the chain
* @return a Flux of maps containing the output events generated by the chain
*/
protected Flux<Map<String, String>> ainnerCall(Map<String, Object> inputs) {
throw new UnsupportedOperationException("Not supported yet.");
}

/**
* Run the logic of this chain and add to output if desired.
@@ -104,6 +118,17 @@ public Map<String, String> call(Map<String, Object> inputs, boolean returnOnlyOu
return prepOutputs(inputs, outputs, returnOnlyOutputs);
}

public Flux<Map<String, String>> acall(String input, boolean returnOnlyOutputs) {
Map<String, Object> inputs = prepInputs(input);
return acall(inputs, returnOnlyOutputs);
}

public Flux<Map<String, String>> acall(Map<String, Object> inputs, boolean returnOnlyOutputs) {
inputs = prepInputs(inputs);
Flux<Map<String, String>> outputs = ainnerCall(inputs);
return prepaOutputs(inputs, outputs, returnOnlyOutputs);
}

/**
* Validate and prep outputs.
*/
@@ -123,6 +148,35 @@ private Map<String, String> prepOutputs(Map<String, Object> inputs, Map<String,
}
}

/**
* Validate and prep outputs asynchronously.
*/
private Flux<Map<String, String>> prepaOutputs(Map<String, Object> inputs, Flux<Map<String, String>> outputs,
boolean returnOnlyOutputs) {
Map<String, String> collector = Maps.newHashMap();
return outputs.doOnNext(this::validateOutputs)
.doOnNext(m -> m.forEach((k, v) -> collector.compute(k, (s, old) -> {
if (StringUtils.equals(s, outputKeys().get(0))) {
return old + v;
} else {
return StringUtils.firstNonBlank(old, v);
}
}))).map(m -> {
if (returnOnlyOutputs) {
return m;
} else {
Map<String, String> result = Maps.newHashMap();
inputs.forEach((k, v) -> result.put(k, v.toString()));
result.putAll(m);
return result;
}
}).doOnComplete(() -> {
if (memory != null) {
memory.saveContext(inputs, collector);
}
});
}

/**
* Validate and prep inputs.
*/
@@ -183,4 +237,28 @@ public String run(Map<String, Object> args) {
return call(args, false).get(outputKeys().get(0));
}

/**
* Run the chain as text in, text out, async.
*/
public Flux<String> arun(String args) {
if (outputKeys().size() != 1) {
throw new IllegalArgumentException(
"The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
+ ".");
}
return acall(args, false).map(m -> m.get(outputKeys().get(0)));
}

/**
* Run the chain as multiple variables in, text out, async.
*/
public Flux<String> arun(Map<String, Object> args) {
if (outputKeys().size() != 1) {
throw new IllegalArgumentException(
"The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
+ ".");
}
return acall(args, false).map(m -> m.get(outputKeys().get(0)));
}

}
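Here is a minimal sketch of the accumulation idea behind prepaOutputs above, assuming reactor-core and Guava are available; the chunk stream and the "text" output key are stand-ins for what a concrete chain's ainnerCall(...) would emit.

import com.google.common.collect.Maps;

import reactor.core.publisher.Flux;

import java.util.Map;

public class StreamingOutputAccumulationSketch {
    public static void main(String[] args) {
        // Stand-in for ainnerCall(inputs): partial outputs keyed by the chain's output key.
        Flux<Map<String, String>> chunks =
                Flux.just(Map.of("text", "Hel"), Map.of("text", "lo"), Map.of("text", "!"));

        // Side-channel collector, as in prepaOutputs: concatenate the chunks for the output key
        // so the complete answer is available once the stream finishes (e.g. to save to memory).
        Map<String, String> collector = Maps.newHashMap();
        chunks.doOnNext(m -> m.forEach((k, v) -> collector.merge(k, v, String::concat)))
                .doOnComplete(() -> System.out.println("final: " + collector))
                .subscribe(chunk -> System.out.println("chunk: " + chunk));
    }
}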
@@ -24,6 +24,8 @@

import org.apache.commons.lang3.tuple.Pair;

import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -62,10 +64,14 @@ public Optional<Integer> promptLength(List<Document> docs, Map<String, Object> k
* Combine documents into a single string.
*/
public abstract Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<String, Object> kwargs);
/**
* Combine documents into a single string async.
*/
public abstract Flux<Pair<String, Map<String, String>>> acombineDocs(List<Document> docs,
Map<String, Object> kwargs);

@Override

public Map<String, String> innerCall(Map<String, Object> inputs) {
protected Map<String, String> innerCall(Map<String, Object> inputs) {
@SuppressWarnings("unchecked")
var docs = (List<Document>) inputs.get(inputKey);

@@ -76,4 +82,19 @@ public Map<String, String> innerCall(Map<String, Object> inputs) {
extraReturnDict.put(outputKey, result.getLeft());
return extraReturnDict;
}

@Override
protected Flux<Map<String, String>> ainnerCall(Map<String, Object> inputs) {
@SuppressWarnings("unchecked")
var docs = (List<Document>) inputs.get(inputKey);

Map<String, Object> otherKeys = Maps.filterKeys(inputs, key -> !key.equals(inputKey));
var result = this.acombineDocs(docs, otherKeys);

return result.map(pair -> {
var extraReturnDict = new HashMap<>(pair.getRight());
extraReturnDict.put(outputKey, pair.getLeft());
return extraReturnDict;
});
}
}
@@ -26,6 +26,8 @@

import org.apache.commons.lang3.tuple.Pair;

import reactor.core.publisher.Flux;

import java.util.List;
import java.util.Map;

@@ -116,6 +118,12 @@ public Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<St
return Pair.of(llmChain.predict(inputs), Map.of());
}

@Override
public Flux<Pair<String, Map<String, String>>> acombineDocs(List<Document> docs, Map<String, Object> kwargs) {
var inputs = getInputs(docs, kwargs);
return llmChain.apredict(inputs).map(s -> Pair.of(s, Map.of()));
}

@Override
public String chainType() {
return "stuff_documents_chain";
@@ -21,14 +21,13 @@
import com.hw.langchain.base.language.BaseLanguageModel;
import com.hw.langchain.chains.base.Chain;
import com.hw.langchain.prompts.base.BasePromptTemplate;
import com.hw.langchain.schema.BaseLLMOutputParser;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.schema.NoOpOutputParser;
import com.hw.langchain.schema.PromptValue;
import com.hw.langchain.schema.*;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import reactor.core.publisher.Flux;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
@@ -97,11 +96,17 @@ public List<String> outputKeys() {
}

@Override
public Map<String, String> innerCall(Map<String, Object> inputs) {
protected Map<String, String> innerCall(Map<String, Object> inputs) {
LLMResult response = generate(List.of(inputs));
return createOutputs(response).get(0);
}

@Override
protected Flux<Map<String, String>> ainnerCall(Map<String, Object> inputs) {
var response = agenerate(List.of(inputs));
return response.get(0).map(this::createAsyncOutputs);
}

/**
* Generate LLM result from inputs.
*/
@@ -111,6 +116,15 @@ private LLMResult generate(List<Map<String, Object>> inputList) {
return llm.generatePrompt(prompts, stop);
}

/**
* Generate LLM result from inputs async.
*/
private List<Flux<AsyncLLMResult>> agenerate(List<Map<String, Object>> inputList) {
List<String> stop = prepStop(inputList);
List<PromptValue> prompts = prepPrompts(inputList);
return llm.asyncGeneratePrompt(prompts, stop);
}

/**
* Prepare prompts from inputs.
*/
@@ -154,6 +168,18 @@ private List<Map<String, String>> createOutputs(LLMResult llmResult) {
return result;
}

/**
* Create outputs from response async.
*/
private Map<String, String> createAsyncOutputs(AsyncLLMResult llmResult) {
Map<String, String> result = Map.of(outputKey, outputParser.parseResult(llmResult.getGenerations()),
"full_generation", llmResult.getGenerations().toString());
if (returnFinalOnly) {
result = Map.of(outputKey, result.get(outputKey));
}
return result;
}

/**
* Format prompt with kwargs and pass to LLM.
*
@@ -165,6 +191,17 @@ public String predict(Map<String, Object> kwargs) {
return resultMap.get(outputKey);
}

/**
* Format prompt with kwargs and pass to LLM async.
*
* @param kwargs Keys to pass to prompt template.
* @return Completion from LLM.
*/
public Flux<String> apredict(Map<String, Object> kwargs) {
var flux = acall(kwargs, false);
return flux.map(m -> m.get(outputKey));
}

/**
* Call predict and then parse the results.
*/
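To round out the LLMChain hunks above, here is a small sketch of how a caller might assemble the completion streamed by apredict(Map) back into a single string; the Flux literal is a placeholder for a real chain's output and assumes reactor-core on the classpath.

import reactor.core.publisher.Flux;

public class ApredictAssembleSketch {
    public static void main(String[] args) {
        // Stand-in for llmChain.apredict(Map.of("question", "...")): streamed completion chunks.
        Flux<String> stream = Flux.just("The answer", " is ", "42.");

        // Reduce the chunks into the full completion once the stream completes.
        String full = stream.reduce(new StringBuilder(), StringBuilder::append)
                .map(StringBuilder::toString)
                .block(); // blocking only for this demo
        System.out.println(full);
    }
}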
@@ -127,7 +127,7 @@ public Map<String, String> processLLMResult(String llmOutput) {
}

@Override
public Map<String, String> innerCall(Map<String, Object> inputs) {
protected Map<String, String> innerCall(Map<String, Object> inputs) {
var kwargs = Map.of("question", inputs.get(inputKey), "stop", List.of("```output"));
String llmOutput = llmChain.predict(kwargs);
return processLLMResult(llmOutput);