Skip to content

Commit

Permalink
Added some more search logic
Browse files Browse the repository at this point in the history
  • Loading branch information
DeevsDeevs committed Oct 23, 2022
1 parent 68eb1fc commit 352f55f
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 59 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ streamlit-app/median_prices.feather
streamlit-app/preprocess_all_columns_lem.feather
streamlit-app/rec_data.parquet
.DS_Store
streamlit-app/LABSE-5307-epoch-5-lem.feather
streamlit-app/BertCLS_epoch_5-lem.pth
125 changes: 78 additions & 47 deletions streamlit-app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import nltk
import string

nltk.download('punkt')


@st.cache(allow_output_mutation=True)
def load_all():
Expand All @@ -25,16 +27,19 @@ def load_all():

bert_cls = utils.BertCLS(model, 5307)
bert_cls.load_state_dict(torch.load(
"BertCLS_epoch_5.pth", map_location=device))
"BertCLS_epoch_5-lem.pth", map_location=device))
bert_cls = bert_cls.to(device)

final_df = pd.read_feather(
"preprocess_all_columns_lem.feather")
"LABSE-5307-epoch-5-lem.feather")
price_df = pd.read_feather(
"median_prices.feather").rename(columns={'median': 'price'})

data = final_df[final_df.columns[:-768]].reset_index(drop=True)
data = pd.merge(data, price_df.rename(columns={'id': 'ID СТЕ'}), on='ID СТЕ', how='left')

names_counts = data['Название_СТЕ_source'].str.lower().str.strip().value_counts()

rec_dict = dict(pd.read_parquet('rec_data.parquet').values)

embeddings = final_df[final_df.columns[-768:]].values.astype(np.float32)
Expand All @@ -50,62 +55,75 @@ def load_all():
punctuation = set(string.punctuation)
morph = pymorphy2.MorphAnalyzer()

return data, tokenizer, bert_cls, index, kpgz_dict, rec_dict, punctuation, morph

# def get_fig_price(series: pd.Series):
# try:
# data = {"11-20": series[(series.quantile(0.11) <= series) & (series <= series.quantile(0.2))].mean(),
# "21-30": series[(series.quantile(0.21) <= series) & (series <= series.quantile(0.3))].mean(),
# "31-40": series[(series.quantile(0.31) <= series) & (series <= series.quantile(0.4))].mean(),
# "41-50": series[(series.quantile(0.41) <= series) & (series <= series.quantile(0.5))].mean(),
# "51-60": series[(series.quantile(0.51) <= series) & (series <= series.quantile(0.6))].mean(),
# "61-70": series[(series.quantile(0.61) <= series) & (series <= series.quantile(0.7))].mean(),
# "71-80": series[(series.quantile(0.71) <= series) & (series <= series.quantile(0.8))].mean(),
# "81-90": series[(series.quantile(0.81) <= series) & (series <= series.quantile(0.9))].mean(), }
# courses = list(data.keys())
# values = list(data.values())
# plt.style.use('dark_background') # type: ignore
# fig, ax = plt.subplots(figsize=(10, 5))
# ax.bar(courses, values, width=0.4)
# ax.bar_label(ax.containers[0]) # type: ignore
# plt.xlabel("Quantile Group")
# plt.ylabel("Price")
# plt.title("Распределение средней цены по квантилям")
# plt.axhline(y=(data['11-20'] + data['81-90']) / 2,
# linewidth=3, color='red', label="Средняя цена", ls="--")
# plt.axhline(y=series.median(), linewidth=3, color='pink',
# label="Медианная цена", ls="--")
# plt.legend()
# return fig
# except:
# return None

# def get_fig_tops(series: pd.Series):
# try:
# plt.style.use('dark_background') # type: ignore
# fig, ax = plt.subplots(figsize=(7, 3))
# counts = series.value_counts()[:5]
# counts = counts / len(series) * 100
# counts.plot(ax = ax, kind = 'barh', xlabel = 'Процент предложений выставленных этим ИНН от общего числа')
# ax.bar_label(ax.containers[0]) # type: ignore
# return fig
# except:
# return None
return data, names_counts, tokenizer, bert_cls, embeddings, index, kpgz_dict, rec_dict, punctuation, morph

