Skip to content

Commit

Permalink
feat(api): initial Redis support; fix(SearchBar): proper debounce (da…
Browse files Browse the repository at this point in the history
…nny-avila#1039)

* refactor: use keyv for search caching with 1 min expirations

* feat: keyvRedis; chore: bump keyv, bun.lockb, add jsconfig for vscode file resolution

* feat: api/search redis support

* refactor(redis) use ioredis cluster for keyv
fix(OpenID): when redis is configured, use redis memory store for express-session

* fix: revert using uri for keyvredis

* fix(SearchBar): properly debounce search queries, fix weird render behaviors

* refactor: add authentication to search endpoint and show error messages in results

* feat: redis support for violation logs

* fix(logViolation): ensure a number is always being stored in cache

* feat(concurrentLimiter): uses clearPendingReq, clears pendingReq on abort, redis support

* fix(api/search/enable): query only when authenticated

* feat(ModelService): redis support

* feat(checkBan): redis support

* refactor(api/search): consolidate keyv logic

* fix(ci): add default empty value for REDIS_URI

* refactor(keyvRedis): use condition to initialize keyvRedis assignment

* refactor(connectDb): handle disconnected state (should create a new conn)

* fix(ci/e2e): handle case where cleanUp did not successfully run

* fix(getDefaultEndpoint): return endpoint from localStorage if defined and endpointsConfig is default

* ci(e2e): remove afterAll messages as startup/cleanUp will clear messages

* ci(e2e): remove teardown for CI until further notice

* chore: bump playwright/test

* ci(e2e): reinstate teardown as CI issue is specific to github env

* fix(ci): click settings menu trigger by testid
  • Loading branch information
danny-avila authored Oct 11, 2023
1 parent 59235c9 commit 86be4b8
Show file tree
Hide file tree
Showing 15 changed files with 172 additions and 92 deletions.
45 changes: 32 additions & 13 deletions cache/clearPendingReq.js
Original file line number Diff line number Diff line change
@@ -1,29 +1,48 @@
const Keyv = require('keyv');
const { pendingReqFile } = require('./keyvFiles');
const { LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};

const keyv = new Keyv({ store: pendingReqFile, namespace: 'pendingRequests' });
const getLogStores = require('./getLogStores');
const { isEnabled } = require('../server/utils');
const { USE_REDIS, LIMIT_CONCURRENT_MESSAGES } = process.env ?? {};
const ttl = 1000 * 60 * 1;

/**
* Clear pending requests from the cache.
* Clear or decrement pending requests from the cache.
* Checks the environmental variable LIMIT_CONCURRENT_MESSAGES;
* if the rule is enabled ('true'), pending requests in the cache are cleared.
* if the rule is enabled ('true'), it either decrements the count of pending requests
* or deletes the key if the count is less than or equal to 1.
*
* @module clearPendingReq
* @requires keyv
* @requires keyvFiles
* @requires ./getLogStores
* @requires ../server/utils
* @requires process
*
* @async
* @function
* @returns {Promise<void>} A promise that either clears 'pendingRequests' from store or resolves with no value.
* @param {Object} params - The parameters object.
* @param {string} params.userId - The user ID for which the pending requests are to be cleared or decremented.
* @param {Object} [params.cache] - An optional cache object to use. If not provided, a default cache will be fetched using getLogStores.
* @returns {Promise<void>} A promise that either decrements the 'pendingRequests' count, deletes the key from the store, or resolves with no value.
*/
const clearPendingReq = async () => {
if (LIMIT_CONCURRENT_MESSAGES?.toLowerCase() !== 'true') {
const clearPendingReq = async ({ userId, cache: _cache }) => {
if (!userId) {
return;
} else if (!isEnabled(LIMIT_CONCURRENT_MESSAGES)) {
return;
}

const namespace = 'pending_req';
const cache = _cache ?? getLogStores(namespace);

if (!cache) {
return;
}

await keyv.clear();
const key = `${USE_REDIS ? namespace : ''}:${userId ?? ''}`;
const currentReq = +((await cache.get(key)) ?? 0);

if (currentReq && currentReq >= 1) {
await cache.set(key, currentReq - 1, ttl);
} else {
await cache.delete(key);
}
};

module.exports = clearPendingReq;
40 changes: 25 additions & 15 deletions cache/getLogStores.js
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
const Keyv = require('keyv');
const keyvMongo = require('./keyvMongo');
const { math } = require('../server/utils');
const keyvRedis = require('./keyvRedis');
const { math, isEnabled } = require('../server/utils');
const { logFile, violationFile } = require('./keyvFiles');
const { BAN_DURATION } = process.env ?? {};
const { BAN_DURATION, USE_REDIS } = process.env ?? {};

const duration = math(BAN_DURATION, 7200000);

const createViolationInstance = (namespace) => {
const config = isEnabled(USE_REDIS) ? { store: keyvRedis } : { store: violationFile, namespace };
return new Keyv(config);
};

// Serve cache from memory so no need to clear it on startup/exit
const pending_req = isEnabled(USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'pending_req' });

const namespaces = {
ban: new Keyv({ store: keyvMongo, ttl: duration, namespace: 'bans' }),
pending_req,
ban: new Keyv({ store: keyvMongo, namespace: 'bans', duration }),
general: new Keyv({ store: logFile, namespace: 'violations' }),
concurrent: new Keyv({ store: violationFile, namespace: 'concurrent' }),
non_browser: new Keyv({ store: violationFile, namespace: 'non_browser' }),
message_limit: new Keyv({ store: violationFile, namespace: 'message_limit' }),
token_balance: new Keyv({ store: violationFile, namespace: 'token_balance' }),
registrations: new Keyv({ store: violationFile, namespace: 'registrations' }),
logins: new Keyv({ store: violationFile, namespace: 'logins' }),
concurrent: createViolationInstance('concurrent'),
non_browser: createViolationInstance('non_browser'),
message_limit: createViolationInstance('message_limit'),
token_balance: createViolationInstance('token_balance'),
registrations: createViolationInstance('registrations'),
logins: createViolationInstance('logins'),
};

/**
* Returns either the logs of violations specified by type if a type is provided
* or it returns the general log if no type is specified. If an invalid type is passed,
* an error will be thrown.
* Returns the keyv cache specified by type.
* If an invalid type is passed, an error will be thrown.
*
* @module getLogStores
* @requires keyv - a simple key-value storage that allows you to easily switch out storage adapters.
Expand All @@ -31,11 +42,10 @@ const namespaces = {
* @throws Will throw an error if an invalid violation type is passed.
*/
const getLogStores = (type) => {
if (!type) {
if (!type || !namespaces[type]) {
throw new Error(`Invalid store type: ${type}`);
}
const logs = namespaces[type];
return logs;
return namespaces[type];
};

module.exports = getLogStores;
3 changes: 1 addition & 2 deletions cache/index.js
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
const keyvFiles = require('./keyvFiles');
const getLogStores = require('./getLogStores');
const logViolation = require('./logViolation');
const clearPendingReq = require('./clearPendingReq');

module.exports = { ...keyvFiles, getLogStores, logViolation, clearPendingReq };
module.exports = { ...keyvFiles, getLogStores, logViolation };
14 changes: 14 additions & 0 deletions cache/keyvRedis.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
const KeyvRedis = require('@keyv/redis');

const { REDIS_URI } = process.env;

let keyvRedis;

if (REDIS_URI) {
keyvRedis = new KeyvRedis(REDIS_URI, { useRedisSets: false });
keyvRedis.on('error', (err) => console.error('KeyvRedis connection error:', err));
} else {
// console.log('REDIS_URI not provided. Redis module will not be initialized.');
}

module.exports = keyvRedis;
12 changes: 7 additions & 5 deletions cache/logViolation.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const getLogStores = require('./getLogStores');
const banViolation = require('./banViolation');
const { isEnabled } = require('../server/utils');

/**
* Logs the violation.
Expand All @@ -17,21 +18,22 @@ const logViolation = async (req, res, type, errorMessage, score = 1) => {
}
const logs = getLogStores('general');
const violationLogs = getLogStores(type);
const key = isEnabled(process.env.USE_REDIS) ? `${type}:${userId}` : userId;

const userViolations = (await violationLogs.get(userId)) ?? 0;
const violationCount = userViolations + score;
await violationLogs.set(userId, violationCount);
const userViolations = (await violationLogs.get(key)) ?? 0;
const violationCount = +userViolations + +score;
await violationLogs.set(key, violationCount);

errorMessage.user_id = userId;
errorMessage.prev_count = userViolations;
errorMessage.violation_count = violationCount;
errorMessage.date = new Date().toISOString();

await banViolation(req, res, errorMessage);
const userLogs = (await logs.get(userId)) ?? [];
const userLogs = (await logs.get(key)) ?? [];
userLogs.push(errorMessage);
delete errorMessage.user_id;
await logs.set(userId, userLogs);
await logs.set(key, userLogs);
};

module.exports = logViolation;
4 changes: 4 additions & 0 deletions cache/redis.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
const Redis = require('ioredis');
const { REDIS_URI } = process.env ?? {};
const redis = new Redis.Cluster(REDIS_URI);
module.exports = redis;
13 changes: 13 additions & 0 deletions jsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"compilerOptions": {
"target": "ES6",
"module": "CommonJS",
// "checkJs": true, // Report errors in JavaScript files
"baseUrl": "./",
"paths": {
"*": ["*", "node_modules/*"],
"~/*": ["./*"]
}
},
"exclude": ["node_modules"]
}
5 changes: 3 additions & 2 deletions lib/db/connectDb.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ if (!cached) {
}

async function connectDb() {
if (cached.conn) {
if (cached.conn && cached.conn?._readyState === 1) {
return cached.conn;
}

if (!cached.promise) {
const disconnected = cached.conn && cached.conn?._readyState !== 1;
if (!cached.promise || disconnected) {
const opts = {
useNewUrlParser: true,
useUnifiedTopology: true,
Expand Down
5 changes: 4 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
"@anthropic-ai/sdk": "^0.5.4",
"@azure/search-documents": "^11.3.2",
"@keyv/mongo": "^2.1.8",
"@keyv/redis": "^2.8.0",
"@waylaidwanderer/chatgpt-api": "^1.37.2",
"axios": "^1.3.4",
"bcryptjs": "^2.4.3",
"cheerio": "^1.0.0-rc.12",
"cohere-ai": "^6.0.0",
"connect-redis": "^7.1.0",
"cookie": "^0.5.0",
"cors": "^2.8.5",
"dotenv": "^16.0.3",
Expand All @@ -39,10 +41,11 @@
"googleapis": "^118.0.0",
"handlebars": "^4.7.7",
"html": "^1.0.0",
"ioredis": "^5.3.2",
"jose": "^4.15.2",
"js-yaml": "^4.1.0",
"jsonwebtoken": "^9.0.0",
"keyv": "^4.5.3",
"keyv": "^4.5.4",
"keyv-file": "^0.2.0",
"langchain": "^0.0.153",
"lodash": "^4.17.21",
Expand Down
6 changes: 5 additions & 1 deletion server/middleware/abortMiddleware.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
const { sendMessage, sendError, countTokens, isEnabled } = require('../utils');
const { saveMessage, getConvo, getConvoTitle } = require('../../models');
const { sendMessage, sendError, countTokens } = require('../utils');
const clearPendingReq = require('../../cache/clearPendingReq');
const spendTokens = require('../../models/spendTokens');
const abortControllers = require('./abortControllers');

Expand All @@ -20,6 +21,9 @@ async function abortMessage(req, res) {
const handleAbort = () => {
return async (req, res) => {
try {
if (isEnabled(process.env.LIMIT_CONCURRENT_MESSAGES)) {
await clearPendingReq({ userId: req.user.id });
}
return await abortMessage(req, res);
} catch (err) {
console.error(err);
Expand Down
19 changes: 12 additions & 7 deletions server/middleware/checkBan.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ const uap = require('ua-parser-js');
const { getLogStores } = require('../../cache');
const denyRequest = require('./denyRequest');
const { isEnabled, removePorts } = require('../utils');
const keyvRedis = require('../../cache/keyvRedis');

const banCache = new Keyv({ namespace: 'bans', ttl: 0 });
const banCache = isEnabled(process.env.USE_REDIS)
? new Keyv({ store: keyvRedis })
: new Keyv({ namespace: 'bans', ttl: 0 });
const message = 'Your account has been temporarily banned due to violations of our service.';

/**
Expand Down Expand Up @@ -50,9 +53,11 @@ const checkBan = async (req, res, next = () => {}) => {

req.ip = removePorts(req);
const userId = req.user?.id ?? req.user?._id ?? null;
const ipKey = isEnabled(process.env.USE_REDIS) ? `ban_cache:ip:${req.ip}` : req.ip;
const userKey = isEnabled(process.env.USE_REDIS) ? `ban_cache:user:${userId}` : userId;

const cachedIPBan = await banCache.get(req.ip);
const cachedUserBan = await banCache.get(userId);
const cachedIPBan = await banCache.get(ipKey);
const cachedUserBan = await banCache.get(userKey);
const cachedBan = cachedIPBan || cachedUserBan;

if (cachedBan) {
Expand All @@ -78,13 +83,13 @@ const checkBan = async (req, res, next = () => {}) => {
const timeLeft = Number(isBanned.expiresAt) - Date.now();

if (timeLeft <= 0) {
await banLogs.delete(req.ip);
await banLogs.delete(userId);
await banLogs.delete(ipKey);
await banLogs.delete(userKey);
return next();
}

banCache.set(req.ip, isBanned, timeLeft);
banCache.set(userId, isBanned, timeLeft);
banCache.set(ipKey, isBanned, timeLeft);
banCache.set(userKey, isBanned, timeLeft);
req.banned = true;
return await banResponse(req, res);
};
Expand Down
Loading

0 comments on commit 86be4b8

Please sign in to comment.