Skip to content

Commit

Permalink
Merge pull request HamaWhiteGG#49 from zhangxiaojiawow/support-databa…
Browse files Browse the repository at this point in the history
…se-memory

Support database memory
  • Loading branch information
HamaWhiteGG authored Jul 21, 2023
2 parents 624713a + 39ae25a commit 5956141
Show file tree
Hide file tree
Showing 13 changed files with 361 additions and 3 deletions.
16 changes: 16 additions & 0 deletions langchain-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@
<groupId>org.awaitility</groupId>
<artifactId>awaitility</artifactId>
</dependency>

<dependency>
<groupId>org.redisson</groupId>
<artifactId>redisson</artifactId>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver</artifactId>
</dependency>

<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-resolver-dns</artifactId>
</dependency>

</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package com.hw.langchain.chains.query.constructor;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.core.util.DefaultIndenter;
import com.fasterxml.jackson.core.util.DefaultPrettyPrinter;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -31,9 +32,6 @@
*/
public class JsonUtils {

private JsonUtils() {
}

private static final ObjectMapper OBJECT_MAPPER = createObjectMapper();

private static ObjectMapper createObjectMapper() {
Expand All @@ -51,6 +49,26 @@ public static String toJsonStringWithIndent(Object object, int indent) {
}
}

public static String toJsonStringWithIndent(Object object) {
return toJsonStringWithIndent(object, 4);
}

public static <T> T convertFromJsonStr(String jsonStr, Class<T> clazz) {
try {
return OBJECT_MAPPER.readValue(jsonStr, clazz);
} catch (JsonProcessingException e) {
throw new LangChainException("Failed to deserialize json str", e);
}
}

public static <T> T convertFromJsonStr(String jsonStr, TypeReference<T> typeReference) {
try {
return OBJECT_MAPPER.readValue(jsonStr, typeReference);
} catch (JsonProcessingException e) {
throw new LangChainException("Failed to deserialize json str", e);
}
}

private static DefaultPrettyPrinter getPrettyPrinter(int indent) {
DefaultPrettyPrinter printer = new DefaultPrettyPrinter();
printer.indentArraysWith(DefaultPrettyPrinter.FixedSpaceIndenter.instance);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.chat.message.histories.database;

import com.hw.langchain.schema.BaseMessage;

import java.util.List;

/**
* interface for database supported chat message repository;
*
* @author zhangxiaojia002
* @date 2023/7/20 9:50 下午
**/
public interface ChatMessageRepository {

/**
* load all history chat message of given sessionId
*
* @param sessionId
* @return
*/
List<BaseMessage> loadMessage(String sessionId);

void saveMessage(String sessionId, BaseMessage baseMessage);

void clearSessionChatMessage(String sessionId);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.chat.message.histories.database;

import com.hw.langchain.schema.BaseChatMessageHistory;
import com.hw.langchain.schema.BaseMessage;

import java.util.List;

/**
* database based chat message history;
*
* @author zhangxiaojia002
* @date 2023/7/20 9:53 下午
**/
public class DataBaseChatMessageHistory extends BaseChatMessageHistory {

private final String sessionId;
private final ChatMessageRepository chatMessageRepository;

public DataBaseChatMessageHistory(String sessionId, ChatMessageRepository chatMessageRepository) {
this.sessionId = sessionId;
this.chatMessageRepository = chatMessageRepository;
}

@Override
public void addMessage(BaseMessage message) {
chatMessageRepository.saveMessage(sessionId, message);
}

@Override
public void clear() {
chatMessageRepository.clearSessionChatMessage(sessionId);
}

@Override
public List<BaseMessage> getMessages() {
return chatMessageRepository.loadMessage(sessionId);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.chat.message.histories.database.redis;

import com.hw.langchain.memory.chat.message.histories.database.DataBaseChatMessageHistory;
import com.hw.langchain.schema.BaseChatMessageHistory;
import com.hw.langchain.schema.BaseMessage;

import org.redisson.api.RedissonClient;

import java.util.List;

/**
* a simple wrapper for DataBaseChatMessageHistory with redisChatMessageRepository;
*
* @author zhangxiaojia002
* @date 2023/7/21 10:49 上午
**/
public class RedisChatMessageHistory extends BaseChatMessageHistory {

private DataBaseChatMessageHistory dataBaseChatMessageHistory;

public RedisChatMessageHistory(String sessionId, RedissonClient redissonClient, int ttl) {
RedisChatMessageRepository redisChatMessageRepository = new RedisChatMessageRepository(redissonClient, ttl);
dataBaseChatMessageHistory = new DataBaseChatMessageHistory(sessionId, redisChatMessageRepository);
}

public RedisChatMessageHistory(String sessionId, RedissonClient redissonClient) {
RedisChatMessageRepository redisChatMessageRepository = new RedisChatMessageRepository(redissonClient);
dataBaseChatMessageHistory = new DataBaseChatMessageHistory(sessionId, redisChatMessageRepository);
}

@Override
public void addMessage(BaseMessage message) {
dataBaseChatMessageHistory.addMessage(message);
}

@Override
public void clear() {
dataBaseChatMessageHistory.clear();
}

@Override
public List<BaseMessage> getMessages() {
return dataBaseChatMessageHistory.getMessages();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.hw.langchain.memory.chat.message.histories.database.redis;

import com.fasterxml.jackson.core.type.TypeReference;
import com.hw.langchain.chains.query.constructor.JsonUtils;
import com.hw.langchain.memory.chat.message.histories.database.ChatMessageRepository;
import com.hw.langchain.schema.BaseMessage;

import org.redisson.api.RQueue;
import org.redisson.api.RedissonClient;

import lombok.Builder;
import lombok.experimental.Tolerate;

import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Map;

/**
* @author zhangxiaojia002
* @date 2023/7/20 9:59 下午
**/
@Builder
public class RedisChatMessageRepository implements ChatMessageRepository {

private RedissonClient redissonClient;
private String keyPrefix = "message_store";
private Integer ttlSeconds;

@Tolerate
public RedisChatMessageRepository(RedissonClient redissonClient) {
this.redissonClient = redissonClient;
}

@Tolerate
public RedisChatMessageRepository(RedissonClient redissonClient, int ttlSeconds) {
this.redissonClient = redissonClient;
this.ttlSeconds = ttlSeconds;
}

/**
* Construct the record key to use
*
* @return
*/
private String key(String sessionId) {
if (this.keyPrefix == null) {
return sessionId;
}
return this.keyPrefix + sessionId;
}

@Override
public List<BaseMessage> loadMessage(String sessionId) {
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId));
List<String> messageJSonStrList = messageQueue.readAll();
return messageJSonStrList.stream().map(x -> {
Map<String, Object> data =
JsonUtils.convertFromJsonStr(x, new TypeReference<>() {
});
return BaseMessage.fromMap(data);
}).toList();
}

@Override
public void saveMessage(String sessionId, BaseMessage baseMessage) {
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId));
messageQueue.add(JsonUtils.toJsonStringWithIndent(baseMessage.toMap()));
if (this.ttlSeconds != null) {
messageQueue.expire(Duration.of(ttlSeconds, ChronoUnit.SECONDS));
}
}

@Override
public void clearSessionChatMessage(String sessionId) {
RQueue<String> messageQueue = redissonClient.getQueue(key(sessionId));
messageQueue.delete();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@

package com.hw.langchain.schema;

import lombok.NoArgsConstructor;

/**
* Type of message that is spoken by the AI.
* @author HamaWhite
*/
@NoArgsConstructor
public class AIMessage extends BaseMessage {

public AIMessage(String content) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

package com.hw.langchain.schema;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;
import com.hw.langchain.chains.query.constructor.JsonUtils;
import com.hw.langchain.exception.LangChainException;

import lombok.Data;
import lombok.NoArgsConstructor;

Expand Down Expand Up @@ -46,4 +52,27 @@ protected BaseMessage(String content) {
* Type of the message, used for serialization.
*/
public abstract String type();

public Map<String, Object> toMap() {
ObjectMapper objectMapper = new ObjectMapper();
SimpleModule module = new SimpleModule();
objectMapper.registerModule(module);
Map<String, Object> map = objectMapper.convertValue(this, new TypeReference<>() {
});
return Map.of("type", type(), "data", map);
}

public static BaseMessage fromMap(Map<String, Object> message) {
String type = (String) message.get("type");
Object data = message.get("data");
String jsonStr = JsonUtils.toJsonStringWithIndent(data, 0);
return switch (type) {
case "ai" -> JsonUtils.convertFromJsonStr(jsonStr, AIMessage.class);
case "human" -> JsonUtils.convertFromJsonStr(jsonStr, HumanMessage.class);
case "system" -> JsonUtils.convertFromJsonStr(jsonStr, SystemMessage.class);
case "chat" -> JsonUtils.convertFromJsonStr(jsonStr, ChatMessage.class);
case "function" -> JsonUtils.convertFromJsonStr(jsonStr, FunctionMessage.class);
default -> throw new LangChainException(String.format("Got unexpected message type:%s", type));
};
}
}
Loading

0 comments on commit 5956141

Please sign in to comment.