Skip to content

Commit

Permalink
support load limited messages from chat memory
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxiaojiawow committed Jul 22, 2023
1 parent 791b52c commit 8cc8a5d
Show file tree
Hide file tree
Showing 10 changed files with 344 additions and 3 deletions.
5 changes: 4 additions & 1 deletion langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@
<groupId>io.netty</groupId>
<artifactId>netty-resolver-dns</artifactId>
</dependency>

<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.buffer;

import com.hw.langchain.exception.LangChainException;
import com.hw.langchain.schema.BaseMessage;

import java.util.List;

/**
* @author zhangxiaojia002
* @date 2023/7/22 12:24 下午
**/
public class BufferWindowMemoryMessagePicker implements MemoryMessagePicker {

// k represents the number of dialog round of chat messages we want to add into gpt request
private int k = 5;
public BufferWindowMemoryMessagePicker(int k) {
if (k <= 0) {
throw new LangChainException("buffer window size musts larger than zero");
}
this.k = k;
}

@Override
public List<BaseMessage> pickMemoryMessage(List<BaseMessage> memoryMessages) {
// select the last 2 * k messages from memoryMessages List
if (memoryMessages.size() > 2 * k) {
return memoryMessages.subList(memoryMessages.size() - 2 * k, memoryMessages.size());
}
return memoryMessages;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.buffer;

import com.hw.langchain.schema.BaseChatMessageHistory;
import com.hw.langchain.schema.BaseMessage;

import java.util.List;

/**
* @author zhangxiaojia002
* @date 2023/7/22 1:07 下午
**/
public class ConversationBufferLimitedMemory extends ConversationBufferMemory {

private final MemoryMessagePicker memoryMessagePicker;

public ConversationBufferLimitedMemory(MemoryMessagePicker memoryMessagePicker) {
super();
this.memoryMessagePicker = memoryMessagePicker;
}

public ConversationBufferLimitedMemory(MemoryMessagePicker memoryMessagePicker, boolean returnMessage) {
super(returnMessage);
this.memoryMessagePicker = memoryMessagePicker;
}

public ConversationBufferLimitedMemory(MemoryMessagePicker memoryMessagePicker, boolean returnMessage,
BaseChatMessageHistory chatMessageHistory) {
super(returnMessage, chatMessageHistory);
this.memoryMessagePicker = memoryMessagePicker;
}

@Override
protected List<BaseMessage> getMemoryMessages() {
return memoryMessagePicker.pickMemoryMessage(super.getMemoryMessages());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
package com.hw.langchain.memory.buffer;

import com.hw.langchain.memory.chat.memory.BaseChatMemory;
import com.hw.langchain.schema.BaseChatMessageHistory;
import com.hw.langchain.schema.BaseMessage;

import java.util.List;
import java.util.Map;
Expand All @@ -45,17 +47,26 @@ public ConversationBufferMemory(boolean returnMessages) {
this.returnMessages = returnMessages;
}

public ConversationBufferMemory(boolean returnMessages, BaseChatMessageHistory messageHistory) {
super(messageHistory);
this.returnMessages = returnMessages;
}

/**
* String buffer of memory.
*/
public Object buffer() {
if (returnMessages) {
return chatMemory.getMessages();
return getMemoryMessages();
} else {
return getBufferString(chatMemory.getMessages(), humanPrefix, aiPrefix);
return getBufferString(getMemoryMessages(), humanPrefix, aiPrefix);
}
}

protected List<BaseMessage> getMemoryMessages() {
return chatMemory.getMessages();
}

/**
* Will always return list of memory variables.
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.buffer;

import com.hw.langchain.schema.BaseChatMessageHistory;

/**
* @author zhangxiaojia002
* @date 2023/7/22 1:12 下午
**/
public class ConversationBufferWindowMemory extends ConversationBufferLimitedMemory {

public ConversationBufferWindowMemory(int k) {
super(new BufferWindowMemoryMessagePicker(k));
}

public ConversationBufferWindowMemory(int k, boolean returnMessages) {
super(new BufferWindowMemoryMessagePicker(k), returnMessages);
}

public ConversationBufferWindowMemory(int k, boolean returnMessages, BaseChatMessageHistory chatMessageHistory) {
super(new BufferWindowMemoryMessagePicker(k), returnMessages, chatMessageHistory);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.buffer;

import com.hw.langchain.schema.BaseMessage;

import java.util.List;

/**
* Since the token we input to LLM is limited, the history messages we add into gpt request is also limited. <br/>
* <p/>
* This interface defines the behavior to pick up limited history memory.<br/>
* @author zhangxiaojia002
* @date 2023/7/22 12:14 下午
**/
public interface MemoryMessagePicker {

/**
* pick up memory messages as needed
* @param memoryMessages all history memory messages
* @return memory messages picked
*/
List<BaseMessage> pickMemoryMessage(List<BaseMessage> memoryMessages);
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ public abstract class BaseChatMemory extends BaseMemory {

protected boolean returnMessages;

protected BaseChatMemory() {

}

protected BaseChatMemory(BaseChatMessageHistory chatMemory) {
this.chatMemory = chatMemory;
}

private Pair<String, String> getInputOutput(Map<String, Object> inputs, Map<String, String> outputs) {
String promptInputKey;
if (inputKey == null) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.buffer;

import com.hw.langchain.schema.BaseMessage;
import com.hw.langchain.schema.HumanMessage;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
* @author zhangxiaojia002
* @date 2023/7/22 12:41 下午
**/
class BufferWindowMemoryMessagePickerTest {

// when 2*K <= memoryMessages.size(), return all memoryMessages
@Test
void returnAllMessageWhenListLengthLessOrEquals2K() {
int k = 2;
BufferWindowMemoryMessagePicker bufferWindowMemoryMessagePicker = new BufferWindowMemoryMessagePicker(k);
assertEquals(2 * k,
bufferWindowMemoryMessagePicker.pickMemoryMessage(
generateRandomHumanMessages(2 * k)).size());
assertEquals(2 * k - 1,
bufferWindowMemoryMessagePicker.pickMemoryMessage(
generateRandomHumanMessages(2 * k - 1)).size());
}

// when 2*K > memoryMessage.size(), return last 2*k memoryMessages;
@Test
void returnLast2KMemoryMessagesWhenListLengthLargerThan2K() {
int k = 2;
BufferWindowMemoryMessagePicker bufferWindowMemoryMessagePicker = new BufferWindowMemoryMessagePicker(k);
assertEquals(2 * k,
bufferWindowMemoryMessagePicker.pickMemoryMessage(
generateRandomHumanMessages(2 * k + 1)).size());
assertEquals(2 * k,
bufferWindowMemoryMessagePicker.pickMemoryMessage(
generateRandomHumanMessages(2 * k + 2)).size());

List<BaseMessage> messageList = generateRandomHumanMessages(2 * k + 2);
assertEquals(messageList.subList(2, messageList.size()),
bufferWindowMemoryMessagePicker.pickMemoryMessage(messageList));
}

// generate random k Human messages
private List<BaseMessage> generateRandomHumanMessages(int k) {
List<BaseMessage> baseMessageList = new ArrayList<>();
for (int i = 0; i < k; i++) {
baseMessageList.add(new HumanMessage(String.valueOf(i)));
}
return baseMessageList;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.buffer;

import com.hw.langchain.memory.chat.memory.BaseChatMemory;
import com.hw.langchain.schema.AIMessage;
import com.hw.langchain.schema.BaseChatMessageHistory;
import com.hw.langchain.schema.HumanMessage;

import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.when;

/**
* @author zhangxiaojia002
* @date 2023/7/22 1:25 下午
**/
class ConversationBufferWindowMemoryTest {

@Test
void whenChatMessageHistoryHasManyMessagesOnlyReturnMax2KMessages() {
BaseChatMessageHistory baseChatMessageHistory = Mockito.mock(BaseChatMessageHistory.class);
BaseChatMemory conversationBufferWindowMemory =
new ConversationBufferWindowMemory(1, true, baseChatMessageHistory);

when(baseChatMessageHistory.getMessages()).thenReturn(
List.of(new HumanMessage("hi"), new AIMessage("hi"),
new HumanMessage("what are you doing"), new AIMessage("I'm thinking")));

assertEquals(2,
((List<?>) conversationBufferWindowMemory.loadMemoryVariables(Map.of()).get("history")).size());
}
}
Loading

0 comments on commit 8cc8a5d

Please sign in to comment.