-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
325 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,325 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from docx import Document\n", | ||
"# 读取word中表格,处理成表头加行的形式\n", | ||
"def read_table_from_word(file_path):\n", | ||
" all_rows = []\n", | ||
" doc = Document(file_path)\n", | ||
" for i, table in enumerate(doc.tables):\n", | ||
" for row in table.rows:\n", | ||
" cells_text = [cell.text.replace('\\n', '') for cell in row.cells]\n", | ||
" all_rows.append(cells_text)\n", | ||
" new_all_rows=[]\n", | ||
" i = 0\n", | ||
" # 合并跨页的表格行\n", | ||
" while i<len(all_rows):\n", | ||
" if i==len(all_rows)-1:\n", | ||
" new_all_rows.append(all_rows[i])\n", | ||
" break\n", | ||
" if all_rows[i+1][0]!='':\n", | ||
" new_all_rows.append(all_rows[i])\n", | ||
" i+=1\n", | ||
" else:\n", | ||
" new_all_rows.append([a+b for a, b in zip(all_rows[i],all_rows[i+1])])\n", | ||
" i+=2\n", | ||
" \n", | ||
" final_rows=[]\n", | ||
" flag = []\n", | ||
" # 合并表头\n", | ||
" for row in new_all_rows:\n", | ||
" \n", | ||
" if row[0]=='序号':\n", | ||
" flag = row\n", | ||
" merge_now = '|'.join(flag)+'\\n' + '|'.join(row)\n", | ||
" final_rows.append(merge_now)\n", | ||
"\n", | ||
" return final_rows" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_core.documents import Document as doc\n", | ||
"import os\n", | ||
"file_path = '/home/user/wyf/rag_table.docx'\n", | ||
"\n", | ||
"def get_docs(file_path):\n", | ||
" text_chunks = read_table_from_word(file_path)\n", | ||
" file_name = os.path.basename(file_path)\n", | ||
" # 加入元数据\n", | ||
" docs = [doc(page_content=chunk, metadata={'source':file_name}) for chunk in text_chunks]\n", | ||
" return docs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"splits = get_docs(file_path)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"print(splits[1].page_content)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_community.vectorstores import FAISS\n", | ||
"from langchain_huggingface import HuggingFaceEmbeddings\n", | ||
"from langchain_openai import ChatOpenAI\n", | ||
"embeddings = HuggingFaceEmbeddings(model_name='/home/user/wyf/fastchat/bge-large-zh-v1.5', model_kwargs = {'device': 'cuda:1'})\n", | ||
"llm = ChatOpenAI(temperature=0, model='/home/user/Downloads/Qwen2-7B-Instruct/', base_url='http://10.250.2.23:8600/v1', api_key='ww')\n", | ||
"vectorstore = FAISS.from_documents(splits, embeddings)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"retriever = vectorstore.as_retriever(search_kwargs={'k':3})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"recall_text = retriever.get_relevant_documents('斑小将美白隔离防晒乳的检验结果')\n", | ||
"text = ''\n", | ||
"for i in recall_text:\n", | ||
" text += i.page_content\n", | ||
" text += '\\n\\n'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"print(text)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"prompt = '''任务目标:根据给定的知识,回答用户提出的问题。\n", | ||
"任务要求:\n", | ||
"1、不得脱离给定的知识进行回答。\n", | ||
"2、如果给定的知识不包含问题的答案,请回答我不知道。\n", | ||
"\n", | ||
"知识:\n", | ||
"{}\n", | ||
"\n", | ||
"用户问题:\n", | ||
"{}\n", | ||
"'''" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"query = '斑小将美白隔离防晒乳的检验结果'\n", | ||
"print(llm.invoke(prompt.format(text,query)).content)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"import mammoth\n", | ||
"with open(\"/home/user/wyf/rag_test.docx\", \"rb\") as docx_file:\n", | ||
" result = mammoth.convert_to_html(docx_file)\n", | ||
" html = result.value \n", | ||
" print(html)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"with open('tmp.html', 'w', encoding='utf-8') as f:\n", | ||
" f.write(html)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"from langchain_text_splitters import HTMLSectionSplitter\n", | ||
"from langchain_text_splitters import RecursiveCharacterTextSplitter\n", | ||
"headers_to_split_on = [\n", | ||
" (\"h1\", \"Header 1\"),\n", | ||
" (\"h2\", \"Header 2\"),\n", | ||
" (\"h3\", \"Header 3\"),\n", | ||
"]\n", | ||
"\n", | ||
"html_splitter = HTMLSectionSplitter(headers_to_split_on)\n", | ||
"\n", | ||
"html_header_splits = html_splitter.split_text(html)\n", | ||
"for i in html_header_splits:\n", | ||
" print(i)\n", | ||
" print(\"=======================================================\")\n", | ||
"# chunk_size = 500\n", | ||
"# chunk_overlap = 50\n", | ||
"# text_splitter = RecursiveCharacterTextSplitter(\n", | ||
"# chunk_size=chunk_size, chunk_overlap=chunk_overlap\n", | ||
"# )\n", | ||
"\n", | ||
"# # Split\n", | ||
"# splits = text_splitter.split_documents(html_header_splits)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"vectorstore = FAISS.from_documents(html_header_splits, embeddings)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"retriever = vectorstore.as_retriever(search_kwargs={'k':1})" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"recall_text = retriever.get_relevant_documents('国家药监局什么时候发布了《关于更新化妆品禁用原料目录的公告》')\n", | ||
"text = ''\n", | ||
"for i in recall_text:\n", | ||
" text += i.page_content\n", | ||
" text += '\\n\\n'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"vscode": { | ||
"languageId": "plaintext" | ||
} | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"query = '国家药监局什么时候发布了《关于更新化妆品禁用原料目录的公告》'\n", | ||
"print(llm.invoke(prompt.format(text,query)).content)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"language_info": { | ||
"name": "python" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |