Skip to content

Commit

Permalink
refactor: llm client api
Browse files Browse the repository at this point in the history
  • Loading branch information
Yidadaa committed May 14, 2023
1 parent 6da3aab commit bd90caa
Show file tree
Hide file tree
Showing 8 changed files with 279 additions and 22 deletions.
109 changes: 109 additions & 0 deletions app/client/api.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import { fetchEventSource } from "@microsoft/fetch-event-source";
import { ACCESS_CODE_PREFIX } from "../constant";
import { ModelType, useAccessStore } from "../store";
import { ChatGPTApi } from "./platforms/openai";

export enum MessageRole {
System = "system",
User = "user",
Assistant = "assistant",
}

export const Models = ["gpt-3.5-turbo", "gpt-4"] as const;
export type ChatModel = ModelType;

export interface Message {
role: MessageRole;
content: string;
}

export interface LLMConfig {
temperature?: number;
topP?: number;
stream?: boolean;
presencePenalty?: number;
frequencyPenalty?: number;
}

export interface ChatOptions {
messages: Message[];
model: ChatModel;
config: LLMConfig;

onUpdate: (message: string, chunk: string) => void;
onFinish: (message: string) => void;
onError: (err: Error) => void;
onUnAuth: () => void;
}

export interface LLMUsage {
used: number;
total: number;
}

export abstract class LLMApi {
abstract chat(options: ChatOptions): Promise<void>;
abstract usage(): Promise<LLMUsage>;
}

export class ClientApi {
public llm: LLMApi;

constructor() {
this.llm = new ChatGPTApi();
}

headers() {
const accessStore = useAccessStore.getState();
let headers: Record<string, string> = {};

const makeBearer = (token: string) => `Bearer ${token.trim()}`;
const validString = (x: string) => x && x.length > 0;

// use user's api key first
if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers.Authorization = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}

return headers;
}

config() {}

prompts() {}

masks() {}
}

export const api = new ClientApi();

export function getHeaders() {
const accessStore = useAccessStore.getState();
let headers: Record<string, string> = {
"Content-Type": "application/json",
};

const makeBearer = (token: string) => `Bearer ${token.trim()}`;
const validString = (x: string) => x && x.length > 0;

// use user's api key first
if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers.Authorization = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}

return headers;
}
37 changes: 37 additions & 0 deletions app/client/controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// To store message streaming controller
export const ChatControllerPool = {
controllers: {} as Record<string, AbortController>,

addController(
sessionIndex: number,
messageId: number,
controller: AbortController,
) {
const key = this.key(sessionIndex, messageId);
this.controllers[key] = controller;
return key;
},

stop(sessionIndex: number, messageId: number) {
const key = this.key(sessionIndex, messageId);
const controller = this.controllers[key];
controller?.abort();
},

stopAll() {
Object.values(this.controllers).forEach((v) => v.abort());
},

hasPending() {
return Object.values(this.controllers).length > 0;
},

remove(sessionIndex: number, messageId: number) {
const key = this.key(sessionIndex, messageId);
delete this.controllers[key];
},

key(sessionIndex: number, messageIndex: number) {
return `${sessionIndex},${messageIndex}`;
},
};
124 changes: 124 additions & 0 deletions app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import { REQUEST_TIMEOUT_MS } from "@/app/constant";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import {
EventStreamContentType,
fetchEventSource,
} from "@microsoft/fetch-event-source";
import { ChatOptions, LLMApi, LLMUsage } from "../api";

