Skip to content

Commit

Permalink
chatGPT支持自定义REST接口
Browse files Browse the repository at this point in the history
  • Loading branch information
moji committed May 27, 2023
1 parent ce2e297 commit 3559857
Showing 8 changed files with 365 additions and 68 deletions.
Original file line number Diff line number Diff line change
@@ -42,7 +42,8 @@ public class ChatGptConfig {
private String restAiUrl;

/**
* 自定义AI接口HTTP请求方法:POST/GET
* Rest接口是否流式输出
* 非必填,默认值为TRUE
*/
private String restAiHttpMethod;
private Boolean restAiStream = Boolean.TRUE;
}
Original file line number Diff line number Diff line change
@@ -6,10 +6,14 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import com.alibaba.dbhub.server.domain.api.enums.AiSqlSourceEnum;
import com.alibaba.dbhub.server.domain.api.model.Config;
import com.alibaba.dbhub.server.domain.api.model.DataSource;
import com.alibaba.dbhub.server.domain.api.param.TableQueryParam;
import com.alibaba.dbhub.server.domain.api.service.ConfigService;
import com.alibaba.dbhub.server.domain.api.service.DataSourceService;
import com.alibaba.dbhub.server.domain.api.service.TableService;
import com.alibaba.dbhub.server.domain.support.enums.DbTypeEnum;
@@ -25,6 +29,8 @@
import com.alibaba.dbhub.server.web.api.controller.ai.enums.PromptType;
import com.alibaba.dbhub.server.web.api.controller.ai.listener.OpenAIEventSourceListener;
import com.alibaba.dbhub.server.web.api.controller.ai.request.ChatQueryRequest;
import com.alibaba.dbhub.server.web.api.controller.ai.rest.client.RestAIClient;
import com.alibaba.dbhub.server.web.api.util.ApplicationContextUtil;
import com.alibaba.dbhub.server.web.api.util.OpenAIClient;

import cn.hutool.core.util.StrUtil;
@@ -113,21 +119,10 @@ public SseEmitter chat(@RequestParam("message") String msg, @RequestHeader Map<S
if (StrUtil.isBlank(uid)) {
throw new BaseException(CommonError.SYS_ERROR);
}
String messageContext = (String)LocalCache.CACHE.get(uid);
List<Message> messages = new ArrayList<>();
if (StrUtil.isNotBlank(messageContext)) {
messages = JSONUtil.toList(messageContext, Message.class);
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
}
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
messages.add(currentMessage);
} else {
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
messages.add(currentMessage);
if (useOpenAI()) {
return chatWithOpenAi(msg, sseEmitter, uid);
}

return chatGpt35(messages, sseEmitter, uid);
return chatWithRestAi(msg, sseEmitter);
}

/**
@@ -154,6 +149,50 @@ public SseEmitter completions(ChatQueryRequest queryRequest, @RequestHeader Map<
throw new BusinessException(CommonErrorEnum.PARAM_ERROR);
}

if (useOpenAI()) {
return chatWithOpenAiSql(queryRequest, sseEmitter, uid);
}
return chatWithRestAi(queryRequest.getMessage(), sseEmitter);
}

/**
* 是否使用OPENAI
*
* @return
*/
private Boolean useOpenAI() {
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config config = configService.find(RestAIClient.REST_AI_URL).getData();
if (Objects.nonNull(config) && AiSqlSourceEnum.RESTAI.getCode().equals(config.getContent())) {
return false;
}
return true;
}

/**
* 使用自定义AI接口进行聊天
*
* @param prompt
* @param sseEmitter
* @return
*/
private SseEmitter chatWithRestAi(String prompt, SseEmitter sseEmitter) {
OpenAIEventSourceListener openAIEventSourceListener = new OpenAIEventSourceListener(sseEmitter);
RestAIClient.getInstance().restCompletions(prompt, openAIEventSourceListener, sseEmitter);
return sseEmitter;
}

/**
* 使用OPENAI SQL接口
*
* @param queryRequest
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithOpenAiSql(ChatQueryRequest queryRequest, SseEmitter sseEmitter, String uid)
throws IOException {
String prompt = buildPrompt(queryRequest);
if (prompt.length() / TOKEN_CONVERT_CHAR_LENGTH > MAX_PROMPT_LENGTH) {
log.error("提示语超出最大长度:{},输入长度:{}, 请重新输入", MAX_PROMPT_LENGTH,
@@ -178,6 +217,33 @@ public SseEmitter completions(ChatQueryRequest queryRequest, @RequestHeader Map<
return chatGpt3(prompt, sseEmitter, uid);
}

/**
* 使用OPENAI聊天相关接口
*
* @param msg
* @param sseEmitter
* @param uid
* @return
* @throws IOException
*/
private SseEmitter chatWithOpenAi(String msg, SseEmitter sseEmitter, String uid) throws IOException {
String messageContext = (String)LocalCache.CACHE.get(uid);
List<Message> messages = new ArrayList<>();
if (StrUtil.isNotBlank(messageContext)) {
messages = JSONUtil.toList(messageContext, Message.class);
if (messages.size() >= contextLength) {
messages = messages.subList(1, contextLength);
}
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
messages.add(currentMessage);
} else {
Message currentMessage = Message.builder().content(msg).role(Message.Role.USER).build();
messages.add(currentMessage);
}

return chatGpt35(messages, sseEmitter, uid);
}

/**
* 使用GPT3.5模型
*
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/**
* alibaba.com Inc.
* Copyright (c) 2004-2023 All Rights Reserved.
*/
package com.alibaba.dbhub.server.web.api.controller.ai.rest.client;

