
✨ feat: Allow users to define a custom Embedding model #5177

Open
wants to merge 8 commits into base: main
5 changes: 4 additions & 1 deletion .env.example
@@ -208,4 +208,7 @@ OPENAI_API_KEY=sk-xxxxxxxxx

# use `openssl rand -base64 32` to generate a key for the encryption of the database
# we use this key to encrypt the user api key
# KEY_VAULTS_SECRET=xxxxx/xxxxxxxxxxxxxx=
#KEY_VAULTS_SECRET=xxxxx/xxxxxxxxxxxxxx=

# Specify the Embedding model and Reranker model (reranker not yet implemented)
# DEFAULT_FILES_CONFIG="embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_model=full_text"
9 changes: 9 additions & 0 deletions docs/self-hosting/advanced/knowledge-base.mdx
@@ -51,3 +51,12 @@ Unstructured.io is a powerful document processing tool.
- **Note**: Evaluate processing needs based on document complexity

By correctly configuring and integrating these core components, you can build a powerful and efficient knowledge base system for LobeChat. Each component plays a crucial role in the overall architecture, supporting advanced document management and intelligent retrieval functions.

### 5. Custom Embedding (Optional)

- **Purpose**: Use a different Embedding model to generate vector representations of text for semantic search
- **Options**: Supported model providers: zhipu/github/openai/bedrock/ollama
- **Deployment Tip**: Set the environment variable below to configure the default Embedding model
```
environment: DEFAULT_FILES_CONFIG=embedding_model=openai/text-embedding-3-small
```
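For a `.env`-style deployment, the same variable can be set on one line; a minimal sketch mirroring the commented example in `.env.example` (the reranker is not implemented yet):

```
DEFAULT_FILES_CONFIG="embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_model=full_text"
```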
9 changes: 9 additions & 0 deletions docs/self-hosting/advanced/knowledge-base.zh-CN.mdx
@@ -51,3 +51,12 @@ Unstructured.io 是一个强大的文档处理工具。
- **注意事项**:评估处理需求,根据文档复杂度决定是否部署

通过正确配置和集成这些核心组件,您可以为 LobeChat 构建一个强大、高效的知识库系统。每个组件都在整体架构中扮演着关键角色,共同支持高级的文档管理和智能检索功能。

### 5. 自定义 Embedding(可选)

- **用途**: 使用不同的嵌入模型(Embedding)生成文本的向量表示,用于语义搜索
- **选项**: 支持的模型提供商:zhipu/github/openai/bedrock/ollama
- **部署建议**: 使用环境变量配置默认嵌入模型
```
environment: DEFAULT_FILES_CONFIG=embedding_model=openai/text-embedding-3-small
```
2 changes: 2 additions & 0 deletions src/config/knowledge.ts
@@ -4,10 +4,12 @@ import { z } from 'zod';
export const getKnowledgeConfig = () => {
  return createEnv({
    runtimeEnv: {
      DEFAULT_FILES_CONFIG: process.env.DEFAULT_FILES_CONFIG,
      UNSTRUCTURED_API_KEY: process.env.UNSTRUCTURED_API_KEY,
      UNSTRUCTURED_SERVER_URL: process.env.UNSTRUCTURED_SERVER_URL,
    },
    server: {
      DEFAULT_FILES_CONFIG: z.string().optional(),
      UNSTRUCTURED_API_KEY: z.string().optional(),
      UNSTRUCTURED_SERVER_URL: z.string().optional(),
    },
25 changes: 25 additions & 0 deletions src/const/settings/knowledge.ts
@@ -0,0 +1,25 @@
import { FilesConfig, FilesConfigItem } from '@/types/user/settings/filesConfig';

import {
  DEFAULT_EMBEDDING_MODEL,
  DEFAULT_PROVIDER,
  DEFAULT_RERANK_MODEL,
  DEFAULT_RERANK_PROVIDER,
  DEFAULT_RERANK_QUERY_MODE,
} from './llm';

export const DEFAULT_FILE_EMBEDDING_MODEL_ITEM: FilesConfigItem = {
  model: DEFAULT_EMBEDDING_MODEL,
  provider: DEFAULT_PROVIDER,
};

export const DEFAULT_FILE_RERANK_MODEL_ITEM: FilesConfigItem = {
  model: DEFAULT_RERANK_MODEL,
  provider: DEFAULT_RERANK_PROVIDER,
};

export const DEFAULT_FILES_CONFIG: FilesConfig = {
  embeddingModel: DEFAULT_FILE_EMBEDDING_MODEL_ITEM,
  queryModel: DEFAULT_RERANK_QUERY_MODE,
  rerankerModel: DEFAULT_FILE_RERANK_MODEL_ITEM,
};
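For reference, the `FilesConfigItem` and `FilesConfig` shapes these constants satisfy presumably look like the sketch below; the real definitions live in `@/types/user/settings/filesConfig` and are not shown in this diff.

```ts
// Inferred from usage in this diff — not the actual type definitions.
interface FilesConfigItem {
  model: string;
  provider: string;
}

interface FilesConfig {
  embeddingModel: FilesConfigItem;
  queryModel: string;
  rerankerModel: FilesConfigItem;
}
```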
6 changes: 6 additions & 0 deletions src/const/settings/llm.ts
@@ -12,6 +12,12 @@ export const DEFAULT_LLM_CONFIG = genUserLLMConfig({
});

export const DEFAULT_MODEL = 'gpt-4o-mini';

export const DEFAULT_EMBEDDING_MODEL = 'text-embedding-3-small';
export const DEFAULT_EMBEDDING_PROVIDER = ModelProvider.OpenAI;

export const DEFAULT_RERANK_MODEL = 'rerank-english-v3.0';
export const DEFAULT_RERANK_PROVIDER = 'cohere';
export const DEFAULT_RERANK_QUERY_MODE = 'full_text';

export const DEFAULT_PROVIDER = ModelProvider.OpenAI;
4 changes: 2 additions & 2 deletions src/database/schemas/ragEvals.ts
@@ -1,7 +1,7 @@
/* eslint-disable sort-keys-fix/sort-keys-fix */
import { integer, jsonb, pgTable, text, uuid } from 'drizzle-orm/pg-core';

import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
import { DEFAULT_MODEL } from '@/const/settings';
import { EvalEvaluationStatus } from '@/types/eval';

import { timestamps } from './_helpers';
@@ -60,7 +60,7 @@ export const evalEvaluation = pgTable('rag_eval_evaluations', {
    onDelete: 'cascade',
  }),
  languageModel: text('language_model').$defaultFn(() => DEFAULT_MODEL),
  embeddingModel: text('embedding_model').$defaultFn(() => DEFAULT_EMBEDDING_MODEL),
  embeddingModel: text('embedding_model'),

  userId: text('user_id').references(() => users.id, { onDelete: 'cascade' }),
  ...timestamps,
67 changes: 64 additions & 3 deletions src/libs/agent-runtime/bedrock/index.ts
@@ -1,12 +1,20 @@
import {
  BedrockRuntimeClient,
  InvokeModelCommand,
  InvokeModelWithResponseStreamCommand,
} from '@aws-sdk/client-bedrock-runtime';
import { experimental_buildLlama2Prompt } from 'ai/prompts';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import {
  ChatCompetitionOptions,
  ChatStreamPayload,
  Embeddings,
  EmbeddingsOptions,
  EmbeddingsPayload,
  ModelProvider,
} from '../types';
import { buildAnthropicMessages, buildAnthropicTools } from '../utils/anthropicHelpers';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
@@ -32,9 +40,7 @@ export class LobeBedrockAI implements LobeRuntimeAI {
  constructor({ region, accessKeyId, accessKeySecret, sessionToken }: LobeBedrockAIParams = {}) {
    if (!(accessKeyId && accessKeySecret))
      throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidBedrockCredentials);

    this.region = region ?? 'us-east-1';

    this.client = new BedrockRuntimeClient({
      credentials: {
        accessKeyId: accessKeyId,
@@ -50,6 +56,61 @@

    return this.invokeClaudeModel(payload, options);
  }

  /**
   * Supports the Amazon Titan Text Embeddings model series.
   * Cohere Embed models are not supported, because the text size per request
   * can exceed that series' 2048-character limit for a single request.
   * [Bedrock embed guide] https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
   */
  async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
    const promises = input.map((inputText: string) =>
      this.invokeEmbeddingModel(
        {
          dimensions: payload.dimensions,
          input: inputText,
          model: payload.model,
        },
        options,
      ),
    );
    return Promise.all(promises);
  }

  private invokeEmbeddingModel = async (
    payload: EmbeddingsPayload,
    options?: EmbeddingsOptions,
  ): Promise<Embeddings> => {
    const command = new InvokeModelCommand({
      accept: 'application/json',
      body: JSON.stringify({
        dimensions: payload.dimensions,
        inputText: payload.input,
        normalize: true,
      }),
      contentType: 'application/json',
      modelId: payload.model,
    });
    try {
      const res = await this.client.send(command, { abortSignal: options?.signal });
      const responseBody = JSON.parse(new TextDecoder().decode(res.body));
      return responseBody.embedding;
    } catch (e) {
      const err = e as Error & { $metadata: any };
      throw AgentRuntimeError.chat({
        error: {
          body: err.$metadata,
          message: err.message,
          type: err.name,
        },
        errorType: AgentRuntimeErrorType.ProviderBizError,
        provider: ModelProvider.Bedrock,
        region: this.region,
      });
    }
  };

  private invokeClaudeModel = async (
    payload: ChatStreamPayload,
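A hypothetical invocation of the new `embeddings` method; the Titan model ID and the credential values are illustrative assumptions, not part of this diff:

```ts
const bedrock = new LobeBedrockAI({
  accessKeyId: process.env.AWS_ACCESS_KEY_ID,
  accessKeySecret: process.env.AWS_SECRET_ACCESS_KEY,
  region: 'us-east-1',
});

// Each input string becomes its own InvokeModelCommand, so the call
// resolves to one embedding vector per input, computed in parallel.
const vectors = await bedrock.embeddings({
  dimensions: 1024,
  input: ['first chunk', 'second chunk'],
  model: 'amazon.titan-embed-text-v2:0',
});
```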
38 changes: 37 additions & 1 deletion src/libs/agent-runtime/ollama/index.ts
@@ -6,7 +6,13 @@ import { ChatModelCard } from '@/types/llm';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import {
  ChatCompetitionOptions,
  ChatStreamPayload,
  Embeddings,
  EmbeddingsPayload,
  ModelProvider,
} from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
@@ -84,13 +90,43 @@
    }
  }

  async embeddings(payload: EmbeddingsPayload): Promise<Embeddings[]> {
    const input = Array.isArray(payload.input) ? payload.input : [payload.input];
    const promises = input.map((inputText: string) =>
      this.invokeEmbeddingModel({
        dimensions: payload.dimensions,
        input: inputText,
        model: payload.model,
      }),
    );
    return await Promise.all(promises);
  }

  async models(): Promise<ChatModelCard[]> {
    const list = await this.client.list();
    return list.models.map((model) => ({
      id: model.name,
    }));
  }

  private invokeEmbeddingModel = async (payload: EmbeddingsPayload): Promise<Embeddings> => {
    try {
      const responseBody = await this.client.embeddings({
        model: payload.model,
        prompt: payload.input as string,
      });
      return responseBody.embedding;
    } catch (error) {
      const e = error as { message: string; name: string; status_code: number };

      throw AgentRuntimeError.chat({
        error: { message: e.message, name: e.name, status_code: e.status_code },
        errorType: AgentRuntimeErrorType.OllamaBizError,
        provider: ModelProvider.Ollama,
      });
    }
  };

  private buildOllamaMessages(messages: OpenAIChatMessage[]) {
    return messages.map((message) => this.convertContentToOllamaMessage(message));
  }
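As with Bedrock, a hypothetical call through the Ollama runtime; the base URL and model name are assumptions:

```ts
const ollama = new LobeOllamaAI({ baseURL: 'http://localhost:11434' });

// `input` is fanned out into one `client.embeddings` call per string;
// note the Ollama client here only receives `model` and `prompt`.
const vectors = await ollama.embeddings({
  input: ['first chunk', 'second chunk'],
  model: 'nomic-embed-text',
});
```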
6 changes: 6 additions & 0 deletions src/server/globalConfig/index.ts
@@ -1,13 +1,15 @@
import { appEnv, getAppConfig } from '@/config/app';
import { authEnv } from '@/config/auth';
import { fileEnv } from '@/config/file';
import { knowledgeEnv } from '@/config/knowledge';
import { langfuseEnv } from '@/config/langfuse';
import { enableNextAuth } from '@/const/auth';
import { parseSystemAgent } from '@/server/globalConfig/parseSystemAgent';
import { GlobalServerConfig } from '@/types/serverConfig';

import { genServerLLMConfig } from './genServerLLMConfig';
import { parseAgentConfig } from './parseDefaultAgent';
import { parseFilesConfig } from './parseFilesConfig';

export const getServerGlobalConfig = () => {
  const { ACCESS_CODES, DEFAULT_AGENT_CONFIG } = getAppConfig();
@@ -51,3 +53,7 @@ export const getServerDefaultAgentConfig = () => {

  return parseAgentConfig(DEFAULT_AGENT_CONFIG) || {};
};

export const getServerDefaultFilesConfig = () => {
  return parseFilesConfig(knowledgeEnv.DEFAULT_FILES_CONFIG);
};
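A hypothetical call site for the new helper (the surrounding server code is not part of this diff):

```ts
import { getServerDefaultFilesConfig } from '@/server/globalConfig';

// Resolves to the parsed DEFAULT_FILES_CONFIG, or to the built-in
// defaults when the environment variable is unset.
const { embeddingModel } = getServerDefaultFilesConfig();
// e.g. { model: 'text-embedding-3-small', provider: 'openai' }
```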
17 changes: 17 additions & 0 deletions src/server/globalConfig/parseFilesConfig.test.ts
@@ -0,0 +1,17 @@
import { describe, expect, it } from 'vitest';

import { parseFilesConfig } from './parseFilesConfig';

describe('parseFilesConfig', () => {
  // Verify that the embeddings configuration is parsed correctly
  it('parses embeddings configuration correctly', () => {
    const envStr =
      'embedding_model=openai/text-embedding-3-large,reranker_model=cohere/rerank-english-v3.0,query_model=full_text';
    const expected = {
      embeddingModel: { provider: 'openai', model: 'text-embedding-3-large' },
      rerankerModel: { provider: 'cohere', model: 'rerank-english-v3.0' },
      queryModel: 'full_text',
    };
    expect(parseFilesConfig(envStr)).toEqual(expected);
  });
});
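A companion case that could be added inside the `describe` block to cover the error path (a sketch, not part of this PR):

```ts
it('throws on malformed pairs', () => {
  // 'embedding_model=' yields an empty value, which parseFilesConfig rejects.
  expect(() => parseFilesConfig('embedding_model=')).toThrowError(
    'Invalid environment variable format',
  );
});
```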
57 changes: 57 additions & 0 deletions src/server/globalConfig/parseFilesConfig.ts
@@ -0,0 +1,57 @@
import { DEFAULT_FILES_CONFIG } from '@/const/settings/knowledge';
import { SystemEmbeddingConfig } from '@/types/knowledgeBase';
import { FilesConfig } from '@/types/user/settings/filesConfig';

const protectedKeys = Object.keys({
  embedding_model: null,
  query_model: null,
  reranker_model: null,
});

export const parseFilesConfig = (envString: string = ''): SystemEmbeddingConfig => {
  if (!envString) return DEFAULT_FILES_CONFIG;
  const config: FilesConfig = {} as any;

  // Normalize full-width commas and trim surrounding whitespace
  let envValue = envString.replaceAll('，', ',').trim();

  const pairs = envValue.split(',');

  for (const pair of pairs) {
    const [key, value] = pair.split('=').map((s) => s.trim());

    if (key && value) {
      const [provider, ...modelParts] = value.split('/');
      const model = modelParts.join('/');

      if ((!provider || !model) && key !== 'query_model') {
        throw new Error('Missing model or provider value');
      }

      if (key === 'query_model' && value === '') {
        throw new Error('Missing query mode value');
      }

      if (protectedKeys.includes(key)) {
        switch (key) {
          case 'embedding_model': {
            config.embeddingModel = { model: model.trim(), provider: provider.trim() };
            break;
          }
          case 'reranker_model': {
            config.rerankerModel = { model: model.trim(), provider: provider.trim() };
            break;
          }
          case 'query_model': {
            config.queryModel = value;
            break;
          }
        }
      }
    } else {
      throw new Error('Invalid environment variable format');
    }
  }

  return config;
};
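Because full-width commas are normalized before splitting, a value pasted from a Chinese IME still parses; a quick sketch of the behavior:

```ts
// Note the full-width comma '，' between the two pairs.
parseFilesConfig('embedding_model=openai/text-embedding-3-small，query_model=full_text');
// → { embeddingModel: { model: 'text-embedding-3-small', provider: 'openai' },
//     queryModel: 'full_text' }
```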