From fa261866809b2b23e239c06010672a6cad838feb Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Mon, 19 Feb 2024 22:47:39 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=A1=EF=B8=8F=20feat:=20Model=20Validat?= =?UTF-8?q?ion=20Middleware=20(#1841)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor: add ViolationTypes enum and add new violation for illegal model requests * feat: validateModel middleware to protect the backend against illicit requests for unlisted models --- .env.example | 2 + api/cache/getLogStores.js | 7 +- api/server/middleware/index.js | 2 + api/server/middleware/uploadLimiters.js | 4 +- api/server/middleware/validateModel.js | 50 ++++ api/server/routes/ask/anthropic.js | 14 +- api/server/routes/ask/custom.js | 14 +- api/server/routes/ask/google.js | 14 +- api/server/routes/ask/gptPlugins.js | 371 ++++++++++++------------ api/server/routes/ask/openAI.js | 14 +- api/server/routes/assistants/chat.js | 3 +- api/server/routes/edit/anthropic.js | 14 +- api/server/routes/edit/custom.js | 14 +- api/server/routes/edit/google.js | 14 +- api/server/routes/edit/gptPlugins.js | 335 ++++++++++----------- api/server/routes/edit/openAI.js | 14 +- docs/features/mod_system.md | 5 + docs/install/configuration/dotenv.md | 3 + packages/data-provider/src/config.ts | 12 +- 19 files changed, 534 insertions(+), 372 deletions(-) create mode 100644 api/server/middleware/validateModel.js diff --git a/.env.example b/.env.example index cade9c9a3e97..4e8b53edcd02 100644 --- a/.env.example +++ b/.env.example @@ -239,6 +239,8 @@ LIMIT_MESSAGE_USER=false MESSAGE_USER_MAX=40 MESSAGE_USER_WINDOW=1 +ILLEGAL_MODEL_REQ_SCORE=5 + #========================# # Balance # #========================# diff --git a/api/cache/getLogStores.js b/api/cache/getLogStores.js index 1e614cde5ffb..c46166cc0fb1 100644 --- a/api/cache/getLogStores.js +++ b/api/cache/getLogStores.js @@ -1,5 +1,5 @@ const Keyv = require('keyv'); -const { CacheKeys } = require('librechat-data-provider'); +const { CacheKeys, ViolationTypes } = require('librechat-data-provider'); const { logFile, violationFile } = require('./keyvFiles'); const { math, isEnabled } = require('~/server/utils'); const keyvRedis = require('./keyvRedis'); @@ -49,7 +49,10 @@ const namespaces = { message_limit: createViolationInstance('message_limit'), token_balance: createViolationInstance('token_balance'), registrations: createViolationInstance('registrations'), - [CacheKeys.FILE_UPLOAD_LIMIT]: createViolationInstance(CacheKeys.FILE_UPLOAD_LIMIT), + [ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT), + [ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance( + ViolationTypes.ILLEGAL_MODEL_REQUEST, + ), logins: createViolationInstance('logins'), [CacheKeys.ABORT_KEYS]: abortKeys, [CacheKeys.TOKEN_CONFIG]: tokenConfig, diff --git a/api/server/middleware/index.js b/api/server/middleware/index.js index 5b257c9a4d67..d6a1c175cd57 100644 --- a/api/server/middleware/index.js +++ b/api/server/middleware/index.js @@ -3,6 +3,7 @@ const checkBan = require('./checkBan'); const uaParser = require('./uaParser'); const setHeaders = require('./setHeaders'); const loginLimiter = require('./loginLimiter'); +const validateModel = require('./validateModel'); const requireJwtAuth = require('./requireJwtAuth'); const uploadLimiters = require('./uploadLimiters'); const registerLimiter = require('./registerLimiter'); @@ -32,6 +33,7 @@ module.exports = { validateMessageReq, buildEndpointOption, 
validateRegistration, + validateModel, moderateText, noIndex, }; diff --git a/api/server/middleware/uploadLimiters.js b/api/server/middleware/uploadLimiters.js index 80544a580f96..71af164fde47 100644 --- a/api/server/middleware/uploadLimiters.js +++ b/api/server/middleware/uploadLimiters.js @@ -1,5 +1,5 @@ const rateLimit = require('express-rate-limit'); -const { CacheKeys } = require('librechat-data-provider'); +const { ViolationTypes } = require('librechat-data-provider'); const logViolation = require('~/cache/logViolation'); const getEnvironmentVariables = () => { @@ -35,7 +35,7 @@ const createFileUploadHandler = (ip = true) => { } = getEnvironmentVariables(); return async (req, res) => { - const type = CacheKeys.FILE_UPLOAD_LIMIT; + const type = ViolationTypes.FILE_UPLOAD_LIMIT; const errorMessage = { type, max: ip ? fileUploadIpMax : fileUploadUserMax, diff --git a/api/server/middleware/validateModel.js b/api/server/middleware/validateModel.js new file mode 100644 index 000000000000..553097f99e2b --- /dev/null +++ b/api/server/middleware/validateModel.js @@ -0,0 +1,50 @@ +const { EModelEndpoint, CacheKeys, ViolationTypes } = require('librechat-data-provider'); +const { logViolation, getLogStores } = require('~/cache'); +const { handleError } = require('~/server/utils'); + +/** + * Validates the model of the request. + * + * @async + * @param {Express.Request} req - The Express request object. + * @param {Express.Response} res - The Express response object. + * @param {Function} next - The Express next function. + */ +const validateModel = async (req, res, next) => { + const { model, endpoint } = req.body; + if (!model) { + return handleError(res, { text: 'Model not provided' }); + } + + const cache = getLogStores(CacheKeys.CONFIG_STORE); + const modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG); + if (!modelsConfig) { + return handleError(res, { text: 'Models not loaded' }); + } + + const availableModels = modelsConfig[endpoint]; + if (!availableModels) { + return handleError(res, { text: 'Endpoint models not loaded' }); + } + + let validModel = !!availableModels.find((availableModel) => availableModel === model); + if (endpoint === EModelEndpoint.gptPlugins) { + validModel = validModel && availableModels.includes(req.body.agentOptions?.model); + } + + if (validModel) { + return next(); + } + + const { ILLEGAL_MODEL_REQ_SCORE: score = 5 } = process.env ?? 
{}; + + const type = ViolationTypes.ILLEGAL_MODEL_REQUEST; + const errorMessage = { + type, + }; + + await logViolation(req, res, type, errorMessage, score); + return handleError(res, { text: 'Illegal model request' }); +}; + +module.exports = validateModel; diff --git a/api/server/routes/ask/anthropic.js b/api/server/routes/ask/anthropic.js index e0ea0f9857fe..093d64c8de0b 100644 --- a/api/server/routes/ask/anthropic.js +++ b/api/server/routes/ask/anthropic.js @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/anthropic'); const { setHeaders, handleAbort, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -12,8 +13,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await AskController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AskController(req, res, next, initializeClient); + }, +); module.exports = router; diff --git a/api/server/routes/ask/custom.js b/api/server/routes/ask/custom.js index ef979bf0000c..668a9902cb92 100644 --- a/api/server/routes/ask/custom.js +++ b/api/server/routes/ask/custom.js @@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI'); const { handleAbort, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -13,8 +14,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AskController(req, res, next, initializeClient, addTitle); + }, +); module.exports = router; diff --git a/api/server/routes/ask/google.js b/api/server/routes/ask/google.js index 78c648495ff8..b5425d67649c 100644 --- a/api/server/routes/ask/google.js +++ b/api/server/routes/ask/google.js @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google'); const { setHeaders, handleAbort, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -12,8 +13,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await AskController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AskController(req, res, next, initializeClient); + }, +); module.exports = router; diff --git a/api/server/routes/ask/gptPlugins.js b/api/server/routes/ask/gptPlugins.js index 80817e5a41e5..a402f8eaf3ac 100644 --- a/api/server/routes/ask/gptPlugins.js +++ b/api/server/routes/ask/gptPlugins.js @@ -11,6 +11,7 @@ const { createAbortController, handleAbortError, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, moderateText, @@ -20,207 +21,217 @@ const { logger } = require('~/config'); router.use(moderateText); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { - let { - text, - endpointOption, - conversationId, - 
parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); - let metadata; - let userMessage; - let promptTokens; - let userMessageId; - let responseMessageId; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model }); - const newConvo = !conversationId; - const user = req.user.id; - - const plugins = []; - - const addMetadata = (data) => (metadata = data); - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - userMessage = data[key]; - userMessageId = data[key].messageId; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } else if (!conversationId && key === 'conversationId') { - conversationId = data[key]; +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + endpointOption, + conversationId, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; + logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); + let metadata; + let userMessage; + let promptTokens; + let userMessageId; + let responseMessageId; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.modelOptions.model, + }); + const newConvo = !conversationId; + const user = req.user.id; + + const plugins = []; + + const addMetadata = (data) => (metadata = data); + const getReqData = (data = {}) => { + for (let key in data) { + if (key === 'userMessage') { + userMessage = data[key]; + userMessageId = data[key].messageId; + } else if (key === 'responseMessageId') { + responseMessageId = data[key]; + } else if (key === 'promptTokens') { + promptTokens = data[key]; + } else if (!conversationId && key === 'conversationId') { + conversationId = data[key]; + } } - } - }; + }; - let streaming = null; - let timer = null; + let streaming = null; + let timer = null; - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + + if (timer) { + clearTimeout(timer); + } + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender, + conversationId, + parentMessageId: overrideParentMessageId || userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + error: false, + plugins, + user, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + + streaming = new Promise((resolve) => { + timer = setTimeout(() => { + resolve(); + }, 250); + }); + }, + }); - if (timer) { - clearTimeout(timer); - } + const pluginMap = new Map(); + const onAgentAction = async (action, runId) => { + pluginMap.set(runId, action.tool); + sendIntermediateMessage(res, { plugins }); + }; - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender, - conversationId, - parentMessageId: 
overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - error: false, - plugins, - user, - }); + const onToolStart = async (tool, input, runId, parentRunId) => { + const pluginName = pluginMap.get(parentRunId); + const latestPlugin = { + runId, + loading: true, + inputs: [input], + latest: pluginName, + outputs: null, + }; + + if (streaming) { + await streaming; } + const extraTokens = ':::plugin:::\n'; + plugins.push(latestPlugin); + sendIntermediateMessage(res, { plugins }, extraTokens); + }; - if (saveDelay < 500) { - saveDelay = 500; + const onToolEnd = async (output, runId) => { + if (streaming) { + await streaming; } - streaming = new Promise((resolve) => { - timer = setTimeout(() => { - resolve(); - }, 250); - }); - }, - }); - - const pluginMap = new Map(); - const onAgentAction = async (action, runId) => { - pluginMap.set(runId, action.tool); - sendIntermediateMessage(res, { plugins }); - }; - - const onToolStart = async (tool, input, runId, parentRunId) => { - const pluginName = pluginMap.get(parentRunId); - const latestPlugin = { - runId, - loading: true, - inputs: [input], - latest: pluginName, - outputs: null, - }; + const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); - if (streaming) { - await streaming; - } - const extraTokens = ':::plugin:::\n'; - plugins.push(latestPlugin); - sendIntermediateMessage(res, { plugins }, extraTokens); - }; - - const onToolEnd = async (output, runId) => { - if (streaming) { - await streaming; - } + if (pluginIndex !== -1) { + plugins[pluginIndex].loading = false; + plugins[pluginIndex].outputs = output; + } + }; - const pluginIndex = plugins.findIndex((plugin) => plugin.runId === runId); + const onChainEnd = () => { + saveMessage({ ...userMessage, user }); + sendIntermediateMessage(res, { plugins }); + }; - if (pluginIndex !== -1) { - plugins[pluginIndex].loading = false; - plugins[pluginIndex].outputs = output; - } - }; - - const onChainEnd = () => { - saveMessage({ ...userMessage, user }); - sendIntermediateMessage(res, { plugins }); - }; - - const getAbortData = () => ({ - sender, - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugins: plugins.map((p) => ({ ...p, loading: false })), - userMessage, - promptTokens, - }); - const { abortController, onStart } = createAbortController(req, res, getAbortData); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient({ req, res, endpointOption }); - - let response = await client.sendMessage(text, { - user, + const getAbortData = () => ({ + sender, conversationId, - parentMessageId, - overrideParentMessageId, - getReqData, - onAgentAction, - onChainEnd, - onToolStart, - onToolEnd, - onStart, - addMetadata, - getPartialText, - ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - parentMessageId: overrideParentMessageId || userMessageId, - plugins, - }), - abortController, + messageId: responseMessageId, + parentMessageId: overrideParentMessageId ?? 
userMessageId, + text: getPartialText(), + plugins: plugins.map((p) => ({ ...p, loading: false })), + userMessage, + promptTokens, }); + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = await initializeClient({ req, res, endpointOption }); + + let response = await client.sendMessage(text, { + user, + conversationId, + parentMessageId, + overrideParentMessageId, + getReqData, + onAgentAction, + onChainEnd, + onToolStart, + onToolEnd, + onStart, + addMetadata, + getPartialText, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId, + plugins, + }), + abortController, + }); - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } - if (metadata) { - response = { ...response, ...metadata }; - } + if (metadata) { + response = { ...response, ...metadata }; + } - logger.debug('[/ask/gptPlugins]', response); + logger.debug('[/ask/gptPlugins]', response); - response.plugins = plugins.map((p) => ({ ...p, loading: false })); - await saveMessage({ ...response, user }); + response.plugins = plugins.map((p) => ({ ...p, loading: false })); + await saveMessage({ ...response, user }); - sendMessage(res, { - title: await getConvoTitle(user, conversationId), - final: true, - conversation: await getConvo(user, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); + sendMessage(res, { + title: await getConvoTitle(user, conversationId), + final: true, + conversation: await getConvo(user, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); - if (parentMessageId === Constants.NO_PARENT && newConvo) { - addTitle(req, { - text, - response, - client, + if (parentMessageId === Constants.NO_PARENT && newConvo) { + addTitle(req, { + text, + response, + client, + }); + } + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender, + messageId: responseMessageId, + parentMessageId: userMessageId ?? parentMessageId, }); } - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, - sender, - messageId: responseMessageId, - parentMessageId: userMessageId ?? 
parentMessageId, - }); - } -}); + }, +); module.exports = router; diff --git a/api/server/routes/ask/openAI.js b/api/server/routes/ask/openAI.js index 31b3111077fa..5083a08b1041 100644 --- a/api/server/routes/ask/openAI.js +++ b/api/server/routes/ask/openAI.js @@ -4,6 +4,7 @@ const { addTitle, initializeClient } = require('~/server/services/Endpoints/open const { handleAbort, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, moderateText, @@ -13,8 +14,15 @@ const router = express.Router(); router.use(moderateText); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await AskController(req, res, next, initializeClient, addTitle); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await AskController(req, res, next, initializeClient, addTitle); + }, +); module.exports = router; diff --git a/api/server/routes/assistants/chat.js b/api/server/routes/assistants/chat.js index 57a3d4d28032..2df34a9ce3a9 100644 --- a/api/server/routes/assistants/chat.js +++ b/api/server/routes/assistants/chat.js @@ -21,6 +21,7 @@ const router = express.Router(); const { setHeaders, handleAbort, + validateModel, handleAbortError, // validateEndpoint, buildEndpointOption, @@ -36,7 +37,7 @@ router.post('/abort', handleAbort()); * @param {express.Response} res - The response object, used to send back a response. * @returns {void} */ -router.post('/', buildEndpointOption, setHeaders, async (req, res) => { +router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res) => { logger.debug('[/assistants/chat/] req.body', req.body); const { text, diff --git a/api/server/routes/edit/anthropic.js b/api/server/routes/edit/anthropic.js index 34dd9d6dfac1..c7bf128d7cb4 100644 --- a/api/server/routes/edit/anthropic.js +++ b/api/server/routes/edit/anthropic.js @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/anthropic'); const { setHeaders, handleAbort, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -12,8 +13,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await EditController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await EditController(req, res, next, initializeClient); + }, +); module.exports = router; diff --git a/api/server/routes/edit/custom.js b/api/server/routes/edit/custom.js index dd63c96c8f94..0bf97ba18003 100644 --- a/api/server/routes/edit/custom.js +++ b/api/server/routes/edit/custom.js @@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI'); const { handleAbort, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -13,8 +14,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await EditController(req, res, next, initializeClient, addTitle); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await EditController(req, res, next, initializeClient, addTitle); + }, +); module.exports = router; diff --git 
a/api/server/routes/edit/google.js b/api/server/routes/edit/google.js index e4dfbcd14127..7482f11b4c09 100644 --- a/api/server/routes/edit/google.js +++ b/api/server/routes/edit/google.js @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google'); const { setHeaders, handleAbort, + validateModel, validateEndpoint, buildEndpointOption, } = require('~/server/middleware'); @@ -12,8 +13,15 @@ const router = express.Router(); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => { - await EditController(req, res, next, initializeClient); -}); +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res, next) => { + await EditController(req, res, next, initializeClient); + }, +); module.exports = router; diff --git a/api/server/routes/edit/gptPlugins.js b/api/server/routes/edit/gptPlugins.js index 8ddf92c25079..33126a73bf2d 100644 --- a/api/server/routes/edit/gptPlugins.js +++ b/api/server/routes/edit/gptPlugins.js @@ -10,6 +10,7 @@ const { createAbortController, handleAbortError, setHeaders, + validateModel, validateEndpoint, buildEndpointOption, moderateText, @@ -19,179 +20,189 @@ const { logger } = require('~/config'); router.use(moderateText); router.post('/abort', handleAbort()); -router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res) => { - let { - text, - generation, - endpointOption, - conversationId, - responseMessageId, - isContinued = false, - parentMessageId = null, - overrideParentMessageId = null, - } = req.body; - - logger.debug('[/edit/gptPlugins]', { - text, - generation, - isContinued, - conversationId, - ...endpointOption, - }); - let metadata; - let userMessage; - let promptTokens; - let lastSavedTimestamp = 0; - let saveDelay = 100; - const sender = getResponseSender({ ...endpointOption, model: endpointOption.modelOptions.model }); - const userMessageId = parentMessageId; - const user = req.user.id; - - const plugin = { - loading: true, - inputs: [], - latest: null, - outputs: null, - }; - - const addMetadata = (data) => (metadata = data); - const getReqData = (data = {}) => { - for (let key in data) { - if (key === 'userMessage') { - userMessage = data[key]; - } else if (key === 'responseMessageId') { - responseMessageId = data[key]; - } else if (key === 'promptTokens') { - promptTokens = data[key]; - } - } - }; - - const { - onProgress: progressCallback, - sendIntermediateMessage, - getPartialText, - } = createOnProgress({ - generation, - onProgress: ({ text: partialText }) => { - const currentTimestamp = Date.now(); - - if (plugin.loading === true) { - plugin.loading = false; - } - - if (currentTimestamp - lastSavedTimestamp > saveDelay) { - lastSavedTimestamp = currentTimestamp; - saveMessage({ - messageId: responseMessageId, - sender, - conversationId, - parentMessageId: overrideParentMessageId || userMessageId, - text: partialText, - model: endpointOption.modelOptions.model, - unfinished: true, - isEdited: true, - error: false, - user, - }); - } +router.post( + '/', + validateEndpoint, + validateModel, + buildEndpointOption, + setHeaders, + async (req, res) => { + let { + text, + generation, + endpointOption, + conversationId, + responseMessageId, + isContinued = false, + parentMessageId = null, + overrideParentMessageId = null, + } = req.body; - if (saveDelay < 500) { - saveDelay = 500; - } - }, - }); - - const onAgentAction = (action, start = false) => { - const 
formattedAction = formatAction(action); - plugin.inputs.push(formattedAction); - plugin.latest = formattedAction.plugin; - if (!start) { - saveMessage({ ...userMessage, user }); - } - sendIntermediateMessage(res, { plugin }); - // logger.debug('PLUGIN ACTION', formattedAction); - }; - - const onChainEnd = (data) => { - let { intermediateSteps: steps } = data; - plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; - plugin.loading = false; - saveMessage({ ...userMessage, user }); - sendIntermediateMessage(res, { plugin }); - // logger.debug('CHAIN END', plugin.outputs); - }; - - const getAbortData = () => ({ - sender, - conversationId, - messageId: responseMessageId, - parentMessageId: overrideParentMessageId ?? userMessageId, - text: getPartialText(), - plugin: { ...plugin, loading: false }, - userMessage, - promptTokens, - }); - const { abortController, onStart } = createAbortController(req, res, getAbortData); - - try { - endpointOption.tools = await validateTools(user, endpointOption.tools); - const { client } = await initializeClient({ req, res, endpointOption }); - - let response = await client.sendMessage(text, { - user, + logger.debug('[/edit/gptPlugins]', { + text, generation, isContinued, - isEdited: true, conversationId, - parentMessageId, - responseMessageId, - overrideParentMessageId, - getReqData, - onAgentAction, - onChainEnd, - onStart, - addMetadata, ...endpointOption, - onProgress: progressCallback.call(null, { - res, - text, - plugin, - parentMessageId: overrideParentMessageId || userMessageId, - }), - abortController, }); + let metadata; + let userMessage; + let promptTokens; + let lastSavedTimestamp = 0; + let saveDelay = 100; + const sender = getResponseSender({ + ...endpointOption, + model: endpointOption.modelOptions.model, + }); + const userMessageId = parentMessageId; + const user = req.user.id; + + const plugin = { + loading: true, + inputs: [], + latest: null, + outputs: null, + }; + + const addMetadata = (data) => (metadata = data); + const getReqData = (data = {}) => { + for (let key in data) { + if (key === 'userMessage') { + userMessage = data[key]; + } else if (key === 'responseMessageId') { + responseMessageId = data[key]; + } else if (key === 'promptTokens') { + promptTokens = data[key]; + } + } + }; - if (overrideParentMessageId) { - response.parentMessageId = overrideParentMessageId; - } - - if (metadata) { - response = { ...response, ...metadata }; - } + const { + onProgress: progressCallback, + sendIntermediateMessage, + getPartialText, + } = createOnProgress({ + generation, + onProgress: ({ text: partialText }) => { + const currentTimestamp = Date.now(); + + if (plugin.loading === true) { + plugin.loading = false; + } + + if (currentTimestamp - lastSavedTimestamp > saveDelay) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender, + conversationId, + parentMessageId: overrideParentMessageId || userMessageId, + text: partialText, + model: endpointOption.modelOptions.model, + unfinished: true, + isEdited: true, + error: false, + user, + }); + } + + if (saveDelay < 500) { + saveDelay = 500; + } + }, + }); - logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); - response.plugin = { ...plugin, loading: false }; - await saveMessage({ ...response, user }); + const onAgentAction = (action, start = false) => { + const formattedAction = formatAction(action); + plugin.inputs.push(formattedAction); + plugin.latest = formattedAction.plugin; + if (!start) { + saveMessage({ 
...userMessage, user }); + } + sendIntermediateMessage(res, { plugin }); + // logger.debug('PLUGIN ACTION', formattedAction); + }; + + const onChainEnd = (data) => { + let { intermediateSteps: steps } = data; + plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.'; + plugin.loading = false; + saveMessage({ ...userMessage, user }); + sendIntermediateMessage(res, { plugin }); + // logger.debug('CHAIN END', plugin.outputs); + }; - sendMessage(res, { - title: await getConvoTitle(user, conversationId), - final: true, - conversation: await getConvo(user, conversationId), - requestMessage: userMessage, - responseMessage: response, - }); - res.end(); - } catch (error) { - const partialText = getPartialText(); - handleAbortError(res, req, error, { - partialText, - conversationId, + const getAbortData = () => ({ sender, + conversationId, messageId: responseMessageId, - parentMessageId: userMessageId ?? parentMessageId, + parentMessageId: overrideParentMessageId ?? userMessageId, + text: getPartialText(), + plugin: { ...plugin, loading: false }, + userMessage, + promptTokens, }); - } -}); + const { abortController, onStart } = createAbortController(req, res, getAbortData); + + try { + endpointOption.tools = await validateTools(user, endpointOption.tools); + const { client } = await initializeClient({ req, res, endpointOption }); + + let response = await client.sendMessage(text, { + user, + generation, + isContinued, + isEdited: true, + conversationId, + parentMessageId, + responseMessageId, + overrideParentMessageId, + getReqData, + onAgentAction, + onChainEnd, + onStart, + addMetadata, + ...endpointOption, + onProgress: progressCallback.call(null, { + res, + text, + plugin, + parentMessageId: overrideParentMessageId || userMessageId, + }), + abortController, + }); + + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + + if (metadata) { + response = { ...response, ...metadata }; + } + + logger.debug('[/edit/gptPlugins] CLIENT RESPONSE', response); + response.plugin = { ...plugin, loading: false }; + await saveMessage({ ...response, user }); + + sendMessage(res, { + title: await getConvoTitle(user, conversationId), + final: true, + conversation: await getConvo(user, conversationId), + requestMessage: userMessage, + responseMessage: response, + }); + res.end(); + } catch (error) { + const partialText = getPartialText(); + handleAbortError(res, req, error, { + partialText, + conversationId, + sender, + messageId: responseMessageId, + parentMessageId: userMessageId ?? 
parentMessageId,
+      });
+    }
+  },
+);
 
 module.exports = router;
