Skip to content

Commit

Permalink
new file: table_rag.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
wyf3 committed Oct 28, 2024
1 parent 840bdb7 commit 5891f9d
Showing 1 changed file with 325 additions and 0 deletions.
325 changes: 325 additions & 0 deletions table_rag.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,325 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from docx import Document\n",
"# 读取word中表格,处理成表头加行的形式\n",
"def read_table_from_word(file_path):\n",
" all_rows = []\n",
" doc = Document(file_path)\n",
" for i, table in enumerate(doc.tables):\n",
" for row in table.rows:\n",
" cells_text = [cell.text.replace('\\n', '') for cell in row.cells]\n",
" all_rows.append(cells_text)\n",
" new_all_rows=[]\n",
" i = 0\n",
" # 合并跨页的表格行\n",
" while i<len(all_rows):\n",
" if i==len(all_rows)-1:\n",
" new_all_rows.append(all_rows[i])\n",
" break\n",
" if all_rows[i+1][0]!='':\n",
" new_all_rows.append(all_rows[i])\n",
" i+=1\n",
" else:\n",
" new_all_rows.append([a+b for a, b in zip(all_rows[i],all_rows[i+1])])\n",
" i+=2\n",
" \n",
" final_rows=[]\n",
" flag = []\n",
" # 合并表头\n",
" for row in new_all_rows:\n",
" \n",
" if row[0]=='序号':\n",
" flag = row\n",
" merge_now = '|'.join(flag)+'\\n' + '|'.join(row)\n",
" final_rows.append(merge_now)\n",
"\n",
" return final_rows"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from langchain_core.documents import Document as doc\n",
"import os\n",
"file_path = '/home/user/wyf/rag_table.docx'\n",
"\n",
"def get_docs(file_path):\n",
" text_chunks = read_table_from_word(file_path)\n",
" file_name = os.path.basename(file_path)\n",
" # 加入元数据\n",
" docs = [doc(page_content=chunk, metadata={'source':file_name}) for chunk in text_chunks]\n",
" return docs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"splits = get_docs(file_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"print(splits[1].page_content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from langchain_community.vectorstores import FAISS\n",
"from langchain_huggingface import HuggingFaceEmbeddings\n",
"from langchain_openai import ChatOpenAI\n",
"embeddings = HuggingFaceEmbeddings(model_name='/home/user/wyf/fastchat/bge-large-zh-v1.5', model_kwargs = {'device': 'cuda:1'})\n",
"llm = ChatOpenAI(temperature=0, model='/home/user/Downloads/Qwen2-7B-Instruct/', base_url='http://10.250.2.23:8600/v1', api_key='ww')\n",
"vectorstore = FAISS.from_documents(splits, embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"retriever = vectorstore.as_retriever(search_kwargs={'k':3})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"recall_text = retriever.get_relevant_documents('斑小将美白隔离防晒乳的检验结果')\n",
"text = ''\n",
"for i in recall_text:\n",
" text += i.page_content\n",
" text += '\\n\\n'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"print(text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"prompt = '''任务目标:根据给定的知识,回答用户提出的问题。\n",
"任务要求:\n",
"1、不得脱离给定的知识进行回答。\n",
"2、如果给定的知识不包含问题的答案,请回答我不知道。\n",
"\n",
"知识:\n",
"{}\n",
"\n",
"用户问题:\n",
"{}\n",
"'''"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"query = '斑小将美白隔离防晒乳的检验结果'\n",
"print(llm.invoke(prompt.format(text,query)).content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"import mammoth\n",
"with open(\"/home/user/wyf/rag_test.docx\", \"rb\") as docx_file:\n",
" result = mammoth.convert_to_html(docx_file)\n",
" html = result.value \n",
" print(html)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"with open('tmp.html', 'w', encoding='utf-8') as f:\n",
" f.write(html)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"from langchain_text_splitters import HTMLSectionSplitter\n",
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
"headers_to_split_on = [\n",
" (\"h1\", \"Header 1\"),\n",
" (\"h2\", \"Header 2\"),\n",
" (\"h3\", \"Header 3\"),\n",
"]\n",
"\n",
"html_splitter = HTMLSectionSplitter(headers_to_split_on)\n",
"\n",
"html_header_splits = html_splitter.split_text(html)\n",
"for i in html_header_splits:\n",
" print(i)\n",
" print(\"=======================================================\")\n",
"# chunk_size = 500\n",
"# chunk_overlap = 50\n",
"# text_splitter = RecursiveCharacterTextSplitter(\n",
"# chunk_size=chunk_size, chunk_overlap=chunk_overlap\n",
"# )\n",
"\n",
"# # Split\n",
"# splits = text_splitter.split_documents(html_header_splits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"vectorstore = FAISS.from_documents(html_header_splits, embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"retriever = vectorstore.as_retriever(search_kwargs={'k':1})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"recall_text = retriever.get_relevant_documents('国家药监局什么时候发布了《关于更新化妆品禁用原料目录的公告》')\n",
"text = ''\n",
"for i in recall_text:\n",
" text += i.page_content\n",
" text += '\\n\\n'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"vscode": {
"languageId": "plaintext"
}
},
"outputs": [],
"source": [
"query = '国家药监局什么时候发布了《关于更新化妆品禁用原料目录的公告》'\n",
"print(llm.invoke(prompt.format(text,query)).content)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 5891f9d

Please sign in to comment.