
✨ feat: support to customize Embedding model with env #5177

Merged · 11 commits · Jan 15, 2025
feat: add embedding model configuration support and update related docs and tests
cookieY committed Dec 25, 2024
commit f8f64e070a01c901cdaebf20fa23b73075d8d380
3 changes: 3 additions & 0 deletions .env.example
@@ -198,3 +198,6 @@ OPENAI_API_KEY=sk-xxxxxxxxx
# use `openssl rand -base64 32` to generate a key for the encryption of the database
# we use this key to encrypt the user api key
#KEY_VAULTS_SECRET=xxxxx/xxxxxxxxxxxxxx=

# Specify the Embedding model and Reranker model (reranker not yet implemented)
# DEFAULT_FILES_CONFIG="embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_mode=full_text"
9 changes: 9 additions & 0 deletions docs/self-hosting/advanced/knowledge-base.mdx
@@ -51,3 +51,12 @@ Unstructured.io is a powerful document processing tool.
- **Note**: Evaluate processing needs based on document complexity

By correctly configuring and integrating these core components, you can build a powerful and efficient knowledge base system for LobeChat. Each component plays a crucial role in the overall architecture, supporting advanced document management and intelligent retrieval functions.

### 5. Custom Embedding (Optional)

- **Purpose**: Use a different Embedding model to generate vector representations of text for semantic search
- **Options**: Supported model providers: zhipu/github/openai/bedrock/ollama
- **Deployment Tip**: Set the environment variable below to configure the default Embedding model
```
environment: DEFAULT_FILES_CONFIG=embedding_model=openai/text-embedding-3-small
```
9 changes: 9 additions & 0 deletions docs/self-hosting/advanced/knowledge-base.zh-CN.mdx
@@ -51,3 +51,12 @@ Unstructured.io 是一个强大的文档处理工具。
- **注意事项**:评估处理需求,根据文档复杂度决定是否部署

通过正确配置和集成这些核心组件,您可以为 LobeChat 构建一个强大、高效的知识库系统。每个组件都在整体架构中扮演着关键角色,共同支持高级的文档管理和智能检索功能。

### 5. 自定义 Embedding(可选)

- **用途**: 使用不同的嵌入模型(Embedding)生成文本的向量表示,用于语义搜索
- **选项**: 支持的模型提供商:zhipu/github/openai/bedrock/ollama
- **部署建议**: 使用环境变量配置默认嵌入模型
```
environment: DEFAULT_FILES_CONFIG=embedding_model=openai/text-embedding-3-small
```
4 changes: 4 additions & 0 deletions src/config/knowledge.ts
@@ -4,10 +4,14 @@ import { z } from 'zod';
export const getKnowledgeConfig = () => {
return createEnv({
runtimeEnv: {
DEFAULT_FILES_CONFIG:
process.env.DEFAULT_FILES_CONFIG ||
'embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_mode=full_text',
UNSTRUCTURED_API_KEY: process.env.UNSTRUCTURED_API_KEY,
UNSTRUCTURED_SERVER_URL: process.env.UNSTRUCTURED_SERVER_URL,
},
server: {
DEFAULT_FILES_CONFIG: z.string().optional(),
UNSTRUCTURED_API_KEY: z.string().optional(),
UNSTRUCTURED_SERVER_URL: z.string().optional(),
},
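Elsewhere in this PR, the parsed env is consumed through the `knowledgeEnv` export imported in `src/server/globalConfig/index.ts` below. A minimal sketch of reading it — assuming `knowledgeEnv` is the exported result of `getKnowledgeConfig()`, which this diff does not show:

```ts
// Sketch — assumes knowledgeEnv is the exported result of getKnowledgeConfig().
import { knowledgeEnv } from '@/config/knowledge';

// Falls back to the default string when DEFAULT_FILES_CONFIG is unset.
console.log(knowledgeEnv.DEFAULT_FILES_CONFIG);
// → 'embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0,query_mode=full_text'
```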
9 changes: 9 additions & 0 deletions src/const/settings/knowledge.ts
@@ -0,0 +1,9 @@
import { SystemEmbeddingConfig } from '@/types/knowledgeBase';

import { DEFAULT_SYSTEM_AGENT_ITEM } from './systemAgent';

export const SYSTEM_EMBEDDING_CONFIG: SystemEmbeddingConfig = {
embedding_model: DEFAULT_SYSTEM_AGENT_ITEM,
query_model: DEFAULT_SYSTEM_AGENT_ITEM,
reranker_model: DEFAULT_SYSTEM_AGENT_ITEM,
};
7 changes: 5 additions & 2 deletions src/database/schemas/ragEvals.ts
@@ -1,7 +1,8 @@
/* eslint-disable sort-keys-fix/sort-keys-fix */
import { integer, jsonb, pgTable, text, uuid } from 'drizzle-orm/pg-core';

import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
import { DEFAULT_MODEL } from '@/const/settings';
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
import { EvalEvaluationStatus } from '@/types/eval';

import { timestamps } from './_helpers';
@@ -60,7 +61,9 @@ export const evalEvaluation = pgTable('rag_eval_evaluations', {
onDelete: 'cascade',
}),
languageModel: text('language_model').$defaultFn(() => DEFAULT_MODEL),
embeddingModel: text('embedding_model').$defaultFn(() => DEFAULT_EMBEDDING_MODEL),
embeddingModel: text('embedding_model').$defaultFn(() =>
getServerDefaultFilesConfig().getEmbeddingModel(),
),

userId: text('user_id').references(() => users.id, { onDelete: 'cascade' }),
...timestamps,
68 changes: 65 additions & 3 deletions src/libs/agent-runtime/bedrock/index.ts
@@ -1,12 +1,20 @@
import {
BedrockRuntimeClient,
InvokeModelCommand,
InvokeModelWithResponseStreamCommand,
} from '@aws-sdk/client-bedrock-runtime';
import { experimental_buildLlama2Prompt } from 'ai/prompts';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import {
ChatCompetitionOptions,
ChatStreamPayload,
Embeddings,
EmbeddingsOptions,
EmbeddingsPayload,
ModelProvider,
} from '../types';
import { buildAnthropicMessages, buildAnthropicTools } from '../utils/anthropicHelpers';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
@@ -32,9 +40,7 @@ export class LobeBedrockAI implements LobeRuntimeAI {
constructor({ region, accessKeyId, accessKeySecret, sessionToken }: LobeBedrockAIParams = {}) {
if (!(accessKeyId && accessKeySecret))
throw AgentRuntimeError.createError(AgentRuntimeErrorType.InvalidBedrockCredentials);

this.region = region ?? 'us-east-1';

this.client = new BedrockRuntimeClient({
credentials: {
accessKeyId: accessKeyId,
@@ -50,6 +56,62 @@

return this.invokeClaudeModel(payload, options);
}

/**
* Supports the Amazon Titan Text Embeddings model series.
* Cohere Embed models are not supported, because the text size of a
* single request can exceed their 2048-character-per-request limit.
* [Bedrock embed guide] https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-embed.html
*/
async embeddings(payload: EmbeddingsPayload, options?: EmbeddingsOptions): Promise<Embeddings[]> {
const input = Array.isArray(payload.input) ? payload.input : [payload.input];
const promises = input.map((inputText: string, index: number) =>
this.invokeEmbeddingModel(
{
dimensions: payload.dimensions,
index: index,
input: inputText,
model: payload.model,
},
options,
),
);
return Promise.all(promises);
}

private invokeEmbeddingModel = async (
payload: EmbeddingsPayload,
options?: EmbeddingsOptions,
): Promise<Embeddings> => {
const command = new InvokeModelCommand({
accept: 'application/json',
body: JSON.stringify({
dimensions: payload.dimensions,
inputText: payload.input,
normalize: true,
}),
contentType: 'application/json',
modelId: payload.model,
});
try {
const res = await this.client.send(command, { abortSignal: options?.signal });
const responseBody = JSON.parse(new TextDecoder().decode(res.body));
return responseBody.embedding;
} catch (e) {
const err = e as Error & { $metadata: any };
throw AgentRuntimeError.chat({
error: {
body: err.$metadata,
message: err.message,
type: err.name,
},
errorType: AgentRuntimeErrorType.ProviderBizError,
provider: ModelProvider.Bedrock,
region: this.region,
});
}
};

private invokeClaudeModel = async (
payload: ChatStreamPayload,
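A minimal usage sketch of the new Bedrock embeddings path — the credential wiring and the Titan model ID are illustrative assumptions, not part of this diff:

```ts
// Sketch — credentials and model ID are illustrative, not part of this PR.
import { LobeBedrockAI } from '@/libs/agent-runtime/bedrock';

const bedrock = new LobeBedrockAI({
  accessKeyId: process.env.AWS_ACCESS_KEY_ID,
  accessKeySecret: process.env.AWS_SECRET_ACCESS_KEY,
  region: 'us-east-1',
});

// Each input string fans out to its own InvokeModelCommand;
// Promise.all preserves input order in the returned vectors.
const vectors = await bedrock.embeddings({
  dimensions: 1024,
  input: ['first chunk', 'second chunk'],
  model: 'amazon.titan-embed-text-v2:0',
});
```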
39 changes: 38 additions & 1 deletion src/libs/agent-runtime/ollama/index.ts
@@ -6,7 +6,13 @@ import { ChatModelCard } from '@/types/llm';

import { LobeRuntimeAI } from '../BaseAI';
import { AgentRuntimeErrorType } from '../error';
import { ChatCompetitionOptions, ChatStreamPayload, ModelProvider } from '../types';
import {
ChatCompetitionOptions,
ChatStreamPayload,
Embeddings,
EmbeddingsPayload,
ModelProvider,
} from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
@@ -84,13 +90,44 @@ export class LobeOllamaAI implements LobeRuntimeAI {
}
}

async embeddings(payload: EmbeddingsPayload): Promise<Embeddings[]> {
const input = Array.isArray(payload.input) ? payload.input : [payload.input];
const promises = input.map((inputText: string, index: number) =>
this.invokeEmbeddingModel({
dimensions: payload.dimensions,
index: index,
input: inputText,
model: payload.model,
}),
);
return await Promise.all(promises);
}

async models(): Promise<ChatModelCard[]> {
const list = await this.client.list();
return list.models.map((model) => ({
id: model.name,
}));
}

private invokeEmbeddingModel = async (payload: EmbeddingsPayload): Promise<Embeddings> => {
try {
const responseBody = await this.client.embeddings({
model: payload.model,
prompt: payload.input as string,
});
return responseBody.embedding;
} catch (error) {
const e = error as { message: string; name: string; status_code: number };

throw AgentRuntimeError.chat({
error: { message: e.message, name: e.name, status_code: e.status_code },
errorType: AgentRuntimeErrorType.OllamaBizError,
provider: ModelProvider.Ollama,
});
}
};

private buildOllamaMessages(messages: OpenAIChatMessage[]) {
return messages.map((message) => this.convertContentToOllamaMessage(message));
}
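The Ollama runtime follows the same per-input fan-out. A hedged sketch — the constructor options and model name are assumptions, and it presumes a local Ollama server with the model already pulled:

```ts
// Sketch — constructor options and model name are assumptions.
import { LobeOllamaAI } from '@/libs/agent-runtime/ollama';

const ollama = new LobeOllamaAI({ baseURL: 'http://127.0.0.1:11434' });

// As with Bedrock, each input becomes one client.embeddings call.
const vectors = await ollama.embeddings({
  input: ['hello world', 'second text'],
  model: 'nomic-embed-text',
});
```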
1 change: 1 addition & 0 deletions src/libs/agent-runtime/types/embeddings.ts
@@ -4,6 +4,7 @@ export interface EmbeddingsPayload {
* supported in `text-embedding-3` and later models.
*/
dimensions?: number;
index?: number;
/**
* Input text to embed, encoded as a string or array of tokens. To embed multiple
inputs in a single request, pass an array of strings.
34 changes: 29 additions & 5 deletions src/libs/agent-runtime/utils/openaiCompatibleFactory/index.ts
@@ -275,22 +275,46 @@ export const LobeOpenAICompatibleFactory = <T extends Record<string, any> = any>
.filter(Boolean) as ChatModelCard[];
}

/**
* Some OpenAI-compatible providers cap embedding requests at 8k tokens,
* and a chunked input array can exceed that limit in a single call.
* Each input is therefore sent as its own concurrent request.
*/
async embeddings(
payload: EmbeddingsPayload,
options?: EmbeddingsOptions,
): Promise<Embeddings[]> {
try {
const res = await this.client.embeddings.create(
{ ...payload, user: options?.user },
{ headers: options?.headers, signal: options?.signal },
const input = Array.isArray(payload.input) ? payload.input : [payload.input];
const promises = input.map((inputText: string) =>
this.invokeEmbeddings(
{
dimensions: payload.dimensions,
input: inputText,
model: payload.model,
},
options,
),
);

return res.data.map((item) => item.embedding);
const results = await Promise.all(promises);
return results.flat();
} catch (error) {
throw this.handleError(error);
}
}

async invokeEmbeddings(
payload: EmbeddingsPayload,
options?: EmbeddingsOptions,
): Promise<Embeddings[]> {
const res = await this.client.embeddings.create(
{ ...payload, user: options?.user },
{ headers: options?.headers, signal: options?.signal },
);
return res.data.map((item) => item.embedding);
}

async textToImage(payload: TextToImagePayload) {
try {
const res = await this.client.images.generate(payload);
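In effect, `embeddings` on any factory-built runtime now issues one request per input string instead of a single batched call. A sketch of the observable behavior — the `LobeOpenAI` class and its `{ apiKey }` option are assumed from the wider codebase, not this diff:

```ts
// Sketch — LobeOpenAI and its constructor option are assumptions;
// any runtime built by LobeOpenAICompatibleFactory behaves the same way.
import { LobeOpenAI } from '@/libs/agent-runtime/openai';

const runtime = new LobeOpenAI({ apiKey: process.env.OPENAI_API_KEY });

// Three inputs → three concurrent embeddings.create calls,
// flattened back into one ordered Embeddings[] result.
const vectors = await runtime.embeddings({
  input: ['chunk a', 'chunk b', 'chunk c'],
  model: 'text-embedding-3-small',
});
console.log(vectors.length); // 3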
9 changes: 7 additions & 2 deletions src/server/globalConfig/index.ts
@@ -1,15 +1,16 @@
import { appEnv, getAppConfig } from '@/config/app';
import { authEnv } from '@/config/auth';
import { fileEnv } from '@/config/file';
import { knowledgeEnv } from '@/config/knowledge';
import { langfuseEnv } from '@/config/langfuse';
import { enableNextAuth } from '@/const/auth';
import { parseSystemAgent } from '@/server/globalConfig/parseSystemAgent';
import { FilesStore } from '@/server/modules/Files';
import { GlobalServerConfig } from '@/types/serverConfig';

import { genServerLLMConfig } from './genServerLLMConfig';
import { parseAgentConfig } from './parseDefaultAgent';

import { genServerLLMConfig } from './genServerLLMConfig'

export const getServerGlobalConfig = () => {
const { ACCESS_CODES, DEFAULT_AGENT_CONFIG } = getAppConfig();

@@ -52,3 +53,7 @@ export const getServerDefaultAgentConfig = () => {

return parseAgentConfig(DEFAULT_AGENT_CONFIG) || {};
};

export const getServerDefaultFilesConfig = () => {
return new FilesStore(parseSystemAgent(knowledgeEnv.DEFAULT_FILES_CONFIG));
};
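Callers obtain the parsed files config through this helper; `getEmbeddingModel()` is the accessor used in the `ragEvals` schema change above. A minimal sketch, assuming it returns the configured embedding model identifier:

```ts
// Sketch — FilesStore's accessor shape is inferred from the ragEvals usage above.
import { getServerDefaultFilesConfig } from '@/server/globalConfig';

// With DEFAULT_FILES_CONFIG unset, this resolves from the default
// 'embedding_model=openai/text-embedding-3-small,...' string.
const embeddingModel = getServerDefaultFilesConfig().getEmbeddingModel();
```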
11 changes: 11 additions & 0 deletions src/server/globalConfig/parseSystemAgent.test.ts
@@ -92,4 +92,15 @@ describe('parseSystemAgent', () => {

expect(parseSystemAgent(envValue)).toEqual(expected);
});
// verify that the embeddings configuration is parsed correctly
it('parses embeddings configuration correctly', () => {
const envStr =
'embedding_model=openai/text-embedding-3-large,reranker_model=cohere/rerank-english-v3.0,query_model=full_text';
const expected = {
embedding_model: { provider: 'openai', model: 'text-embedding-3-large' },
reranker_model: { provider: 'cohere', model: 'rerank-english-v3.0' },
query_model: { provider: 'full_text', model: '' },
};
expect(parseSystemAgent(envStr)).toEqual(expected);
});
});
14 changes: 9 additions & 5 deletions src/server/globalConfig/parseSystemAgent.ts
@@ -1,12 +1,16 @@
import { DEFAULT_SYSTEM_AGENT_CONFIG } from '@/const/settings';
import { SYSTEM_EMBEDDING_CONFIG } from '@/const/settings/knowledge';
import { SystemEmbeddingConfig } from '@/types/knowledgeBase';
import { UserSystemAgentConfig } from '@/types/user/settings';

const protectedKeys = Object.keys(DEFAULT_SYSTEM_AGENT_CONFIG);
const protectedKeys = Object.keys({ ...DEFAULT_SYSTEM_AGENT_CONFIG, ...SYSTEM_EMBEDDING_CONFIG });

export const parseSystemAgent = (envString: string = ''): Partial<UserSystemAgentConfig> => {
export interface CommonSystemConfig extends UserSystemAgentConfig, SystemEmbeddingConfig {}

export const parseSystemAgent = (envString: string = ''): Partial<CommonSystemConfig> => {
if (!envString) return {};

const config: Partial<UserSystemAgentConfig> = {};
const config: Partial<CommonSystemConfig> = {};

// normalize full-width commas and trim extra whitespace
let envValue = envString.replaceAll('，', ',').trim();
@@ -20,12 +24,12 @@ export const parseSystemAgent = (envString: string = ''): Partial<UserSystemAgen
const [provider, ...modelParts] = value.split('/');
const model = modelParts.join('/');

if (!provider || !model) {
if (!provider) {
throw new Error('Missing model or provider value');
}

if (protectedKeys.includes(key)) {
config[key as keyof UserSystemAgentConfig] = {
config[key as keyof CommonSystemConfig] = {
enabled: key === 'queryRewrite' ? true : undefined,
model: model.trim(),
provider: provider.trim(),
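With the widened `protectedKeys`, an embedding-style env string now parses into provider/model pairs, mirroring the new test above. A sketch:

```ts
import { parseSystemAgent } from '@/server/globalConfig/parseSystemAgent';

const parsed = parseSystemAgent(
  'embedding_model=openai/text-embedding-3-small,reranker_model=cohere/rerank-english-v3.0',
);
// parsed ≈ {
//   embedding_model: { enabled: undefined, model: 'text-embedding-3-small', provider: 'openai' },
//   reranker_model: { enabled: undefined, model: 'rerank-english-v3.0', provider: 'cohere' },
// }
```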