Skip to content

Commit

Permalink
Fixed bm25
Browse files Browse the repository at this point in the history
  • Loading branch information
DeevsDeevs committed Oct 23, 2022
1 parent 7f29bab commit 941227d
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions streamlit-app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ def get_search_results(search_request: str, additional_info: str, data: pd.DataF
bm25_top10 = pd.DataFrame()
if bm25 != None:
ids, scores = bm25.predict(search_request)
bm25_top10 = goods_df.iloc[ids].head(10) # type: ignore
# bm25_top10['kpgz_sim'] = bm25_top10['Код КПГЗ'].apply(lambda x: calc_kpgz(x, kpgz_table))
bm25_top10['kpgz_sim'] = 0
bm25_top10 = goods_df.iloc[ids].head(10)['ID СТЕ'].values # type: ignore

faiss.normalize_L2(xq)
k = 100
Expand All @@ -165,8 +163,14 @@ def get_search_results(search_request: str, additional_info: str, data: pd.DataF
faiss_results = new_data.loc[selected].reset_index(
drop=True) # type: ignore

faiss_results['additional_dist'] = faiss_results['Характеристики'].apply(lambda x: sum((tok in x) for tok in nltk.word_tokenize(additional_info, language="ru")))

faiss_results['cos_sim'] = list(torch.nn.functional.cosine_similarity(torch.from_numpy(embeddings[selected]), torch.from_numpy(xq)).cpu().detach().numpy())
if not rec:
bm_data = new_data[((new_data['ID СТЕ'].isin(bm25_top10)))].reset_index(drop=True)
bm_data['cos_sim'] = 1
faiss_results = pd.concat([faiss_results, bm_data])

faiss_results['additional_dist'] = faiss_results['Характеристики'].apply(lambda x: sum((tok in x) for tok in nltk.word_tokenize(additional_info, language="ru")))
faiss_results['string_dist'] = faiss_results['Название СТЕ'].apply(
lambda x: string_dist(x, search_request))

Expand All @@ -175,11 +179,6 @@ def get_search_results(search_request: str, additional_info: str, data: pd.DataF
conf_int = max(conf_int, 0.65)
faiss_results = faiss_results.loc[faiss_results['cos_sim'] >= conf_int]

if not rec:
bm25_top10['cos_sim'] = 1
faiss_results = pd.concat([faiss_results, bm25_top10]) # type: ignore
faiss_results = faiss_results.drop_duplicates(subset=['ID СТЕ'])

if rec:
faiss_results = faiss_results[faiss_results['sub_kpgz'] == True]

Expand Down

0 comments on commit 941227d

Please sign in to comment.