diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/conversation/base/ConversationChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/conversation/base/ConversationChain.java
index 16e476d63..7726d1fa9 100644
--- a/langchain-core/src/main/java/com/hw/langchain/chains/conversation/base/ConversationChain.java
+++ b/langchain-core/src/main/java/com/hw/langchain/chains/conversation/base/ConversationChain.java
@@ -41,16 +41,14 @@ public class ConversationChain extends LLMChain {
     protected String inputKey = "input";
 
     public ConversationChain(BaseLanguageModel llm) {
-        super(llm, PROMPT, "response");
-        // Default memory store.
-        this.memory = new ConversationBufferMemory();
-
-        validatePromptInputVariables();
+        this(llm, PROMPT, new ConversationBufferMemory());
     }
 
     public ConversationChain(BaseLanguageModel llm, BasePromptTemplate prompt, BaseMemory memory) {
-        super(llm, prompt);
+        super(llm, prompt, "response");
         this.memory = memory;
+
+        validatePromptInputVariables();
     }
 
     /**
diff --git a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java
index f09becf9b..30265dc31 100644
--- a/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java
+++ b/langchain-core/src/main/java/com/hw/langchain/chains/llm/LLMChain.java
@@ -21,7 +21,9 @@
 import com.hw.langchain.base.language.BaseLanguageModel;
 import com.hw.langchain.chains.base.Chain;
 import com.hw.langchain.prompts.base.BasePromptTemplate;
+import com.hw.langchain.schema.BaseLLMOutputParser;
 import com.hw.langchain.schema.LLMResult;
+import com.hw.langchain.schema.NoOpOutputParser;
 import com.hw.langchain.schema.PromptValue;
 
 import org.slf4j.Logger;
@@ -50,6 +52,18 @@ public class LLMChain extends Chain {
 
     protected String outputKey = "text";
 
+    /**
+     * Output parser to use.
+     * Defaults to one that takes the most likely string but does not change it.
+     */
+    protected BaseLLMOutputParser<String> outputParser = new NoOpOutputParser();
+
+    /**
+     * Whether to return only the final parsed result. Defaults to true.
+     * If false, will return a bunch of extra information about the generation.
+     */
+    protected boolean returnFinalOnly = true;
+
     public LLMChain(BaseLanguageModel llm, BasePromptTemplate prompt) {
         this.llm = llm;
         this.prompt = prompt;
@@ -126,10 +140,18 @@ private List<String> prepStop(List<Map<String, Object>> inputList) {
     /**
      * Create outputs from response.
      */
-    private List<Map<String, String>> createOutputs(LLMResult response) {
-        return response.getGenerations().stream()
-                .map(generationList -> Map.of(outputKey, generationList.get(0).getText()))
+    private List<Map<String, String>> createOutputs(LLMResult llmResult) {
+        var result = llmResult.getGenerations().stream()
+                .map(generation -> Map.of(outputKey, outputParser.parseResult(generation),
+                        "full_generation", generation.toString()))
                 .toList();
+
+        if (returnFinalOnly) {
+            result = result.stream()
+                    .map(r -> Map.of(outputKey, r.get(outputKey)))
+                    .toList();
+        }
+        return result;
     }
 
     /**
diff --git a/langchain-core/src/main/java/com/hw/langchain/schema/BaseLLMOutputParser.java b/langchain-core/src/main/java/com/hw/langchain/schema/BaseLLMOutputParser.java
new file mode 100644
index 000000000..eac7c7865
--- /dev/null
+++ b/langchain-core/src/main/java/com/hw/langchain/schema/BaseLLMOutputParser.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.hw.langchain.schema;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * @author HamaWhite
+ */
+public abstract class BaseLLMOutputParser<T> implements Serializable {
+
+    /**
+     * Parse LLM Result.
+     */
+    public abstract T parseResult(List<Generation> result);
+}
\ No newline at end of file
diff --git a/langchain-core/src/main/java/com/hw/langchain/schema/BaseOutputParser.java b/langchain-core/src/main/java/com/hw/langchain/schema/BaseOutputParser.java
index bd25127f6..f5a8cbdba 100644
--- a/langchain-core/src/main/java/com/hw/langchain/schema/BaseOutputParser.java
+++ b/langchain-core/src/main/java/com/hw/langchain/schema/BaseOutputParser.java
@@ -18,14 +18,20 @@
 
 package com.hw.langchain.schema;
 
+import java.util.List;
+
 /**
  * Class to parse the output of an LLM call.
- * <p>
  * Output parsers help structure language model responses.
  *
  * @author HamaWhite
 */
-public abstract class BaseOutputParser<T> {
+public abstract class BaseOutputParser<T> extends BaseLLMOutputParser<T> {
+
+    @Override
+    public T parseResult(List<Generation> result) {
+        return parse(result.get(0).getText());
+    }
 
     /**
      * Parse the output of an LLM call.
@@ -55,5 +61,7 @@ public Object parseWithPrompt(String completion, PromptValue prompt) {
      *
      * @return format instructions
      */
-    public abstract String getFormatInstructions();
+    public String getFormatInstructions() {
+        throw new UnsupportedOperationException("Method getFormatInstructions() is not implemented.");
+    }
 }
diff --git a/langchain-core/src/main/java/com/hw/langchain/schema/NoOpOutputParser.java b/langchain-core/src/main/java/com/hw/langchain/schema/NoOpOutputParser.java
new file mode 100644
index 000000000..19b4f70cc
--- /dev/null
+++ b/langchain-core/src/main/java/com/hw/langchain/schema/NoOpOutputParser.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.hw.langchain.schema;
+
+/**
+ * Output parser that just returns the text as is.
+ *
+ * @author HamaWhite
+ */
+public class NoOpOutputParser extends BaseOutputParser<String> {
+
+    @Override
+    public String parse(String text) {
+        return text;
+    }
+
+}
diff --git a/langchain-core/src/test/java/com/hw/langchain/chains/conversation/base/ConversationChainTest.java b/langchain-core/src/test/java/com/hw/langchain/chains/conversation/base/ConversationChainTest.java
index e58514d2f..e344f76e4 100644
--- a/langchain-core/src/test/java/com/hw/langchain/chains/conversation/base/ConversationChainTest.java
+++ b/langchain-core/src/test/java/com/hw/langchain/chains/conversation/base/ConversationChainTest.java
@@ -67,9 +67,8 @@ void testConversationChainWithChatModel() {
         var prompt = ChatPromptTemplate.fromMessages(List.of(
                 SystemMessagePromptTemplate.fromTemplate(
                         "The following is a friendly conversation between a human and an AI. The AI is talkative and " +
-                                "provides lots of specific details from its context. If the AI does not know the answer to a "
-                                +
-                                "question, it truthfully says it does not know."),
+                                "provides lots of specific details from its context. If the AI does not know the " +
+                                "answer to a question, it truthfully says it does not know."),
                 new MessagesPlaceholder("history"),
                 HumanMessagePromptTemplate.fromTemplate("{input}")));
 
@@ -77,8 +76,19 @@ void testConversationChainWithChatModel() {
         var memory = new ConversationBufferMemory(true);
         var conversation = new ConversationChain(chat, prompt, memory);
 
-        conversation.predict(Map.of("input", "Hi there!"));
-        conversation.predict(Map.of("input", "I'm doing well! Just having a conversation with an AI."));
-        conversation.predict(Map.of("input", "Tell me about yourself."));
+        var output1 = conversation.predict(Map.of("input", "Hi there!"));
+        // Hello! How can I assist you today?
+        LOG.info("output1: \n{}", output1);
+        assertNotNull(output1, "output1 should not be null");
+
+        var output2 = conversation.predict(Map.of("input", "I'm doing well! Just having a conversation with an AI."));
+        // That sounds like fun! I'm happy to chat with you. What would you like to talk about?
+        LOG.info("output2: \n{}", output2);
+        assertNotNull(output2, "output2 should not be null");
+
+        var output3 = conversation.predict(Map.of("input", "Tell me about yourself."));
+        // Sure! I am an AI language model created by OpenAI. I was trained on a large dataset ...
+        LOG.info("output3: \n{}", output3);
+        assertNotNull(output3, "output3 should not be null");
     }
 }
\ No newline at end of file
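
A minimal usage sketch, not part of this diff: a custom parser built on the new BaseOutputParser hierarchy. The class name, package, and comma-splitting behaviour are illustrative assumptions; the sketch relies only on contracts visible above — the inherited parseResult feeds the first generation's text into parse(String), and getFormatInstructions() now has a default that throws, so overriding it is optional.

package com.hw.langchain.examples.schema; // hypothetical package, for illustration only

import com.hw.langchain.schema.BaseOutputParser;

import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical parser that turns a completion such as "foo, bar, baz" into a trimmed list.
 * parseResult is inherited from BaseOutputParser and passes the first generation's text
 * to parse(String).
 */
public class CommaSeparatedListOutputParser extends BaseOutputParser<List<String>> {

    @Override
    public List<String> parse(String text) {
        // Split on commas and trim each item.
        return Arrays.stream(text.split(",")).map(String::trim).toList();
    }

    @Override
    public String getFormatInstructions() {
        // Optional now that the base class default throws UnsupportedOperationException.
        return "Your response should be a list of comma separated values, e.g. `foo, bar, baz`";
    }
}

Wiring such a parser into LLMChain would still need a constructor argument or setter for the protected outputParser field, which this diff does not add.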