export class ChatGPTApi implements LLMApi {
public ChatPath = "v1/chat/completions";

path(path: string): string {
const openaiUrl = useAccessStore.getState().openaiUrl;
if (openaiUrl.endsWith("/")) openaiUrl.slice(0, openaiUrl.length - 1);
return [openaiUrl, path].join("/");
}

extractMessage(res: any) {
return res.choices?.at(0)?.message?.content ?? "";
}

async chat(options: ChatOptions) {
const messages = options.messages.map((v) => ({
role: v.role,
content: v.content,
}));

const modelConfig = {
...useAppConfig.getState().modelConfig,
...useChatStore.getState().currentSession().mask.modelConfig,
...{
model: options.model,
},
};

const requestPayload = {
messages,
stream: options.config.stream,
model: modelConfig.model,
temperature: modelConfig.temperature,
presence_penalty: modelConfig.presence_penalty,
};

console.log("[Request] openai payload: ", requestPayload);

const shouldStream = !!options.config.stream;
const controller = new AbortController();

try {
const chatPath = this.path(this.ChatPath);
const chatPayload = {
method: "POST",
body: JSON.stringify(requestPayload),
signal: controller.signal,
};

// make a fetch request
const reqestTimeoutId = setTimeout(
() => controller.abort(),
REQUEST_TIMEOUT_MS,
);
if (shouldStream) {
let responseText = "";

fetchEventSource(chatPath, {
...chatPayload,
async onopen(res) {
if (
res.ok &&
res.headers.get("Content-Type") === EventStreamContentType
) {
return;
}

if (res.status === 401) {
// TODO: Unauthorized 401
responseText += "\n\n";
} else if (res.status !== 200) {
console.error("[Request] response", res);
throw new Error("[Request] server error");
}
},
onmessage: (ev) => {
if (ev.data === "[DONE]") {
return options.onFinish(responseText);
}
try {
const resJson = JSON.parse(ev.data);
const message = this.extractMessage(resJson);
responseText += message;
options.onUpdate(responseText, message);
} catch (e) {
console.error("[Request] stream error", e);
options.onError(e as Error);
}
},
onclose() {
options.onError(new Error("stream closed unexpected"));
},
onerror(err) {
options.onError(err);
},
});
} else {
const res = await fetch(chatPath, chatPayload);

const resJson = await res.json();
const message = this.extractMessage(resJson);
options.onFinish(message);
}

clearTimeout(reqestTimeoutId);
} catch (e) {
console.log("[Request] failed to make a chat reqeust", e);
options.onError(e as Error);
}
}
async usage() {
return {
used: 0,
total: 0,
} as LLMUsage;
}
}
2 changes: 2 additions & 0 deletions app/constant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,5 @@ export const NARROW_SIDEBAR_WIDTH = 100;
export const ACCESS_CODE_PREFIX = "ak-";

export const LAST_INPUT_KEY = "last-input";

export const REQUEST_TIMEOUT_MS = 60000;
22 changes: 0 additions & 22 deletions app/requests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,28 +43,6 @@ const makeRequestParam = (
};
};

export function getHeaders() {
const accessStore = useAccessStore.getState();
let headers: Record<string, string> = {};

const makeBearer = (token: string) => `Bearer ${token.trim()}`;
const validString = (x: string) => x && x.length > 0;

// use user's api key first
if (validString(accessStore.token)) {
headers.Authorization = makeBearer(accessStore.token);
} else if (
accessStore.enabledAccessControl() &&
validString(accessStore.accessCode)
) {
headers.Authorization = makeBearer(
ACCESS_CODE_PREFIX + accessStore.accessCode,
);
}

return headers;
}

export function requestOpenaiClient(path: string) {
const openaiUrl = useAccessStore.getState().openaiUrl;
return (body: any, method = "POST") =>
Expand Down
1 change: 1 addition & 0 deletions app/store/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { showToast } from "../components/ui-lib";
import { ModelType } from "./config";
import { createEmptyMask, Mask } from "./mask";
import { StoreKey } from "../constant";
import { api } from "../client/api";

export type Message = ChatCompletionResponseMessage & {
date: string;
Expand Down
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
},
"dependencies": {
"@hello-pangea/dnd": "^16.2.0",
"@microsoft/fetch-event-source": "^2.0.1",
"@svgr/webpack": "^6.5.1",
"@vercel/analytics": "^0.1.11",
"emoji-picker-react": "^4.4.7",
Expand Down
5 changes: 5 additions & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,11 @@
dependencies:
"@types/react" ">=16.0.0"

"@microsoft/fetch-event-source@^2.0.1":
version "2.0.1"
resolved "https://registry.npmmirror.com/@microsoft/fetch-event-source/-/fetch-event-source-2.0.1.tgz#9ceecc94b49fbaa15666e38ae8587f64acce007d"
integrity sha512-W6CLUJ2eBMw3Rec70qrsEW0jOm/3twwJv21mrmj2yORiaVmVYGS4sSS5yUwvQc1ZlDLYGPnClVWmUUMagKNsfA==

"@next/env@13.3.1-canary.8":
version "13.3.1-canary.8"
resolved "https://registry.yarnpkg.com/@next/env/-/env-13.3.1-canary.8.tgz#9f5cf57999e4f4b59ef6407924803a247cc4e451"
Expand Down

0 comments on commit bd90caa

Please sign in to comment.