forked from alibaba/ali-dbhub
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
moji
committed
May 27, 2023
1 parent
ce2e297
commit 3559857
Showing
8 changed files
with
365 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
74 changes: 74 additions & 0 deletions
74
...rc/main/java/com/alibaba/dbhub/server/web/api/controller/ai/rest/client/RestAIClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
/** | ||
* alibaba.com Inc. | ||
* Copyright (c) 2004-2023 All Rights Reserved. | ||
*/ | ||
package com.alibaba.dbhub.server.web.api.controller.ai.rest.client; | ||
|
||
import com.alibaba.dbhub.server.domain.api.model.Config; | ||
import com.alibaba.dbhub.server.domain.api.service.ConfigService; | ||
import com.alibaba.dbhub.server.web.api.util.ApplicationContextUtil; | ||
|
||
import lombok.extern.slf4j.Slf4j; | ||
|
||
/** | ||
* @author moji | ||
* @version : RestAIClient.java | ||
*/ | ||
@Slf4j | ||
public class RestAIClient { | ||
|
||
/** | ||
* AI SQL选择的接口来源 | ||
*/ | ||
public static final String AI_SQL_SOURCE = "ai.sql.source"; | ||
|
||
/** | ||
* 自定义AI接口地址 | ||
*/ | ||
public static final String REST_AI_URL = "rest.ai.url"; | ||
|
||
/** | ||
* 自定义AI接口请求方法 | ||
*/ | ||
public static final String REST_AI_STREAM_OUT = "rest.ai.stream"; | ||
|
||
private static RestAiStreamClient REST_AI_STREAM_CLIENT; | ||
|
||
public static RestAiStreamClient getInstance() { | ||
if (REST_AI_STREAM_CLIENT != null) { | ||
return REST_AI_STREAM_CLIENT; | ||
} else { | ||
return singleton(); | ||
} | ||
} | ||
|
||
private static RestAiStreamClient singleton() { | ||
if (REST_AI_STREAM_CLIENT == null) { | ||
synchronized (RestAIClient.class) { | ||
if (REST_AI_STREAM_CLIENT == null) { | ||
refresh(); | ||
} | ||
} | ||
} | ||
return REST_AI_STREAM_CLIENT; | ||
} | ||
|
||
/** | ||
* 刷新客户端 | ||
*/ | ||
public static void refresh() { | ||
String apiUrl = ""; | ||
Boolean stream = Boolean.TRUE; | ||
ConfigService configService = ApplicationContextUtil.getBean(ConfigService.class); | ||
Config apiHostConfig = configService.find(REST_AI_URL).getData(); | ||
if (apiHostConfig != null) { | ||
apiUrl = apiHostConfig.getContent(); | ||
} | ||
Config config = configService.find(REST_AI_STREAM_OUT).getData(); | ||
if (config != null) { | ||
stream = Boolean.valueOf(config.getContent()); | ||
} | ||
REST_AI_STREAM_CLIENT = new RestAiStreamClient(apiUrl, stream); | ||
} | ||
|
||
} |
153 changes: 153 additions & 0 deletions
153
...n/java/com/alibaba/dbhub/server/web/api/controller/ai/rest/client/RestAiStreamClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
package com.alibaba.dbhub.server.web.api.controller.ai.rest.client; | ||
|
||
import java.util.Objects; | ||
import java.util.concurrent.TimeUnit; | ||
|
||
import com.alibaba.dbhub.server.tools.base.excption.BusinessException; | ||
import com.alibaba.dbhub.server.tools.base.excption.CommonErrorEnum; | ||
import com.alibaba.dbhub.server.web.api.controller.ai.rest.model.RestAiCompletion; | ||
|
||
import cn.hutool.http.ContentType; | ||
import com.fasterxml.jackson.databind.ObjectMapper; | ||
import com.unfbx.chatgpt.sse.ConsoleEventSourceListener; | ||
import lombok.Getter; | ||
import lombok.extern.slf4j.Slf4j; | ||
import okhttp3.MediaType; | ||
import okhttp3.OkHttpClient; | ||
import okhttp3.Request; | ||
import okhttp3.RequestBody; | ||
import okhttp3.Response; | ||
import okhttp3.sse.EventSource; | ||
import okhttp3.sse.EventSourceListener; | ||
import okhttp3.sse.EventSources; | ||
import org.apache.commons.lang3.StringUtils; | ||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; | ||
|
||
/** | ||
* 自定义AI接口client | ||
* @author moji | ||
*/ | ||
@Slf4j | ||
public class RestAiStreamClient { | ||
/** | ||
* rest api url | ||
*/ | ||
@Getter | ||
private String apiUrl; | ||
|
||
/** | ||
* 是否流式接口 | ||
*/ | ||
@Getter | ||
private Boolean stream; | ||
/** | ||
* okHttpClient | ||
*/ | ||
@Getter | ||
private OkHttpClient okHttpClient; | ||
|
||
/** | ||
* 构造实例对象 | ||
* | ||
* @param url | ||
*/ | ||
public RestAiStreamClient(String url, Boolean stream) { | ||
this.apiUrl = url; | ||
this.stream = stream; | ||
this.okHttpClient = new OkHttpClient | ||
.Builder() | ||
.connectTimeout(10, TimeUnit.SECONDS) | ||
.writeTimeout(50, TimeUnit.SECONDS) | ||
.readTimeout(50, TimeUnit.SECONDS) | ||
.build(); | ||
} | ||
|
||
/** | ||
* 请求RESTAI接口 | ||
* | ||
* @param prompt | ||
* @param sseEmitter | ||
* @param eventSourceListener | ||
*/ | ||
public void restCompletions(String prompt, | ||
EventSourceListener eventSourceListener, | ||
SseEmitter sseEmitter) { | ||
RestAiCompletion completion = new RestAiCompletion(); | ||
completion.setPrompt(prompt); | ||
if (Objects.isNull(stream) || stream) { | ||
streamCompletions(completion, eventSourceListener); | ||
return; | ||
} | ||
nonStreamCompletions(completion, sseEmitter); | ||
} | ||
|
||
/** | ||
* 问答接口 stream 形式 | ||
* | ||
* @param completion open ai 参数 | ||
* @param eventSourceListener sse监听器 | ||
* @see ConsoleEventSourceListener | ||
*/ | ||
public void streamCompletions(RestAiCompletion completion, EventSourceListener eventSourceListener) { | ||
if (Objects.isNull(eventSourceListener)) { | ||
log.error("参数异常:EventSourceListener不能为空"); | ||
throw new BusinessException(CommonErrorEnum.PARAM_ERROR); | ||
} | ||
if (StringUtils.isBlank(completion.getPrompt())) { | ||
log.error("参数异常:Prompt不能为空"); | ||
throw new BusinessException(CommonErrorEnum.PARAM_ERROR); | ||
} | ||
try { | ||
EventSource.Factory factory = EventSources.createFactory(this.okHttpClient); | ||
ObjectMapper mapper = new ObjectMapper(); | ||
String requestBody = mapper.writeValueAsString(completion); | ||
Request request = new Request.Builder() | ||
.url(this.apiUrl) | ||
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) | ||
.build(); | ||
//创建事件 | ||
EventSource eventSource = factory.newEventSource(request, eventSourceListener); | ||
} catch (Exception e) { | ||
log.error("请求参数解析异常", e); | ||
throw new BusinessException(CommonErrorEnum.PARAM_ERROR); | ||
} | ||
} | ||
|
||
/** | ||
* 请求非流式输出接口 | ||
* | ||
* @param completion | ||
* @param sseEmitter | ||
*/ | ||
public void nonStreamCompletions(RestAiCompletion completion, SseEmitter sseEmitter) { | ||
if (StringUtils.isBlank(completion.getPrompt())) { | ||
log.error("参数异常:Prompt不能为空"); | ||
throw new BusinessException(CommonErrorEnum.PARAM_ERROR); | ||
} | ||
try { | ||
ObjectMapper mapper = new ObjectMapper(); | ||
String requestBody = mapper.writeValueAsString(completion); | ||
Request request = new Request.Builder() | ||
.url(this.apiUrl) | ||
.post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()), requestBody)) | ||
.build(); | ||
// 发起请求并获取响应 | ||
Response response = this.okHttpClient.newCall(request).execute(); | ||
|
||
// 处理响应 | ||
if (response.isSuccessful()) { | ||
// 获取响应体数据 | ||
assert response.body() != null; | ||
String responseBody = response.body().string(); | ||
// 处理响应结果 | ||
sseEmitter.send(responseBody); | ||
} else { | ||
sseEmitter.send("请求REST接口失败:" + response.message()); | ||
} | ||
} catch (Exception e) { | ||
log.error("请求参数解析异常", e); | ||
throw new BusinessException(CommonErrorEnum.PARAM_ERROR); | ||
} | ||
} | ||
|
||
} |
Oops, something went wrong.