def get_fig_price(series: pd.Series):
try:
series = series.dropna()
data = {"11-20": series[(series.quantile(0.11) <= series) & (series <= series.quantile(0.2))].mean(),
"21-30": series[(series.quantile(0.21) <= series) & (series <= series.quantile(0.3))].mean(),
"31-40": series[(series.quantile(0.31) <= series) & (series <= series.quantile(0.4))].mean(),
"41-50": series[(series.quantile(0.41) <= series) & (series <= series.quantile(0.5))].mean(),
"51-60": series[(series.quantile(0.51) <= series) & (series <= series.quantile(0.6))].mean(),
"61-70": series[(series.quantile(0.61) <= series) & (series <= series.quantile(0.7))].mean(),
"71-80": series[(series.quantile(0.71) <= series) & (series <= series.quantile(0.8))].mean(),
"81-90": series[(series.quantile(0.81) <= series) & (series <= series.quantile(0.9))].mean(), }
courses = list(data.keys())
values = list(data.values())
plt.style.use('dark_background') # type: ignore
fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(courses, values, width=0.4)
ax.bar_label(ax.containers[0]) # type: ignore
plt.xlabel("Quantile Group")
plt.ylabel("Price")
plt.title("Распределение средней цены по квантилям")
plt.axhline(y=(data['11-20'] + data['81-90']) / 2,
linewidth=3, color='red', label="Средняя цена", ls="--")
plt.axhline(y=series.median(), linewidth=3, color='pink',
label="Медианная цена", ls="--")
plt.legend()
return fig
except:
return None



def main():
st.set_page_config(page_title='Tender Search Engine')
st.markdown("""
# Tender Search Engine
""")
data, tokenizer, bert_cls, index, kpgz_dict, rec_dict, punctuation, morph = load_all()
data, names_counts, tokenizer, bert_cls, embeddings, index, kpgz_dict, rec_dict, punctuation, morph = load_all()
search_request = st.text_input('Введите слова для поиска:').lower().strip()
search_expander = st.expander('Дополнительные настройки')
additional_info = search_expander.text_input('Дополнительные параметры').lower().strip()
additional_info = search_expander.text_input('Ключевые характеристики').lower().strip()
min_price = search_expander.number_input('Введите минимальную стоимость')
max_price = search_expander.number_input('Введите максимальную стоимость')
kpgz_code = search_expander.text_input('Введите КПГЗ код')
if max_price < min_price:
min_price, max_price = max_price, min_price
if search_request:
cnt = 0
pos_pop_requests = []
for val in (names_counts.index):
if val.startswith(search_request):
if val != search_request:
pos_pop_requests.append(val)
cnt += 1
if cnt == 3:
break
st.markdown(f"""
Автодополнение: <br/> {r"<br/>".join(pos_pop_requests)}
""", unsafe_allow_html=True)
search_request = utils.clear_text(search_request, punctuation, morph)
additional_info = utils.clear_text(additional_info, punctuation, morph)
search_results = utils.get_search_results(search_request=search_request, additional_info=additional_info, data=data,
bert_cls = bert_cls, index=index, tokenizer=tokenizer, kpgz_dict = kpgz_dict)
bert_cls = bert_cls, embeddings=embeddings, index=index, tokenizer=tokenizer, kpgz_dict = kpgz_dict,
min_price=min_price, max_price=max_price, kpgz_code=kpgz_code)
figs_expander = st.expander("Анализ рынка")
fig = get_fig_price(search_results['price'])
if fig is not None:
figs_expander.pyplot(fig)
gb_main = st_agg.GridOptionsBuilder.from_dataframe(search_results)
gb_main.configure_default_column(
groupable=True, value=True, enableRowGroup=True, editable=False)
Expand All @@ -126,8 +144,21 @@ def main():
)
selected_main_row = grid_main_response['selected_rows']
if len(selected_main_row) != 0:
near_id = -1
row_index = selected_main_row[0]['_selectedRowNodeInfo']['nodeRowIndex']

if selected_main_row[0]['ID СТЕ'] in rec_dict:
near_id = selected_main_row[0]['ID СТЕ']
elif (row_index - 1) >= 0:
if search_results.iloc[row_index - 1]['ID СТЕ'] in rec_dict: # type: ignore
near_id = search_results.iloc[row_index - 1]['ID СТЕ'] # type: ignore
elif (row_index + 1) < len(search_results):
if search_results.iloc[row_index + 1]['ID СТЕ'] in rec_dict: # type: ignore
near_id = search_results.iloc[row_index + 1]['ID СТЕ'] # type: ignore

recommend_results = utils.get_search_results(search_request=selected_main_row[0]['Название СТЕ'].strip().lower(), additional_info=additional_info, data=data,
bert_cls = bert_cls, index=index, tokenizer=tokenizer, kpgz_dict = kpgz_dict, rec = True, rec_dict=rec_dict, item_index=selected_main_row[0]['ID СТЕ'])
bert_cls = bert_cls, embeddings=embeddings, index=index, tokenizer=tokenizer,
kpgz_dict = kpgz_dict, rec = True, rec_dict=rec_dict, item_id=near_id)
gb_rec = st_agg.GridOptionsBuilder.from_dataframe(search_results)
gb_rec.configure_default_column(
groupable=True, value=True, enableRowGroup=True, editable=False)
Expand Down
37 changes: 25 additions & 12 deletions streamlit-app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import time
import json
import nltk
import scipy


