Skip to content

Commit

Permalink
🌿 feat: Multi-response Streaming (#3191)
Browse files Browse the repository at this point in the history
* chore: comment back handlePlusCommand

* chore: ignore .git dir

* refactor: pass newConversation to `useSelectMention`

refactor: pass newConversation to Mention component

refactor: useChatFunctions for modular use of `ask` and `regenerate`

refactor: set latest message only for the first index in useChatFunctions

refactor: pass setLatestMessage to useChatFunctions

refactor: Pass setSubmission to useChatFunctions for submission handling

refactor: consolidate event handlers to separate hook from useSSE

WIP: additional response handlers

feat: responsive added convo, clears on new chat/navigating to chat, assistants excluded

feat: Add conversationByKeySelector to select any conversation by index

WIP: handle second submission with messages paired to root

* style: surface-primary-contrast

* refactor: remove unnecessary console.log statement in useChatFunctions

* refactor: Consolidate imports in ChatForm and Input hooks

* refactor: compositional usage of useSSE for multiple streams

* WIP: set latest 'multi' message

* WIP: first pass, added response streaming

* pass: performant multi-message stream

* fix: styling and message render

* second pass: modular, performant multi-stream

* fix: align parentMessageId of multiMessage

* refactor: move resetting latestMultiMessage

* chore: update footer text in Chat component

* fix: stop button styling

* fix: handle abortMessage request for multi-response

* clear messages but bug with latest message reset present

* fix: add delay for additional message generation

* fix: access LAST_CONVO_SETUP by index

* style: add div to prevent layout shift before hover buttons render

* chore: Update Message component styling for card messages

* chore: move hook use order

* fix: abort middleware using unsent field from req.body

* feat: support multi-response stream from initial message

* refactor: buildTree function to improve readability and remove unused code

* feat: add logger for frontend dev

* refactor: use depth to track if message is really last in its branch

* fix(buildTree): default export

* fix: share parent message Id and avoid duplication error for multi-response streams

* fix: prevent addedConvo reset to response convo

* feat: allow setting multi message as latest message to control which to respond to

* chore: wrap setSiblingIdxRev with useCallback

* chore: styling and allow editing messages

* style: styling fixes

* feat: Add "AddMultiConvo" component to Chat Header

* feat: prevent clearing added convos on endpoint, preset, mention, or modelSpec switch

* fix: message styling fixes, mainly related to code blocks

* fix: stop button visibility logic

* fix: Handle edge case in abortMiddleware for non-existant `abortControllers`

* refactor: optimize/memoize icons

* chore(GoogleClient): change info to debug logs

* style: active message styling

* style: prevent layout shift due to placeholder row

* chore: remove unused code

* fix: Update BaseClient to handle optional request body properties

* fix(ci): `onStart` now accepts 2 args, the 2nd being responseMessageId

* chore: bump data-provider
  • Loading branch information
danny-avila committed Aug 5, 2024
1 parent a939e76 commit 4f61c4f
Show file tree
Hide file tree
Showing 72 changed files with 2,685 additions and 1,314 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ logs
pids
*.pid
*.seed
.git

# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
Expand Down Expand Up @@ -45,6 +46,7 @@ api/node_modules/
client/node_modules/
bower_components/
*.d.ts
!vite-env.d.ts

# Floobits
.floo
Expand Down
42 changes: 38 additions & 4 deletions api/app/clients/BaseClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class BaseClient {
day: 'numeric',
});
this.fetch = this.fetch.bind(this);
/** @type {boolean} */
this.skipSaveConvo = false;
/** @type {boolean} */
this.skipSaveUserMessage = false;
}

