Skip to content

Commit

Permalink
🤖 feat: Add titling to Google client (#2983)
Browse files Browse the repository at this point in the history
* feat: Add titling to Google client

* feat: Add titling to Google client

* PR feedback changes
  • Loading branch information
mungewrath authored Jun 22, 2024
1 parent aac01df commit b5081bf
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 5 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ GOOGLE_KEY=user_provided
# Vertex AI
# GOOGLE_MODELS=gemini-1.5-flash-preview-0514,gemini-1.5-pro-preview-0514,gemini-1.0-pro-vision-001,gemini-1.0-pro-002,gemini-1.0-pro-001,gemini-pro-vision,gemini-1.0-pro

# GOOGLE_TITLE_MODEL=gemini-pro

# Google Gemini Safety Settings
# NOTE (Vertex AI): You do not have access to the BLOCK_NONE setting by default.
# To use this restricted HarmBlockThreshold setting, you will need to either:
Expand Down
131 changes: 129 additions & 2 deletions api/app/clients/GoogleClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,15 @@ const {
AuthKeys,
} = require('librechat-data-provider');
const { encodeAndFormat } = require('~/server/services/Files/images');
const { formatMessage, createContextHandlers } = require('./prompts');
const { getModelMaxTokens } = require('~/utils');
const BaseClient = require('./BaseClient');
const { logger } = require('~/config');
const {
formatMessage,
createContextHandlers,
titleInstruction,
truncateText,
} = require('./prompts');
const BaseClient = require('./BaseClient');

const loc = 'us-central1';
const publisher = 'google';
Expand Down Expand Up @@ -591,12 +596,16 @@ class GoogleClient extends BaseClient {
createLLM(clientOptions) {
const model = clientOptions.modelName ?? clientOptions.model;
if (this.project_id && this.isTextModel) {
logger.debug('Creating Google VertexAI client');
return new GoogleVertexAI(clientOptions);
} else if (this.project_id && this.isChatModel) {
logger.debug('Creating Chat Google VertexAI client');
return new ChatGoogleVertexAI(clientOptions);
} else if (this.project_id) {
logger.debug('Creating VertexAI client');
return new ChatVertexAI(clientOptions);
} else if (model.includes('1.5')) {
logger.debug('Creating GenAI client');
return new GenAI(this.apiKey).getGenerativeModel(
{
...clientOptions,
Expand All @@ -606,6 +615,7 @@ class GoogleClient extends BaseClient {
);
}

logger.debug('Creating Chat Google Generative AI client');
return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey });
}

Expand Down Expand Up @@ -717,6 +727,123 @@ class GoogleClient extends BaseClient {
return reply;
}

/**
* Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming
*/
async titleChatCompletion(_payload, options = {}) {
const { abortController } = options;
const { parameters, instances } = _payload;
const { messages: _messages, examples: _examples } = instances?.[0] ?? {};

let clientOptions = { ...parameters, maxRetries: 2 };

logger.info('Initialized title client options');

if (this.project_id) {
clientOptions['authOptions'] = {
credentials: {
...this.serviceKey,
},
projectId: this.project_id,
};
}

if (!parameters) {
clientOptions = { ...clientOptions, ...this.modelOptions };
}

if (this.isGenerativeModel && !this.project_id) {
clientOptions.modelName = clientOptions.model;
delete clientOptions.model;
}

const model = this.createLLM(clientOptions);

let reply = '';
const messages = this.isTextModel ? _payload.trim() : _messages;

const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
if (modelName?.includes('1.5') && !this.project_id) {
logger.info('Identified titling model as 1.5 version');
/** @type {GenerativeModel} */
const client = model;
const requestOptions = {
contents: _payload,
};

if (this.options?.promptPrefix?.length) {
requestOptions.systemInstruction = {
parts: [
{
text: this.options.promptPrefix,
},
],
};
}

const safetySettings = _payload.safetySettings;
requestOptions.safetySettings = safetySettings;

const result = await client.generateContent(requestOptions);

reply = result.response?.text();

return reply;
} else {
logger.info('Beginning titling');
const safetySettings = _payload.safetySettings;

const titleResponse = await model.invoke(messages, {
signal: abortController.signal,
timeout: 7000,
safetySettings: safetySettings,
});

reply = titleResponse.content;

return reply;
}
}

async titleConvo({ text, responseText = '' }) {
let title = 'New Chat';
const convo = `||>User:
"${truncateText(text)}"
||>Response:
"${JSON.stringify(truncateText(responseText))}"`;

let { prompt: payload } = await this.buildMessages([
{
text: `Please generate ${titleInstruction}
${convo}
||>Title:`,
isCreatedByUser: true,
author: this.userLabel,
},
]);

if (this.isVisionModel) {
logger.warn(
`Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`,
);

payload.parameters = { ...payload.parameters, model: settings.model.default };
}

try {
title = await this.titleChatCompletion(payload, {
abortController: new AbortController(),
onProgress: () => {},
});
} catch (e) {
logger.error('[GoogleClient] There was an issue generating the title', e);
}
logger.info(`Title response: ${title}`);
return title;
}

getSaveOptions() {
return {
promptPrefix: this.options.promptPrefix,
Expand Down
4 changes: 2 additions & 2 deletions api/server/routes/ask/google.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
const express = require('express');
const AskController = require('~/server/controllers/AskController');
const { initializeClient } = require('~/server/services/Endpoints/google');
const { initializeClient, addTitle } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
Expand All @@ -20,7 +20,7 @@ router.post(
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient);
await AskController(req, res, next, initializeClient, addTitle);
},
);

Expand Down
58 changes: 58 additions & 0 deletions api/server/services/Endpoints/google/addTitle.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
const { CacheKeys, Constants } = require('librechat-data-provider');
const getLogStores = require('~/cache/getLogStores');
const { isEnabled } = require('~/server/utils');
const { saveConvo } = require('~/models');
const { logger } = require('~/config');
const initializeClient = require('./initializeClient');

const addTitle = async (req, { text, response, client }) => {
const { TITLE_CONVO = 'true' } = process.env ?? {};
if (!isEnabled(TITLE_CONVO)) {
return;
}

if (client.options.titleConvo === false) {
return;
}

const DEFAULT_TITLE_MODEL = 'gemini-pro';
const { GOOGLE_TITLE_MODEL } = process.env ?? {};

let model = GOOGLE_TITLE_MODEL ?? DEFAULT_TITLE_MODEL;

if (GOOGLE_TITLE_MODEL === Constants.CURRENT_MODEL) {
model = client.options?.modelOptions.model;

if (client.isVisionModel) {
logger.warn(
`current_model was specified for Google title request, but the model ${model} cannot process a text-only conversation. Falling back to ${DEFAULT_TITLE_MODEL}`,
);

model = DEFAULT_TITLE_MODEL;
}
}

const titleEndpointOptions = {
...client.options,
modelOptions: { ...client.options?.modelOptions, model: model },
attachments: undefined, // After a response, this is set to an empty array which results in an error during setOptions
};

const { client: titleClient } = await initializeClient({
req,
res: response,
endpointOption: titleEndpointOptions,
});

const titleCache = getLogStores(CacheKeys.GEN_TITLE);
const key = `${req.user.id}-${response.conversationId}`;

const title = await titleClient.titleConvo({ text, responseText: response?.text });
await titleCache.set(key, title, 120000);
await saveConvo(req.user.id, {
conversationId: response.conversationId,
title,
});
};

module.exports = addTitle;
3 changes: 2 additions & 1 deletion api/server/services/Endpoints/google/index.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
const addTitle = require('./addTitle');
const buildOptions = require('./buildOptions');
const initializeClient = require('./initializeClient');

module.exports = {
// addTitle, // todo
addTitle,
buildOptions,
initializeClient,
};

0 comments on commit b5081bf

Please sign in to comment.