Skip to content

Commit

Permalink
Fix embedding model migration with existing index_attempts
Browse files Browse the repository at this point in the history
  • Loading branch information
Weves committed Feb 20, 2024
1 parent 4eaf2b1 commit e246ea9
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 50 deletions.
60 changes: 60 additions & 0 deletions backend/alembic/versions/dbaa756c2ccf_embedding_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean

from danswer.db.embedding_model import (
get_new_default_embedding_model,
get_old_default_embedding_model,
user_has_overridden_embedding_model,
)
from danswer.db.models import IndexModelStatus

# revision identifiers, used by Alembic.
Expand All @@ -34,6 +40,60 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id"),
)
# since all index attempts must be associated with an embedding model,
# need to put something in here to avoid nulls. On server startup,
# this value will be overriden
EmbeddingModel = table(
"embedding_model",
column("id", Integer),
column("model_name", String),
column("model_dim", Integer),
column("normalize", Boolean),
column("query_prefix", String),
column("passage_prefix", String),
column("index_name", String),
column(
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
),
)
# insert an embedding model row that corresponds to the embedding model
# the user selected via env variables before this change. This is needed since
# all index_attempts must be associated with an embedding model, so without this
# we will run into violations of non-null contraints
old_embedding_model = get_old_default_embedding_model()
op.bulk_insert(
EmbeddingModel,
[
{
"model_name": old_embedding_model.model_name,
"model_dim": old_embedding_model.model_dim,
"normalize": old_embedding_model.normalize,
"query_prefix": old_embedding_model.query_prefix,
"passage_prefix": old_embedding_model.passage_prefix,
"index_name": old_embedding_model.index_name,
"status": old_embedding_model.status,
}
],
)
# if the user has not overridden the default embedding model via env variables,
# insert the new default model into the database to auto-upgrade them
if not user_has_overridden_embedding_model():
new_embedding_model = get_new_default_embedding_model(is_present=False)
op.bulk_insert(
EmbeddingModel,
[
{
"model_name": new_embedding_model.model_name,
"model_dim": new_embedding_model.model_dim,
"normalize": new_embedding_model.normalize,
"query_prefix": new_embedding_model.query_prefix,
"passage_prefix": new_embedding_model.passage_prefix,
"index_name": new_embedding_model.index_name,
"status": IndexModelStatus.FUTURE,
}
],
)

op.add_column(
"index_attempt",
sa.Column("embedding_model_id", sa.Integer(), nullable=True),
Expand Down
70 changes: 28 additions & 42 deletions backend/danswer/db/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
Expand Down Expand Up @@ -77,53 +76,40 @@ def update_embedding_model_status(
db_session.commit()


def insert_initial_embedding_models(db_session: Session) -> None:
"""Should be called on startup to ensure that the initial
embedding model is present in the DB."""
existing_embedding_models = db_session.scalars(select(EmbeddingModel)).all()
if existing_embedding_models:
logger.error(
"Called `insert_initial_embedding_models` but models already exist in the DB. Skipping."
)
return

existing_cc_pairs = get_connector_credential_pairs(db_session)

# if the user is overriding the `DOCUMENT_ENCODER_MODEL`, then
# allow them to continue to use that model and do nothing fancy
# in the background OR if the user has no connectors, then we can
# also just use the new model immediately
can_skip_upgrade = (
DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
or not existing_cc_pairs
def user_has_overridden_embedding_model() -> bool:
return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL


def get_old_default_embedding_model() -> EmbeddingModel:
is_overridden = user_has_overridden_embedding_model()
return EmbeddingModel(
model_name=(
DOCUMENT_ENCODER_MODEL
if is_overridden
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
),
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
),
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
status=IndexModelStatus.PRESENT,
index_name="danswer_chunk",
)

# if we need to automatically upgrade the user, then create
# an entry which will automatically be replaced by the
# below desired model
if not can_skip_upgrade:
embedding_model_to_upgrade = EmbeddingModel(
model_name=OLD_DEFAULT_DOCUMENT_ENCODER_MODEL,
model_dim=OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM,
normalize=OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS,
query_prefix="",
passage_prefix="",
status=IndexModelStatus.PRESENT,
index_name="danswer_chunk",
)
db_session.add(embedding_model_to_upgrade)

desired_embedding_model = EmbeddingModel(

def get_new_default_embedding_model(is_present: bool) -> EmbeddingModel:
return EmbeddingModel(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,
status=IndexModelStatus.PRESENT
if can_skip_upgrade
else IndexModelStatus.FUTURE,
status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
)
db_session.add(desired_embedding_model)

db_session.commit()
9 changes: 1 addition & 8 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from danswer.db.credentials import create_initial_public_credential
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import insert_initial_embedding_models
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
Expand Down Expand Up @@ -252,13 +251,7 @@ def startup_event() -> None:
)

with Session(engine) as db_session:
try:
db_embedding_model = get_current_db_embedding_model(db_session)
except RuntimeError:
logger.info("No embedding model's found in DB, creating initial model.")
insert_initial_embedding_models(db_session)
db_embedding_model = get_current_db_embedding_model(db_session)

db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)

# Break bad state for thrashing indexes
Expand Down

0 comments on commit e246ea9

Please sign in to comment.