forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request HamaWhiteGG#52 from zhangxiaojiawow/limited_conver…
…sation_buffer_memory support load limited messages from chat memory
- Loading branch information
Showing
10 changed files
with
344 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
49 changes: 49 additions & 0 deletions
49
...in-core/src/main/java/com/hw/langchain/memory/buffer/BufferWindowMemoryMessagePicker.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.buffer; | ||
|
||
import com.hw.langchain.exception.LangChainException; | ||
import com.hw.langchain.schema.BaseMessage; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/22 12:24 下午 | ||
**/ | ||
public class BufferWindowMemoryMessagePicker implements MemoryMessagePicker { | ||
|
||
// k represents the number of dialog round of chat messages we want to add into gpt request | ||
private int k = 5; | ||
public BufferWindowMemoryMessagePicker(int k) { | ||
if (k <= 0) { | ||
throw new LangChainException("buffer window size musts larger than zero"); | ||
} | ||
this.k = k; | ||
} | ||
|
||
@Override | ||
public List<BaseMessage> pickMemoryMessage(List<BaseMessage> memoryMessages) { | ||
// select the last 2 * k messages from memoryMessages List | ||
if (memoryMessages.size() > 2 * k) { | ||
return memoryMessages.subList(memoryMessages.size() - 2 * k, memoryMessages.size()); | ||
} | ||
return memoryMessages; | ||
} | ||
} |
54 changes: 54 additions & 0 deletions
54
...in-core/src/main/java/com/hw/langchain/memory/buffer/ConversationBufferLimitedMemory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.buffer; | ||
|
||
import com.hw.langchain.schema.BaseChatMessageHistory; | ||
import com.hw.langchain.schema.BaseMessage; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/22 1:07 下午 | ||
**/ | ||
public class ConversationBufferLimitedMemory extends ConversationBufferMemory { | ||
|
||
private final MemoryMessagePicker memoryMessagePicker; | ||
|
||
public ConversationBufferLimitedMemory(MemoryMessagePicker memoryMessagePicker) { | ||
super(); | ||
this.memoryMessagePicker = memoryMessagePicker; | ||
} | ||
|
||
public ConversationBufferLimitedMemory(MemoryMessagePicker memoryMessagePicker, boolean returnMessage) { | ||
super(returnMessage); | ||
this.memoryMessagePicker = memoryMessagePicker; | ||
} | ||
|
||
public ConversationBufferLimitedMemory(MemoryMessagePicker memoryMessagePicker, boolean returnMessage, | ||
BaseChatMessageHistory chatMessageHistory) { | ||
super(returnMessage, chatMessageHistory); | ||
this.memoryMessagePicker = memoryMessagePicker; | ||
} | ||
|
||
@Override | ||
protected List<BaseMessage> getMemoryMessages() { | ||
return memoryMessagePicker.pickMemoryMessage(super.getMemoryMessages()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
40 changes: 40 additions & 0 deletions
40
...ain-core/src/main/java/com/hw/langchain/memory/buffer/ConversationBufferWindowMemory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.buffer; | ||
|
||
import com.hw.langchain.schema.BaseChatMessageHistory; | ||
|
||
/** | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/22 1:12 下午 | ||
**/ | ||
public class ConversationBufferWindowMemory extends ConversationBufferLimitedMemory { | ||
|
||
public ConversationBufferWindowMemory(int k) { | ||
super(new BufferWindowMemoryMessagePicker(k)); | ||
} | ||
|
||
public ConversationBufferWindowMemory(int k, boolean returnMessages) { | ||
super(new BufferWindowMemoryMessagePicker(k), returnMessages); | ||
} | ||
|
||
public ConversationBufferWindowMemory(int k, boolean returnMessages, BaseChatMessageHistory chatMessageHistory) { | ||
super(new BufferWindowMemoryMessagePicker(k), returnMessages, chatMessageHistory); | ||
} | ||
} |
40 changes: 40 additions & 0 deletions
40
langchain-core/src/main/java/com/hw/langchain/memory/buffer/MemoryMessagePicker.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.buffer; | ||
|
||
import com.hw.langchain.schema.BaseMessage; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* Since the token we input to LLM is limited, the history messages we add into gpt request is also limited. <br/> | ||
* <p/> | ||
* This interface defines the behavior to pick up limited history memory.<br/> | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/22 12:14 下午 | ||
**/ | ||
public interface MemoryMessagePicker { | ||
|
||
/** | ||
* pick up memory messages as needed | ||
* @param memoryMessages all history memory messages | ||
* @return memory messages picked | ||
*/ | ||
List<BaseMessage> pickMemoryMessage(List<BaseMessage> memoryMessages); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
75 changes: 75 additions & 0 deletions
75
...ore/src/test/java/com/hw/langchain/memory/buffer/BufferWindowMemoryMessagePickerTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.buffer; | ||
|
||
import com.hw.langchain.schema.BaseMessage; | ||
import com.hw.langchain.schema.HumanMessage; | ||
|
||
import org.junit.jupiter.api.Test; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
|
||
/** | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/22 12:41 下午 | ||
**/ | ||
class BufferWindowMemoryMessagePickerTest { | ||
|
||
// when 2*K <= memoryMessages.size(), return all memoryMessages | ||
@Test | ||
void returnAllMessageWhenListLengthLessOrEquals2K() { | ||
int k = 2; | ||
BufferWindowMemoryMessagePicker bufferWindowMemoryMessagePicker = new BufferWindowMemoryMessagePicker(k); | ||
assertEquals(2 * k, | ||
bufferWindowMemoryMessagePicker.pickMemoryMessage( | ||
generateRandomHumanMessages(2 * k)).size()); | ||
assertEquals(2 * k - 1, | ||
bufferWindowMemoryMessagePicker.pickMemoryMessage( | ||
generateRandomHumanMessages(2 * k - 1)).size()); | ||
} | ||
|
||
// when 2*K > memoryMessage.size(), return last 2*k memoryMessages; | ||
@Test | ||
void returnLast2KMemoryMessagesWhenListLengthLargerThan2K() { | ||
int k = 2; | ||
BufferWindowMemoryMessagePicker bufferWindowMemoryMessagePicker = new BufferWindowMemoryMessagePicker(k); | ||
assertEquals(2 * k, | ||
bufferWindowMemoryMessagePicker.pickMemoryMessage( | ||
generateRandomHumanMessages(2 * k + 1)).size()); | ||
assertEquals(2 * k, | ||
bufferWindowMemoryMessagePicker.pickMemoryMessage( | ||
generateRandomHumanMessages(2 * k + 2)).size()); | ||
|
||
List<BaseMessage> messageList = generateRandomHumanMessages(2 * k + 2); | ||
assertEquals(messageList.subList(2, messageList.size()), | ||
bufferWindowMemoryMessagePicker.pickMemoryMessage(messageList)); | ||
} | ||
|
||
// generate random k Human messages | ||
private List<BaseMessage> generateRandomHumanMessages(int k) { | ||
List<BaseMessage> baseMessageList = new ArrayList<>(); | ||
for (int i = 0; i < k; i++) { | ||
baseMessageList.add(new HumanMessage(String.valueOf(i))); | ||
} | ||
return baseMessageList; | ||
} | ||
} |
54 changes: 54 additions & 0 deletions
54
...core/src/test/java/com/hw/langchain/memory/buffer/ConversationBufferWindowMemoryTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.buffer; | ||
|
||
import com.hw.langchain.memory.chat.memory.BaseChatMemory; | ||
import com.hw.langchain.schema.AIMessage; | ||
import com.hw.langchain.schema.BaseChatMessageHistory; | ||
import com.hw.langchain.schema.HumanMessage; | ||
|
||
import org.junit.jupiter.api.Test; | ||
import org.mockito.Mockito; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
import static org.junit.jupiter.api.Assertions.*; | ||
import static org.mockito.Mockito.when; | ||
|
||
/** | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/22 1:25 下午 | ||
**/ | ||
class ConversationBufferWindowMemoryTest { | ||
|
||
@Test | ||
void whenChatMessageHistoryHasManyMessagesOnlyReturnMax2KMessages() { | ||
BaseChatMessageHistory baseChatMessageHistory = Mockito.mock(BaseChatMessageHistory.class); | ||
BaseChatMemory conversationBufferWindowMemory = | ||
new ConversationBufferWindowMemory(1, true, baseChatMessageHistory); | ||
|
||
when(baseChatMessageHistory.getMessages()).thenReturn( | ||
List.of(new HumanMessage("hi"), new AIMessage("hi"), | ||
new HumanMessage("what are you doing"), new AIMessage("I'm thinking"))); | ||
|
||
assertEquals(2, | ||
((List<?>) conversationBufferWindowMemory.loadMemoryVariables(Map.of()).get("history")).size()); | ||
} | ||
} |
Oops, something went wrong.