Skip to content

Commit

Permalink
fix: 🐛 filters bug
Browse files Browse the repository at this point in the history
  • Loading branch information
gmpetrov committed Oct 2, 2023
1 parent 011eb8d commit 4c5df81
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 48 deletions.
7 changes: 7 additions & 0 deletions pages/api/agents/[id]/query.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,13 @@ export const chatAgentRequest = async (
throw new ApiError(ApiErrorType.UNAUTHORIZED);
}

// Make sure the Agent has access to datastores passed as filters
for (const datastoreId of data.filters?.datastore_ids || []) {
if (!agent?.tools?.find((one) => one?.datastoreId === datastoreId)) {
throw new ApiError(ApiErrorType.UNAUTHORIZED);
}
}

const orgSession =
session?.organization || formatOrganizationSession(agent?.organization!);
const usage = orgSession?.usage as Usage;
Expand Down
5 changes: 5 additions & 0 deletions pages/api/datastores/[id]/query.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ export const queryURL = async (
throw new ApiError(ApiErrorType.INVALID_REQUEST);
}

if (data.filters?.datastore_ids) {
throw new ApiError(ApiErrorType.UNAUTHORIZED);
}

const datastore = await prisma.datastore.findUnique({
where: {
id: datastoreId,
Expand Down Expand Up @@ -64,6 +68,7 @@ export const queryURL = async (
source: each.metadata.source_url,
datasource_name: each.metadata.datasource_name,
datasource_id: each.metadata.datasource_id,
custom_id: each.metadata.custom_id,
}));
};

Expand Down
5 changes: 0 additions & 5 deletions types/dtos.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,6 @@ export type TaskRemoveDatastoreSchema = z.infer<
>;

export const FiltersSchema = z.object({
custom_id: z.string().optional(),
datasource_id: z
.string()
.cuid()
.optional(),
datastore_ids: z.array(z.string().cuid()).optional(),
datasource_ids: z.array(z.string().cuid()).optional(),
custom_ids: z.array(z.string()).optional(),
Expand Down
39 changes: 14 additions & 25 deletions utils/agent.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,5 @@
import {
Agent,
Datastore,
Message,
MessageFrom,
PromptType,
Tool,
ToolType,
} from '@prisma/client';
import { AgentExecutor, ZeroShotAgent } from 'langchain/agents';
import { LLMChain } from 'langchain/chains';
import { ChatOpenAI } from 'langchain/chat_models/openai';
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
} from 'langchain/prompts';
import { AIMessage, HumanMessage, SystemMessage } from 'langchain/schema';
import { Tool as LangchainTool } from 'langchain/tools';
import { Agent, Datastore, Message, PromptType, Tool } from '@prisma/client';
import { AIMessage, HumanMessage } from 'langchain/schema';

import { ChatRequest } from '@app/types/dtos';

Expand Down Expand Up @@ -92,14 +75,20 @@ export default class AgentManager {

const SIMILARITY_THRESHOLD = 0.7;

const filterDatastoreIds = filters?.datastore_ids
? filters?.datastore_ids
: this.agent?.tools
?.filter((each) => !!each?.datastoreId)
?.map((each) => each?.datastoreId);

// Only allow datasource filtering if datastore are present
const filterDatasourceIds =
filterDatastoreIds?.length > 0 ? filters?.datasource_ids : [];

const _filters = {
...filters,
datastore_ids: [
...this.agent?.tools
?.filter((each) => !!each?.datastoreId)
?.map((each) => each?.datastoreId),
...(filters?.datastore_ids || [])!,
],
datastore_ids: filterDatastoreIds,
datasource_ids: filterDatasourceIds,
} as AgentManagerProps['filters'];

return chatRetrieval({
Expand Down
1 change: 1 addition & 0 deletions utils/chains/chat-retrieval.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ const chatRetrieval = async ({
.sort((a, b) => b.metadata.score! - a.metadata.score!)
.map((each) => ({
chunk_id: each.metadata.chunk_id,
custom_id: each.metadata.custom_id,
datasource_id: each.metadata.datasource_id!,
datasource_name: each.metadata.datasource_name!,
datasource_type: each.metadata.datasource_type!,
Expand Down
30 changes: 12 additions & 18 deletions utils/datastores/qdrant.ts
Original file line number Diff line number Diff line change
Expand Up @@ -256,36 +256,30 @@ export class QdrantManager extends ClientManager<DatastoreType> {
},
]
: []),
...(props.filters?.custom_id
...((props.filters?.custom_ids || [])?.length > 0
? [
{
key: MetadataFields.custom_id,
match: { value: props.filters.custom_id },
match: { any: props.filters?.custom_ids },
},
]
: []),
...(props.filters?.datasource_id
...((props.filters?.datasource_ids || [])?.length > 0
? [
{
key: MetadataFields.datasource_id,
match: { value: props.filters.datasource_id },
match: { any: props.filters?.datasource_ids },
},
]
: []),
...((props.filters?.datastore_ids || [])?.length > 0
? [
{
key: MetadataFields.datastore_id,
match: { any: props.filters?.datastore_ids },
},
]
: []),
],
should: [
...(props.filters?.custom_ids || [])?.map((each) => ({
key: MetadataFields.custom_id,
match: { value: each },
})),
...(props.filters?.datasource_ids || [])?.map((each) => ({
key: MetadataFields.datasource_id,
match: { value: each },
})),
...(props.filters?.datastore_ids || [])?.map((each) => ({
key: MetadataFields.datastore_id,
match: { value: each },
})),
],
},
}
Expand Down

0 comments on commit 4c5df81

Please sign in to comment.