From c54a9a797b14784986d2876efce5f896f3d5201d Mon Sep 17 00:00:00 2001 From: "lingjue@ubuntu" Date: Tue, 1 Aug 2023 17:39:37 +0800 Subject: [PATCH 1/8] offer default implement for base language model --- .../hw/langchain/base/language/BaseLanguageModel.java | 10 ++++++++-- .../hw/langchain/chat/models/base/BaseChatModel.java | 10 ---------- .../main/java/com/hw/langchain/llms/base/BaseLLM.java | 10 ---------- 3 files changed, 8 insertions(+), 22 deletions(-) diff --git a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java index 36e6f98d5..ed35fd0e2 100644 --- a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java @@ -39,7 +39,9 @@ public interface BaseLanguageModel { /** * Predict text from text. */ - String predict(String text); + default String predict(String text) { + return predict(text, null); + } /** * Predict text from text. @@ -49,10 +51,14 @@ public interface BaseLanguageModel { /** * Predict message from messages. */ - BaseMessage predictMessages(List messages); + default BaseMessage predictMessages(List messages) { + return predictMessages(messages, null); + } /** * Predict message from messages. */ BaseMessage predictMessages(List messages, List stop); + + } diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java index 8109d6e90..f55c9e486 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/BaseChatModel.java @@ -90,11 +90,6 @@ public BaseMessage call(List messages, List stop) { } } - @Override - public String predict(String text) { - return predict(text, null); - } - @Override public String predict(String text, List stop) { List copyStop = stop != null ? List.copyOf(stop) : null; @@ -104,11 +99,6 @@ public String predict(String text, List stop) { return result.getContent(); } - @Override - public BaseMessage predictMessages(List messages) { - return predictMessages(messages, null); - } - @Override public BaseMessage predictMessages(List messages, List stop) { List copyStop = stop != null ? List.copyOf(stop) : null; diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java b/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java index d095b13ad..733167de2 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java @@ -70,21 +70,11 @@ public LLMResult generatePrompt(List prompts, List stop) { return generate(promptStrings, stop); } - @Override - public String predict(String text) { - return predict(text, null); - } - @Override public String predict(String text, List stop) { return call(text, stop); } - @Override - public BaseMessage predictMessages(List messages) { - return predictMessages(messages, null); - } - @Override public BaseMessage predictMessages(List messages, List stop) { return null; From d08c1532e41c6dc9bc119fc9741e67006d8d28b9 Mon Sep 17 00:00:00 2001 From: "lingjue@ubuntu" Date: Tue, 1 Aug 2023 18:40:48 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E5=AE=9A=E4=B9=89=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- langchain-core/pom.xml | 6 ++ .../langchain/agents/agent/AgentExecutor.java | 2 +- .../langchain/chains/api/base/ApiChain.java | 2 +- .../com/hw/langchain/chains/base/Chain.java | 78 ++++++++++++++++++- .../base/BaseCombineDocumentsChain.java | 3 +- .../com/hw/langchain/chains/llm/LLMChain.java | 2 +- .../chains/llm/math/base/LLMMathChain.java | 2 +- .../retrieval/qa/base/BaseRetrievalQA.java | 7 +- .../sql/database/base/SQLDatabaseChain.java | 2 +- .../base/SQLDatabaseSequentialChain.java | 2 +- pom.xml | 8 ++ 11 files changed, 104 insertions(+), 10 deletions(-) diff --git a/langchain-core/pom.xml b/langchain-core/pom.xml index 5b96590c7..5bcdcc233 100644 --- a/langchain-core/pom.xml +++ b/langchain-core/pom.xml @@ -144,6 +144,12 @@ io.netty netty-resolver-dns + + + io.projectreactor + reactor-core + + org.mockito mockito-core diff --git a/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java b/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java index cc655b047..4f3427e48 100644 --- a/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java +++ b/langchain-core/src/main/java/com/hw/langchain/agents/agent/AgentExecutor.java @@ -138,7 +138,7 @@ public Object takeNextStep(Map nameToToolMap, Map innerCall(Map inputs) { + protected Map innerCall(Map inputs) { // Construct a mapping of tool name to tool for easy lookup Map nameToToolMap = tools.stream().collect(Collectors.toMap(BaseTool::getName, tool -> tool)); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java index 397a77357..fb089f3c0 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/api/base/ApiChain.java @@ -103,7 +103,7 @@ public List outputKeys() { } @Override - public Map innerCall(Map inputs) { + protected Map innerCall(Map inputs) { var question = inputs.get(QUESTION_KEY); String apiUrl = apiRequestChain.predict(Map.of(QUESTION_KEY, question, API_DOCS, apiDocs)); apiUrl = apiUrl.strip(); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java index aab019a82..cf0147447 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java @@ -20,6 +20,8 @@ import com.google.common.collect.Maps; import com.hw.langchain.schema.BaseMemory; +import org.apache.commons.lang3.StringUtils; +import reactor.core.publisher.Flux; import java.util.*; @@ -73,7 +75,17 @@ private void validateOutputs(Map outputs) { * @param inputs the inputs to be processed by the chain * @return a map containing the output generated by the chain */ - public abstract Map innerCall(Map inputs); + protected abstract Map innerCall(Map inputs); + + /** + * Runs the logic of this chain and returns the async output. + * + * @param inputs the inputs to be processed by the chain + * @return a map flux containing the output generated event by the chain + */ + protected Flux> ainnerCall(Map inputs) { + throw new UnsupportedOperationException("Not supported yet."); + } /** * Run the logic of this chain and add to output if desired. @@ -104,6 +116,17 @@ public Map call(Map inputs, boolean returnOnlyOu return prepOutputs(inputs, outputs, returnOnlyOutputs); } + public Flux> acall(String input, boolean returnOnlyOutputs) { + Map inputs = prepInputs(input); + return acall(inputs, returnOnlyOutputs); + } + + public Flux> acall(Map inputs, boolean returnOnlyOutputs) { + inputs = prepInputs(inputs); + Flux> outputs = ainnerCall(inputs); + return prepaOutputs(inputs, outputs, returnOnlyOutputs); + } + /** * Validate and prep outputs. */ @@ -123,6 +146,35 @@ private Map prepOutputs(Map inputs, Map> prepaOutputs(Map inputs, Flux> outputs, + boolean returnOnlyOutputs) { + Map collector = Maps.newHashMap(); + return outputs.doOnNext(this::validateOutputs) + .doOnNext(m -> m.forEach((k, v) -> collector.compute(k, (s, old) -> { + if (StringUtils.equals(s, outputKeys().get(0))) { + return old + v; + } else { + return StringUtils.firstNonBlank(old, v); + } + }))).map(m -> { + if (returnOnlyOutputs) { + return m; + } else { + Map result = Maps.newHashMap(); + inputs.forEach((k, v) -> result.put(k, v.toString())); + result.putAll(m); + return result; + } + }).doOnComplete(() -> { + if (memory != null) { + memory.saveContext(inputs, collector); + } + }); + } + /** * Validate and prep inputs. */ @@ -183,4 +235,28 @@ public String run(Map args) { return call(args, false).get(outputKeys().get(0)); } + /** + * Run the chain as text in, text out async + */ + public Flux arun(String args) { + if (outputKeys().size() != 1) { + throw new IllegalArgumentException( + "The `run` method is not supported when there is not exactly one output key. Got " + outputKeys() + + "."); + } + return acall(args, false).map(m -> m.get(outputKeys().get(0))); + } + + /** + * Run the chain as multiple variables, text out async. + */ + public Flux arun(Map args) { + if (outputKeys().size() != 1) { + throw new IllegalArgumentException( + "The `run` method is not supported when there is not exactly one output key. Got " + outputKeys() + + "."); + } + return acall(args, false).map(m -> m.get(outputKeys().get(0))); + } + } diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java index 00f899271..930f08436 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java @@ -64,8 +64,7 @@ public Optional promptLength(List docs, Map k public abstract Pair> combineDocs(List docs, Map kwargs); @Override - - public Map innerCall(Map inputs) { + protected Map innerCall(Map inputs) { @SuppressWarnings("unchecked") var docs = (List) inputs.get(inputKey); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java index 8382da744..4baeffd95 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java @@ -97,7 +97,7 @@ public List outputKeys() { } @Override - public Map innerCall(Map inputs) { + protected Map innerCall(Map inputs) { LLMResult response = generate(List.of(inputs)); return createOutputs(response).get(0); } diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/llm/math/base/LLMMathChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/llm/math/base/LLMMathChain.java index de6d26db5..726d97f8f 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/llm/math/base/LLMMathChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/llm/math/base/LLMMathChain.java @@ -127,7 +127,7 @@ public Map processLLMResult(String llmOutput) { } @Override - public Map innerCall(Map inputs) { + protected Map innerCall(Map inputs) { var kwargs = Map.of("question", inputs.get(inputKey), "stop", List.of("```output")); String llmOutput = llmChain.predict(kwargs); return processLLMResult(llmOutput); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java index d7df94fae..cf695927d 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java @@ -24,6 +24,7 @@ import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain; import com.hw.langchain.chains.query.constructor.JsonUtils; import com.hw.langchain.schema.Document; +import reactor.core.publisher.Flux; import java.util.List; import java.util.Map; @@ -74,7 +75,7 @@ public List outputKeys() { * Run getRelevantText and llm on input query. */ @Override - public Map innerCall(Map inputs) { + protected Map innerCall(Map inputs) { var question = inputs.get(inputKey).toString(); List docs = getDocs(question); @@ -92,4 +93,8 @@ public Map innerCall(Map inputs) { return result; } + @Override + protected Flux> ainnerCall(Map inputs) { + + } } diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.java index 2f26a32ca..1a4055f4d 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseChain.java @@ -118,7 +118,7 @@ public List outputKeys() { } @Override - public Map innerCall(Map inputs) { + protected Map innerCall(Map inputs) { String inputText = inputs.get(this.inputKey) + "\nSQLQuery:"; // If not present, then defaults to null which is all tables. var tableNamesToUse = (List) inputs.get("table_names_to_use"); diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseSequentialChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseSequentialChain.java index 025cec545..f1f593258 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseSequentialChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/sql/database/base/SQLDatabaseSequentialChain.java @@ -101,7 +101,7 @@ public List outputKeys() { } @Override - public Map innerCall(Map inputs) { + protected Map innerCall(Map inputs) { List tableNameList = sqlChain.getDatabase().getUsableTableNames(); String tableNames = String.join(", ", tableNameList); var llmInputs = Map.of("query", inputs.get(inputKey), diff --git a/pom.xml b/pom.xml index 7fcce1821..25ef45aca 100644 --- a/pom.xml +++ b/pom.xml @@ -47,6 +47,7 @@ 8.0.0.Final 3.17.3 4.1.43.Final + 3.5.8 3.12.4 2.12 @@ -252,6 +253,13 @@ netty-resolver-dns ${netty-resolver.version} + + + io.projectreactor + reactor-core + ${reactor.version} + + org.mockito mockito-core From 15125e4d5302afc232375ed3aefeb875122808b4 Mon Sep 17 00:00:00 2001 From: "lingjue@ubuntu" Date: Tue, 1 Aug 2023 19:06:20 +0800 Subject: [PATCH 3/8] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=BC=82=E6=AD=A5?= =?UTF-8?q?=E7=94=9F=E6=88=90=E6=8E=A5=E5=8F=A3=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../base/language/BaseLanguageModel.java | 6 ++++ .../hw/langchain/schema/AsyncLLMResult.java | 30 +++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 langchain-core/src/main/java/com/hw/langchain/schema/AsyncLLMResult.java diff --git a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java index ed35fd0e2..4c23ea892 100644 --- a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java @@ -18,9 +18,11 @@ package com.hw.langchain.base.language; +import com.hw.langchain.schema.AsyncLLMResult; import com.hw.langchain.schema.BaseMessage; import com.hw.langchain.schema.LLMResult; import com.hw.langchain.schema.PromptValue; +import reactor.core.publisher.Flux; import java.util.List; @@ -60,5 +62,9 @@ default BaseMessage predictMessages(List messages) { */ BaseMessage predictMessages(List messages, List stop); + /** + * Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue. + */ + List> asyncGeneratePrompt(List prompts, List stop); } diff --git a/langchain-core/src/main/java/com/hw/langchain/schema/AsyncLLMResult.java b/langchain-core/src/main/java/com/hw/langchain/schema/AsyncLLMResult.java new file mode 100644 index 000000000..214e50df0 --- /dev/null +++ b/langchain-core/src/main/java/com/hw/langchain/schema/AsyncLLMResult.java @@ -0,0 +1,30 @@ +package com.hw.langchain.schema; + +import lombok.Data; + +import java.util.List; +import java.util.Map; + +/** + * @author lingjue@ubuntu + * @since 8/1/23 7:01 PM + */ +@Data +public class AsyncLLMResult { + + /** + * List of the things generated. This is List> because each input could have multiple generations. + */ + private List generations; + + /** + * For arbitrary LLM provider specific output. + */ + private Map llmOutput; + + public AsyncLLMResult(List generations, Map llmOutput) { + this.generations = generations; + this.llmOutput = llmOutput; + } + +} From 6fcf0f29426f395a1c9fecefe9235caabe7897f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=87=8C=E7=BB=9D?= Date: Wed, 2 Aug 2023 09:18:56 +0800 Subject: [PATCH 4/8] add sync function definition --- .../base/language/BaseLanguageModel.java | 32 +++++++++++++- .../base/BaseCombineDocumentsChain.java | 20 +++++++++ .../documents/stuff/StuffDocumentsChain.java | 7 +++ .../com/hw/langchain/chains/llm/LLMChain.java | 43 +++++++++++++++++-- .../retrieval/qa/base/BaseRetrievalQA.java | 8 ++++ 5 files changed, 105 insertions(+), 5 deletions(-) diff --git a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java index 4c23ea892..3b20cd288 100644 --- a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java @@ -65,6 +65,36 @@ default BaseMessage predictMessages(List messages) { /** * Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue. */ - List> asyncGeneratePrompt(List prompts, List stop); + default List> asyncGeneratePrompt(List prompts) { + return asyncGeneratePrompt(prompts, null); + } + + /** + * Take in a list of prompt values and return an Flux<AsyncLLMResult> for every PromptValue. + */ + default List> asyncGeneratePrompt(List prompts, List stop) { + throw new UnsupportedOperationException("not supported yet."); + }; + + /** + * Predict text from text async. + */ + default Flux apredict(String text) { + return apredict(text, null); + } + + /** + * Predict text from text async. + */ + default Flux apredict(String text, List stop) { + throw new UnsupportedOperationException("not supported yet."); + } + + /** + * Predict message from messages async. + */ + default Flux apredictMessages(List messages, List stop) { + throw new UnsupportedOperationException("not supported yet."); + } } diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java index 930f08436..ca59e698f 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java @@ -23,6 +23,7 @@ import com.hw.langchain.schema.Document; import org.apache.commons.lang3.tuple.Pair; +import reactor.core.publisher.Flux; import java.util.HashMap; import java.util.List; @@ -62,6 +63,10 @@ public Optional promptLength(List docs, Map k * Combine documents into a single string. */ public abstract Pair> combineDocs(List docs, Map kwargs); + /** + * Combine documents into a single string async. + */ + public abstract Flux>> acombineDocs(List docs, Map kwargs); @Override protected Map innerCall(Map inputs) { @@ -75,4 +80,19 @@ protected Map innerCall(Map inputs) { extraReturnDict.put(outputKey, result.getLeft()); return extraReturnDict; } + + @Override + protected Flux> ainnerCall(Map inputs) { + @SuppressWarnings("unchecked") + var docs = (List) inputs.get(inputKey); + + Map otherKeys = Maps.filterKeys(inputs, key -> !key.equals(inputKey)); + var result = this.acombineDocs(docs, otherKeys); + + return result.map(pair -> { + var extraReturnDict = new HashMap<>(pair.getRight()); + extraReturnDict.put(outputKey, pair.getLeft()); + return extraReturnDict; + }); + } } diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java index 17f54ff59..66947d2ae 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java @@ -25,6 +25,7 @@ import com.hw.langchain.schema.Document; import org.apache.commons.lang3.tuple.Pair; +import reactor.core.publisher.Flux; import java.util.List; import java.util.Map; @@ -116,6 +117,12 @@ public Pair> combineDocs(List docs, Map>> acombineDocs(List docs, Map kwargs) { + var inputs = getInputs(docs, kwargs); + return llmChain.apredict(inputs).map(s -> Pair.of(s, Map.of())); + } + @Override public String chainType() { return "stuff_documents_chain"; diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java index 4baeffd95..4891cad77 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java @@ -21,13 +21,11 @@ import com.hw.langchain.base.language.BaseLanguageModel; import com.hw.langchain.chains.base.Chain; import com.hw.langchain.prompts.base.BasePromptTemplate; -import com.hw.langchain.schema.BaseLLMOutputParser; -import com.hw.langchain.schema.LLMResult; -import com.hw.langchain.schema.NoOpOutputParser; -import com.hw.langchain.schema.PromptValue; +import com.hw.langchain.schema.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.HashMap; @@ -102,6 +100,12 @@ protected Map innerCall(Map inputs) { return createOutputs(response).get(0); } + @Override + protected Flux> ainnerCall(Map inputs) { + var response = agenerate(List.of(inputs)); + return response.get(0).map(this::createAsyncOutputs); + } + /** * Generate LLM result from inputs. */ @@ -111,6 +115,15 @@ private LLMResult generate(List> inputList) { return llm.generatePrompt(prompts, stop); } + /** + * Generate LLM result from inputs async. + */ + private List> agenerate(List> inputList) { + List stop = prepStop(inputList); + List prompts = prepPrompts(inputList); + return llm.asyncGeneratePrompt(prompts, stop); + } + /** * Prepare prompts from inputs. */ @@ -154,6 +167,17 @@ private List> createOutputs(LLMResult llmResult) { return result; } + /** + * Create outputs from response async. + */ + private Map createAsyncOutputs(AsyncLLMResult llmResult) { + Map result = Map.of(outputKey, outputParser.parseResult(llmResult.getGenerations()), "full_generation", llmResult.getGenerations().toString()); + if (returnFinalOnly) { + result = Map.of(outputKey, result.get(outputKey)); + } + return result; + } + /** * Format prompt with kwargs and pass to LLM. * @@ -165,6 +189,17 @@ public String predict(Map kwargs) { return resultMap.get(outputKey); } + /** + * Format prompt with kwargs and pass to LLM async. + * + * @param kwargs Keys to pass to prompt template. + * @return Completion from LLM. + */ + public Flux apredict(Map kwargs) { + var flux = acall(kwargs, false); + return flux.map(m -> m.get(outputKey)); + } + /** * Call predict and then parse the results. */ diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java index cf695927d..d54f5b8ac 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java @@ -95,6 +95,14 @@ protected Map innerCall(Map inputs) { @Override protected Flux> ainnerCall(Map inputs) { + var question = inputs.get(inputKey).toString(); + List docs = getDocs(question); + inputs.put("input_documents", docs); + if (!inputs.containsKey("question")) { + inputs.put("question", question); + } + Flux answer = combineDocumentsChain.arun(inputs); + return answer.map(s -> Map.of(outputKey, s)); } } From cc46f51daee1fecdb04ad20d7d676e2addaf63f6 Mon Sep 17 00:00:00 2001 From: "lingjue@ubuntu" Date: Mon, 11 Sep 2023 10:27:58 +0800 Subject: [PATCH 5/8] rebase dev --- .java-version | 1 + .../base/language/BaseLanguageModel.java | 1 + .../com/hw/langchain/chains/base/Chain.java | 4 ++- .../base/BaseCombineDocumentsChain.java | 4 ++- .../documents/stuff/StuffDocumentsChain.java | 1 + .../com/hw/langchain/chains/llm/LLMChain.java | 4 ++- .../retrieval/qa/base/BaseRetrievalQA.java | 1 + .../hw/langchain/chat/models/base/LLM.java | 7 ++++ .../com/hw/langchain/llms/base/BaseLLM.java | 19 ++++++++++ .../com/hw/langchain/llms/ollama/Ollama.java | 7 ++++ .../hw/langchain/llms/openai/BaseOpenAI.java | 7 ++++ .../hw/langchain/llms/openai/OpenAIChat.java | 7 ++++ .../hw/langchain/schema/AsyncLLMResult.java | 18 ++++++++++ langchain-server/pom.xml | 36 +++++++++++++++++++ langchain-web/pom.xml | 28 +++++++++++++++ 15 files changed, 142 insertions(+), 3 deletions(-) create mode 100644 .java-version create mode 100644 langchain-server/pom.xml create mode 100644 langchain-web/pom.xml diff --git a/.java-version b/.java-version new file mode 100644 index 000000000..98d9bcb75 --- /dev/null +++ b/.java-version @@ -0,0 +1 @@ +17 diff --git a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java index 3b20cd288..79684370d 100644 --- a/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java +++ b/langchain-core/src/main/java/com/hw/langchain/base/language/BaseLanguageModel.java @@ -22,6 +22,7 @@ import com.hw.langchain.schema.BaseMessage; import com.hw.langchain.schema.LLMResult; import com.hw.langchain.schema.PromptValue; + import reactor.core.publisher.Flux; import java.util.List; diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java index cf0147447..91b80ccd3 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/base/Chain.java @@ -20,7 +20,9 @@ import com.google.common.collect.Maps; import com.hw.langchain.schema.BaseMemory; + import org.apache.commons.lang3.StringUtils; + import reactor.core.publisher.Flux; import java.util.*; @@ -150,7 +152,7 @@ private Map prepOutputs(Map inputs, Map> prepaOutputs(Map inputs, Flux> outputs, - boolean returnOnlyOutputs) { + boolean returnOnlyOutputs) { Map collector = Maps.newHashMap(); return outputs.doOnNext(this::validateOutputs) .doOnNext(m -> m.forEach((k, v) -> collector.compute(k, (s, old) -> { diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java index ca59e698f..e9c8d82da 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/base/BaseCombineDocumentsChain.java @@ -23,6 +23,7 @@ import com.hw.langchain.schema.Document; import org.apache.commons.lang3.tuple.Pair; + import reactor.core.publisher.Flux; import java.util.HashMap; @@ -66,7 +67,8 @@ public Optional promptLength(List docs, Map k /** * Combine documents into a single string async. */ - public abstract Flux>> acombineDocs(List docs, Map kwargs); + public abstract Flux>> acombineDocs(List docs, + Map kwargs); @Override protected Map innerCall(Map inputs) { diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java index 66947d2ae..ba8cbf1a9 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/combine/documents/stuff/StuffDocumentsChain.java @@ -25,6 +25,7 @@ import com.hw.langchain.schema.Document; import org.apache.commons.lang3.tuple.Pair; + import reactor.core.publisher.Flux; import java.util.List; diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java index 4891cad77..38af05c16 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java @@ -25,6 +25,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import reactor.core.publisher.Flux; import java.util.ArrayList; @@ -171,7 +172,8 @@ private List> createOutputs(LLMResult llmResult) { * Create outputs from response async. */ private Map createAsyncOutputs(AsyncLLMResult llmResult) { - Map result = Map.of(outputKey, outputParser.parseResult(llmResult.getGenerations()), "full_generation", llmResult.getGenerations().toString()); + Map result = Map.of(outputKey, outputParser.parseResult(llmResult.getGenerations()), + "full_generation", llmResult.getGenerations().toString()); if (returnFinalOnly) { result = Map.of(outputKey, result.get(outputKey)); } diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java index d54f5b8ac..7a23b38cd 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java +++ b/langchain-core/src/main/java/com/hw/langchain/chains/retrieval/qa/base/BaseRetrievalQA.java @@ -24,6 +24,7 @@ import com.hw.langchain.chains.combine.documents.base.BaseCombineDocumentsChain; import com.hw.langchain.chains.query.constructor.JsonUtils; import com.hw.langchain.schema.Document; + import reactor.core.publisher.Flux; import java.util.List; diff --git a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/LLM.java b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/LLM.java index 52d9fa980..9f39a8f9e 100644 --- a/langchain-core/src/main/java/com/hw/langchain/chat/models/base/LLM.java +++ b/langchain-core/src/main/java/com/hw/langchain/chat/models/base/LLM.java @@ -19,10 +19,12 @@ package com.hw.langchain.chat.models.base; import com.hw.langchain.llms.base.BaseLLM; +import com.hw.langchain.schema.AsyncLLMResult; import com.hw.langchain.schema.Generation; import com.hw.langchain.schema.LLMResult; import lombok.experimental.SuperBuilder; +import reactor.core.publisher.Flux; import java.util.List; @@ -54,4 +56,9 @@ protected LLMResult innerGenerate(List prompts, List stop) { return new LLMResult(generations); } + + @Override + protected Flux _agenerate(List prompts, List stop) { + throw new UnsupportedOperationException("not supported yet."); + } } diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java b/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java index 733167de2..f258eb823 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/base/BaseLLM.java @@ -19,11 +19,13 @@ package com.hw.langchain.llms.base; import com.hw.langchain.base.language.BaseLanguageModel; +import com.hw.langchain.schema.AsyncLLMResult; import com.hw.langchain.schema.BaseMessage; import com.hw.langchain.schema.LLMResult; import com.hw.langchain.schema.PromptValue; import lombok.experimental.SuperBuilder; +import reactor.core.publisher.Flux; import java.util.List; @@ -43,6 +45,10 @@ public abstract class BaseLLM implements BaseLanguageModel { * Run the LLM on the given prompts. */ protected abstract LLMResult innerGenerate(List prompts, List stop); + /** + * Run the LLM on the given prompts async. + */ + protected abstract Flux _agenerate(List prompts, List stop); /** * Check Cache and run the LLM on the given prompt and input. @@ -70,11 +76,24 @@ public LLMResult generatePrompt(List prompts, List stop) { return generate(promptStrings, stop); } + @Override + public List> asyncGeneratePrompt(List prompts, List stop) { + List promptStrings = prompts.stream() + .map(PromptValue::toString) + .toList(); + return promptStrings.stream().map(s -> _agenerate(List.of(s), stop)).toList(); + } + @Override public String predict(String text, List stop) { return call(text, stop); } + @Override + public Flux apredict(String text, List stop) { + return _agenerate(List.of(text), stop).map(result -> result.getGenerations().get(0).getText()); + } + @Override public BaseMessage predictMessages(List messages, List stop) { return null; diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/ollama/Ollama.java b/langchain-core/src/main/java/com/hw/langchain/llms/ollama/Ollama.java index fa83f8e0f..a58d12509 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/ollama/Ollama.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/ollama/Ollama.java @@ -23,6 +23,7 @@ import com.hw.langchain.chains.query.constructor.JsonUtils; import com.hw.langchain.llms.base.BaseLLM; import com.hw.langchain.requests.TextRequestsWrapper; +import com.hw.langchain.schema.AsyncLLMResult; import com.hw.langchain.schema.GenerationChunk; import com.hw.langchain.schema.LLMResult; @@ -30,6 +31,7 @@ import lombok.Builder; import lombok.experimental.SuperBuilder; +import reactor.core.publisher.Flux; import java.util.ArrayList; import java.util.List; @@ -206,6 +208,11 @@ protected LLMResult innerGenerate(List prompts, List stop) { return new LLMResult(generations); } + @Override + protected Flux _agenerate(List prompts, List stop) { + throw new UnsupportedOperationException("not supported yet."); + } + /** * Convert a stream response to a generation chunk. * diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java index ef52a20fb..93244b095 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/BaseOpenAI.java @@ -19,6 +19,7 @@ package com.hw.langchain.llms.openai; import com.hw.langchain.llms.base.BaseLLM; +import com.hw.langchain.schema.AsyncLLMResult; import com.hw.langchain.schema.Generation; import com.hw.langchain.schema.LLMResult; import com.hw.openai.OpenAiClient; @@ -29,6 +30,7 @@ import lombok.Builder; import lombok.experimental.SuperBuilder; import okhttp3.Interceptor; +import reactor.core.publisher.Flux; import java.util.*; @@ -212,6 +214,11 @@ protected LLMResult innerGenerate(List prompts, List stop) { return createLLMResult(choices, prompts, Map.of()); } + @Override + protected Flux _agenerate(List prompts, List stop) { + throw new UnsupportedOperationException("not supported yet."); + } + /** * Create the LLMResult from the choices and prompts. */ diff --git a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java index 4a6cd4226..629807495 100644 --- a/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java +++ b/langchain-core/src/main/java/com/hw/langchain/llms/openai/OpenAIChat.java @@ -19,6 +19,7 @@ package com.hw.langchain.llms.openai; import com.hw.langchain.llms.base.BaseLLM; +import com.hw.langchain.schema.AsyncLLMResult; import com.hw.langchain.schema.Generation; import com.hw.langchain.schema.LLMResult; import com.hw.langchain.utils.Utils; @@ -29,6 +30,7 @@ import lombok.Builder; import lombok.experimental.SuperBuilder; +import reactor.core.publisher.Flux; import java.util.*; @@ -201,6 +203,11 @@ protected LLMResult innerGenerate(List prompts, List stop) { return new LLMResult(generations, llmOutput); } + @Override + protected Flux _agenerate(List prompts, List stop) { + throw new UnsupportedOperationException("not supported yet."); + } + private List getChatMessages(List prompts) { checkArgument(prompts.size() == 1, "OpenAIChat currently only supports single prompt, got %s", prompts); List messages = new ArrayList<>(prefixMessages); diff --git a/langchain-core/src/main/java/com/hw/langchain/schema/AsyncLLMResult.java b/langchain-core/src/main/java/com/hw/langchain/schema/AsyncLLMResult.java index 214e50df0..5e3d09d9a 100644 --- a/langchain-core/src/main/java/com/hw/langchain/schema/AsyncLLMResult.java +++ b/langchain-core/src/main/java/com/hw/langchain/schema/AsyncLLMResult.java @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package com.hw.langchain.schema; import lombok.Data; diff --git a/langchain-server/pom.xml b/langchain-server/pom.xml new file mode 100644 index 000000000..dfe57c4b7 --- /dev/null +++ b/langchain-server/pom.xml @@ -0,0 +1,36 @@ + + + 4.0.0 + + io.github.hamawhitegg + langchain-java + 0.1.11-SNAPSHOT + + + langchain-server + + + true + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + + com.diffplug.spotless + spotless-maven-plugin + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + diff --git a/langchain-web/pom.xml b/langchain-web/pom.xml new file mode 100644 index 000000000..23146b536 --- /dev/null +++ b/langchain-web/pom.xml @@ -0,0 +1,28 @@ + + + 4.0.0 + + io.github.hamawhitegg + langchain-java + 0.1.11-SNAPSHOT + + + langchain-web + + + true + + + + + + org.apache.maven.plugins + maven-deploy-plugin + + true + + + + + From 79bf29d7f2e868fb2d12c7cc00a6e9905f7c38be Mon Sep 17 00:00:00 2001 From: "lingjue@ubuntu" Date: Wed, 2 Aug 2023 14:46:51 +0800 Subject: [PATCH 6/8] format code by spotless --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 98936c327..1374a694f 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,7 @@ logPath_IS_UNDEFINED target # other ignore +.java-version *.log *.tmp Thumbs.db From 8ea9d24846d6b787251e07f1b60e28bb5e6c727a Mon Sep 17 00:00:00 2001 From: "lingjue@ubuntu" Date: Wed, 2 Aug 2023 16:25:34 +0800 Subject: [PATCH 7/8] support -Dnexus-staging-maven-plugin.executions=false to deploy to my private nexus for experience and preview --- pom.xml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 25ef45aca..533f3651f 100644 --- a/pom.xml +++ b/pom.xml @@ -61,6 +61,7 @@ 3.0.1 3.0.0 1.6.13 + true ${target.java.version} ${target.java.version} UTF-8 @@ -336,7 +337,7 @@ org.sonatype.plugins nexus-staging-maven-plugin ${nexus-staging-maven-plugin.version} - true + ${nexus-staging-maven-plugin.executions} ossrh https://s01.oss.sonatype.org/ From c0d9ff99a0ffb9537617a0323a1a1cba6469e126 Mon Sep 17 00:00:00 2001 From: "lingjue@ubuntu" Date: Mon, 11 Sep 2023 10:29:07 +0800 Subject: [PATCH 8/8] rebase dev --- langchain-server/pom.xml | 2 +- langchain-web/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/langchain-server/pom.xml b/langchain-server/pom.xml index dfe57c4b7..0a7d44892 100644 --- a/langchain-server/pom.xml +++ b/langchain-server/pom.xml @@ -5,7 +5,7 @@ io.github.hamawhitegg langchain-java - 0.1.11-SNAPSHOT + 0.1.12-SNAPSHOT langchain-server diff --git a/langchain-web/pom.xml b/langchain-web/pom.xml index 23146b536..9d961c0d8 100644 --- a/langchain-web/pom.xml +++ b/langchain-web/pom.xml @@ -5,7 +5,7 @@ io.github.hamawhitegg langchain-java - 0.1.11-SNAPSHOT + 0.1.12-SNAPSHOT langchain-web