Skip to content

Commit

Permalink
add database based chat history message memory
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangxiaojiawow committed Jul 21, 2023
1 parent 3a267e8 commit 24be2ec
Show file tree
Hide file tree
Showing 6 changed files with 185 additions and 2 deletions.
16 changes: 16 additions & 0 deletions langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
</dependency>

<dependency>
<groupId>org.redisson</groupId>
<artifactId>redisson</artifactId>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver</artifactId>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver-dns</artifactId>
</dependency>

</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
package com.hw.langchain.chains.query.constructor;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.core.util.DefaultIndenter;
import com.fasterxml.jackson.core.util.DefaultPrettyPrinter;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import com.fasterxml.jackson.databind.SerializationFeature;
import com.hw.langchain.exception.LangChainException;
import com.sun.jna.platform.win32.Netapi32Util.User;

/**
* @author HamaWhite
Expand All @@ -52,14 +52,26 @@ public static String toJsonStringWithIndent(Object object, int indent) {
}
}

public static <T> T convertFromJsonStr(String jsonStr, Class<T> clazz) {
public static String toJsonStringWithIndent(Object object) {
return toJsonStringWithIndent(object, 4);
}

public static <T> T convertFromJsonStr(String jsonStr, Class<T> clazz) {
try {
return OBJECT_MAPPER.readValue(jsonStr, clazz);
} catch (JsonProcessingException e) {
throw new LangChainException("Failed to deserialize json str", e);
}
}

public static <T> T convertFromJsonStr(String jsonStr, TypeReference<T> typeReference) {
try {
return OBJECT_MAPPER.readValue(jsonStr, typeReference);
} catch (JsonProcessingException e) {
throw new LangChainException("Failed to deserialize json str", e);
}
}

private static DefaultPrettyPrinter getPrettyPrinter(int indent) {
DefaultPrettyPrinter printer = new DefaultPrettyPrinter();
printer.indentArraysWith(DefaultPrettyPrinter.FixedSpaceIndenter.instance);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.hw.langchain.memory.chat.message.histories.database;

import com.hw.langchain.schema.BaseMessage;
import java.util.List;

/**
* interface for database supported chat message repository;
* @author zhangxiaojia002
* @date 2023/7/20 9:50 下午
**/
public interface ChatMessageRepository {

/**
* load all history chat message of given sessionId
* @param sessionId
* @return
*/
List<BaseMessage> loadMessage(String sessionId);

void saveMessage(String sessionId, BaseMessage baseMessage);

void clearSessionChatMessage(String sessionId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.hw.langchain.memory.chat.message.histories.database;

import com.hw.langchain.schema.BaseChatMessageHistory;
import com.hw.langchain.schema.BaseMessage;
import java.util.List;

/**
* database based chat message history;
* @author zhangxiaojia002
* @date 2023/7/20 9:53 下午
**/
public class DataBaseChatMessageHistory extends BaseChatMessageHistory {
private final String sessionId;
private final ChatMessageRepository chatMessageRepository;

public DataBaseChatMessageHistory(String sessionId, ChatMessageRepository chatMessageRepository) {
this.sessionId = sessionId;
this.chatMessageRepository = chatMessageRepository;
}

@Override
public void addMessage(BaseMessage message) {
chatMessageRepository.saveMessage(sessionId, message);
}

@Override
public void clear() {
chatMessageRepository.clearSessionChatMessage(sessionId);
}

@Override
public List<BaseMessage> getMessages() {
return chatMessageRepository.loadMessage(sessionId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package com.hw.langchain.memory.chat.message.histories.database;

import com.fasterxml.jackson.core.type.TypeReference;
import com.hw.langchain.chains.query.constructor.JsonUtils;
import com.hw.langchain.schema.BaseMessage;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Map;
import lombok.Builder;
import lombok.experimental.Tolerate;
import org.redisson.api.RQueue;
import org.redisson.api.RedissonClient;

/**
* @author zhangxiaojia002
* @date 2023/7/20 9:59 下午
**/
@Builder
public class RedisChatMessageRepository implements ChatMessageRepository {

private RedissonClient redissonClient;
private String keyPrefix = "message_store";
private Integer ttlSeconds;

@Tolerate
public RedisChatMessageRepository(RedissonClient redissonClient) {
this.redissonClient = redissonClient;
}

@Tolerate
public RedisChatMessageRepository(RedissonClient redissonClient, int ttlSeconds) {
this.redissonClient = redissonClient;
this.ttlSeconds = ttlSeconds;
}

/**
* Construct the record key to use
*
* @return
*/
private String key(String sessionId) {
if (this.keyPrefix == null) {
return sessionId;
}
return this.keyPrefix + sessionId;
}

@Override
public List<BaseMessage> loadMessage(String sessionId) {
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId));
List<String> messageJSonStrList = messageQueue.readAll();
return messageJSonStrList.stream().map(x -> {
Map<String, Object> data =
JsonUtils.convertFromJsonStr(x, new TypeReference<>() {
});
return BaseMessage.fromMap(data);
}).toList();
}

@Override
public void saveMessage(String sessionId, BaseMessage baseMessage) {
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId));
messageQueue.add(JsonUtils.toJsonStringWithIndent(baseMessage.toMap()));
if (this.ttlSeconds != null) {
messageQueue.expire(Duration.of(ttlSeconds, ChronoUnit.SECONDS));
}
}

@Override
public void clearSessionChatMessage(String sessionId) {
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId));
messageQueue.delete();
}
}
22 changes: 22 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
<jackson-annotation.version>2.15.1</jackson-annotation.version>
<commons-collections4.version>4.4</commons-collections4.version>
<hibernate-validator.version>8.0.0.Final</hibernate-validator.version>
<redisson.version>3.17.3</redisson.version>
<netty-resolver.version>4.1.43.Final</netty-resolver.version>

<target.java.version>17</target.java.version>
<maven-compiler-plugin.version>3.8.0</maven-compiler-plugin.version>
Expand Down Expand Up @@ -211,6 +213,26 @@
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.redisson</groupId>
<artifactId>redisson</artifactId>
<version>${redisson.version}</version>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver</artifactId>
<version>${netty-resolver.version}</version>
</dependency>

<!-- https://mvnrepository.com/artifact/io.netty/netty-resolver-dns -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver-dns</artifactId>
<version>${netty-resolver.version}</version>
</dependency>


</dependencies>
</dependencyManagement>

Expand Down

0 comments on commit 24be2ec

Please sign in to comment.