Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#106 from HamaWhiteGG/dev
Browse files Browse the repository at this point in the history
support stream response
  • Loading branch information
HamaWhiteGG authored Sep 21, 2023
2 parents 85f2d23 + 0d6c47d commit 5e4d57f
Show file tree
Hide file tree
Showing 54 changed files with 1,506 additions and 210 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ The following example can view in the [langchain-example](langchain-examples/src

## 3. Integrations
### 3.1 LLMs
- [OpenAI](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OpenAIExample.java)
- [OpenAI](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OpenAIExample.java), (support [stream](langchain-examples/src/main/java/com/hw/langchain/examples/llms/StreamOpenAIExample.java))
- [Azure OpenAI](openai-client/src/test/java/com/hw/openai/AzureOpenAiClientTest.java)
- [ChatGLM2](langchain-examples/src/main/java/com/hw/langchain/examples/llms/ChatGLMExample.java)
- [Ollama](langchain-examples/src/main/java/com/hw/langchain/examples/llms/OllamaExample.java)
Expand Down
16 changes: 16 additions & 0 deletions langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,22 @@
<groupId>io.netty</groupId>
<artifactId>netty-resolver-dns</artifactId>
</dependency>

<dependency>
<groupId>io.projectreactor</groupId>
<artifactId>reactor-core</artifactId>
</dependency>

<dependency>
<groupId>io.projectreactor.addons</groupId>
<artifactId>reactor-adapter</artifactId>
</dependency>

<dependency>
<groupId>io.reactivex.rxjava2</groupId>
<artifactId>rxjava</artifactId>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public Object takeNextStep(Map<String, BaseTool> nameToToolMap, Map<String, Obje
* Run text through and get agent response.
*/
@Override
public Map<String, String> innerCall(Map<String, Object> inputs) {
protected Map<String, String> innerCall(Map<String, Object> inputs) {
// Construct a mapping of tool name to tool for easy lookup
Map<String, BaseTool> nameToToolMap = tools.stream().collect(Collectors.toMap(BaseTool::getName, tool -> tool));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@

package com.hw.langchain.base.language;

import com.hw.langchain.schema.AsyncLLMResult;
import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.schema.PromptValue;

import reactor.core.publisher.Flux;

import java.util.List;

/**
Expand All @@ -39,7 +42,9 @@ public interface BaseLanguageModel {
/**
* Predict text from text.
*/
String predict(String text);
default String predict(String text) {
return predict(text, null);
}

/**
* Predict text from text.
Expand All @@ -49,10 +54,48 @@ public interface BaseLanguageModel {
/**
* Predict message from messages.
*/
BaseMessage predictMessages(List<BaseMessage> messages);
default BaseMessage predictMessages(List<BaseMessage> messages) {
return predictMessages(messages, null);
}

/**
* Predict message from messages.
*/
BaseMessage predictMessages(List<BaseMessage> messages, List<String> stop);

/**
* Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue.
*/
default List<Flux<AsyncLLMResult>> asyncGeneratePrompt(List<PromptValue> prompts) {
return asyncGeneratePrompt(prompts, null);
}

/**
* Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue.
*/
default List<Flux<AsyncLLMResult>> asyncGeneratePrompt(List<PromptValue> prompts, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

/**
* Predict text from text async.
*/
default Flux<String> asyncPredict(String text) {
return asyncPredict(text, null);
}

/**
* Predict text from text async.
*/
default Flux<String> asyncPredict(String text, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

/**
* Predict message from messages async.
*/
default Flux<BaseMessage> asyncPredictMessages(List<BaseMessage> messages, List<String> stop) {
throw new UnsupportedOperationException("not supported yet.");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public List<String> outputKeys() {
}

@Override
public Map<String, String> innerCall(Map<String, Object> inputs) {
protected Map<String, String> innerCall(Map<String, Object> inputs) {
var question = inputs.get(QUESTION_KEY);
String apiUrl = apiRequestChain.predict(Map.of(QUESTION_KEY, question, API_DOCS, apiDocs));
apiUrl = apiUrl.strip();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
import com.google.common.collect.Maps;
import com.hw.langchain.schema.BaseMemory;

import org.apache.commons.lang3.StringUtils;

import reactor.core.publisher.Flux;

import java.util.*;

/**
Expand Down Expand Up @@ -73,7 +77,17 @@ private void validateOutputs(Map<String, String> outputs) {
* @param inputs the inputs to be processed by the chain
* @return a map containing the output generated by the chain
*/
public abstract Map<String, String> innerCall(Map<String, Object> inputs);
protected abstract Map<String, String> innerCall(Map<String, Object> inputs);

/**
* Runs the logic of this chain and returns the async output.
*
* @param inputs the inputs to be processed by the chain
* @return a map flux containing the output generated event by the chain
*/
protected Flux<Map<String, String>> asyncInnerCall(Map<String, Object> inputs) {
throw new UnsupportedOperationException("Not supported yet.");
}

/**
* Run the logic of this chain and add to output if desired.
Expand Down Expand Up @@ -104,6 +118,17 @@ public Map<String, String> call(Map<String, Object> inputs, boolean returnOnlyOu
return prepOutputs(inputs, outputs, returnOnlyOutputs);
}

public Flux<Map<String, String>> asyncCall(Object input, boolean returnOnlyOutputs) {
Map<String, Object> inputs = prepInputs(input);
return asyncCall(inputs, returnOnlyOutputs);
}

public Flux<Map<String, String>> asyncCall(Map<String, Object> inputs, boolean returnOnlyOutputs) {
inputs = prepInputs(inputs);
Flux<Map<String, String>> outputs = asyncInnerCall(inputs);
return asyncPrepOutputs(inputs, outputs, returnOnlyOutputs);
}

/**
* Validate and prep outputs.
*/
Expand All @@ -123,6 +148,35 @@ private Map<String, String> prepOutputs(Map<String, Object> inputs, Map<String,
}
}

/**
* Validate and async prep outputs.
*/
private Flux<Map<String, String>> asyncPrepOutputs(Map<String, Object> inputs, Flux<Map<String, String>> outputs,
boolean returnOnlyOutputs) {
Map<String, String> collector = Maps.newHashMap();
return outputs.doOnNext(this::validateOutputs)
.doOnNext(m -> m.forEach((k, v) -> collector.compute(k, (s, old) -> {
if (StringUtils.equals(s, outputKeys().get(0))) {
return old + v;
} else {
return StringUtils.firstNonBlank(old, v);
}
}))).map(m -> {
if (returnOnlyOutputs) {
return m;
} else {
Map<String, String> result = Maps.newHashMap();
inputs.forEach((k, v) -> result.put(k, v.toString()));
result.putAll(m);
return result;
}
}).doOnComplete(() -> {
if (memory != null) {
memory.saveContext(inputs, collector);
}
});
}

/**
* Validate and prep inputs.
*/
Expand Down Expand Up @@ -183,4 +237,28 @@ public String run(Map<String, Object> args) {
return call(args, false).get(outputKeys().get(0));
}

/**
* Run the chain as text in, text out async
*/
public Flux<String> asyncRun(Object args) {
if (outputKeys().size() != 1) {
throw new IllegalArgumentException(
"The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
+ ".");
}
return asyncCall(args, false).map(m -> m.get(outputKeys().get(0)));
}

/**
* Run the chain as multiple variables, text out async.
*/
public Flux<String> asyncRun(Map<String, Object> args) {
if (outputKeys().size() != 1) {
throw new IllegalArgumentException(
"The `run` method is not supported when there is not exactly one output key. Got " + outputKeys()
+ ".");
}
return asyncCall(args, false).map(m -> m.get(outputKeys().get(0)));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

import org.apache.commons.lang3.tuple.Pair;

import reactor.core.publisher.Flux;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -63,9 +65,14 @@ public Optional<Integer> promptLength(List<Document> docs, Map<String, Object> k
*/
public abstract Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<String, Object> kwargs);

@Override
/**
* Combine documents into a single string async.
*/
public abstract Flux<Pair<String, Map<String, String>>> asyncCombineDocs(List<Document> docs,
Map<String, Object> kwargs);

public Map<String, String> innerCall(Map<String, Object> inputs) {
@Override
protected Map<String, String> innerCall(Map<String, Object> inputs) {
@SuppressWarnings("unchecked")
var docs = (List<Document>) inputs.get(inputKey);

Expand All @@ -76,4 +83,19 @@ public Map<String, String> innerCall(Map<String, Object> inputs) {
extraReturnDict.put(outputKey, result.getLeft());
return extraReturnDict;
}

@Override
protected Flux<Map<String, String>> asyncInnerCall(Map<String, Object> inputs) {
@SuppressWarnings("unchecked")
var docs = (List<Document>) inputs.get(inputKey);

Map<String, Object> otherKeys = Maps.filterKeys(inputs, key -> !key.equals(inputKey));
var result = this.asyncCombineDocs(docs, otherKeys);

return result.map(pair -> {
var extraReturnDict = new HashMap<>(pair.getRight());
extraReturnDict.put(outputKey, pair.getLeft());
return extraReturnDict;
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

import org.apache.commons.lang3.tuple.Pair;

import reactor.core.publisher.Flux;

import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -116,6 +118,12 @@ public Pair<String, Map<String, String>> combineDocs(List<Document> docs, Map<St
return Pair.of(llmChain.predict(inputs), Map.of());
}

@Override
public Flux<Pair<String, Map<String, String>>> asyncCombineDocs(List<Document> docs, Map<String, Object> kwargs) {
var inputs = getInputs(docs, kwargs);
return llmChain.asyncPredict(inputs).map(s -> Pair.of(s, Map.of()));
}

@Override
public String chainType() {
return "stuff_documents_chain";
Expand Down
Loading

0 comments on commit 5e4d57f

Please sign in to comment.