diff --git a/.env.example b/.env.example
index ec6de216f5b2..5c8754675c1a 100644
--- a/.env.example
+++ b/.env.example
@@ -208,4 +208,7 @@ OPENAI_API_KEY=sk-xxxxxxxxx
 
 # use `openssl rand -base64 32` to generate a key for the encryption of the database
 # we use this key to encrypt the user api key
-# KEY_VAULTS_SECRET=xxxxx/xxxxxxxxxxxxxx=
+#KEY_VAULTS_SECRET=xxxxx/xxxxxxxxxxxxxx=
+
+# Specify the embedding model and reranker model (the reranker is not yet implemented)
+# DEFAULT_FILES_CONFIG="embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_model=full_text"
diff --git a/docs/self-hosting/advanced/knowledge-base.mdx b/docs/self-hosting/advanced/knowledge-base.mdx
index f1bbd2f24c0a..6b54f28e08ca 100644
--- a/docs/self-hosting/advanced/knowledge-base.mdx
+++ b/docs/self-hosting/advanced/knowledge-base.mdx
@@ -62,3 +62,12 @@ Unstructured.io is a powerful document processing tool.
 - **Note**: Evaluate processing needs based on document complexity
 
 By correctly configuring and integrating these core components, you can build a powerful and efficient knowledge base system for LobeChat. Each component plays a crucial role in the overall architecture, supporting advanced document management and intelligent retrieval functions.
+
+### 5. Custom Embedding (Optional)
+
+- **Purpose**: Use a different embedding model to generate the vector representations used for semantic search
+- **Options**: Supported model providers: zhipu/github/openai/bedrock/ollama
+- **Deployment Tip**: Configure the default embedding model via an environment variable, as shown below
+```
+environment: DEFAULT_FILES_CONFIG=embedding_model=openai/text-embedding-3-small
+```
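For readers following the diff: the `DEFAULT_FILES_CONFIG` string documented above is parsed by `parseFilesConfig` (added further down in this diff) into a structured object. A minimal sketch of that mapping, using the values from the example string:

```ts
// The env string is a comma-separated list of key=value pairs, where model
// values take the form <provider>/<model>.
const envString =
  'embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_model=full_text';

// Expected parse result (shape: SystemEmbeddingConfig, defined later in this diff):
const parsed = {
  embeddingModel: { provider: 'openai', model: 'text-embedding-3-small' },
  rerankerModel: { provider: 'cohere', model: 'rerank-english-v3.0' },
  queryModel: 'full_text',
};
```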
diff --git a/docs/self-hosting/advanced/knowledge-base.zh-CN.mdx b/docs/self-hosting/advanced/knowledge-base.zh-CN.mdx
index 59cecf36dfce..7504b089bf17 100644
--- a/docs/self-hosting/advanced/knowledge-base.zh-CN.mdx
+++ b/docs/self-hosting/advanced/knowledge-base.zh-CN.mdx
@@ -60,3 +60,12 @@ Unstructured.io 是一个强大的文档处理工具。
 - **注意事项**:评估处理需求,根据文档复杂度决定是否部署
 
 通过正确配置和集成这些核心组件,您可以为 LobeChat 构建一个强大、高效的知识库系统。每个组件都在整体架构中扮演着关键角色,共同支持高级的文档管理和智能检索功能。
+
+### 5. 自定义 Embedding(可选)
+
+- **用途**: 使用不同的嵌入模型(Embedding)生成文本的向量表示,用于语义搜索
+- **选项**: 支持的模型提供商:zhipu/github/openai/bedrock/ollama
+- **部署建议**: 使用环境变量配置默认嵌入模型
+```
+environment: DEFAULT_FILES_CONFIG=embedding_model=openai/text-embedding-3-small
+```
diff --git a/src/config/knowledge.ts b/src/config/knowledge.ts
index 9b38768f1df7..02f79a5af19e 100644
--- a/src/config/knowledge.ts
+++ b/src/config/knowledge.ts
@@ -4,10 +4,12 @@ import { z } from 'zod';
 
 export const getKnowledgeConfig = () => {
   return createEnv({
     runtimeEnv: {
+      DEFAULT_FILES_CONFIG: process.env.DEFAULT_FILES_CONFIG,
       UNSTRUCTURED_API_KEY: process.env.UNSTRUCTURED_API_KEY,
       UNSTRUCTURED_SERVER_URL: process.env.UNSTRUCTURED_SERVER_URL,
     },
     server: {
+      DEFAULT_FILES_CONFIG: z.string().optional(),
       UNSTRUCTURED_API_KEY: z.string().optional(),
       UNSTRUCTURED_SERVER_URL: z.string().optional(),
     },
diff --git a/src/const/settings/knowledge.ts b/src/const/settings/knowledge.ts
new file mode 100644
index 000000000000..7593358525ce
--- /dev/null
+++ b/src/const/settings/knowledge.ts
@@ -0,0 +1,25 @@
+import { FilesConfig, FilesConfigItem } from '@/types/user/settings/filesConfig';
+
+import {
+  DEFAULT_EMBEDDING_MODEL,
+  DEFAULT_EMBEDDING_PROVIDER,
+  DEFAULT_RERANK_MODEL,
+  DEFAULT_RERANK_PROVIDER,
+  DEFAULT_RERANK_QUERY_MODE,
+} from './llm';
+
+export const DEFAULT_FILE_EMBEDDING_MODEL_ITEM: FilesConfigItem = {
+  model: DEFAULT_EMBEDDING_MODEL,
+  provider: DEFAULT_EMBEDDING_PROVIDER,
+};
+
+export const DEFAULT_FILE_RERANK_MODEL_ITEM: FilesConfigItem = {
+  model: DEFAULT_RERANK_MODEL,
+  provider: DEFAULT_RERANK_PROVIDER,
+};
+
+export const DEFAULT_FILES_CONFIG: FilesConfig = {
+  embeddingModel: DEFAULT_FILE_EMBEDDING_MODEL_ITEM,
+  queryModel: DEFAULT_RERANK_QUERY_MODE,
+  rerankerModel: DEFAULT_FILE_RERANK_MODEL_ITEM,
+};
diff --git a/src/const/settings/llm.ts b/src/const/settings/llm.ts
index d7a75d72a780..da2abdf43d41 100644
--- a/src/const/settings/llm.ts
+++ b/src/const/settings/llm.ts
@@ -12,6 +12,12 @@ export const DEFAULT_LLM_CONFIG = genUserLLMConfig({
 });
 
 export const DEFAULT_MODEL = 'gpt-4o-mini';
+export const DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small';
+export const DEFAULT_EMBEDDING_PROVIDER = ModelProvider.OpenAI;
+
+export const DEFAULT_RERANK_MODEL = 'rerank-english-v3.0';
+export const DEFAULT_RERANK_PROVIDER = 'cohere';
+export const DEFAULT_RERANK_QUERY_MODE = 'full_text';
 
 export const DEFAULT_PROVIDER = ModelProvider.OpenAI;
diff --git a/src/database/schemas/ragEvals.ts b/src/database/schemas/ragEvals.ts
index 6f1f44301bdb..fa04353a8c2f 100644
--- a/src/database/schemas/ragEvals.ts
+++ b/src/database/schemas/ragEvals.ts
@@ -1,7 +1,7 @@
 /* eslint-disable sort-keys-fix/sort-keys-fix */
 import { integer, jsonb, pgTable, text, uuid } from 'drizzle-orm/pg-core';
 
-import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
+import { DEFAULT_MODEL } from '@/const/settings';
 import { EvalEvaluationStatus } from '@/types/eval';
 
 import { timestamps } from './_helpers';
@@ -60,7 +60,7 @@ export const evalEvaluation = pgTable('rag_eval_evaluations', {
     onDelete: 'cascade',
   }),
   languageModel: text('language_model').$defaultFn(() => DEFAULT_MODEL),
-  embeddingModel: text('embedding_model').$defaultFn(() => DEFAULT_EMBEDDING_MODEL),
+  embeddingModel: text('embedding_model'),
 
   userId: text('user_id').references(() => users.id, { onDelete: 'cascade' }),
   ...timestamps,
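When `DEFAULT_FILES_CONFIG` is unset, the server falls back to the constants introduced above. As a sketch, the fully resolved fallback object (with `ModelProvider.OpenAI` inlined to its string value `'openai'`):

```ts
// DEFAULT_FILES_CONFIG from const/settings/knowledge, with every constant
// from const/settings/llm replaced by its literal value.
const resolvedFallback = {
  embeddingModel: { model: 'text-embedding-3-small', provider: 'openai' },
  queryModel: 'full_text',
  rerankerModel: { model: 'rerank-english-v3.0', provider: 'cohere' },
};
```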
diff --git a/src/libs/agent-runtime/bedrock/index.ts b/src/libs/agent-runtime/bedrock/index.ts
index a0668e273cc0..2e8ddeec0290 100644
--- a/src/libs/agent-runtime/bedrock/index.ts
+++ b/src/libs/agent-runtime/bedrock/index.ts
@@ -1,12 +1,20 @@
 import {
   BedrockRuntimeClient,
+  InvokeModelCommand,
   InvokeModelWithResponseStreamCommand,
 } from '@aws-sdk/client-bedrock-runtime';
 import { experimental_buildLlama2Prompt } from 'ai/prompts';
 
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
-import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
+import {
+  ChatCompetitionOptions,
+  ChatStreamPayload,
+  Embeddings,
+  EmbeddingsOptions,
+  EmbeddingsPayload,
+  ModelProvider,
+} from '../types';
 import { buildAnthropicMessages, buildAnthropicTools } from '../utils/anthropicHelpers';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
@@ -32,9 +40,7 @@ export class LobeBedrockAI implements LobeRuntimeAI {
   constructor({ region, accessKeyId, accessKeySecret, sessionToken }: LobeBedrockAIParams = {}) {
     if (!(accessKeyId && accessKeySecret))
       throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidBedrockCredentials);
-
     this.region = region ?? 'us-east-1';
-
     this.client = new BedrockRuntimeClient({
       credentials: {
         accessKeyId: accessKeyId,
@@ -50,6 +56,61 @@ export class LobeBedrockAI implements LobeRuntimeAI {
     return this.invokeClaudeModel(payload, options);
   }
 
+  /**
+   * Supports the Amazon Titan Text Embeddings model series.
+   * Cohere Embed models are not supported, because the text size we send
+   * per request exceeds that series' 2048-character limit for a single request.
+   * Bedrock embeddings guide: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
+   */
+  async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
+    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
+    const promises = input.map((inputText: string) =>
+      this.invokeEmbeddingModel(
+        {
+          dimensions: payload.dimensions,
+          input: inputText,
+          model: payload.model,
+        },
+        options,
+      ),
+    );
+    return Promise.all(promises);
+  }
+
+  private invokeEmbeddingModel = async (
+    payload: EmbeddingsPayload,
+    options?: EmbeddingsOptions,
+  ): Promise<Embeddings> => {
+    const command = new InvokeModelCommand({
+      accept: 'application/json',
+      body: JSON.stringify({
+        dimensions: payload.dimensions,
+        inputText: payload.input,
+        normalize: true,
+      }),
+      contentType: 'application/json',
+      modelId: payload.model,
+    });
+    try {
+      const res = await this.client.send(command, { abortSignal: options?.signal });
+      const responseBody = JSON.parse(new TextDecoder().decode(res.body));
+      return responseBody.embedding;
+    } catch (e) {
+      const err = e as Error & { $metadata: any };
+
+      throw AgentRuntimeError.chat({
+        error: {
+          body: err.$metadata,
+          message: err.message,
+          type: err.name,
+        },
+        errorType: AgentRuntimeErrorType.ProviderBizError,
+        provider: ModelProvider.Bedrock,
+        region: this.region,
+      });
+    }
+  };
+
   private invokeClaudeModel = async (
     payload: ChatStreamPayload,
diff --git a/src/libs/agent-runtime/ollama/index.ts b/src/libs/agent-runtime/ollama/index.ts
index 47b6023caf64..f61ccd90a061 100644
--- a/src/libs/agent-runtime/ollama/index.ts
+++ b/src/libs/agent-runtime/ollama/index.ts
@@ -6,7 +6,13 @@ import { ChatModelCard } from '@/types/llm';
 
 import { LobeRuntimeAI } from '../BaseAI';
 import { AgentRuntimeErrorType } from '../error';
-import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
+import {
+  ChatCompetitionOptions,
+  ChatStreamPayload,
+  Embeddings,
+  EmbeddingsPayload,
+  ModelProvider,
+} from '../types';
 import { AgentRuntimeError } from '../utils/createError';
 import { debugStream } from '../utils/debugStream';
 import { StreamingResponse } from '../utils/response';
@@ -84,6 +90,18 @@ export class LobeOllamaAI implements LobeRuntimeAI {
     }
   }
 
+  async embeddings(payload: EmbeddingsPayload): Promise<Embeddings[]> {
+    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
+    const promises = input.map((inputText: string) =>
+      this.invokeEmbeddingModel({
+        dimensions: payload.dimensions,
+        input: inputText,
+        model: payload.model,
+      }),
+    );
+    return await Promise.all(promises);
+  }
+
   async models(): Promise<ChatModelCard[]> {
     const list = await this.client.list();
     return list.models.map((model) => ({
@@ -91,6 +109,24 @@ export class LobeOllamaAI implements LobeRuntimeAI {
     }));
   }
 
+  private invokeEmbeddingModel = async (payload: EmbeddingsPayload): Promise<Embeddings> => {
+    try {
+      const responseBody = await this.client.embeddings({
+        model: payload.model,
+        prompt: payload.input as string,
+      });
+      return responseBody.embedding;
+    } catch (error) {
+      const e = error as { message: string; name: string; status_code: number };
+
+      throw AgentRuntimeError.chat({
+        error: { message: e.message, name: e.name, status_code: e.status_code },
+        errorType: AgentRuntimeErrorType.OllamaBizError,
+        provider: ModelProvider.Ollama,
+      });
+    }
+  };
+
   private buildOllamaMessages(messages: OpenAIChatMessage[]) {
     return messages.map((message) => this.convertContentToOllamaMessage(message));
   }
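Both runtimes above implement `embeddings` with the same fan-out pattern: a batch payload is split into one model invocation per input string, since neither backend accepts the whole batch in a single request here. A simplified, self-contained sketch of the pattern (the names `embedAll` and `embedOne` are illustrative, not part of the runtime API):

```ts
interface EmbeddingsPayload {
  dimensions?: number;
  input: string | string[];
  model: string;
}

// Fan a batch payload out into one request per input string.
// Promise.all preserves order, so result[i] is the embedding of inputs[i].
async function embedAll(
  payload: EmbeddingsPayload,
  embedOne: (text: string) => Promise<number[]>,
): Promise<number[][]> {
  const inputs = Array.isArray(payload.input) ? payload.input : [payload.input];
  return Promise.all(inputs.map((text) => embedOne(text)));
}
```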
diff --git a/src/server/globalConfig/index.ts b/src/server/globalConfig/index.ts
index b2af9f616315..9869618decd1 100644
--- a/src/server/globalConfig/index.ts
+++ b/src/server/globalConfig/index.ts
@@ -1,6 +1,7 @@
 import { appEnv, getAppConfig } from '@/config/app';
 import { authEnv } from '@/config/auth';
 import { fileEnv } from '@/config/file';
+import { knowledgeEnv } from '@/config/knowledge';
 import { langfuseEnv } from '@/config/langfuse';
 import { enableNextAuth } from '@/const/auth';
 import { parseSystemAgent } from '@/server/globalConfig/parseSystemAgent';
@@ -9,6 +10,7 @@ import { GlobalServerConfig } from '@/types/serverConfig';
 
 import { genServerLLMConfig } from './_deprecated';
 import { genServerAiProvidersConfig } from './genServerAiProviderConfig';
 import { parseAgentConfig } from './parseDefaultAgent';
+import { parseFilesConfig } from './parseFilesConfig';
 
 export const getServerGlobalConfig = () => {
   const { ACCESS_CODES, DEFAULT_AGENT_CONFIG } = getAppConfig();
@@ -73,3 +75,7 @@ export const getServerDefaultAgentConfig = () => {
 
   return parseAgentConfig(DEFAULT_AGENT_CONFIG) || {};
 };
+
+export const getServerDefaultFilesConfig = () => {
+  return parseFilesConfig(knowledgeEnv.DEFAULT_FILES_CONFIG);
+};
diff --git a/src/server/globalConfig/parseFilesConfig.test.ts b/src/server/globalConfig/parseFilesConfig.test.ts
new file mode 100644
index 000000000000..c001ead1748c
--- /dev/null
+++ b/src/server/globalConfig/parseFilesConfig.test.ts
@@ -0,0 +1,17 @@
+import { describe, expect, it } from 'vitest';
+
+import { parseFilesConfig } from './parseFilesConfig';
+
+describe('parseFilesConfig', () => {
+  // verifies that the embeddings configuration is parsed correctly
+  it('parses embeddings configuration correctly', () => {
+    const envStr =
+      'embedding_model=openai/text-embedding-3-large,reranker_model=cohere/rerank-english-v3.0,query_model=full_text';
+    const expected = {
+      embeddingModel: { provider: 'openai', model: 'text-embedding-3-large' },
+      rerankerModel: { provider: 'cohere', model: 'rerank-english-v3.0' },
+      queryModel: 'full_text',
+    };
+    expect(parseFilesConfig(envStr)).toEqual(expected);
+  });
+});
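The test above only covers the happy path. Two further cases one might add (hypothetical, not part of this diff) to exercise the fallback and error branches of `parseFilesConfig`:

```ts
import { describe, expect, it } from 'vitest';

import { DEFAULT_FILES_CONFIG } from '@/const/settings/knowledge';

import { parseFilesConfig } from './parseFilesConfig';

describe('parseFilesConfig edge cases', () => {
  // an empty string falls back to the built-in defaults
  it('returns the defaults for an empty string', () => {
    expect(parseFilesConfig('')).toEqual(DEFAULT_FILES_CONFIG);
  });

  // an entry without a key=value shape is rejected
  it('throws on a malformed entry', () => {
    expect(() => parseFilesConfig('embedding_model')).toThrowError(
      'Invalid environment variable format',
    );
  });
});
```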
diff --git a/src/server/globalConfig/parseFilesConfig.ts b/src/server/globalConfig/parseFilesConfig.ts
new file mode 100644
index 000000000000..7175b3206448
--- /dev/null
+++ b/src/server/globalConfig/parseFilesConfig.ts
@@ -0,0 +1,57 @@
+import { DEFAULT_FILES_CONFIG } from '@/const/settings/knowledge';
+import { SystemEmbeddingConfig } from '@/types/knowledgeBase';
+import { FilesConfig } from '@/types/user/settings/filesConfig';
+
+const protectedKeys = Object.keys({
+  embedding_model: null,
+  query_model: null,
+  reranker_model: null,
+});
+
+export const parseFilesConfig = (envString: string = ''): SystemEmbeddingConfig => {
+  if (!envString) return DEFAULT_FILES_CONFIG;
+  const config: FilesConfig = {} as any;
+
+  // normalize full-width commas and trim surrounding whitespace
+  const envValue = envString.replaceAll(',', ',').trim();
+
+  const pairs = envValue.split(',');
+
+  for (const pair of pairs) {
+    const [key, value] = pair.split('=').map((s) => s.trim());
+
+    if (key && value) {
+      const [provider, ...modelParts] = value.split('/');
+      const model = modelParts.join('/');
+
+      if ((!provider || !model) && key !== 'query_model') {
+        throw new Error('Missing model or provider value');
+      }
+
+      if (key === 'query_model' && value === '') {
+        throw new Error('Missing query mode value');
+      }
+
+      if (protectedKeys.includes(key)) {
+        switch (key) {
+          case 'embedding_model': {
+            config.embeddingModel = { model: model.trim(), provider: provider.trim() };
+            break;
+          }
+          case 'reranker_model': {
+            config.rerankerModel = { model: model.trim(), provider: provider.trim() };
+            break;
+          }
+          case 'query_model': {
+            config.queryModel = value;
+            break;
+          }
+        }
+      }
+    } else {
+      throw new Error('Invalid environment variable format');
+    }
+  }
+
+  return config;
+};
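One parser behavior worth calling out: full-width commas (`,`), which are easy to type with a CJK input method, are normalized to ASCII commas before splitting. A quick usage sketch:

```ts
import { parseFilesConfig } from './parseFilesConfig';

// Both strings parse to the same config; the second uses a full-width comma.
const a = parseFilesConfig('embedding_model=openai/text-embedding-3-small,query_model=full_text');
const b = parseFilesConfig('embedding_model=openai/text-embedding-3-small,query_model=full_text');
```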
diff --git a/src/server/routers/async/file.ts b/src/server/routers/async/file.ts
index dfe3b346087f..e67621b4d26a 100644
--- a/src/server/routers/async/file.ts
+++ b/src/server/routers/async/file.ts
@@ -5,15 +5,15 @@ import { z } from 'zod';
 
 import { serverDBEnv } from '@/config/db';
 import { fileEnv } from '@/config/file';
-import { DEFAULT_EMBEDDING_MODEL } from '@/const/settings';
+import { DEFAULT_FILE_EMBEDDING_MODEL_ITEM } from '@/const/settings/knowledge';
 import { NewChunkItem, NewEmbeddingsItem } from '@/database/schemas';
 import { serverDB } from '@/database/server';
 import { ASYNC_TASK_TIMEOUT, AsyncTaskModel } from '@/database/server/models/asyncTask';
 import { ChunkModel } from '@/database/server/models/chunk';
 import { EmbeddingModel } from '@/database/server/models/embedding';
 import { FileModel } from '@/database/server/models/file';
-import { ModelProvider } from '@/libs/agent-runtime';
 import { asyncAuthedProcedure, asyncRouter as router } from '@/libs/trpc/async';
+import { getServerDefaultFilesConfig } from '@/server/globalConfig';
 import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';
 import { S3 } from '@/server/modules/S3';
 import { ChunkService } from '@/server/services/chunk';
@@ -44,7 +44,6 @@ export const fileRouter = router({
     .input(
       z.object({
         fileId: z.string(),
-        model: z.string().default(DEFAULT_EMBEDDING_MODEL),
         taskId: z.string(),
       }),
     )
@@ -57,6 +56,9 @@ export const fileRouter = router({
 
       const asyncTask = await ctx.asyncTaskModel.findById(input.taskId);
 
+      const { model, provider } =
+        getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
+
      if (!asyncTask) throw new TRPCError({ code: 'BAD_REQUEST', message: 'Async Task not found' });
 
       try {
@@ -84,13 +86,12 @@ export const fileRouter = router({
           const chunks = await ctx.chunkModel.getChunksTextByFileId(input.fileId);
 
           const requestArray = chunk(chunks, CHUNK_SIZE);
-
           try {
             await pMap(
               requestArray,
               async (chunks, index) => {
                 const agentRuntime = await initAgentRuntimeWithUserPayload(
-                  ModelProvider.OpenAI,
+                  provider,
                   ctx.jwtPayload,
                 );
 
@@ -98,11 +99,10 @@ export const fileRouter = router({
                 console.log(`执行第 ${number} 个任务`);
                 console.time(`任务[${number}]: embeddings`);
-
                 const embeddings = await agentRuntime.embeddings({
                   dimensions: 1024,
                   input: chunks.map((c) => c.text),
-                  model: input.model,
+                  model,
                 });
                 console.timeEnd(`任务[${number}]: embeddings`);
 
@@ -111,7 +111,7 @@ export const fileRouter = router({
                     chunkId: chunks[idx].id,
                     embeddings: e,
                     fileId: input.fileId,
-                    model: input.model,
+                    model,
                   })) || [];
 
                 console.time(`任务[${number}]: insert db`);
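The route above splits chunks into fixed-size batches with lodash's `chunk` and walks the batches with `p-map`. A self-contained sketch of that batching pattern (the `CHUNK_SIZE` and concurrency values here are illustrative, not the router's actual settings):

```ts
import { chunk } from 'lodash-es';
import pMap from 'p-map';

const CHUNK_SIZE = 50; // illustrative batch size

// Split texts into batches, then embed each batch with bounded concurrency.
async function embedInBatches(
  texts: string[],
  embedBatch: (batch: string[]) => Promise<number[][]>,
): Promise<number[][][]> {
  const batches = chunk(texts, CHUNK_SIZE);
  return pMap(batches, (batch) => embedBatch(batch), { concurrency: 2 });
}
```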
diff --git a/src/server/routers/lambda/chunk.ts b/src/server/routers/lambda/chunk.ts
index c663c74d31e6..e83f0653a8c3 100644
--- a/src/server/routers/lambda/chunk.ts
+++ b/src/server/routers/lambda/chunk.ts
@@ -1,7 +1,7 @@
 import { inArray } from 'drizzle-orm/expressions';
 import { z } from 'zod';
 
-import { DEFAULT_EMBEDDING_MODEL } from '@/const/settings';
+import { DEFAULT_FILE_EMBEDDING_MODEL_ITEM } from '@/const/settings/knowledge';
 import { knowledgeBaseFiles } from '@/database/schemas';
 import { serverDB } from '@/database/server';
 import { AsyncTaskModel } from '@/database/server/models/asyncTask';
@@ -9,9 +9,9 @@ import { ChunkModel } from '@/database/server/models/chunk';
 import { EmbeddingModel } from '@/database/server/models/embedding';
 import { FileModel } from '@/database/server/models/file';
 import { MessageModel } from '@/database/server/models/message';
-import { ModelProvider } from '@/libs/agent-runtime';
 import { authedProcedure, router } from '@/libs/trpc';
 import { keyVaults } from '@/libs/trpc/middleware/keyVaults';
+import { getServerDefaultFilesConfig } from '@/server/globalConfig';
 import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';
 import { ChunkService } from '@/server/services/chunk';
 import { SemanticSearchSchema } from '@/types/rag';
@@ -101,21 +101,18 @@ export const chunkRouter = router({
     .input(
       z.object({
         fileIds: z.array(z.string()).optional(),
-        model: z.string().default(DEFAULT_EMBEDDING_MODEL),
         query: z.string(),
       }),
     )
     .mutation(async ({ ctx, input }) => {
-
       console.time('embedding');
-      const agentRuntime = await initAgentRuntimeWithUserPayload(
-        ModelProvider.OpenAI,
-        ctx.jwtPayload,
-      );
+      const { model, provider } =
+        getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
+      const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);
 
       const embeddings = await agentRuntime.embeddings({
         dimensions: 1024,
         input: input.query,
-        model: input.model,
+        model,
       });
       console.timeEnd('embedding');
@@ -130,27 +127,25 @@ export const chunkRouter = router({
     .input(SemanticSearchSchema)
     .mutation(async ({ ctx, input }) => {
       const item = await ctx.messageModel.findMessageQueriesById(input.messageId);
+      const { model, provider } =
+        getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
 
       let embedding: number[];
       let ragQueryId: string;
-
       // if there is no message rag or it's embeddings, then we need to create one
       if (!item || !item.embeddings) {
         // TODO: need to support customize
-        const agentRuntime = await initAgentRuntimeWithUserPayload(
-          ModelProvider.OpenAI,
-          ctx.jwtPayload,
-        );
+        const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);
 
         const embeddings = await agentRuntime.embeddings({
           dimensions: 1024,
           input: input.rewriteQuery,
-          model: input.model || DEFAULT_EMBEDDING_MODEL,
+          model,
         });
 
         embedding = embeddings![0];
         const embeddingsId = await ctx.embeddingModel.create({
           embeddings: embedding,
-          model: input.model,
+          model,
         });
 
         const result = await ctx.messageModel.createMessageQuery({
@@ -182,6 +177,7 @@ export const chunkRouter = router({
         fileIds: finalFileIds,
         query: input.rewriteQuery,
       });
+      // TODO: need to rerank the chunks
       console.timeEnd('semanticSearch');
 
       return { chunks, queryId: ragQueryId };
diff --git a/src/types/knowledgeBase/index.ts b/src/types/knowledgeBase/index.ts
index 16e2a271391c..20355071c9ff 100644
--- a/src/types/knowledgeBase/index.ts
+++ b/src/types/knowledgeBase/index.ts
@@ -1,3 +1,5 @@
+import { FilesConfigItem } from '../user/settings/filesConfig';
+
 export enum KnowledgeBaseTabs {
   Files = 'files',
   Settings = 'Settings',
@@ -43,3 +45,9 @@ export interface KnowledgeItem {
   name: string;
   type: KnowledgeType;
 }
+
+export interface SystemEmbeddingConfig {
+  embeddingModel: FilesConfigItem;
+  queryModel: string;
+  rerankerModel: FilesConfigItem;
+}
diff --git a/src/types/user/settings/filesConfig.ts b/src/types/user/settings/filesConfig.ts
new file mode 100644
index 000000000000..ecfb3b6e3290
--- /dev/null
+++ b/src/types/user/settings/filesConfig.ts
@@ -0,0 +1,9 @@
+export interface FilesConfigItem {
+  model: string;
+  provider: string;
+}
+export interface FilesConfig {
+  embeddingModel: FilesConfigItem;
+  queryModel: string;
+  rerankerModel: FilesConfigItem;
+}
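Taken together, the pieces in this diff wire up as follows. A condensed sketch of the server-side flow, simplified from the routers above (`jwtPayload` is assumed to come from the authenticated tRPC context):

```ts
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';

// Resolve the configured embedding model, then embed a single query string.
async function embedQuery(query: string, jwtPayload: any): Promise<number[]> {
  const { model, provider } = getServerDefaultFilesConfig().embeddingModel;
  const runtime = await initAgentRuntimeWithUserPayload(provider, jwtPayload);

  const embeddings = await runtime.embeddings({ dimensions: 1024, input: query, model });
  return embeddings![0];
}
```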