add test cases for memory of chat models
HamaWhiteGG committed Jun 24, 2023
1 parent 1f87329 commit ebfd774
Showing 6 changed files with 122 additions and 18 deletions.
ConversationChain.java

@@ -41,16 +41,14 @@ public class ConversationChain extends LLMChain {
     protected String inputKey = "input";

     public ConversationChain(BaseLanguageModel llm) {
-        super(llm, PROMPT, "response");
-        // Default memory store.
-        this.memory = new ConversationBufferMemory();
-
-        validatePromptInputVariables();
+        this(llm, PROMPT, new ConversationBufferMemory());
     }

     public ConversationChain(BaseLanguageModel llm, BasePromptTemplate prompt, BaseMemory memory) {
-        super(llm, prompt);
+        super(llm, prompt, "response");
         this.memory = memory;

         validatePromptInputVariables();
     }

     /**
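With this refactor the two constructors chain together: the single-argument constructor now delegates to the three-argument one, which supplies the default ConversationBufferMemory and forwards the "response" output key to LLMChain. A minimal usage sketch, illustration only (the llm and prompt variables are assumed to be in scope, for example the ChatOpenAI instance and ChatPromptTemplate built in the test further down):

    // Default memory: equivalent to new ConversationChain(llm, PROMPT, new ConversationBufferMemory()).
    var conversation = new ConversationChain(llm);

    // A custom prompt and memory go through the three-argument constructor, which now also
    // passes the "response" output key to the LLMChain super constructor.
    var chatConversation = new ConversationChain(llm, prompt, new ConversationBufferMemory(true));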
LLMChain.java

@@ -21,7 +21,9 @@
 import com.hw.langchain.base.language.BaseLanguageModel;
 import com.hw.langchain.chains.base.Chain;
 import com.hw.langchain.prompts.base.BasePromptTemplate;
+import com.hw.langchain.schema.BaseLLMOutputParser;
 import com.hw.langchain.schema.LLMResult;
+import com.hw.langchain.schema.NoOpOutputParser;
 import com.hw.langchain.schema.PromptValue;

 import org.slf4j.Logger;
@@ -50,6 +52,18 @@ public class LLMChain extends Chain {

     protected String outputKey = "text";

+    /**
+     * Output parser to use.
+     * Defaults to one that takes the most likely string but does not change it.
+     */
+    protected BaseLLMOutputParser<String> outputParser = new NoOpOutputParser();
+
+    /**
+     * Whether to return only the final parsed result. Defaults to true.
+     * If false, will return a bunch of extra information about the generation.
+     */
+    protected boolean returnFinalOnly = true;
+
     public LLMChain(BaseLanguageModel llm, BasePromptTemplate prompt) {
         this.llm = llm;
         this.prompt = prompt;
@@ -126,10 +140,18 @@ private List<String> prepStop(List<Map<String, Object>> inputList) {
     /**
      * Create outputs from response.
      */
-    private List<Map<String, String>> createOutputs(LLMResult response) {
-        return response.getGenerations().stream()
-                .map(generationList -> Map.of(outputKey, generationList.get(0).getText()))
+    private List<Map<String, String>> createOutputs(LLMResult llmResult) {
+        var result = llmResult.getGenerations().stream()
+                .map(generation -> Map.of(outputKey, outputParser.parseResult(generation),
+                        "full_generation", generation.toString()))
                 .toList();
+
+        if (returnFinalOnly) {
+            result = result.stream()
+                    .map(r -> Map.of(outputKey, r.get(outputKey)))
+                    .toList();
+        }
+        return result;
     }

     /**
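To see what the reworked createOutputs produces, here is a standalone sketch (not part of the commit) that mirrors the transformation on plain collections; the String stand-ins replace LLMResult and outputParser.parseResult(...) purely for illustration. Each inner list of generations becomes one output map, and with returnFinalOnly = true the "full_generation" entry is dropped so only the parsed text under outputKey remains.

    import java.util.List;
    import java.util.Map;

    class CreateOutputsSketch {

        // Mirrors LLMChain.createOutputs(): one map per generation list, trimmed down
        // to the outputKey entry when returnFinalOnly is true.
        static List<Map<String, String>> createOutputs(List<List<String>> generations,
                String outputKey, boolean returnFinalOnly) {
            var result = generations.stream()
                    .map(g -> Map.of(outputKey, g.get(0), // stand-in for outputParser.parseResult(g)
                            "full_generation", g.toString()))
                    .toList();

            if (returnFinalOnly) {
                result = result.stream()
                        .map(r -> Map.of(outputKey, r.get(outputKey)))
                        .toList();
            }
            return result;
        }

        public static void main(String[] args) {
            // Prints something like [{text=Hello!}]
            System.out.println(createOutputs(List.of(List.of("Hello!")), "text", true));
            // Prints something like [{text=Hello!, full_generation=[Hello!]}]
            System.out.println(createOutputs(List.of(List.of("Hello!")), "text", false));
        }
    }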
BaseLLMOutputParser.java

@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.schema;

import java.io.Serializable;
import java.util.List;

/**
 * @author HamaWhite
 */
public abstract class BaseLLMOutputParser<T> implements Serializable {

    /**
     * Parse LLM Result.
     */
    public abstract T parseResult(List<? extends Generation> result);
}
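The list-level contract means a parser can look at every Generation produced for a prompt, not just the first one. A hedged sketch of a direct subclass (hypothetical class, not in the repo; it reuses the Generation.getText() accessor that BaseOutputParser below relies on):

    package com.hw.langchain.schema;

    import java.util.Comparator;
    import java.util.List;

    // Hypothetical example, not part of the commit: returns the longest generation's text
    // instead of the first one.
    public class LongestGenerationParser extends BaseLLMOutputParser<String> {

        @Override
        public String parseResult(List<? extends Generation> result) {
            return result.stream()
                    .map(Generation::getText)
                    .max(Comparator.comparingInt(String::length))
                    .orElse("");
        }
    }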
BaseOutputParser.java

@@ -18,14 +18,20 @@

 package com.hw.langchain.schema;

+import java.util.List;
+
 /**
  * Class to parse the output of an LLM call.
  * <p>
  * Output parsers help structure language model responses.
  *
  * @author HamaWhite
  */
-public abstract class BaseOutputParser<T> {
+public abstract class BaseOutputParser<T> extends BaseLLMOutputParser<T> {

+    @Override
+    public T parseResult(List<? extends Generation> result) {
+        return parse(result.get(0).getText());
+    }
+
     /**
      * Parse the output of an LLM call.
@@ -55,5 +61,7 @@ public Object parseWithPrompt(String completion, PromptValue prompt) {
      *
      * @return format instructions
      */
-    public abstract String getFormatInstructions();
+    public String getFormatInstructions() {
+        throw new UnsupportedOperationException("Method getFormatInstructions() is not implemented.");
+    }
 }
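For concrete parsers this change means two things: parseResult() is inherited and feeds the first generation's text into parse(String), and getFormatInstructions() only needs overriding when a parser actually provides format instructions; otherwise the new default throws. A hedged sketch of such a parser (hypothetical class, not in the repo):

    package com.hw.langchain.schema;

    import java.util.Arrays;
    import java.util.List;

    // Hypothetical example: splits a comma-separated completion into a list of values.
    // parseResult() is inherited from BaseOutputParser; getFormatInstructions() is
    // overridden here, but could also be left to the new default.
    public class CommaSeparatedListParser extends BaseOutputParser<List<String>> {

        @Override
        public List<String> parse(String text) {
            return Arrays.stream(text.split(","))
                    .map(String::trim)
                    .toList();
        }

        @Override
        public String getFormatInstructions() {
            return "Your response should be a list of comma separated values, e.g. `foo, bar, baz`";
        }
    }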
NoOpOutputParser.java

@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.schema;

/**
 * Output parser that just returns the text as is.
 *
 * @author HamaWhite
 */
public class NoOpOutputParser extends BaseOutputParser<String> {

    @Override
    public String parse(String text) {
        return text;
    }

}
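NoOpOutputParser is what keeps existing LLMChain behavior unchanged: it is the default value of the new outputParser field, and it hands the completion back untouched. A tiny illustration:

    var parser = new NoOpOutputParser();
    // parse() returns its input as is; via the inherited parseResult() this is the text of
    // the first generation, so the default chain output stays the raw completion.
    String text = parser.parse("Hello! How can I assist you today?");
    // text is still "Hello! How can I assist you today?"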
ConversationChain test class

@@ -67,18 +67,28 @@ void testConversationChainWithChatModel() {
         var prompt = ChatPromptTemplate.fromMessages(List.of(
                 SystemMessagePromptTemplate.fromTemplate(
                         "The following is a friendly conversation between a human and an AI. The AI is talkative and " +
-                                "provides lots of specific details from its context. If the AI does not know the answer to a "
-                                        +
-                                        "question, it truthfully says it does not know."),
+                                "provides lots of specific details from its context. If the AI does not know the " +
+                                "answer to a question, it truthfully says it does not know."),
                 new MessagesPlaceholder("history"),
                 HumanMessagePromptTemplate.fromTemplate("{input}")));

         var chat = ChatOpenAI.builder().temperature(0).build().init();
         var memory = new ConversationBufferMemory(true);
         var conversation = new ConversationChain(chat, prompt, memory);

-        conversation.predict(Map.of("input", "Hi there!"));
-        conversation.predict(Map.of("input", "I'm doing well! Just having a conversation with an AI."));
-        conversation.predict(Map.of("input", "Tell me about yourself."));
+        var output1 = conversation.predict(Map.of("input", "Hi there!"));
+        // Hello! How can I assist you today?
+        LOG.info("output1: \n{}", output1);
+        assertNotNull(output1, "output1 should not be null");
+
+        var output2 = conversation.predict(Map.of("input", "I'm doing well! Just having a conversation with an AI."));
+        // That sounds like fun! I'm happy to chat with you. What would you like to talk about?
+        LOG.info("output2: \n{}", output2);
+        assertNotNull(output2, "output2 should not be null");
+
+        var output3 = conversation.predict(Map.of("input", "Tell me about yourself."));
+        // Sure! I am an AI language model created by OpenAI. I was trained on a large dataset ...
+        LOG.info("output3: \n{}", output3);
+        assertNotNull(output3, "output3 should not be null");
     }
 }
