Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ feat: support to customize Embedding model with env #5177

Merged
merged 11 commits into from
Jan 15, 2025
Prev Previous commit
Next Next commit
♻ Code Refactoring
- Update the file configuration and standardize the model naming to camel case.
  • Loading branch information
cookieY committed Jan 7, 2025
commit 91ed1dc9b7b858f5897772b9db1869a776536485
6 changes: 3 additions & 3 deletions src/const/settings/knowledge.ts
cookieY marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ export const DEFAULT_FILE_RERANK_MODEL_ITEM: FilesConfigItem = {
};

export const DEFAULT_FILES_CONFIG: FilesConfig = {
embedding_model: DEFAULT_FILE_EMBEDDING_MODEL_ITEM,
query_model: DEFAULT_RERANK_QUERY_MODE,
reranker_model: DEFAULT_FILE_RERANK_MODEL_ITEM,
embeddingModel: DEFAULT_FILE_EMBEDDING_MODEL_ITEM,
queryModel: DEFAULT_RERANK_QUERY_MODE,
rerankerModel: DEFAULT_FILE_RERANK_MODEL_ITEM,
};
6 changes: 3 additions & 3 deletions src/server/globalConfig/parseFilesConfig.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ describe('parseFilesConfig', () => {
const envStr =
'embedding_model=openai/embedding-text-3-large,reranker_model=cohere/rerank-english-v3.0,query_model=full_text';
const expected = {
embedding_model: { provider: 'openai', model: 'embedding-text-3-large' },
reranker_model: { provider: 'cohere', model: 'rerank-english-v3.0' },
query_model: 'full_text',
embeddingModel: { provider: 'openai', model: 'embedding-text-3-large' },
rerankerModel: { provider: 'cohere', model: 'rerank-english-v3.0' },
queryModel: 'full_text',
};
expect(parseFilesConfig(envStr)).toEqual(expected);
});
Expand Down
26 changes: 18 additions & 8 deletions src/server/globalConfig/parseFilesConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { DEFAULT_FILES_CONFIG } from '@/const/settings/knowledge';
import { SystemEmbeddingConfig } from '@/types/knowledgeBase';
import { FilesConfig } from '@/types/user/settings/filesConfig';

const protectedKeys = Object.keys(DEFAULT_FILES_CONFIG);
const protectedKeys = Object.keys({
embedding_model: null,
query_model: null,
reranker_model: null,
});

export const parseFilesConfig = (envString: string = ''): SystemEmbeddingConfig => {
if (!envString) return DEFAULT_FILES_CONFIG;
Expand All @@ -29,13 +33,19 @@ export const parseFilesConfig = (envString: string = ''): SystemEmbeddingConfig
}

if (protectedKeys.includes(key)) {
if (key === 'query_model') {
config.query_model = value;
} else {
config[key as keyof FilesConfig] = {
model: model.trim(),
provider: provider.trim(),
} as any;
switch (key) {
case 'embedding_model': {
config.embeddingModel = { model: model.trim(), provider: provider.trim() };
break;
}
case 'reranker_model': {
config.rerankerModel = { model: model.trim(), provider: provider.trim() };
break;
}
case 'query_model': {
config.queryModel = value;
break;
}
}
}
} else {
Expand Down
11 changes: 0 additions & 11 deletions src/server/globalConfig/parseSystemAgent.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,4 @@ describe('parseSystemAgent', () => {

expect(parseSystemAgent(envValue)).toEqual(expected);
});
// 测试embeddings配置是否被正确解析
it('parses embeddings configuration correctly', () => {
const envStr =
'embedding_model=openai/embedding-text-3-large,reranker_model=cohere/rerank-english-v3.0,query_model=full_text';
const expected = {
embedding_model: { provider: 'openai', model: 'embedding-text-3-large' },
reranker_model: { provider: 'cohere', model: 'rerank-english-v3.0' },
query_model: { provider: 'full_text', model: '' },
};
expect(parseSystemAgent(envStr)).toEqual(expected);
});
});
9 changes: 5 additions & 4 deletions src/server/routers/async/file.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { z } from 'zod';