setOptions() {
Expand Down Expand Up @@ -84,19 +88,45 @@ class BaseClient {
await stream.processTextStream(onProgress);
}

/**
* @returns {[string|undefined, string|undefined]}
*/
processOverideIds() {
/** @type {Record<string, string | undefined>} */
let { overrideConvoId, overrideUserMessageId } = this.options?.req?.body ?? {};
if (overrideConvoId) {
const [conversationId, index] = overrideConvoId.split(Constants.COMMON_DIVIDER);
overrideConvoId = conversationId;
if (index !== '0') {
this.skipSaveConvo = true;
}
}
if (overrideUserMessageId) {
const [userMessageId, index] = overrideUserMessageId.split(Constants.COMMON_DIVIDER);
overrideUserMessageId = userMessageId;
if (index !== '0') {
this.skipSaveUserMessage = true;
}
}

return [overrideConvoId, overrideUserMessageId];
}

async setMessageOptions(opts = {}) {
if (opts && opts.replaceOptions) {
this.setOptions(opts);
}

const [overrideConvoId, overrideUserMessageId] = this.processOverideIds();
const { isEdited, isContinued } = opts;
const user = opts.user ?? null;
this.user = user;
const saveOptions = this.getSaveOptions();
this.abortController = opts.abortController ?? new AbortController();
const conversationId = opts.conversationId ?? crypto.randomUUID();
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
const userMessageId =
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
let head = isEdited ? responseMessageId : parentMessageId;
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
Expand Down Expand Up @@ -160,7 +190,7 @@ class BaseClient {
}

if (typeof opts?.onStart === 'function') {
opts.onStart(userMessage);
opts.onStart(userMessage, responseMessageId);
}

return {
Expand Down Expand Up @@ -450,7 +480,7 @@ class BaseClient {
this.handleTokenCountMap(tokenCountMap);
}

if (!isEdited) {
if (!isEdited && !this.skipSaveUserMessage) {
await this.saveMessageToDatabase(userMessage, saveOptions, user);
}

Expand Down Expand Up @@ -569,6 +599,10 @@ class BaseClient {
unfinished: false,
user,
});

if (this.skipSaveConvo) {
return;
}
await saveConvo(user, {
conversationId: message.conversationId,
endpoint: this.options.endpoint,
Expand Down
8 changes: 4 additions & 4 deletions api/app/clients/GoogleClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ class GoogleClient extends BaseClient {

let clientOptions = { ...parameters, maxRetries: 2 };

logger.info('Initialized title client options');
logger.debug('Initialized title client options');

if (this.project_id) {
clientOptions['authOptions'] = {
Expand All @@ -764,7 +764,7 @@ class GoogleClient extends BaseClient {

const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
logger.info('Identified titling model as 1.5 version');
logger.debug('Identified titling model as 1.5 version');
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {
Expand All @@ -790,7 +790,7 @@ class GoogleClient extends BaseClient {

return reply;
} else {
logger.info('Beginning titling');
logger.debug('Beginning titling');
const safetySettings = _payload.safetySettings;

const titleResponse = await model.invoke(messages, {
Expand Down Expand Up @@ -840,7 +840,7 @@ class GoogleClient extends BaseClient {
} catch (e) {
logger.error('[GoogleClient] There was an issue generating the title', e);
}
logger.info(`Title response: ${title}`);
logger.debug(`Title response: ${title}`);
return title;
}

Expand Down
5 changes: 4 additions & 1 deletion api/app/clients/PluginsClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,10 @@ class PluginsClient extends OpenAIClient {
if (payload) {
this.currentMessages = payload;
}
await this.saveMessageToDatabase(userMessage, saveOptions, user);

if (!this.skipSaveUserMessage) {
await this.saveMessageToDatabase(userMessage, saveOptions, user);
}

if (isEnabled(process.env.CHECK_BALANCE)) {
await checkBalance({
Expand Down
6 changes: 5 additions & 1 deletion api/app/clients/specs/BaseClient.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -576,7 +576,11 @@ describe('BaseClient', () => {
const onStart = jest.fn();
const opts = { onStart };
await TestClient.sendMessage('Hello, world!', opts);
expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));

expect(onStart).toHaveBeenCalledWith(
expect.objectContaining({ text: 'Hello, world!' }),
expect.any(String),
);
});

test('saveMessageToDatabase is called with the correct arguments', async () => {
Expand Down
6 changes: 4 additions & 2 deletions api/server/controllers/AskController.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
promptTokens,
});

const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);

res.on('close', () => {
logger.debug('[AskController] Request closed');
Expand Down Expand Up @@ -144,7 +144,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
await saveMessage({ ...response, user });
}

await saveMessage(userMessage);
if (!client.skipSaveUserMessage) {
await saveMessage(userMessage);
}

if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
addTitle(req, {
Expand Down
2 changes: 1 addition & 1 deletion api/server/controllers/EditController.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ const EditController = async (req, res, next, initializeClient) => {
promptTokens,
});

const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);

res.on('close', () => {
logger.debug('[EditController] Request closed');
Expand Down
43 changes: 34 additions & 9 deletions api/server/middleware/abortMiddleware.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,28 @@ const { abortRun } = require('./abortRun');
const { logger } = require('~/config');

async function abortMessage(req, res) {
let { abortKey, conversationId, endpoint } = req.body;

if (!abortKey && conversationId) {
abortKey = conversationId;
}
let { abortKey, endpoint } = req.body;

if (isAssistantsEndpoint(endpoint)) {
return await abortRun(req, res);
}

const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;

if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
abortKey = conversationId;
}

if (!abortControllers.has(abortKey) && !res.headersSent) {
return res.status(204).send({ message: 'Request not found' });
}

const { abortController } = abortControllers.get(abortKey);
const { abortController } = abortControllers.get(abortKey) ?? {};
if (!abortController) {
return res.status(204).send({ message: 'Request not found' });
}
const finalEvent = await abortController.abortCompletion();
logger.debug('[abortMessage] Aborted request', { abortKey });
logger.info('[abortMessage] Aborted request', { abortKey });
abortControllers.delete(abortKey);

if (res.headersSent && finalEvent) {
Expand All @@ -50,12 +55,32 @@ const handleAbort = () => {
};
};

const createAbortController = (req, res, getAbortData) => {
const createAbortController = (req, res, getAbortData, getReqData) => {
const abortController = new AbortController();
const { endpointOption } = req.body;
const onStart = (userMessage) => {

abortController.getAbortData = function () {
return getAbortData();
};

/**
* @param {TMessage} userMessage
* @param {string} responseMessageId
*/
const onStart = (userMessage, responseMessageId) => {
sendMessage(res, { message: userMessage, created: true });
const abortKey = userMessage?.conversationId ?? req.user.id;
const prevRequest = abortControllers.get(abortKey);
if (prevRequest && prevRequest?.abortController) {
const data = prevRequest.abortController.getAbortData();
getReqData({ userMessage: data?.userMessage });
const addedAbortKey = `${abortKey}:${responseMessageId}`;
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
res.on('finish', function () {
abortControllers.delete(addedAbortKey);
});
return;
}
abortControllers.set(abortKey, { abortController, ...endpointOption });

res.on('finish', function () {
Expand Down
22 changes: 12 additions & 10 deletions api/server/routes/ask/gptPlugins.js
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,6 @@ router.post(
}
};

const onChainEnd = () => {
saveMessage({ ...userMessage, user });
sendIntermediateMessage(res, {
plugins,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
};

const getAbortData = () => ({
sender,
conversationId,
Expand All @@ -167,12 +158,23 @@ router.post(
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);

try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });

const onChainEnd = () => {
if (!client.skipSaveUserMessage) {
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, {
plugins,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
};

let response = await client.sendMessage(text, {
user,
conversationId,
Expand Down
32 changes: 16 additions & 16 deletions api/server/routes/edit/gptPlugins.js
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,6 @@ router.post(
},
});

const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start) {
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('PLUGIN ACTION', formattedAction);
};

const onChainEnd = (data) => {
let { intermediateSteps: steps } = data;
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
Expand All @@ -141,12 +126,27 @@ router.post(
userMessage,
promptTokens,
});
const { abortController, onStart } = createAbortController(req, res, getAbortData);
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);

try {
endpointOption.tools = await validateTools(user, endpointOption.tools);
const { client } = await initializeClient({ req, res, endpointOption });

const onAgentAction = (action, start = false) => {
const formattedAction = formatAction(action);
plugin.inputs.push(formattedAction);
plugin.latest = formattedAction.plugin;
if (!start && !client.skipSaveUserMessage) {
saveMessage({ ...userMessage, user });
}
sendIntermediateMessage(res, {
plugin,
parentMessageId: userMessage.messageId,
messageId: responseMessageId,
});
// logger.debug('PLUGIN ACTION', formattedAction);
};

let response = await client.sendMessage(text, {
user,
generation,
Expand Down
6 changes: 6 additions & 0 deletions client/src/Providers/AddedChatContext.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import { createContext, useContext } from 'react';
import useAddedResponse from '~/hooks/Chat/useAddedResponse';
type TAddedChatContext = ReturnType<typeof useAddedResponse>;

export const AddedChatContext = createContext<TAddedChatContext>({} as TAddedChatContext);
export const useAddedChatContext = () => useContext(AddedChatContext);
2 changes: 1 addition & 1 deletion client/src/Providers/ChatContext.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { createContext, useContext } from 'react';
import useChatHelpers from '~/hooks/useChatHelpers';
import useChatHelpers from '~/hooks/Chat/useChatHelpers';
type TChatContext = ReturnType<typeof useChatHelpers>;

export const ChatContext = createContext<TChatContext>({} as TChatContext);
Expand Down
1 change: 1 addition & 0 deletions client/src/Providers/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ export * from './ShareContext';
export * from './ToastContext';
export * from './SearchContext';
export * from './FileMapContext';
export * from './AddedChatContext';
export * from './ChatFormContext';
export * from './DashboardContext';
export * from './AssistantsContext';
Expand Down
Loading

0 comments on commit 4f61c4f

Please sign in to comment.