Skip to content

Commit

Permalink
♻️ refactor: refactor the agent runtime payload (lobehub#5250)
Browse files Browse the repository at this point in the history
* refactor the code

* fix locale

* refactor endpoint to baseURL

* add tests

* fix tests

* fix tests

* Update auth.ts

* update snapshot

* fix tests

* fix tests
  • Loading branch information
arvinxx authored Dec 31, 2024
1 parent 21b6610 commit e420ab3
Show file tree
Hide file tree
Showing 27 changed files with 628 additions and 81 deletions.
1 change: 0 additions & 1 deletion next.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ const nextConfig: NextConfig = {
'@icons-pack/react-simple-icons',
'@lobehub/ui',
'gpt-tokenizer',
'chroma-js',
],
webVitalsAttribution: ['CLS', 'LCP'],
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ import { Checkbox, Form, FormInstance, Input } from 'antd';
import { memo, useEffect } from 'react';
import { useTranslation } from 'react-i18next';

import MaxTokenSlider from '@/components/MaxTokenSlider';
import { useIsMobile } from '@/hooks/useIsMobile';
import { ChatModelCard } from '@/types/llm';

import MaxTokenSlider from './MaxTokenSlider';

interface ModelConfigFormProps {
initialValues?: ChatModelCard;
onFormInstanceReady: (instance: FormInstance) => void;
Expand Down Expand Up @@ -66,7 +65,10 @@ const ModelConfigForm = memo<ModelConfigFormProps>(
>
<Input placeholder={t('llm.customModelCards.modelConfig.displayName.placeholder')} />
</Form.Item>
<Form.Item label={t('llm.customModelCards.modelConfig.tokens.title')} name={'contextWindowTokens'}>
<Form.Item
label={t('llm.customModelCards.modelConfig.tokens.title')}
name={'contextWindowTokens'}
>
<MaxTokenSlider />
</Form.Item>
<Form.Item
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
import useMergeState from 'use-merge-value';

import { useServerConfigStore } from '@/store/serverConfig';
import { serverConfigSelectors } from '@/store/serverConfig/selectors';
import { useIsMobile } from '@/hooks/useIsMobile';

const Kibi = 1024;

Expand All @@ -20,7 +19,7 @@ interface MaxTokenSliderProps {
}

const MaxTokenSlider = memo<MaxTokenSliderProps>(({ value, onChange, defaultValue }) => {
const { t } = useTranslation('setting');
const { t } = useTranslation('components');

const [token, setTokens] = useMergeState(0, {
defaultValue,
Expand All @@ -45,7 +44,7 @@ const MaxTokenSlider = memo<MaxTokenSliderProps>(({ value, onChange, defaultValu
setPowValue(exponent(value / Kibi));
};

const isMobile = useServerConfigStore(serverConfigSelectors.isMobile);
const isMobile = useIsMobile();

const marks = useMemo(() => {
return {
Expand Down Expand Up @@ -74,7 +73,7 @@ const MaxTokenSlider = memo<MaxTokenSliderProps>(({ value, onChange, defaultValu
tooltip={{
formatter: (x) => {
if (typeof x === 'undefined') return;
if (x === 0) return t('llm.customModelCards.modelConfig.tokens.unlimited');
if (x === 0) return t('MaxTokenSlider.unlimited');

let value = getRealValue(x);
if (value < 125) return value.toFixed(0) + 'K';
Expand Down
9 changes: 6 additions & 3 deletions src/components/ModelSelect/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { FC, memo } from 'react';
import { useTranslation } from 'react-i18next';
import { Center, Flexbox } from 'react-layout-kit';

import { ModelAbilities } from '@/types/aiModel';
import { ChatModelCard } from '@/types/llm';
import { formatTokenNumber } from '@/utils/format';

Expand Down Expand Up @@ -57,8 +58,10 @@ const useStyles = createStyles(({ css, token }) => ({
`,
}));

interface ModelInfoTagsProps extends ChatModelCard {
interface ModelInfoTagsProps extends ModelAbilities {
contextWindowTokens?: number | null;
directionReverse?: boolean;
isCustom?: boolean;
placement?: 'top' | 'right';
}

Expand Down Expand Up @@ -102,7 +105,7 @@ export const ModelInfoTags = memo<ModelInfoTagsProps>(
</div>
</Tooltip>
)}
{model.contextWindowTokens !== undefined && (
{typeof model.contextWindowTokens === 'number' && (
<Tooltip
overlayStyle={{ maxWidth: 'unset', pointerEvents: 'none' }}
placement={placement}
Expand All @@ -117,7 +120,7 @@ export const ModelInfoTags = memo<ModelInfoTagsProps>(
{model.contextWindowTokens === 0 ? (
<Infinity size={17} strokeWidth={1.6} />
) : (
formatTokenNumber(model.contextWindowTokens)
formatTokenNumber(model.contextWindowTokens as number)
)}
</Center>
</Tooltip>
Expand Down
10 changes: 9 additions & 1 deletion src/components/NProgress/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,15 @@ import { memo } from 'react';

const NProgress = memo(() => {
const theme = useTheme();
return <NextTopLoader color={theme.colorText} height={2} shadow={false} showSpinner={false} />;
return (
<NextTopLoader
color={theme.colorText}
height={2}
shadow={false}
showSpinner={false}
zIndex={1000}
/>
);
});

export default NProgress;
2 changes: 1 addition & 1 deletion src/const/auth.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export interface JWTPayload {
/**
* Represents the endpoint of provider
*/
endpoint?: string;
baseURL?: string;

azureApiVersion?: string;

Expand Down
11 changes: 11 additions & 0 deletions src/database/server/models/__tests__/user.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ describe('UserModel', () => {
});
});

describe('getUserSettings', () => {
it('should get user settings', async () => {
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(userSettings).values({ id: userId, general: { language: 'en-US' } });

const data = await userModel.getUserSettings();

expect(data).toMatchObject({ id: userId, general: { language: 'en-US' } });
});
});

describe('deleteSetting', () => {
it('should delete user settings', async () => {
await serverDB.insert(users).values({ id: userId });
Expand Down
4 changes: 4 additions & 0 deletions src/database/server/models/user.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ export class UserModel {
};
};

getUserSettings = async () => {
return this.db.query.userSettings.findFirst({ where: eq(userSettings.id, this.userId) });
};

updateUser = async (value: Partial<UserItem>) => {
return this.db
.update(users)
Expand Down
20 changes: 10 additions & 10 deletions src/libs/agent-runtime/AgentRuntime.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ describe('AgentRuntime', () => {
describe('Azure OpenAI provider', () => {
it('should initialize correctly', async () => {
const jwtPayload = {
apikey: 'user-azure-key',
endpoint: 'user-azure-endpoint',
apiKey: 'user-azure-key',
baseURL: 'user-azure-endpoint',
apiVersion: '2024-06-01',
};

Expand All @@ -90,8 +90,8 @@ describe('AgentRuntime', () => {
});
it('should initialize with azureOpenAIParams correctly', async () => {
const jwtPayload = {
apikey: 'user-openai-key',
endpoint: 'user-endpoint',
apiKey: 'user-openai-key',
baseURL: 'user-endpoint',
apiVersion: 'custom-version',
};

Expand All @@ -106,8 +106,8 @@ describe('AgentRuntime', () => {

it('should initialize with AzureAI correctly', async () => {
const jwtPayload = {
apikey: 'user-azure-key',
endpoint: 'user-azure-endpoint',
apiKey: 'user-azure-key',
baseURL: 'user-azure-endpoint',
};
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Azure, {
azure: jwtPayload,
Expand Down Expand Up @@ -171,7 +171,7 @@ describe('AgentRuntime', () => {

describe('Ollama provider', () => {
it('should initialize correctly', async () => {
const jwtPayload: JWTPayload = { endpoint: 'user-ollama-url' };
const jwtPayload: JWTPayload = { baseURL: 'https://user-ollama-url' };
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.Ollama, {
ollama: jwtPayload,
});
Expand Down Expand Up @@ -255,7 +255,7 @@ describe('AgentRuntime', () => {

describe('AgentRuntime chat method', () => {
it('should run correctly', async () => {
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' };
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, {
openai: jwtPayload,
});
Expand All @@ -271,7 +271,7 @@ describe('AgentRuntime', () => {
await runtime.chat(payload);
});
it('should handle options correctly', async () => {
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' };
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, {
openai: jwtPayload,
});
Expand Down Expand Up @@ -300,7 +300,7 @@ describe('AgentRuntime', () => {
});

describe('callback', async () => {
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', endpoint: 'user-endpoint' };
const jwtPayload: JWTPayload = { apiKey: 'user-openai-key', baseURL: 'user-endpoint' };
const runtime = await AgentRuntime.initializeWithProviderOptions(ModelProvider.OpenAI, {
openai: jwtPayload,
});
Expand Down
6 changes: 3 additions & 3 deletions src/libs/agent-runtime/AgentRuntime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class AgentRuntime {
ai21: Partial<ClientOptions>;
ai360: Partial<ClientOptions>;
anthropic: Partial<ClientOptions>;
azure: { apiVersion?: string; apikey?: string; endpoint?: string };
azure: { apiKey?: string; apiVersion?: string; baseURL?: string };
baichuan: Partial<ClientOptions>;
bedrock: Partial<LobeBedrockAIParams>;
cloudflare: Partial<LobeCloudflareParams>;
Expand Down Expand Up @@ -180,8 +180,8 @@ class AgentRuntime {

case ModelProvider.Azure: {
runtimeModel = new LobeAzureOpenAI(
params.azure?.endpoint,
params.azure?.apikey,
params.azure?.baseURL,
params.azure?.apiKey,
params.azure?.apiVersion,
);
break;
Expand Down
5 changes: 4 additions & 1 deletion src/libs/agent-runtime/ollama/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ describe('LobeOllamaAI', () => {
try {
new LobeOllamaAI({ baseURL: 'invalid-url' });
} catch (e) {
expect(e).toEqual(AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs));
expect(e).toEqual({
error: new TypeError('Invalid URL'),
errorType: 'InvalidOllamaArgs',
});
}
});
});
Expand Down
4 changes: 2 additions & 2 deletions src/libs/agent-runtime/ollama/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ export class LobeOllamaAI implements LobeRuntimeAI {
constructor({ baseURL }: ClientOptions = {}) {
try {
if (baseURL) new URL(baseURL);
} catch {
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs);
} catch (e) {
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidOllamaArgs, e);
}

this.client = new Ollama(!baseURL ? undefined : { host: baseURL });
Expand Down
10 changes: 10 additions & 0 deletions src/libs/agent-runtime/openai/__snapshots__/index.test.ts.snap
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 0.5,
"output": 1.5,
},
"releasedAt": "2023-02-28",
},
{
"id": "gpt-3.5-turbo-16k",
Expand All @@ -35,6 +36,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 10,
"output": 30,
},
"releasedAt": "2024-01-23",
},
{
"contextWindowTokens": 128000,
Expand All @@ -46,6 +48,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 10,
"output": 30,
},
"releasedAt": "2024-01-23",
},
{
"contextWindowTokens": 4096,
Expand All @@ -56,6 +59,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 1.5,
"output": 2,
},
"releasedAt": "2023-08-24",
},
{
"id": "gpt-3.5-turbo-0301",
Expand All @@ -73,6 +77,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 1,
"output": 2,
},
"releasedAt": "2023-11-02",
},
{
"contextWindowTokens": 128000,
Expand All @@ -84,13 +89,15 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 10,
"output": 30,
},
"releasedAt": "2023-11-02",
},
{
"contextWindowTokens": 128000,
"deploymentName": "gpt-4-vision",
"description": "GPT-4 视觉预览版,专为图像分析和处理任务设计。",
"displayName": "GPT 4 Turbo with Vision Preview",
"id": "gpt-4-vision-preview",
"releasedAt": "2023-11-02",
"vision": true,
},
{
Expand All @@ -103,6 +110,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 30,
"output": 60,
},
"releasedAt": "2023-06-27",
},
{
"contextWindowTokens": 16385,
Expand All @@ -114,6 +122,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 0.5,
"output": 1.5,
},
"releasedAt": "2024-01-23",
},
{
"contextWindowTokens": 8192,
Expand All @@ -125,6 +134,7 @@ exports[`LobeOpenAI > models > should get models 1`] = `
"input": 30,
"output": 60,
},
"releasedAt": "2023-06-12",
},
]
`;
Loading

0 comments on commit e420ab3

Please sign in to comment.