import { serverDBEnv } from '@/config/db';
import { fileEnv } from '@/config/file';
import { DEFAULT_FILE_EMBEDDING_MODEL_ITEM } from '@/const/settings/knowledge';
import { NewChunkItem, NewEmbeddingsItem } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { ASYNC_TASK_TIMEOUT, AsyncTaskModel } from '@/database/server/models/asyncTask';
Expand Down Expand Up @@ -55,8 +56,8 @@ export const fileRouter = router({

const asyncTask = await ctx.asyncTaskModel.findById(input.taskId);

const model = getServerDefaultFilesConfig().embedding_model.model;
const provider = getServerDefaultFilesConfig().embedding_model.provider;
const { model, provider } =
getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;

if (!asyncTask) throw new TRPCError({ code: 'BAD_REQUEST', message: 'Async Task not found' });

Expand Down Expand Up @@ -101,7 +102,7 @@ export const fileRouter = router({
const embeddings = await agentRuntime.embeddings({
dimensions: 1024,
input: chunks.map((c) => c.text),
model: model,
model,
});
console.timeEnd(`任务[${number}]: embeddings`);

Expand All @@ -110,7 +111,7 @@ export const fileRouter = router({
chunkId: chunks[idx].id,
embeddings: e,
fileId: input.fileId,
model: model,
model,
})) || [];

console.time(`任务[${number}]: insert db`);
Expand Down
18 changes: 8 additions & 10 deletions src/server/routers/lambda/chunk.ts
cookieY marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { inArray } from 'drizzle-orm/expressions';
import { z } from 'zod';

import { DEFAULT_FILE_EMBEDDING_MODEL_ITEM } from '@/const/settings/knowledge';
import { knowledgeBaseFiles } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { AsyncTaskModel } from '@/database/server/models/asyncTask';
Expand Down Expand Up @@ -104,15 +105,14 @@ export const chunkRouter = router({
}),
)
.mutation(async ({ ctx, input }) => {
console.time('embedding');
const model = getServerDefaultFilesConfig().embedding_model.model;
const provider = getServerDefaultFilesConfig().embedding_model.provider;
const { model, provider } =
getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
const agentRuntime = await initAgentRuntimeWithUserPayload(provider, ctx.jwtPayload);

const embeddings = await agentRuntime.embeddings({
dimensions: 1024,
input: input.query,
model: model,
model,
});
console.timeEnd('embedding');

Expand All @@ -127,12 +127,10 @@ export const chunkRouter = router({
.input(SemanticSearchSchema)
.mutation(async ({ ctx, input }) => {
const item = await ctx.messageModel.findMessageQueriesById(input.messageId);
const model = getServerDefaultFilesConfig().embedding_model.model;
const provider = getServerDefaultFilesConfig().embedding_model.provider;
const { model, provider } =
getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM;
let embedding: number[];
let ragQueryId: string;
console.log('embeddingProvider:', provider);
console.log('embeddingModel:', model);
// if there is no message rag or it's embeddings, then we need to create one
if (!item || !item.embeddings) {
// TODO: need to support customize
Expand All @@ -141,13 +139,13 @@ export const chunkRouter = router({
const embeddings = await agentRuntime.embeddings({
dimensions: 1024,
input: input.rewriteQuery,
model: model,
model,
});

embedding = embeddings![0];
const embeddingsId = await ctx.embeddingModel.create({
embeddings: embedding,
model: model,
model,
});

const result = await ctx.messageModel.createMessageQuery({
Expand Down
6 changes: 3 additions & 3 deletions src/types/knowledgeBase/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ export interface KnowledgeItem {
}

export interface SystemEmbeddingConfig {
embedding_model: FilesConfigItem;
query_model: string;
reranker_model: FilesConfigItem;
embeddingModel: FilesConfigItem;
queryModel: string;
rerankerModel: FilesConfigItem;
}
6 changes: 3 additions & 3 deletions src/types/user/settings/filesConfig.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ export interface FilesConfigItem {
provider: string;
}
export interface FilesConfig {
embedding_model: FilesConfigItem;
query_model: string;
reranker_model: FilesConfigItem;
embeddingModel: FilesConfigItem;
queryModel: string;
rerankerModel: FilesConfigItem;
}