diff --git a/api/server/routes/edit/openAI.js b/api/server/routes/edit/openAI.js
index e54881148dc6..ae26b235c799 100644
--- a/api/server/routes/edit/openAI.js
+++ b/api/server/routes/edit/openAI.js
@@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/openAI');
 const {
   handleAbort,
   setHeaders,
+  validateModel,
   validateEndpoint,
   buildEndpointOption,
   moderateText,
@@ -13,8 +14,15 @@ const router = express.Router();
 router.use(moderateText);
 router.post('/abort', handleAbort());
 
-router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
-  await EditController(req, res, next, initializeClient);
-});
+router.post(
+  '/',
+  validateEndpoint,
+  validateModel,
+  buildEndpointOption,
+  setHeaders,
+  async (req, res, next) => {
+    await EditController(req, res, next, initializeClient);
+  },
+);
 
 module.exports = router;
diff --git a/docs/features/mod_system.md b/docs/features/mod_system.md
index 099c5cb3a45b..3bef6b11ddd5 100644
--- a/docs/features/mod_system.md
+++ b/docs/features/mod_system.md
@@ -69,8 +69,13 @@ MESSAGE_IP_WINDOW=1 # in minutes, determines the window of time for MESSAGE_IP_M
 LIMIT_MESSAGE_USER=false # Whether to limit the amount of messages an IP can send per MESSAGE_USER_WINDOW
 MESSAGE_USER_MAX=40 # The max amount of messages an IP can send per MESSAGE_USER_WINDOW
 MESSAGE_USER_WINDOW=1 # in minutes, determines the window of time for MESSAGE_USER_MAX messages
