From e246ea9d3b2da2b1905e070c88e76059c346fda0 Mon Sep 17 00:00:00 2001 From: Weves Date: Mon, 19 Feb 2024 17:54:39 -0800 Subject: [PATCH] Fix embedding model migration with existing index_attempts --- .../versions/dbaa756c2ccf_embedding_models.py | 60 ++++++++++++++++ backend/danswer/db/embedding_model.py | 70 ++++++++----------- backend/danswer/main.py | 9 +-- 3 files changed, 89 insertions(+), 50 deletions(-) diff --git a/backend/alembic/versions/dbaa756c2ccf_embedding_models.py b/backend/alembic/versions/dbaa756c2ccf_embedding_models.py index 7b104e1217f..daa6294af26 100644 --- a/backend/alembic/versions/dbaa756c2ccf_embedding_models.py +++ b/backend/alembic/versions/dbaa756c2ccf_embedding_models.py @@ -7,7 +7,13 @@ """ from alembic import op import sqlalchemy as sa +from sqlalchemy import table, column, String, Integer, Boolean +from danswer.db.embedding_model import ( + get_new_default_embedding_model, + get_old_default_embedding_model, + user_has_overridden_embedding_model, +) from danswer.db.models import IndexModelStatus # revision identifiers, used by Alembic. @@ -34,6 +40,60 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) + # since all index attempts must be associated with an embedding model, + # need to put something in here to avoid nulls. On server startup, + # this value will be overriden + EmbeddingModel = table( + "embedding_model", + column("id", Integer), + column("model_name", String), + column("model_dim", Integer), + column("normalize", Boolean), + column("query_prefix", String), + column("passage_prefix", String), + column("index_name", String), + column( + "status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False) + ), + ) + # insert an embedding model row that corresponds to the embedding model + # the user selected via env variables before this change. This is needed since + # all index_attempts must be associated with an embedding model, so without this + # we will run into violations of non-null contraints + old_embedding_model = get_old_default_embedding_model() + op.bulk_insert( + EmbeddingModel, + [ + { + "model_name": old_embedding_model.model_name, + "model_dim": old_embedding_model.model_dim, + "normalize": old_embedding_model.normalize, + "query_prefix": old_embedding_model.query_prefix, + "passage_prefix": old_embedding_model.passage_prefix, + "index_name": old_embedding_model.index_name, + "status": old_embedding_model.status, + } + ], + ) + # if the user has not overridden the default embedding model via env variables, + # insert the new default model into the database to auto-upgrade them + if not user_has_overridden_embedding_model(): + new_embedding_model = get_new_default_embedding_model(is_present=False) + op.bulk_insert( + EmbeddingModel, + [ + { + "model_name": new_embedding_model.model_name, + "model_dim": new_embedding_model.model_dim, + "normalize": new_embedding_model.normalize, + "query_prefix": new_embedding_model.query_prefix, + "passage_prefix": new_embedding_model.passage_prefix, + "index_name": new_embedding_model.index_name, + "status": IndexModelStatus.FUTURE, + } + ], + ) + op.add_column( "index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True), diff --git a/backend/danswer/db/embedding_model.py b/backend/danswer/db/embedding_model.py index e6b77df5c39..ae2b98d514f 100644 --- a/backend/danswer/db/embedding_model.py +++ b/backend/danswer/db/embedding_model.py @@ -10,7 +10,6 @@ from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS -from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.models import EmbeddingModel from danswer.db.models import IndexModelStatus from danswer.indexing.models import EmbeddingModelDetail @@ -77,53 +76,40 @@ def update_embedding_model_status( db_session.commit() -def insert_initial_embedding_models(db_session: Session) -> None: - """Should be called on startup to ensure that the initial - embedding model is present in the DB.""" - existing_embedding_models = db_session.scalars(select(EmbeddingModel)).all() - if existing_embedding_models: - logger.error( - "Called `insert_initial_embedding_models` but models already exist in the DB. Skipping." - ) - return - - existing_cc_pairs = get_connector_credential_pairs(db_session) - - # if the user is overriding the `DOCUMENT_ENCODER_MODEL`, then - # allow them to continue to use that model and do nothing fancy - # in the background OR if the user has no connectors, then we can - # also just use the new model immediately - can_skip_upgrade = ( - DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL - or not existing_cc_pairs +def user_has_overridden_embedding_model() -> bool: + return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL + + +def get_old_default_embedding_model() -> EmbeddingModel: + is_overridden = user_has_overridden_embedding_model() + return EmbeddingModel( + model_name=( + DOCUMENT_ENCODER_MODEL + if is_overridden + else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL + ), + model_dim=( + DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM + ), + normalize=( + NORMALIZE_EMBEDDINGS + if is_overridden + else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS + ), + query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""), + passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""), + status=IndexModelStatus.PRESENT, + index_name="danswer_chunk", ) - # if we need to automatically upgrade the user, then create - # an entry which will automatically be replaced by the - # below desired model - if not can_skip_upgrade: - embedding_model_to_upgrade = EmbeddingModel( - model_name=OLD_DEFAULT_DOCUMENT_ENCODER_MODEL, - model_dim=OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM, - normalize=OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS, - query_prefix="", - passage_prefix="", - status=IndexModelStatus.PRESENT, - index_name="danswer_chunk", - ) - db_session.add(embedding_model_to_upgrade) - - desired_embedding_model = EmbeddingModel( + +def get_new_default_embedding_model(is_present: bool) -> EmbeddingModel: + return EmbeddingModel( model_name=DOCUMENT_ENCODER_MODEL, model_dim=DOC_EMBEDDING_DIM, normalize=NORMALIZE_EMBEDDINGS, query_prefix=ASYM_QUERY_PREFIX, passage_prefix=ASYM_PASSAGE_PREFIX, - status=IndexModelStatus.PRESENT - if can_skip_upgrade - else IndexModelStatus.FUTURE, + status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE, index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}", ) - db_session.add(desired_embedding_model) - - db_session.commit() diff --git a/backend/danswer/main.py b/backend/danswer/main.py index e7b6e849108..5977f2a7cd5 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -47,7 +47,6 @@ from danswer.db.credentials import create_initial_public_credential from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_secondary_db_embedding_model -from danswer.db.embedding_model import insert_initial_embedding_models from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import expire_index_attempts @@ -252,13 +251,7 @@ def startup_event() -> None: ) with Session(engine) as db_session: - try: - db_embedding_model = get_current_db_embedding_model(db_session) - except RuntimeError: - logger.info("No embedding model's found in DB, creating initial model.") - insert_initial_embedding_models(db_session) - db_embedding_model = get_current_db_embedding_model(db_session) - + db_embedding_model = get_current_db_embedding_model(db_session) secondary_db_embedding_model = get_secondary_db_embedding_model(db_session) # Break bad state for thrashing indexes