import com.alibaba.dbhub.server.domain.api.model.Config;
import com.alibaba.dbhub.server.domain.api.service.ConfigService;
import com.alibaba.dbhub.server.web.api.util.ApplicationContextUtil;

import lombok.extern.slf4j.Slf4j;

/**
* @author moji
* @version : RestAIClient.java
*/
@Slf4j
public class RestAIClient {

/**
* AI SQL选择的接口来源
*/
public static final String AI_SQL_SOURCE = "ai.sql.source";

/**
* 自定义AI接口地址
*/
public static final String REST_AI_URL = "rest.ai.url";

/**
* 自定义AI接口请求方法
*/
public static final String REST_AI_STREAM_OUT = "rest.ai.stream";

private static RestAiStreamClient REST_AI_STREAM_CLIENT;

public static RestAiStreamClient getInstance() {
if (REST_AI_STREAM_CLIENT != null) {
return REST_AI_STREAM_CLIENT;
} else {
return singleton();
}
}

private static RestAiStreamClient singleton() {
if (REST_AI_STREAM_CLIENT == null) {
synchronized (RestAIClient.class) {
if (REST_AI_STREAM_CLIENT == null) {
refresh();
}
}
}
return REST_AI_STREAM_CLIENT;
}

/**
* 刷新客户端
*/
public static void refresh() {
String apiUrl = "";
Boolean stream = Boolean.TRUE;
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class);
Config apiHostConfig = configService.find(REST_AI_URL).getData();
if (apiHostConfig != null) {
apiUrl = apiHostConfig.getContent();
}
Config config = configService.find(REST_AI_STREAM_OUT).getData();
if (config != null) {
stream = Boolean.valueOf(config.getContent());
}
REST_AI_STREAM_CLIENT = new RestAiStreamClient(apiUrl, stream);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package com.alibaba.dbhub.server.web.api.controller.ai.rest.client;

import java.util.Objects;
import java.util.concurrent.TimeUnit;

import com.alibaba.dbhub.server.tools.base.excption.BusinessException;
import com.alibaba.dbhub.server.tools.base.excption.CommonErrorEnum;
import com.alibaba.dbhub.server.web.api.controller.ai.rest.model.RestAiCompletion;

import cn.hutool.http.ContentType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.unfbx.chatgpt.sse.ConsoleEventSourceListener;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;

/**
* 自定义AI接口client
* @author moji
*/
@Slf4j
public class RestAiStreamClient {
/**
* rest api url
*/
@Getter
private String apiUrl;

/**
* 是否流式接口
*/
@Getter
private Boolean stream;
/**
* okHttpClient
*/
@Getter
private OkHttpClient okHttpClient;

/**
* 构造实例对象
*
* @param url
*/
public RestAiStreamClient(String url, Boolean stream) {
this.apiUrl = url;
this.stream = stream;
this.okHttpClient = new OkHttpClient
.Builder()
.connectTimeout(10, TimeUnit.SECONDS)
.writeTimeout(50, TimeUnit.SECONDS)
.readTimeout(50, TimeUnit.SECONDS)
.build();
}

/**
* 请求RESTAI接口
*
* @param prompt
* @param sseEmitter
* @param eventSourceListener
*/
public void restCompletions(String prompt,
EventSourceListener eventSourceListener,
SseEmitter sseEmitter) {
RestAiCompletion completion = new RestAiCompletion();
completion.setPrompt(prompt);
if (Objects.isNull(stream) || stream) {
streamCompletions(completion, eventSourceListener);
return;
}
nonStreamCompletions(completion, sseEmitter);
}

/**
* 问答接口 stream 形式
*
* @param completion open ai 参数
* @param eventSourceListener sse监听器
* @see ConsoleEventSourceListener
*/
public void streamCompletions(RestAiCompletion completion, EventSourceListener eventSourceListener) {
if (Objects.isNull(eventSourceListener)) {
log.error("参数异常:EventSourceListener不能为空");
throw new BusinessException(CommonErrorEnum.PARAM_ERROR);
}
if (StringUtils.isBlank(completion.getPrompt())) {
log.error("参数异常:Prompt不能为空");
throw new BusinessException(CommonErrorEnum.PARAM_ERROR);
}
try {
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
ObjectMapper mapper = new ObjectMapper();
String requestBody = mapper.writeValueAsString(completion);
Request request = new Request.Builder()
.url(this.apiUrl)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
//创建事件
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
} catch (Exception e) {
log.error("请求参数解析异常", e);
throw new BusinessException(CommonErrorEnum.PARAM_ERROR);
}
}

/**
* 请求非流式输出接口
*
* @param completion
* @param sseEmitter
*/
public void nonStreamCompletions(RestAiCompletion completion, SseEmitter sseEmitter) {
if (StringUtils.isBlank(completion.getPrompt())) {
log.error("参数异常:Prompt不能为空");
throw new BusinessException(CommonErrorEnum.PARAM_ERROR);
}
try {
ObjectMapper mapper = new ObjectMapper();
String requestBody = mapper.writeValueAsString(completion);
Request request = new Request.Builder()
.url(this.apiUrl)
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody))
.build();
// 发起请求并获取响应
Response response = this.okHttpClient.newCall(request).execute();

// 处理响应
if (response.isSuccessful()) {
// 获取响应体数据
assert response.body() != null;
String responseBody = response.body().string();
// 处理响应结果
sseEmitter.send(responseBody);
} else {
sseEmitter.send("请求REST接口失败:" + response.message());
}
} catch (Exception e) {
log.error("请求参数解析异常", e);
throw new BusinessException(CommonErrorEnum.PARAM_ERROR);
}
}

}
Loading

0 comments on commit 3559857

Please sign in to comment.