Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: 允许用户自行定义 Embedding 模型 #5177

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
feat: 重构文件配置,更新默认设置和相关测试
  • Loading branch information
cookieY committed Dec 27, 2024
commit dd26797d725aca33d94c6efbaf1e408c96d5ea36
4 changes: 1 addition & 3 deletions src/config/knowledge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import { z } from 'zod';
export const getKnowledgeConfig = () => {
return createEnv({
runtimeEnv: {
DEFAULT_FILES_CONFIG: !!process.env.DEFAULT_FILES_CONFIG
? process.env.DEFAULT_FILES_CONFIG
: 'embedding_model=openai/embedding-text-3-small,reranker_model=cohere/rerank-english-v3.0,query_mode=full_text',
DEFAULT_FILES_CONFIG: process.env.DEFAULT_FILES_CONFIG,
UNSTRUCTURED_API_KEY: process.env.UNSTRUCTURED_API_KEY,
UNSTRUCTURED_SERVER_URL: process.env.UNSTRUCTURED_SERVER_URL,
},
Expand Down
28 changes: 22 additions & 6 deletions src/const/settings/knowledge.ts
cookieY marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
import { SystemEmbeddingConfig } from '@/types/knowledgeBase';
import { FilesConfig, FilesConfigItem } from '@/types/user/settings/filesConfig';

import { DEFAULT_SYSTEM_AGENT_ITEM } from './systemAgent';
import {
DEFAULT_EMBEDDING_MODEL,
DEFAULT_PROVIDER,
DEFAULT_RERANK_MODEL,
DEFAULT_RERANK_PROVIDER,
DEFAULT_RERANK_QUERY_MODE,
} from './llm';

export const SYSTEM_EMBEDDING_CONFIG: SystemEmbeddingConfig = {
embedding_model: DEFAULT_SYSTEM_AGENT_ITEM,
query_model: DEFAULT_SYSTEM_AGENT_ITEM,
reranker_model: DEFAULT_SYSTEM_AGENT_ITEM,
export const DEFAULT_FILE_EMBEDDING_MODEL_ITEM: FilesConfigItem = {
model: DEFAULT_EMBEDDING_MODEL,
provider: DEFAULT_PROVIDER,
};

export const DEFAULT_FILE_RERANK_MODEL_ITEM: FilesConfigItem = {
model: DEFAULT_RERANK_MODEL,
provider: DEFAULT_RERANK_PROVIDER,
};

export const DEFAULT_FILES_CONFIG: FilesConfig = {
embedding_model: DEFAULT_FILE_EMBEDDING_MODEL_ITEM,
cookieY marked this conversation as resolved.
Show resolved Hide resolved
query_model: DEFAULT_RERANK_QUERY_MODE,
reranker_model: DEFAULT_FILE_RERANK_MODEL_ITEM,
};
6 changes: 6 additions & 0 deletions src/const/settings/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ export const DEFAULT_LLM_CONFIG = genUserLLMConfig({
});

export const DEFAULT_MODEL = 'gpt-4o-mini';

export const DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small';
export const DEFAULT_EMBEDDING_PROVIDER = ModelProvider.OpenAI;

export const DEFAULT_RERANK_MODEL = 'rerank-english-v3.0';
export const DEFAULT_RERANK_PROVIDER = 'cohere';
export const DEFAULT_RERANK_QUERY_MODE = 'full_text';

export const DEFAULT_PROVIDER = ModelProvider.OpenAI;
5 changes: 1 addition & 4 deletions src/database/schemas/ragEvals.ts
cookieY marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import { integer, jsonb, pgTable, text, uuid } from 'drizzle-orm/pg-core';

import { DEFAULT_MODEL } from '@/const/settings';
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
import { EvalEvaluationStatus } from '@/types/eval';

import { timestamps } from './_helpers';
Expand Down Expand Up @@ -61,9 +60,7 @@ export const evalEvaluation = pgTable('rag_eval_evaluations', {
onDelete: 'cascade',
}),
languageModel: text('language_model').$defaultFn(() => DEFAULT_MODEL),
embeddingModel: text('embedding_model').$defaultFn(() =>
getServerDefaultFilesConfig().getEmbeddingModel(),
),
embeddingModel: text('embedding_model'),

userId: text('user_id').references(() => users.id, { onDelete: 'cascade' }),
...timestamps,
Expand Down
3 changes: 1 addition & 2 deletions src/libs/agent-runtime/bedrock/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,10 @@ export class LobeBedrockAI implements LobeRuntimeAI {
*/
async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
const input = Array.isArray(payload.input) ? payload.input : [payload.input];
const promises = input.map((inputText: string, index: number) =>
const promises = input.map((inputText: string) =>
this.invokeEmbeddingModel(
{
dimensions: payload.dimensions,
index: index,
input: inputText,
model: payload.model,
},
Expand Down
3 changes: 1 addition & 2 deletions src/libs/agent-runtime/ollama/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,9 @@ export class LobeOllamaAI implements LobeRuntimeAI {

async embeddings(payload: EmbeddingsPayload): Promise<Embeddings[]> {
const input = Array.isArray(payload.input) ? payload.input : [payload.input];
const promises = input.map((inputText: string, index: number) =>
const promises = input.map((inputText: string) =>
this.invokeEmbeddingModel({
dimensions: payload.dimensions,
index: index,
input: inputText,
model: payload.model,
}),
Expand Down
1 change: 0 additions & 1 deletion src/libs/agent-runtime/types/embeddings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ export interface EmbeddingsPayload {
* supported in `text-embedding-3` and later models.
*/
dimensions?: number;
index?: number;
/**
* Input text to embed, encoded as a string or array of tokens. To embed multiple
* inputs in a single request, pass an array of strings .
Expand Down
34 changes: 5 additions & 29 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -275,46 +275,22 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
.filter(Boolean) as ChatModelCard[];
}

/**
* Due to the limitation of a maximum of 8k tokens per request in
* the openai interface provided by other vendors
* the current chunked input array size exceeds 8k tokens
* Therefore, a single array with multiple concurrent requests is used.
*/
async embeddings(
payload: EmbeddingsPayload,
options?: EmbeddingsOptions,
): Promise<Embeddings[]> {
try {
const input = Array.isArray(payload.input) ? payload.input : [payload.input];
const promises = input.map((inputText: string) =>
this.invokeEmbeddings(
{
dimensions: payload.dimensions,
input: inputText,
model: payload.model,
},
options,
),
const res = await this.client.embeddings.create(
{ ...payload, user: options?.user },
{ headers: options?.headers, signal: options?.signal },
);
const results = await Promise.all(promises);
return results.flat();

return res.data.map((item) => item.embedding);
} catch (error) {
throw this.handleError(error);
}
}

async invokeEmbeddings(
payload: EmbeddingsPayload,
options?: EmbeddingsOptions,
): Promise<Embeddings[]> {
const res = await this.client.embeddings.create(
{ ...payload, user: options?.user },
{ headers: options?.headers, signal: options?.signal },
);
return res.data.map((item) => item.embedding);
}

async textToImage(payload: TextToImagePayload) {
try {
const res = await this.client.images.generate(payload);
Expand Down
4 changes: 2 additions & 2 deletions src/server/globalConfig/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ import { knowledgeEnv } from '@/config/knowledge';
import { langfuseEnv } from '@/config/langfuse';
import { enableNextAuth } from '@/const/auth';
import { parseSystemAgent } from '@/server/globalConfig/parseSystemAgent';
import { FilesStore } from '@/server/modules/Files';
import { GlobalServerConfig } from '@/types/serverConfig';

import { genServerLLMConfig } from './genServerLLMConfig';
import { parseAgentConfig } from './parseDefaultAgent';
import { parseFilesConfig } from './parseFilesConfig';

export const getServerGlobalConfig = () => {
const { ACCESS_CODES, DEFAULT_AGENT_CONFIG } = getAppConfig();
Expand Down Expand Up @@ -55,5 +55,5 @@ export const getServerDefaultAgentConfig = () => {
};

export const getServerDefaultFilesConfig = () => {
return new FilesStore(parseSystemAgent(knowledgeEnv.DEFAULT_FILES_CONFIG));
return parseFilesConfig(knowledgeEnv.DEFAULT_FILES_CONFIG);
};
17 changes: 17 additions & 0 deletions src/server/globalConfig/parseFilesConfig.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { describe, expect, it } from 'vitest';

import { parseFilesConfig } from './parseFilesConfig';

describe('parseFilesConfig', () => {
  // A fully-specified env string should yield one entry per config key,
  // with provider/model values split on the first '/'.
  it('parses embeddings configuration correctly', () => {
    const envInput =
      'embedding_model=openai/embedding-text-3-large,reranker_model=cohere/rerank-english-v3.0,query_model=full_text';

    expect(parseFilesConfig(envInput)).toEqual({
      embedding_model: { model: 'embedding-text-3-large', provider: 'openai' },
      query_model: 'full_text',
      reranker_model: { model: 'rerank-english-v3.0', provider: 'cohere' },
    });
  });
});
47 changes: 47 additions & 0 deletions src/server/globalConfig/parseFilesConfig.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import { DEFAULT_FILES_CONFIG } from '@/const/settings/knowledge';
import { SystemEmbeddingConfig } from '@/types/knowledgeBase';
import { FilesConfig } from '@/types/user/settings/filesConfig';

// Keys accepted from the env string; unknown keys are silently ignored.
const protectedKeys = Object.keys(DEFAULT_FILES_CONFIG);

/**
 * Parses the `DEFAULT_FILES_CONFIG` env string, e.g.
 * `embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_model=full_text`,
 * into a files config object.
 *
 * - Returns `DEFAULT_FILES_CONFIG` when the string is empty.
 * - Keys missing from the env string keep their default values, so the result
 *   is always fully populated (callers dereference `embedding_model.model`
 *   directly and would otherwise crash on a partial config).
 * - Throws on a malformed pair: missing `=`, empty key/value, or a model
 *   value without a `provider/model` separator.
 */
export const parseFilesConfig = (envString: string = ''): SystemEmbeddingConfig => {
  if (!envString) return DEFAULT_FILES_CONFIG;

  const config: Partial<FilesConfig> = {};

  // Normalize full-width commas (,) and surrounding whitespace.
  const pairs = envString.replaceAll(',', ',').trim().split(',');

  for (const pair of pairs) {
    const [key, value] = pair.split('=').map((s) => s.trim());

    if (!key || !value) {
      throw new Error('Invalid environment variable format');
    }

    // query_model is a plain mode string, not a provider/model pair.
    if (key === 'query_model') {
      if (protectedKeys.includes(key)) config.query_model = value;
      continue;
    }

    const [provider, ...modelParts] = value.split('/');
    const model = modelParts.join('/');

    if (!provider || !model) {
      throw new Error('Missing model or provider value');
    }

    if (protectedKeys.includes(key)) {
      config[key as keyof FilesConfig] = { model, provider } as any;
    }
  }

  // Merge over the defaults so a partially-specified env string still yields
  // a complete config.
  return { ...DEFAULT_FILES_CONFIG, ...config };
};
14 changes: 5 additions & 9 deletions src/server/globalConfig/parseSystemAgent.ts
cookieY marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import { DEFAULT_SYSTEM_AGENT_CONFIG } from '@/const/settings';
import { SYSTEM_EMBEDDING_CONFIG } from '@/const/settings/knowledge';
import { SystemEmbeddingConfig } from '@/types/knowledgeBase';
import { UserSystemAgentConfig } from '@/types/user/settings';

const protectedKeys = Object.keys({ ...DEFAULT_SYSTEM_AGENT_CONFIG, ...SYSTEM_EMBEDDING_CONFIG });
const protectedKeys = Object.keys(DEFAULT_SYSTEM_AGENT_CONFIG);

export interface CommonSystemConfig extends UserSystemAgentConfig, SystemEmbeddingConfig {}

export const parseSystemAgent = (envString: string = ''): Partial<CommonSystemConfig> => {
export const parseSystemAgent = (envString: string = ''): Partial<UserSystemAgentConfig> => {
if (!envString) return {};

const config: Partial<CommonSystemConfig> = {};
const config: Partial<UserSystemAgentConfig> = {};

// 处理全角逗号和多余空格
let envValue = envString.replaceAll(',', ',').trim();
Expand All @@ -24,12 +20,12 @@ export const parseSystemAgent = (envString: string = ''): Partial<CommonSystemCo
const [provider, ...modelParts] = value.split('/');
const model = modelParts.join('/');

if (!provider) {
if (!provider || !model) {
throw new Error('Missing model or provider value');
}

if (protectedKeys.includes(key)) {
config[key as keyof CommonSystemConfig] = {
config[key as keyof UserSystemAgentConfig] = {
enabled: key === 'queryRewrite' ? true : undefined,
model: model.trim(),
provider: provider.trim(),
Expand Down
42 changes: 0 additions & 42 deletions src/server/modules/Files/index.test.ts

This file was deleted.

26 changes: 0 additions & 26 deletions src/server/modules/Files/index.ts
cookieY marked this conversation as resolved.
Outdated
Show resolved Hide resolved

This file was deleted.

4 changes: 2 additions & 2 deletions src/server/routers/async/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ export const fileRouter = router({

const asyncTask = await ctx.asyncTaskModel.findById(input.taskId);

const model = getServerDefaultFilesConfig().getEmbeddingModel();
const provider = getServerDefaultFilesConfig().getEmbeddingProvider();
const model = getServerDefaultFilesConfig().embedding_model.model;
cookieY marked this conversation as resolved.
Show resolved Hide resolved
const provider = getServerDefaultFilesConfig().embedding_model.provider;

if (!asyncTask) throw new TRPCError({ code: 'BAD_REQUEST', message: 'Async Task not found' });

Expand Down
8 changes: 4 additions & 4 deletions src/server/routers/lambda/chunk.ts
cookieY marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ export const chunkRouter = router({
)
.mutation(async ({ ctx, input }) => {
console.time('embedding');
const model = getServerDefaultFilesConfig().getEmbeddingModel();
const provider = getServerDefaultFilesConfig().getEmbeddingProvider();
const model = getServerDefaultFilesConfig().embedding_model.model;
const provider = getServerDefaultFilesConfig().embedding_model.provider;
const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);

const embeddings = await agentRuntime.embeddings({
Expand All @@ -127,8 +127,8 @@ export const chunkRouter = router({
.input(SemanticSearchSchema)
.mutation(async ({ ctx, input }) => {
const item = await ctx.messageModel.findMessageQueriesById(input.messageId);
const model = getServerDefaultFilesConfig().getEmbeddingModel();
const provider = getServerDefaultFilesConfig().getEmbeddingProvider();
const model = getServerDefaultFilesConfig().embedding_model.model;
const provider = getServerDefaultFilesConfig().embedding_model.provider;
let embedding: number[];
let ragQueryId: string;
console.log('embeddingProvider:', provider);
cookieY marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
Loading