Skip to content

Commit

Permalink
Added some shit
Browse files Browse the repository at this point in the history
  • Loading branch information
DeevsDeevs committed Oct 22, 2022
1 parent f1e3e15 commit 68eb1fc
Show file tree
Hide file tree
Showing 6 changed files with 393 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
streamlit-app/BertCLS_epoch_5.pth
streamlit-app/median_prices.feather
streamlit-app/preprocess_all_columns_lem.feather
streamlit-app/rec_data.parquet
.DS_Store
155 changes: 155 additions & 0 deletions streamlit-app/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from typing import final
import pandas as pd
import numpy as np
from yaml import load
import utils as utils
import torch
import faiss
import onnxruntime as rt
import matplotlib.pyplot as plt
import json

import streamlit as st
import st_aggrid as st_agg

import pymorphy2
import nltk
import string


@st.cache(allow_output_mutation=True)
def load_all():
model, tokenizer = utils.get_model_tokenizer("cointegrated/LaBSE-en-ru")
model.eval()
device = torch.device('cpu')

bert_cls = utils.BertCLS(model, 5307)
bert_cls.load_state_dict(torch.load(
"BertCLS_epoch_5.pth", map_location=device))
bert_cls = bert_cls.to(device)

final_df = pd.read_feather(
"preprocess_all_columns_lem.feather")
price_df = pd.read_feather(
"median_prices.feather").rename(columns={'median': 'price'})
data = final_df[final_df.columns[:-768]].reset_index(drop=True)
data = pd.merge(data, price_df.rename(columns={'id': 'ID СТЕ'}), on='ID СТЕ', how='left')

rec_dict = dict(pd.read_parquet('rec_data.parquet').values)

embeddings = final_df[final_df.columns[-768:]].values.astype(np.float32)
embeddings = np.ascontiguousarray(embeddings)

faiss.normalize_L2(embeddings)
d = 768
index = faiss.IndexFlatIP(d)
index.add(embeddings) # type: ignore

kpgz_dict = dict(final_df[['target', 'Код КПГЗ']].values)

punctuation = set(string.punctuation)
morph = pymorphy2.MorphAnalyzer()

return data, tokenizer, bert_cls, index, kpgz_dict, rec_dict, punctuation, morph

# def get_fig_price(series: pd.Series):
# try:
# data = {"11-20": series[(series.quantile(0.11) <= series) & (series <= series.quantile(0.2))].mean(),
# "21-30": series[(series.quantile(0.21) <= series) & (series <= series.quantile(0.3))].mean(),
# "31-40": series[(series.quantile(0.31) <= series) & (series <= series.quantile(0.4))].mean(),
# "41-50": series[(series.quantile(0.41) <= series) & (series <= series.quantile(0.5))].mean(),
# "51-60": series[(series.quantile(0.51) <= series) & (series <= series.quantile(0.6))].mean(),
# "61-70": series[(series.quantile(0.61) <= series) & (series <= series.quantile(0.7))].mean(),
# "71-80": series[(series.quantile(0.71) <= series) & (series <= series.quantile(0.8))].mean(),
# "81-90": series[(series.quantile(0.81) <= series) & (series <= series.quantile(0.9))].mean(), }
# courses = list(data.keys())
# values = list(data.values())
# plt.style.use('dark_background') # type: ignore
# fig, ax = plt.subplots(figsize=(10, 5))
# ax.bar(courses, values, width=0.4)
# ax.bar_label(ax.containers[0]) # type: ignore
# plt.xlabel("Quantile Group")
# plt.ylabel("Price")
# plt.title("Распределение средней цены по квантилям")
# plt.axhline(y=(data['11-20'] + data['81-90']) / 2,
# linewidth=3, color='red', label="Средняя цена", ls="--")
# plt.axhline(y=series.median(), linewidth=3, color='pink',
# label="Медианная цена", ls="--")
# plt.legend()
# return fig
# except:
# return None

# def get_fig_tops(series: pd.Series):
# try:
# plt.style.use('dark_background') # type: ignore
# fig, ax = plt.subplots(figsize=(7, 3))
# counts = series.value_counts()[:5]
# counts = counts / len(series) * 100
# counts.plot(ax = ax, kind = 'barh', xlabel = 'Процент предложений выставленных этим ИНН от общего числа')
# ax.bar_label(ax.containers[0]) # type: ignore
# return fig
# except:
# return None


