forked from HamaWhiteGG/langchain-java
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request HamaWhiteGG#51 from HamaWhiteGG/dev
Support database memory
- Loading branch information
Showing
14 changed files
with
373 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
55 changes: 55 additions & 0 deletions
55
...n/java/com/hw/langchain/memory/chat/message/histories/database/ChatMessageRepository.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.chat.message.histories.database; | ||
|
||
import com.hw.langchain.schema.BaseMessage; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* interface for database supported chat message repository; | ||
* | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/20 9:50 下午 | ||
**/ | ||
public interface ChatMessageRepository { | ||
|
||
/** | ||
* Loads all historical chat messages for the given sessionId. | ||
* | ||
* @param sessionId The unique identifier of the chat session. | ||
* @return A List of BaseMessage containing the chat message history for the session. | ||
*/ | ||
List<BaseMessage> loadMessage(String sessionId); | ||
|
||
/** | ||
* Saves a chat message to the specified sessionId. | ||
* | ||
* @param sessionId The unique identifier of the chat session. | ||
* @param baseMessage The BaseMessage to be saved. | ||
*/ | ||
void saveMessage(String sessionId, BaseMessage baseMessage); | ||
|
||
/** | ||
* Clears all chat messages for the specified sessionId. | ||
* | ||
* @param sessionId The unique identifier of the chat session. | ||
*/ | ||
void clearSessionChatMessage(String sessionId); | ||
} |
56 changes: 56 additions & 0 deletions
56
...a/com/hw/langchain/memory/chat/message/histories/database/DataBaseChatMessageHistory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.chat.message.histories.database; | ||
|
||
import com.hw.langchain.schema.BaseChatMessageHistory; | ||
import com.hw.langchain.schema.BaseMessage; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* database based chat message history; | ||
* | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/20 9:53 下午 | ||
**/ | ||
public class DataBaseChatMessageHistory extends BaseChatMessageHistory { | ||
|
||
private final String sessionId; | ||
private final ChatMessageRepository chatMessageRepository; | ||
|
||
public DataBaseChatMessageHistory(String sessionId, ChatMessageRepository chatMessageRepository) { | ||
this.sessionId = sessionId; | ||
this.chatMessageRepository = chatMessageRepository; | ||
} | ||
|
||
@Override | ||
public void addMessage(BaseMessage message) { | ||
chatMessageRepository.saveMessage(sessionId, message); | ||
} | ||
|
||
@Override | ||
public void clear() { | ||
chatMessageRepository.clearSessionChatMessage(sessionId); | ||
} | ||
|
||
@Override | ||
public List<BaseMessage> getMessages() { | ||
return chatMessageRepository.loadMessage(sessionId); | ||
} | ||
} |
63 changes: 63 additions & 0 deletions
63
...om/hw/langchain/memory/chat/message/histories/database/redis/RedisChatMessageHistory.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.chat.message.histories.database.redis; | ||
|
||
import com.hw.langchain.memory.chat.message.histories.database.DataBaseChatMessageHistory; | ||
import com.hw.langchain.schema.BaseChatMessageHistory; | ||
import com.hw.langchain.schema.BaseMessage; | ||
|
||
import org.redisson.api.RedissonClient; | ||
|
||
import java.util.List; | ||
|
||
/** | ||
* a simple wrapper for DataBaseChatMessageHistory with redisChatMessageRepository; | ||
* | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/21 10:49 上午 | ||
**/ | ||
public class RedisChatMessageHistory extends BaseChatMessageHistory { | ||
|
||
private DataBaseChatMessageHistory dataBaseChatMessageHistory; | ||
|
||
public RedisChatMessageHistory(String sessionId, RedissonClient redissonClient, int ttl) { | ||
RedisChatMessageRepository redisChatMessageRepository = new RedisChatMessageRepository(redissonClient, ttl); | ||
dataBaseChatMessageHistory = new DataBaseChatMessageHistory(sessionId, redisChatMessageRepository); | ||
} | ||
|
||
public RedisChatMessageHistory(String sessionId, RedissonClient redissonClient) { | ||
RedisChatMessageRepository redisChatMessageRepository = new RedisChatMessageRepository(redissonClient); | ||
dataBaseChatMessageHistory = new DataBaseChatMessageHistory(sessionId, redisChatMessageRepository); | ||
} | ||
|
||
@Override | ||
public void addMessage(BaseMessage message) { | ||
dataBaseChatMessageHistory.addMessage(message); | ||
} | ||
|
||
@Override | ||
public void clear() { | ||
dataBaseChatMessageHistory.clear(); | ||
} | ||
|
||
@Override | ||
public List<BaseMessage> getMessages() { | ||
return dataBaseChatMessageHistory.getMessages(); | ||
} | ||
} |
97 changes: 97 additions & 0 deletions
97
...hw/langchain/memory/chat/message/histories/database/redis/RedisChatMessageRepository.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package com.hw.langchain.memory.chat.message.histories.database.redis; | ||
|
||
import com.fasterxml.jackson.core.type.TypeReference; | ||
import com.hw.langchain.chains.query.constructor.JsonUtils; | ||
import com.hw.langchain.memory.chat.message.histories.database.ChatMessageRepository; | ||
import com.hw.langchain.schema.BaseMessage; | ||
|
||
import org.redisson.api.RQueue; | ||
import org.redisson.api.RedissonClient; | ||
|
||
import lombok.Builder; | ||
import lombok.experimental.Tolerate; | ||
|
||
import java.time.Duration; | ||
import java.time.temporal.ChronoUnit; | ||
import java.util.List; | ||
import java.util.Map; | ||
|
||
/** | ||
* @author zhangxiaojia002 | ||
* @date 2023/7/20 9:59 下午 | ||
**/ | ||
@Builder | ||
public class RedisChatMessageRepository implements ChatMessageRepository { | ||
|
||
private RedissonClient redissonClient; | ||
private String keyPrefix = "message_store"; | ||
private Integer ttlSeconds; | ||
|
||
@Tolerate | ||
public RedisChatMessageRepository(RedissonClient redissonClient) { | ||
this.redissonClient = redissonClient; | ||
} | ||
|
||
@Tolerate | ||
public RedisChatMessageRepository(RedissonClient redissonClient, int ttlSeconds) { | ||
this.redissonClient = redissonClient; | ||
this.ttlSeconds = ttlSeconds; | ||
} | ||
|
||
/** | ||
* Construct the record key to use | ||
* | ||
* @return key | ||
*/ | ||
private String key(String sessionId) { | ||
if (this.keyPrefix == null) { | ||
return sessionId; | ||
} | ||
return this.keyPrefix + sessionId; | ||
} | ||
|
||
@Override | ||
public List<BaseMessage> loadMessage(String sessionId) { | ||
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId)); | ||
List<String> messageJSonStrList = messageQueue.readAll(); | ||
return messageJSonStrList.stream().map(x -> { | ||
Map<String, Object> data = | ||
JsonUtils.convertFromJsonStr(x, new TypeReference<>() { | ||
}); | ||
return BaseMessage.fromMap(data); | ||
}).toList(); | ||
} | ||
|
||
@Override | ||
public void saveMessage(String sessionId, BaseMessage baseMessage) { | ||
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId)); | ||
messageQueue.add(JsonUtils.toJsonStringWithIndent(baseMessage.toMap())); | ||
if (this.ttlSeconds != null) { | ||
messageQueue.expire(Duration.of(ttlSeconds, ChronoUnit.SECONDS)); | ||
} | ||
} | ||
|
||
@Override | ||
public void clearSessionChatMessage(String sessionId) { | ||
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId)); | ||
messageQueue.delete(); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.