Skip to content

Commit

Permalink
🛡️ feat: Model Validation Middleware (danny-avila#1841)
Browse files Browse the repository at this point in the history
* refactor: add ViolationTypes enum and add new violation for illegal model requests

* feat: validateModel middleware to protect the backend against illicit requests for unlisted models
  • Loading branch information
danny-avila authored and zpdsherlock committed Feb 24, 2024
1 parent 3e7aa3e commit fa26186
Show file tree
Hide file tree
Showing 19 changed files with 534 additions and 372 deletions.
2 changes: 2 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,8 @@ LIMIT_MESSAGE_USER=false
MESSAGE_USER_MAX=40
MESSAGE_USER_WINDOW=1

ILLEGAL_MODEL_REQ_SCORE=5

#========================#
# Balance #
#========================#
Expand Down
7 changes: 5 additions & 2 deletions api/cache/getLogStores.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const Keyv = require('keyv');
const { CacheKeys } = require('librechat-data-provider');
const { CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { logFile, violationFile } = require('./keyvFiles');
const { math, isEnabled } = require('~/server/utils');
const keyvRedis = require('./keyvRedis');
Expand Down Expand Up @@ -49,7 +49,10 @@ const namespaces = {
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance('token_balance'),
registrations: createViolationInstance('registrations'),
[CacheKeys.FILE_UPLOAD_LIMIT]: createViolationInstance(CacheKeys.FILE_UPLOAD_LIMIT),
[ViolationTypes.FILE_UPLOAD_LIMIT]: createViolationInstance(ViolationTypes.FILE_UPLOAD_LIMIT),
[ViolationTypes.ILLEGAL_MODEL_REQUEST]: createViolationInstance(
ViolationTypes.ILLEGAL_MODEL_REQUEST,
),
logins: createViolationInstance('logins'),
[CacheKeys.ABORT_KEYS]: abortKeys,
[CacheKeys.TOKEN_CONFIG]: tokenConfig,
Expand Down
2 changes: 2 additions & 0 deletions api/server/middleware/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ const checkBan = require('./checkBan');
const uaParser = require('./uaParser');
const setHeaders = require('./setHeaders');
const loginLimiter = require('./loginLimiter');
const validateModel = require('./validateModel');
const requireJwtAuth = require('./requireJwtAuth');
const uploadLimiters = require('./uploadLimiters');
const registerLimiter = require('./registerLimiter');
Expand Down Expand Up @@ -32,6 +33,7 @@ module.exports = {
validateMessageReq,
buildEndpointOption,
validateRegistration,
validateModel,
moderateText,
noIndex,
};
4 changes: 2 additions & 2 deletions api/server/middleware/uploadLimiters.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
const rateLimit = require('express-rate-limit');
const { CacheKeys } = require('librechat-data-provider');
const { ViolationTypes } = require('librechat-data-provider');
const logViolation = require('~/cache/logViolation');

const getEnvironmentVariables = () => {
Expand Down Expand Up @@ -35,7 +35,7 @@ const createFileUploadHandler = (ip = true) => {
} = getEnvironmentVariables();

return async (req, res) => {
const type = CacheKeys.FILE_UPLOAD_LIMIT;
const type = ViolationTypes.FILE_UPLOAD_LIMIT;
const errorMessage = {
type,
max: ip ? fileUploadIpMax : fileUploadUserMax,
Expand Down
50 changes: 50 additions & 0 deletions api/server/middleware/validateModel.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
const { EModelEndpoint, CacheKeys, ViolationTypes } = require('librechat-data-provider');
const { logViolation, getLogStores } = require('~/cache');
const { handleError } = require('~/server/utils');

/**
* Validates the model of the request.
*
* @async
* @param {Express.Request} req - The Express request object.
* @param {Express.Response} res - The Express response object.
* @param {Function} next - The Express next function.
*/
const validateModel = async (req, res, next) => {
const { model, endpoint } = req.body;
if (!model) {
return handleError(res, { text: 'Model not provided' });
}

const cache = getLogStores(CacheKeys.CONFIG_STORE);
const modelsConfig = await cache.get(CacheKeys.MODELS_CONFIG);
if (!modelsConfig) {
return handleError(res, { text: 'Models not loaded' });
}

const availableModels = modelsConfig[endpoint];
if (!availableModels) {
return handleError(res, { text: 'Endpoint models not loaded' });
}

let validModel = !!availableModels.find((availableModel) => availableModel === model);
if (endpoint === EModelEndpoint.gptPlugins) {
validModel = validModel && availableModels.includes(req.body.agentOptions?.model);
}

if (validModel) {
return next();
}

const { ILLEGAL_MODEL_REQ_SCORE: score = 5 } = process.env ?? {};

const type = ViolationTypes.ILLEGAL_MODEL_REQUEST;
const errorMessage = {
type,
};

await logViolation(req, res, type, errorMessage, score);
return handleError(res, { text: 'Illegal model request' });
};

module.exports = validateModel;
14 changes: 11 additions & 3 deletions api/server/routes/ask/anthropic.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/anthropic');
const {
setHeaders,
handleAbort,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
Expand All @@ -12,8 +13,15 @@ const router = express.Router();

router.post('/abort', handleAbort());

router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await AskController(req, res, next, initializeClient);
});
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient);
},
);

module.exports = router;
14 changes: 11 additions & 3 deletions api/server/routes/ask/custom.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const { addTitle } = require('~/server/services/Endpoints/openAI');
const {
handleAbort,
setHeaders,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
Expand All @@ -13,8 +14,15 @@ const router = express.Router();

router.post('/abort', handleAbort());

router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
});
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient, addTitle);
},
);

module.exports = router;
14 changes: 11 additions & 3 deletions api/server/routes/ask/google.js
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ const { initializeClient } = require('~/server/services/Endpoints/google');
const {
setHeaders,
handleAbort,
validateModel,
validateEndpoint,
buildEndpointOption,
} = require('~/server/middleware');
Expand All @@ -12,8 +13,15 @@ const router = express.Router();

router.post('/abort', handleAbort());

router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req, res, next) => {
await AskController(req, res, next, initializeClient);
});
router.post(
'/',
validateEndpoint,
validateModel,
buildEndpointOption,
setHeaders,
async (req, res, next) => {
await AskController(req, res, next, initializeClient);
},
);

module.exports = router;
Loading

0 comments on commit fa26186

Please sign in to comment.