def main():
st.set_page_config(page_title='Tender Search Engine')
st.markdown("""
# Tender Search Engine
""")
data, tokenizer, bert_cls, index, kpgz_dict, rec_dict, punctuation, morph = load_all()
search_request = st.text_input('Введите слова для поиска:').lower().strip()
search_expander = st.expander('Дополнительные настройки')
additional_info = search_expander.text_input('Дополнительные параметры').lower().strip()
if search_request:
search_request = utils.clear_text(search_request, punctuation, morph)
search_results = utils.get_search_results(search_request=search_request, additional_info=additional_info, data=data,
bert_cls = bert_cls, index=index, tokenizer=tokenizer, kpgz_dict = kpgz_dict)
gb_main = st_agg.GridOptionsBuilder.from_dataframe(search_results)
gb_main.configure_default_column(
groupable=True, value=True, enableRowGroup=True, editable=False)
gb_main.configure_side_bar()
gb_main.configure_selection('single', use_checkbox=True, )
gb_main.configure_pagination(
paginationPageSize=10, paginationAutoPageSize=False)
gb_main.configure_grid_options(domLayout='normal')
grid_main_options = gb_main.build()
grid_main_response = st_agg.AgGrid(
search_results,
gridOptions=grid_main_options,
width='100%',
update_mode=st_agg.GridUpdateMode.MODEL_CHANGED,
data_return_mode=st_agg.DataReturnMode.AS_INPUT,
key='main',
reload_data=True,
)
selected_main_row = grid_main_response['selected_rows']
if len(selected_main_row) != 0:
recommend_results = utils.get_search_results(search_request=selected_main_row[0]['Название СТЕ'].strip().lower(), additional_info=additional_info, data=data,
bert_cls = bert_cls, index=index, tokenizer=tokenizer, kpgz_dict = kpgz_dict, rec = True, rec_dict=rec_dict, item_index=selected_main_row[0]['ID СТЕ'])
gb_rec = st_agg.GridOptionsBuilder.from_dataframe(search_results)
gb_rec.configure_default_column(
groupable=True, value=True, enableRowGroup=True, editable=False)
gb_rec.configure_side_bar()
gb_rec.configure_selection('single', use_checkbox=False, )
gb_rec.configure_pagination(
paginationPageSize=10, paginationAutoPageSize=False)
gb_rec.configure_grid_options(domLayout='normal')
grid_rec_options = gb_rec.build()
st.markdown("""
### Сопутствующие товары
""")
grid_rec_response = st_agg.AgGrid(
recommend_results,
gridOptions=grid_rec_options,
width='100%',
update_mode=st_agg.GridUpdateMode.MODEL_CHANGED,
data_return_mode=st_agg.DataReturnMode.AS_INPUT,
key='rec',
reload_data=True,
)


if __name__ == '__main__':
main()
30 changes: 30 additions & 0 deletions streamlit-app/bert_config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"_name_or_path": "cointegrated/LaBSE-en-ru",
"architectures": [
"BertForPreTraining"
],
"attention_probs_dropout_prob": 0.1,
"directionality": "bidi",
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"position_embedding_type": "absolute",
"transformers_version": "4.5.1",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 55083
}
86 changes: 86 additions & 0 deletions streamlit-app/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
altair==4.2.0
attrs==22.1.0
autopep8==1.7.0
blinker==1.5
cachetools==5.2.0
certifi==2022.9.24
charset-normalizer==2.1.1
click==8.1.3
coloredlogs==15.0.1
commonmark==0.9.1
contourpy==1.0.5
cycler==0.11.0
DAWG-Python==0.7.2
decorator==5.1.1
docopt==0.6.2
entrypoints==0.4
faiss-cpu==1.7.2
filelock==3.8.0
flatbuffers==22.9.24
fonttools==4.37.4
gitdb==4.0.9
GitPython==3.1.27
graphviz==0.20.1
huggingface-hub==0.10.0
humanfriendly==10.0
idna==3.4
importlib-metadata==4.13.0
importlib-resources==5.9.0
Jinja2==3.1.2
joblib==1.2.0
jsonschema==4.16.0
kiwisolver==1.4.4
MarkupSafe==2.1.1
matplotlib==3.6.0
mpmath==1.2.1
nltk==3.7
numpy==1.23.3
onnxruntime==1.12.1
packaging==21.3
pandas==1.5.0
Pillow==9.2.0
pkgutil_resolve_name==1.3.10
plotly==5.10.0
protobuf==3.20.3
pyarrow==9.0.0
pybind11==2.10.0
pycodestyle==2.9.1
pydeck==0.8.0b3
Pygments==2.13.0
pyjarowinkler==1.8
pymorphy2==0.9.1
pymorphy2-dicts-ru==2.4.417127.4579844
Pympler==1.0.1
pyngrok==5.1.0
pyparsing==3.0.9
pyrsistent==0.18.1
python-dateutil==2.8.2
python-decouple==3.6
pytz==2022.2.1
pytz-deprecation-shim==0.1.0.post0
PyYAML==6.0
regex==2022.9.13
requests==2.28.1
rich==12.5.1
scipy==1.9.1
semver==2.13.0
six==1.16.0
smart-open==6.2.0
smmap==5.0.0
streamlit==1.13.0
streamlit-aggrid==0.3.3
sympy==1.11.1
tenacity==8.1.0
tokenizers==0.12.1
toml==0.10.2
toolz==0.12.0
torch==1.12.1
tornado==6.2
tqdm==4.64.1
transformers==4.23.1
typing_extensions==4.3.0
tzdata==2022.4
tzlocal==4.2
urllib3==1.26.12
validators==0.20.0
zipp==3.8.1
6 changes: 6 additions & 0 deletions streamlit-app/server_ngrok.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from pyngrok import ngrok
import os
os.system('ngrok authtoken 2ARsKtGKj47h7y4uXMQPrIeOinS_47Mkh6jkzNjFEJWuZYNEX')
url = ngrok.connect(port = 8501)
print(url)
input()
Loading

0 comments on commit 68eb1fc

Please sign in to comment.