+
+ILLEGAL_MODEL_REQ_SCORE=5 # Violation score to accrue if a user attempts to use an unlisted model.
+
 ```
 
+> Note: Illegal model requests are almost always nefarious, as they typically mean a third party is attempting to access the server through an automated script. For this reason, I recommend a relatively high score, no less than 5.
+
 ## OpenAI moderation text
 
 ### OPENAI_MODERATION
diff --git a/docs/install/configuration/dotenv.md b/docs/install/configuration/dotenv.md
index 4b04c4a53ce6..f436e5085acb 100644
--- a/docs/install/configuration/dotenv.md
+++ b/docs/install/configuration/dotenv.md
@@ -602,8 +602,11 @@ REGISTRATION_VIOLATION_SCORE=1
 CONCURRENT_VIOLATION_SCORE=1
 MESSAGE_VIOLATION_SCORE=1
 NON_BROWSER_VIOLATION_SCORE=20
+ILLEGAL_MODEL_REQ_SCORE=5
 ```
 
+> Note: Non-browser access and illegal model requests are almost always nefarious, as they typically mean a third party is attempting to access the server through an automated script.
+
 #### Login and registration rate limiting.
 - `LOGIN_MAX`: The max amount of logins allowed per IP per `LOGIN_WINDOW`
 - `LOGIN_WINDOW`: In minutes, determines the window of time for `LOGIN_MAX` logins
diff --git a/packages/data-provider/src/config.ts b/packages/data-provider/src/config.ts
index a0fad47a3942..27ef91d69582 100644
--- a/packages/data-provider/src/config.ts
+++ b/packages/data-provider/src/config.ts
@@ -284,10 +284,20 @@ export enum CacheKeys {
   /**
    * Key for the override config cache.
    */
   OVERRIDE_CONFIG = 'overrideConfig',
+}
+
+/**
+ * Enum for violation types, used to identify, log, and cache violations.
+ */
+export enum ViolationTypes {
   /**
-   * Key for accessing File Upload Violations (exceeding limit).
+   * File Upload Violations (exceeding limit).
    */
   FILE_UPLOAD_LIMIT = 'file_upload_limit',
+  /**
+   * Illegal Model Request (not available).
+   */
+  ILLEGAL_MODEL_REQUEST = 'illegal_model_request',
 }
 
 /**