diff --git a/.gitignore b/.gitignore
index 98936c327..1374a694f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,6 +27,7 @@ logPath_IS_UNDEFINED
target
# other ignore
+.java-version
*.log
*.tmp
Thumbs.db
diff --git a/.java-version b/.java-version
new file mode 100644
index 000000000..98d9bcb75
--- /dev/null
+++ b/.java-version
@@ -0,0 +1 @@
+17
diff --git a/langchain-core/pom.xml b/langchain-core/pom.xml
index 5b96590c7..5bcdcc233 100644
--- a/langchain-core/pom.xml
+++ b/langchain-core/pom.xml
@@ -144,6 +144,12 @@
io.netty
netty-resolver-dns
+
+
+ io.projectreactor
+ reactor-core
+
+
org.mockito
mockito-core
diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java b/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java
index cc655b047..4f3427e48 100644
--- a/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java
+++ b/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java
@@ -138,7 +138,7 @@ public Object takeNextStep(Map nameToToolMap, Map innerCall(Map inputs) {
+ protected Map innerCall(Map inputs) {
// Construct a mapping of tool name to tool for easy lookup
Map nameToToolMap = tools.stream().collect(Collectors.toMap(BaseTool::getName, tool -> tool));
diff --git a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java
index 36e6f98d5..79684370d 100644
--- a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java
+++ b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java
@@ -18,10 +18,13 @@
package com.hw.langchain.base.language;
+import com.hw.langchain.schema.AsyncLLMResult;
import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.LLMResult;
import com.hw.langchain.schema.PromptValue;
+import reactor.core.publisher.Flux;
+
import java.util.List;
/**
@@ -39,7 +42,9 @@ public interface BaseLanguageModel {
/**
* Predict text from text.
*/
- String predict(String text);
+ default String predict(String text) {
+ return predict(text, null);
+ }
/**
* Predict text from text.
@@ -49,10 +54,48 @@ public interface BaseLanguageModel {
/**
* Predict message from messages.
*/
- BaseMessage predictMessages(List messages);
+ default BaseMessage predictMessages(List messages) {
+ return predictMessages(messages, null);
+ }
/**
* Predict message from messages.
*/
BaseMessage predictMessages(List messages, List stop);
+
+ /**
+ * Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue.
+ */
+ default List> asyncGeneratePrompt(List prompts) {
+ return asyncGeneratePrompt(prompts, null);
+ }
+
+ /**
+ * Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue.
+ */
+ default List> asyncGeneratePrompt(List prompts, List stop) {
+ throw new UnsupportedOperationException("not supported yet.");
+ };
+
+ /**
+ * Predict text from text async.
+ */
+ default Flux apredict(String text) {
+ return apredict(text, null);
+ }
+
+ /**
+ * Predict text from text async.
+ */
+ default Flux apredict(String text, List stop) {
+ throw new UnsupportedOperationException("not supported yet.");
+ }
+
+ /**
+ * Predict message from messages async.
+ */
+ default Flux apredictMessages(List messages, List stop) {
+ throw new UnsupportedOperationException("not supported yet.");
+ }
+
}
diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java
index 397a77357..fb089f3c0 100644
--- a/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java
+++ b/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java
@@ -103,7 +103,7 @@ public List outputKeys() {
}
@Override
- public Map innerCall(Map inputs) {
+ protected Map innerCall(Map inputs) {
var question = inputs.get(QUESTION_KEY);
String apiUrl = apiRequestChain.predict(Map.of(QUESTION_KEY, question, API_DOCS, apiDocs));
apiUrl = apiUrl.strip();
diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java
index aab019a82..91b80ccd3 100644
--- a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java
+++ b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java
@@ -21,6 +21,10 @@
import com.google.common.collect.Maps;
import com.hw.langchain.schema.BaseMemory;
+import org.apache.commons.lang3.StringUtils;
+
+import reactor.core.publisher.Flux;
+
import java.util.*;
/**
@@ -73,7 +77,17 @@ private void validateOutputs(Map outputs) {
* @param inputs the inputs to be processed by the chain
* @return a map containing the output generated by the chain
*/
- public abstract Map innerCall(Map inputs);
+ protected abstract Map innerCall(Map inputs);
+
+ /**
+ * Runs the logic of this chain and returns the async output.
+ *
+ * @param inputs the inputs to be processed by the chain
+ * @return a map flux containing the output generated event by the chain
+ */
+ protected Flux
+
+
+ io.projectreactor
+ reactor-core
+ ${reactor.version}
+
+
org.mockito
mockito-core
@@ -328,7 +337,7 @@
org.sonatype.plugins
nexus-staging-maven-plugin
${nexus-staging-maven-plugin.version}
- true
+ ${nexus-staging-maven-plugin.executions}
ossrh
https://s01.oss.sonatype.org/