Skip to content

Commit

Permalink
ensure tenant context properly passed to ee bg tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
pablonyx committed Oct 8, 2024
1 parent a932df9 commit e3424e1
Showing 1 changed file with 67 additions and 39 deletions.
106 changes: 67 additions & 39 deletions backend/ee/danswer/background/celery/celery_app.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from datetime import timedelta

from sqlalchemy.orm import Session

from danswer.background.celery.celery_app import celery_app
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.update import get_all_tenant_ids
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.chat import delete_chat_sessions_older_than
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
Expand All @@ -32,6 +32,7 @@
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from shared_configs.configs import current_tenant_id

logger = setup_logger()

Expand All @@ -41,22 +42,26 @@

@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)


@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)


@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def perform_ttl_management_task(
retention_limit_days: int, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)


Expand All @@ -67,59 +72,67 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_doc_permissions_task() -> None:
def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
"""Runs periodically to sync external permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_doc_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
)


@celery_app.task(
name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task() -> None:
def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
"""Runs periodically to sync external group permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_group_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_group_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
)


@celery_app.task(
name="check_ttl_management_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task() -> None:
def check_ttl_management_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = current_tenant_id.set(tenant_id)

settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(retention_limit_days=retention_limit_days),
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
current_tenant_id.reset(token)


@celery_app.task(
name="autogenerate_usage_report_task",
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task() -> None:
def autogenerate_usage_report_task(tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,
Expand All @@ -130,22 +143,37 @@ def autogenerate_usage_report_task() -> None:
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"sync-external-doc-permissions": {
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"sync-external-group-permissions": {
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"autogenerate_usage_report": {
"task": "autogenerate_usage_report_task",
"schedule": timedelta(days=30), # TODO: change this to config flag
},
"check-ttl-management": {
"task": "check_ttl_management_task",
"schedule": timedelta(hours=1),
},
**(celery_app.conf.beat_schedule or {}),
}


def schedule_tenant_tasks() -> None:
tenants = get_all_tenant_ids()
celery_app.conf.beat_schedule = {}

for tenant_id in tenants:
celery_app.conf.beat_schedule.update(
{
f"sync-external-doc-permissions-{tenant_id}": {
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
"args": (tenant_id,),
},
f"sync-external-group-permissions-{tenant_id}": {
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
"args": (tenant_id,),
},
f"autogenerate-usage-report-{tenant_id}": {
"task": "autogenerate_usage_report_task",
"schedule": timedelta(days=30), # TODO: change this to config flag
"args": (tenant_id,),
},
f"check-ttl-management-{tenant_id}": {
"task": "check_ttl_management_task",
"schedule": timedelta(hours=1),
"args": (tenant_id,),
},
}
)


schedule_tenant_tasks()

0 comments on commit e3424e1

Please sign in to comment.