forked from wyf3/llm_related
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
new file: table_extract/README.md new file: table_extract/chinese_cht.ttf new file: table_extract/imgs/2.jpg new file: table_extract/imgs/extract_2.png new file: table_extract/table2txt.ipynb
- Loading branch information
Showing
6 changed files
with
359 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# table2txt | ||
|
||
## 介绍 | ||
支持图片或者pdf中的普通文本提取和表格中文本的提取(pdf需要先转成图片),并保持其结构化排版布局(尽量保持其结构,不完美) | ||
|
||
可参考如下示例: | ||
|
||
需要提取的图片: | ||
|
||
![需要提取的图片](./imgs/2.jpg "表格") | ||
|
||
提取之后: | ||
|
||
![提取之后的图片](./imgs/extract_2.png "表格") | ||
|
||
## 使用方法 | ||
|
||
1、下载模型 | ||
|
||
modelscope下载表格提取模型,并修改代码中相关路径 | ||
|
||
https://modelscope.cn/models/iic/cv_dla34_table-structure-recognition_cycle-centernet | ||
|
||
2、修改代码中需要提取的图片路径 | ||
|
||
## 注意 | ||
|
||
有时会出现调整完坐标之后的效果图片无法绘制的情况,可忽略,文字可正常提取 | ||
|
||
代码中有不完善的地方,可根据需要自行修改 |
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,329 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from modelscope.pipelines import pipeline\n", | ||
"from modelscope.utils.constant import Tasks\n", | ||
"table_recognition = pipeline(Tasks.table_recognition, model='cv_dla34_table-structure-recognition_cycle-centernet模型的路径')\n", | ||
"result = table_recognition('你需要提取的图片路径')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from paddleocr import PaddleOCR\n", | ||
"ocr = PaddleOCR(use_gpu=True, lang='ch')\n", | ||
"image_path = '你需要提取的图片路径'\n", | ||
"res = ocr.ocr(image_path, cls=True)\n", | ||
"print(res)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from PIL import Image, ImageDraw, ImageFont\n", | ||
"import textwrap\n", | ||
"import numpy as np\n", | ||
"def draw_ocr_boxes(image_path, boxes, texts):\n", | ||
" \n", | ||
" img = Image.open(image_path)\n", | ||
" img = Image.new('RGB', img.size, (255, 255, 255))\n", | ||
" \n", | ||
" draw = ImageDraw.Draw(img)\n", | ||
" font = ImageFont.truetype(\"./chinese_cht.ttf\", size=15) \n", | ||
" \n", | ||
"\n", | ||
" # 遍历每个文本框和对应的文本\n", | ||
" for box, text in zip(boxes, texts):\n", | ||
" draw.rectangle(box, outline='red', width=2)\n", | ||
" x, y = box[:2]\n", | ||
" draw.text((x,y), text, font=font, fill='black')\n", | ||
" \n", | ||
" img.save('image_with_boxes_and_text.jpg')\n", | ||
"\n", | ||
"# 示例文本框坐标和对应的文字\n", | ||
"boxes = [(*i[0][0],*i[0][2]) for i in res[0]]\n", | ||
"texts = [i[1][0] for i in res[0]]\n", | ||
"draw_ocr_boxes('你需要提取的图片路径', boxes, texts)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def is_inside_text(cell, text):\n", | ||
" \"\"\"检查文字是否完全在单元格内\"\"\"\n", | ||
" cx1, cy1, cx2, cy2 = cell\n", | ||
" tx1, ty1, tx2, ty2 = text['coords']\n", | ||
" return cx1 <= tx1 and cy1 <= ty1 and cx2 >= tx2 and cy2 >= ty2\n", | ||
"def calculate_iou(cell, text):\n", | ||
" \"\"\"\n", | ||
" 计算两个矩形框的交并比(IoU)。\n", | ||
" \n", | ||
" :param cell: 单元格的坐标 (x1, y1, x2, y2)\n", | ||
" :param text: 文本框的坐标 (x1, y1, x2, y2)\n", | ||
" :return: 交并比(IoU)\n", | ||
" \"\"\"\n", | ||
" # 计算交集的左上角和右下角坐标\n", | ||
" intersection_x1 = max(cell[0], text['coords'][0])\n", | ||
" intersection_y1 = max(cell[1], text['coords'][1])\n", | ||
" intersection_x2 = min(cell[2], text['coords'][2])\n", | ||
" intersection_y2 = min(cell[3], text['coords'][3])\n", | ||
"\n", | ||
" # 如果没有交集,返回 0\n", | ||
" if intersection_x1 >= intersection_x2 or intersection_y1 >= intersection_y2:\n", | ||
" return 0.0\n", | ||
"\n", | ||
" # 计算交集的面积\n", | ||
" intersection_area = (intersection_x2 - intersection_x1) * (intersection_y2 - intersection_y1)\n", | ||
"\n", | ||
" # 计算并集的面积\n", | ||
" area_box1 = (cell[2] - cell[0]) * (cell[3] - cell[1])\n", | ||
" area_box2 = (text['coords'][2] - text['coords'][0]) * (text['coords'][3] - text['coords'][1])\n", | ||
" union_area = area_box1 + area_box2 - intersection_area\n", | ||
"\n", | ||
" # 计算 IoU\n", | ||
" iou = intersection_area / union_area\n", | ||
"\n", | ||
" return iou\n", | ||
"def calculate_iot(cell, text):\n", | ||
" \"\"\"\n", | ||
" 计算两个矩形框的交集面积和文本框面积的比值(IoT)。\n", | ||
" \n", | ||
" :param cell: 单元格的坐标 (x1, y1, x2, y2)\n", | ||
" :param text: 文本框的坐标 (x1, y1, x2, y2)\n", | ||
" :return: IoT\n", | ||
" \"\"\"\n", | ||
" # 计算交集的左上角和右下角坐标\n", | ||
" intersection_x1 = max(cell[0], text['coords'][0])\n", | ||
" intersection_y1 = max(cell[1], text['coords'][1])\n", | ||
" intersection_x2 = min(cell[2], text['coords'][2])\n", | ||
" intersection_y2 = min(cell[3], text['coords'][3])\n", | ||
"\n", | ||
" # 如果没有交集,返回 0\n", | ||
" if intersection_x1 >= intersection_x2 or intersection_y1 >= intersection_y2:\n", | ||
" return 0.0\n", | ||
" # 计算交集的面积\n", | ||
" intersection_area = (intersection_x2 - intersection_x1) * (intersection_y2 - intersection_y1)\n", | ||
"\n", | ||
" text_area = (text['coords'][2] - text['coords'][0]) * (text['coords'][3] - text['coords'][1])\n", | ||
" # 计算 IoT\n", | ||
" iot = intersection_area / text_area\n", | ||
" return iot\n", | ||
"\n", | ||
"def merge_text_into_cells(cell_coords, ocr_results):\n", | ||
" \"\"\"将文字合并到单元格\"\"\"\n", | ||
" # 创建一个字典,键是单元格坐标,值是属于该单元格的文字列表\n", | ||
" cell_text_dict = {cell: [] for cell in cell_coords}\n", | ||
" noncell_text_dict = {}\n", | ||
" \n", | ||
" # 遍历 OCR 结果,将文字分配给正确的单元格\n", | ||
" for cell in cell_coords:\n", | ||
" for result in ocr_results:\n", | ||
" if calculate_iot(cell, result)>0.5:\n", | ||
" cell_text_dict[cell].append(result['text'])\n", | ||
" \n", | ||
" for result in ocr_results:\n", | ||
" if all(calculate_iot(cell, result)<0.1 for cell in cell_coords):\n", | ||
" noncell_text_dict[result['coords']] = result['text']\n", | ||
"\n", | ||
" merged_text = {}\n", | ||
" for cell, texts in cell_text_dict.items():\n", | ||
" merged_text[cell] = ''.join(texts).strip()\n", | ||
" for coords, text in noncell_text_dict.items():\n", | ||
" merged_text[coords] = ''.join(text).strip()\n", | ||
" \n", | ||
" return merged_text\n", | ||
"\n", | ||
"cell_coords = [tuple([*i[:2],*i[4:6]]) for i in result['polygons']]\n", | ||
"ocr_results = [\n", | ||
" {'text': i[1][0], 'coords': tuple([*i[0][0],*i[0][2]])} for i in res[0]]\n", | ||
"merged_text = merge_text_into_cells(cell_coords, ocr_results)\n", | ||
"print(merged_text)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from PIL import Image, ImageDraw, ImageFont\n", | ||
"import textwrap\n", | ||
"import numpy as np\n", | ||
"def draw_text_boxes(image_path, boxes, texts):\n", | ||
" # 加载图像\n", | ||
" img = Image.open(image_path)\n", | ||
" img = Image.new('RGB', img.size, (255, 255, 255))\n", | ||
" # 创建一个 ImageDraw 对象\n", | ||
" draw = ImageDraw.Draw(img)\n", | ||
" \n", | ||
" # 设置字体\n", | ||
" font = ImageFont.truetype(\"./chinese_cht.ttf\", size=15) # 选择合适的字体和大小\n", | ||
" \n", | ||
"\n", | ||
" # 遍历每个文本框和对应的文本\n", | ||
" for box, text in zip(boxes, texts):\n", | ||
" # 绘制文本框\n", | ||
" draw.rectangle(box, outline='red', width=2)\n", | ||
" \n", | ||
" \n", | ||
" text_len = draw.textbbox(xy=box[:2], text=text, font=font)\n", | ||
" \n", | ||
" if (text_len[2]-text_len[0]) > (box[2] - box[0]):\n", | ||
" # 如果文本长度大于文本框宽度,则将文本换行\n", | ||
" text = '\\n'.join(textwrap.wrap(text, width=int(np.ceil((len(text) / np.ceil((text_len[2]-text_len[0]) / (box[2] - box[0])))))))\n", | ||
" else:\n", | ||
" # 否则直接绘制文本\n", | ||
" text = text\n", | ||
" x, y = box[:2]\n", | ||
" \n", | ||
" # 在文本框内居中文本\n", | ||
" draw.text((x,y), text, font=font, fill='black')\n", | ||
" \n", | ||
" # 保存带有文本框和文字的图像\n", | ||
" img.save('你保存的图片路径')\n", | ||
"\n", | ||
"# 示例文本框坐标和对应的文字\n", | ||
"boxes = list(merged_text.keys())\n", | ||
"texts = list(merged_text.values())\n", | ||
"draw_text_boxes('你需要提取的图片路径', boxes, texts)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"\n", | ||
"\n", | ||
"def adjust_coordinates(merged_text, image_path):\n", | ||
" \n", | ||
" image = Image.open(image_path)\n", | ||
" width, height = image.size\n", | ||
" threshold = height / 100\n", | ||
" groups = {}\n", | ||
" \n", | ||
" for coordinates, text in merged_text.items():\n", | ||
" # 查找与当前 y 坐标相差不超过 threshold 的分组\n", | ||
" found_group = False\n", | ||
" for group_y in groups.keys():\n", | ||
" if abs(coordinates[1] - group_y) <= threshold:\n", | ||
" groups[group_y].append((coordinates,text))\n", | ||
" found_group = True\n", | ||
" break\n", | ||
"\n", | ||
" # 如果没有找到合适的分组,则创建一个新的分组\n", | ||
" if not found_group:\n", | ||
" groups[coordinates[1]] = [(coordinates,text)]\n", | ||
" \n", | ||
" # 计算每个分组的 y 坐标的平均值,并更新坐标列表\n", | ||
" adjusted_coordinates = {}\n", | ||
" for group_y, group_coords in groups.items():\n", | ||
" avg_y = sum(coord[0][1] for coord in group_coords) / len(group_coords)\n", | ||
" for i in group_coords:\n", | ||
" adjusted_coordinates[(i[0][0], avg_y, i[0][2], i[0][3])] = i[1]\n", | ||
" \n", | ||
"\n", | ||
" return adjusted_coordinates\n", | ||
"\n", | ||
"# 调用函数处理坐标\n", | ||
"adjusted_merged_text = adjust_coordinates(merged_text, '你需要提取的图片路径')\n", | ||
"\n", | ||
"# 打印结果\n", | ||
"print(\"原始坐标:\", merged_text)\n", | ||
"print(\"调整后的坐标:\", adjusted_merged_text)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from PIL import Image, ImageDraw, ImageFont\n", | ||
"import textwrap\n", | ||
"import numpy as np\n", | ||
"def draw_text_boxes(image_path, boxes, texts):\n", | ||
" \n", | ||
" img = Image.open(image_path)\n", | ||
" img = Image.new('RGB', img.size, (255, 255, 255))\n", | ||
" draw = ImageDraw.Draw(img)\n", | ||
" font = ImageFont.truetype(\"./chinese_cht.ttf\", size=15) # 选择合适的字体和大小\n", | ||
" for box, text in zip(boxes, texts):\n", | ||
" \n", | ||
" draw.rectangle(box, outline='red', width=2)\n", | ||
" \n", | ||
" \n", | ||
" text_len = draw.textbbox(xy=box[:2], text=text, font=font)\n", | ||
" \n", | ||
" if (text_len[2]-text_len[0]) > (box[2] - box[0]):\n", | ||
" # 如果文本长度大于文本框宽度,则将文本换行\n", | ||
" text = '\\n'.join(textwrap.wrap(text, width=int(np.ceil(len(text) / np.ceil((text_len[2]-text_len[0]) / (box[2] - box[0]))))))\n", | ||
" else:\n", | ||
" # 否则直接绘制文本\n", | ||
" text = text\n", | ||
" x, y = box[:2]\n", | ||
" \n", | ||
" draw.text((x,y), text, font=font, fill='black')\n", | ||
" img.save('你需要保存的图片路径')\n", | ||
"\n", | ||
"boxes = list(adjusted_merged_text.keys())\n", | ||
"texts = list(adjusted_merged_text.values())\n", | ||
"draw_text_boxes('你需要提取的图片路径', boxes, texts)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#输出最终的文本\n", | ||
"adjusted_merged_text_sorted = sorted(adjusted_merged_text.items(), key=lambda x: (x[0][1], x[0][0]))\n", | ||
"adjusted_merged_text_sorted_group = {}\n", | ||
"for coordinates, text in adjusted_merged_text_sorted:\n", | ||
" if coordinates[1] not in adjusted_merged_text_sorted_group:\n", | ||
" adjusted_merged_text_sorted_group[coordinates[1]] = [text]\n", | ||
" else:\n", | ||
" adjusted_merged_text_sorted_group[coordinates[1]].append(text)\n", | ||
"for text_list in adjusted_merged_text_sorted_group.values():\n", | ||
" print(' | '.join(text_list))\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "wyf", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |