Skip to content

Commit

Permalink
feat(chat): add rate limit to agnent chat
Browse files Browse the repository at this point in the history
  • Loading branch information
OdapX authored and gmpetrov committed Sep 3, 2023
1 parent c8cc88b commit 4352085
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 6 deletions.
18 changes: 18 additions & 0 deletions components/ChatBoxFrame.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ function ChatBoxFrame(props: { initConfig?: AgentInterfaceConfig }) {
handleAbort,
} = useChat({
endpoint: `/api/agents/${router.query?.agentId}/query`,

channel: ConversationChannel.website,
agentId,
});

const primaryColor =
Expand All @@ -51,6 +53,22 @@ function ChatBoxFrame(props: { initConfig?: AgentInterfaceConfig }) {
return pickColorBasedOnBgColor(primaryColor, '#ffffff', '#000000');
}, [primaryColor]);

// TODO: find why onSuccess is not working
// useSWR<Agent>(`${API_URL}/api/agents/${agentId}`, fetcher, {
// onSuccess: (data) => {
// const agentConfig = data?.interfaceConfig as AgentInterfaceConfig;

// setAgent(data);
// setConfig({
// ...defaultChatBubbleConfig,
// ...agentConfig,
// });
// },
// onError: (err) => {
// console.error(err);
// },
// });

const handleFetchAgent = async () => {
try {
const res = await fetch(`${API_URL}/api/agents/${agentId}`);
Expand Down
23 changes: 21 additions & 2 deletions components/ChatBubble.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ function App(props: { agentId: string; initConfig?: AgentInterfaceConfig }) {
endpoint: `${API_URL}/api/agents/${props.agentId}/query`,
channel: 'website',
// channel: ConversationChannel.website // not working with bundler parcel,
agentId: props?.agentId,
});

const textColor = useMemo(() => {
Expand All @@ -94,6 +95,22 @@ function App(props: { agentId: string; initConfig?: AgentInterfaceConfig }) {
);
}, [state.config.primaryColor]);

// TODO: find why onSuccess is not working
// useSWR<Agent>(`${API_URL}/api/agents/${agentId}`, fetcher, {
// onSuccess: (data) => {
// const agentConfig = data?.interfaceConfig as AgentInterfaceConfig;

// setAgent(data);
// setConfig({
// ...defaultChatBubbleConfig,
// ...agentConfig,
// });
// },
// onError: (err) => {
// console.error(err);
// },
// });

const handleFetchAgent = async () => {
try {
const res = await fetch(`${API_URL}/api/agents/${props.agentId}`);
Expand All @@ -115,8 +132,10 @@ function App(props: { agentId: string; initConfig?: AgentInterfaceConfig }) {
};

useEffect(() => {
handleFetchAgent();
}, []);
if (props.agentId) {
handleFetchAgent();
}
}, [props.agentId]);

useEffect(() => {
if (state.config?.initialMessage) {
Expand Down
14 changes: 13 additions & 1 deletion components/Input.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,19 @@ const Input = forwardRef((props: Props, ref) => {

return (
<FormControl error={!!errorMsg}>
{!props.hidden && label && <FormLabel>{label}</FormLabel>}
{!props.hidden && label && (
<FormLabel
sx={{
...(props.disabled
? {
color: 'text.disabled',
}
: {}),
}}
>
{label}
</FormLabel>
)}

<BaseInput
ref={ref as any}
Expand Down
96 changes: 96 additions & 0 deletions components/RateLimitForm.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import { zodResolver } from '@hookform/resolvers/zod';
import { Button, Checkbox, FormLabel, Stack, Typography } from '@mui/joy';
import { useForm } from 'react-hook-form';
import { z } from 'zod';

import Input from '@app/components/Input';
import { AgentInterfaceConfig } from '@app/types/models';

const rateLimitSchema = AgentInterfaceConfig.pick({
rateLimit: true,
});

export type RateLimitFields = z.infer<typeof rateLimitSchema>;

interface Props extends RateLimitFields {
onSubmit(args: RateLimitFields): Promise<void>;
}

const RateLimitForm: React.FC<Props> = ({ onSubmit, rateLimit }) => {
const { register, control, handleSubmit, watch } = useForm<RateLimitFields>({
resolver: zodResolver(rateLimitSchema),
defaultValues: {
rateLimit,
},
});

const isRateLimitEnabled = watch('rateLimit.enabled');

return (
<form onSubmit={handleSubmit(onSubmit)}>
<FormLabel>Rate Limit</FormLabel>
<Typography
level="body3"
sx={{
mb: 2,
}}
>
Limit the number of messages sent from one device on the Chat Bubble,
iFrame and Standalone integrations.
</Typography>

<Stack gap={2}>
<div className="flex space-x-4">
<Checkbox
size="lg"
{...register('rateLimit.enabled')}
defaultChecked={isRateLimitEnabled}
/>
<div className="flex flex-col">
<FormLabel>Enable Rate Limit</FormLabel>
<Typography level="body3">
X messages max every Y seconds
</Typography>
</div>
</div>

<Stack gap={2} pl={4}>
<Input
control={control as any}
label="Max number of queries"
disabled={!isRateLimitEnabled}
placeholder="10"
{...register('rateLimit.maxQueries')}
/>
<Input
control={control as any}
label="Interval (in seconds)"
disabled={!isRateLimitEnabled}
placeholder="60"
{...register('rateLimit.interval')}
/>
<Input
control={control as any}
label="Rate Limit Reached Message"
placeholder="Usage limit reached"
disabled={!isRateLimitEnabled}
{...register('rateLimit.limitReachedMessage')}
/>
</Stack>
</Stack>

<div style={{ display: 'flex', justifyContent: 'flex-end' }}>
<Button
type="submit"
variant="solid"
color="primary"
sx={{ ml: 2, mt: 2 }} // Adjust the margin as needed
>
Save
</Button>
</div>
</form>
);
};

export default RateLimitForm;
23 changes: 22 additions & 1 deletion hooks/useChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import {
} from '@microsoft/fetch-event-source';
import type { ConversationChannel, Prisma } from '@prisma/client';
import { useCallback, useEffect } from 'react';
import { useSWRConfig } from 'swr';
import useSWRInfinite from 'swr/infinite';

import { getConversation } from '@app/pages/api/conversations/[conversationId]';
Expand All @@ -14,6 +13,7 @@ import type { ChatResponse, EvalAnswer } from '@app/types/dtos';
import { ApiError, ApiErrorType } from '@app/utils/api-error';
import { fetcher } from '@app/utils/swr-fetcher';

import useRateLimit from './useRateLimit';
import useStateReducer from './useStateReducer';

const API_URL = process.env.NEXT_PUBLIC_DASHBOARD_URL;
Expand All @@ -24,6 +24,8 @@ type Props = {
queryBody?: any;
datasourceId?: string;
localStorageConversationIdKey?: string;
// TODO: Remove when rate limit implemented from backend
agentId?: string;
};

export const handleEvalAnswer = async (props: {
Expand Down Expand Up @@ -64,6 +66,12 @@ const useChat = ({ endpoint, channel, queryBody, ...otherProps }: Props) => {
handleAbort: undefined as any,
});

// TODO: Remove when rate limit implemented from backend
const { isRateExceeded, rateExceededMessage, handleIncrementRateLimitCount } =
useRateLimit({
agentId: otherProps.agentId,
});

const getConversationQuery = useSWRInfinite<
Prisma.PromiseReturnType<typeof getConversation>
>(
Expand Down Expand Up @@ -124,6 +132,16 @@ const useChat = ({ endpoint, channel, queryBody, ...otherProps }: Props) => {
return;
}

if (isRateExceeded) {
setState({
history: [
...state.history,
{ from: 'agent', message: rateExceededMessage },
] as any,
});
return;
}

const ctrl = new AbortController();
const history = [...state.history, { from: 'human', message }];
const nextIndex = history.length;
Expand Down Expand Up @@ -288,8 +306,11 @@ const useChat = ({ endpoint, channel, queryBody, ...otherProps }: Props) => {
}
},
});

handleIncrementRateLimitCount?.();
} catch (err) {
console.error('err', err);

if (err instanceof ApiError) {
if (err?.message) {
error = err?.message;
Expand Down
55 changes: 55 additions & 0 deletions hooks/useRateLimit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import type { Agent } from '@prisma/client';
import { useCallback, useEffect, useState } from 'react';
import useSWR from 'swr';

import { AgentInterfaceConfig } from '@app/types/models';
import { fetcher } from '@app/utils/swr-fetcher';

const API_URL = process.env.NEXT_PUBLIC_DASHBOARD_URL;

interface RateResponse {
isRateExceeded: boolean;
rateExceededMessage?: string;
handleIncrementRateLimitCount: () => any;
}

const useRateLimit = ({ agentId }: { agentId?: string }): RateResponse => {
const [isRateExceeded, setIsRateExceeded] = useState(false);

const getAgentQuery = useSWR<Agent>(
agentId ? `${API_URL}/api/agents/${agentId}` : null,
fetcher
);

const config = getAgentQuery?.data?.interfaceConfig as AgentInterfaceConfig;
const rateLimit = config?.rateLimit?.maxQueries || 0;

const handleIncrementRateLimitCount = useCallback(() => {
let currentRateCount = Number(localStorage.getItem('rateLimitCount')) || 0;
localStorage.setItem('rateLimitCount', `${++currentRateCount}`);

if (currentRateCount >= rateLimit) {
setIsRateExceeded(true);
}
}, [rateLimit]);

useEffect(() => {
if (!config?.rateLimit?.interval) return;

const interval = setInterval(() => {
localStorage.setItem('rateLimitCount', '0');
setIsRateExceeded(false);
}, config?.rateLimit?.interval * 1000);

return () => clearInterval(interval);
}, [config]);

return {
isRateExceeded: config?.rateLimit?.enabled ? isRateExceeded : false,
handleIncrementRateLimitCount,
rateExceededMessage:
config?.rateLimit?.limitReachedMessage || 'Usage limit reached',
};
};

export default useRateLimit;
28 changes: 26 additions & 2 deletions pages/agents/[agentId]/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,14 @@ import ChatBox from '@app/components/ChatBox';
import ChatBubble from '@app/components/ChatBubble';
import ConversationList from '@app/components/ConversationList';
import Layout from '@app/components/Layout';
import RateLimitForm, { RateLimitFields } from '@app/components/RateLimitForm';
import UsageLimitModal from '@app/components/UsageLimitModal';
import useChat from '@app/hooks/useChat';
import useStateReducer from '@app/hooks/useStateReducer';
import { upsertAgent } from '@app/pages/api/agents';
import { getAgent } from '@app/pages/api/agents/[id]';
import { RouteNames } from '@app/types';
import { AgentInterfaceConfig } from '@app/types/models';
import agentToolFormat from '@app/utils/agent-tool-format';
import { fetcher, postFetcher } from '@app/utils/swr-fetcher';
import { withAuth } from '@app/utils/withAuth';
Expand Down Expand Up @@ -163,6 +165,23 @@ export default function AgentPage() {
}
};

const handleSubmitRateLimit = async (values: RateLimitFields) => {
await toast.promise(
upsertAgentMutation.trigger({
...getAgentQuery?.data,
interfaceConfig: {
...(getAgentQuery?.data?.interfaceConfig as any),
rateLimit: values.rateLimit,
},
} as any),
{
loading: 'Updating...',
success: 'Updated!',
error: 'Something went wrong',
}
);
};

const handleChangeTab = (tab: string) => {
router.query.tab = tab;
router.replace(router);
Expand All @@ -188,7 +207,7 @@ export default function AgentPage() {
}

const agent = getAgentQuery?.data;

const agentConfig = agent.interfaceConfig as AgentInterfaceConfig;
return (
<Box
component="main"
Expand Down Expand Up @@ -660,6 +679,12 @@ export default function AgentPage() {

<Divider sx={{ my: 4 }} />

<RateLimitForm
onSubmit={handleSubmitRateLimit}
rateLimit={agentConfig.rateLimit}
/>

<Divider sx={{ my: 4 }} />
<FormControl sx={{ gap: 1 }}>
<FormLabel>Agent ID</FormLabel>
<Typography level="body3" mb={2}>
Expand Down Expand Up @@ -702,7 +727,6 @@ export default function AgentPage() {
</FormControl>

<Divider sx={{ my: 4 }} />

<FormControl sx={{ gap: 1 }}>
<FormLabel>Delete Agent</FormLabel>
<Typography level="body3">
Expand Down
Loading

0 comments on commit 4352085

Please sign in to comment.