def clear_text(text, punctuation, morph):
Expand Down Expand Up @@ -67,18 +68,21 @@ def string_dist(str1, str2):
scaling=0.2)


def prepare_data(data: pd.DataFrame) -> pd.DataFrame:

def prepare_data(data: pd.DataFrame, min_price=0.0, max_price=float('inf'), kpgz_code = "") -> pd.DataFrame:
if max_price != 0.0:
data = data.loc[(data['price'] >= min_price) &
(data['price'] <= max_price)]
if kpgz_code != "":
data = data.loc[data['Код КПГЗ'].str.startswith(kpgz_code)]
return data

def filter_for_rec(kpgz, kpgz_table):
return kpgz != kpgz_table

def get_search_results(search_request: str, additional_info: str, data: pd.DataFrame, bert_cls: BertCLS, index: faiss.IndexFlatIP, tokenizer, kpgz_dict: dict, rec = False, rec_dict = dict(), item_index = None) -> pd.DataFrame:
def get_search_results(search_request: str, additional_info: str, data: pd.DataFrame, bert_cls: BertCLS, embeddings, index: faiss.IndexFlatIP, tokenizer, kpgz_dict: dict, rec = False, rec_dict = dict(), item_id = None, min_price=0.0, max_price=float('inf'), kpgz_code="") -> pd.DataFrame:
new_data = deepcopy(data)
new_data = prepare_data(data=new_data)

search_request = search_request
search_request += " [SEP] " + additional_info
new_data = prepare_data(data=new_data, min_price=min_price, max_price=max_price, kpgz_code=kpgz_code)
kpgz_table = get_kpgz(bert_cls, tokenizer, search_request, kpgz_dict)
if rec:
new_data['sub_kpgz'] = new_data['Код КПГЗ'].apply(lambda x: filter_for_rec(x, kpgz_table))
Expand All @@ -93,19 +97,28 @@ def get_search_results(search_request: str, additional_info: str, data: pd.DataF

faiss_results = new_data.loc[selected].reset_index(
drop=True) # type: ignore


faiss_results['additional_dist'] = faiss_results['Характеристики'].apply(lambda x: sum((tok in x) for tok in nltk.word_tokenize(additional_info, language="ru")))
faiss_results['cos_sim'] = list(torch.nn.functional.cosine_similarity(torch.from_numpy(embeddings[selected]), torch.from_numpy(xq)).cpu().detach().numpy())
faiss_results['string_dist'] = faiss_results['Название СТЕ'].apply(
lambda x: string_dist(x, search_request))

a = faiss_results['cos_sim'].to_list()
conf_int = scipy.stats.norm.interval(0.95, loc=np.mean(a), scale=scipy.stats.sem(a))[0]
conf_int = max(conf_int, 0.65)
faiss_results = faiss_results.loc[faiss_results['cos_sim'] >= conf_int]

if rec:
faiss_results = faiss_results[faiss_results['sub_kpgz'] == True]

if item_index in rec_dict:
history_rec_ids = rec_dict[item_index]
if item_id in rec_dict:
history_rec_ids = rec_dict[item_id]
history_rec_ids = sorted(history_rec_ids, key=lambda x: x[1], reverse=True)[:10]
history_rec_ids = set(i[0] for i in history_rec_ids)
history_rec_df = new_data[new_data['ID СТЕ'].isin(history_rec_ids)]
return pd.concat([history_rec_df, faiss_results.sort_values(by=['kpgz_sim', 'string_dist'], ascending=[False, False]).head(5)])

return pd.concat([history_rec_df, faiss_results.sort_values(by=['kpgz_sim', 'string_dist'], ascending=[False, False]).head(10)])

return faiss_results.sort_values(by=['kpgz_sim', 'string_dist'], ascending=[False, False]).head(5)
return faiss_results.sort_values(by=['kpgz_sim', 'string_dist'], ascending=[False, False]).head(10)

return faiss_results.sort_values(by=['kpgz_sim', 'string_dist'], ascending=[False, False]).head(5)
return faiss_results.sort_values(by=['additional_dist', 'kpgz_sim', 'string_dist'], ascending=[False, False, False]).head(10)

0 comments on commit 352f55f

Please sign in to comment.