Skip to content

Commit

Permalink
deleted: README.md
Browse files Browse the repository at this point in the history
	new file:   table_extract/README.md
	new file:   table_extract/chinese_cht.ttf
	new file:   table_extract/imgs/2.jpg
	new file:   table_extract/imgs/extract_2.png
	new file:   table_extract/table2txt.ipynb
  • Loading branch information
wyf3 committed Aug 4, 2024
1 parent ed1b78a commit 9ef69f3
Show file tree
Hide file tree
Showing 6 changed files with 359 additions and 2 deletions.
2 changes: 0 additions & 2 deletions README.md

This file was deleted.

30 changes: 30 additions & 0 deletions table_extract/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# table2txt

## 介绍
支持提取图片或 PDF 中的普通文本和表格内文本(PDF 需要先转换成图片),并尽量保留原有的结构化排版布局(还原效果并不完美)。

可参考如下示例:

需要提取的图片:

![需要提取的图片](./imgs/2.jpg "表格")

提取之后:

![提取之后的图片](./imgs/extract_2.png "表格")

## 使用方法

1、下载模型

从 ModelScope 下载表格结构识别模型,并将代码中 `pipeline(..., model=...)` 的模型路径改为本地模型所在的路径:

https://modelscope.cn/models/iic/cv_dla34_table-structure-recognition_cycle-centernet

2、将代码中所有的 `你需要提取的图片路径` 改为待提取图片的实际路径,并将保存图片的路径占位符改为输出图片的路径

## 注意

坐标调整之后,效果图片有时会绘制失败;这只影响可视化结果,可以忽略,文字仍可正常提取

代码中有不完善的地方,可根据需要自行修改
Binary file added table_extract/chinese_cht.ttf
Binary file not shown.
Binary file added table_extract/imgs/2.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added table_extract/imgs/extract_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
329 changes: 329 additions & 0 deletions table_extract/table2txt.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,329 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from modelscope.pipelines import pipeline\n",
"from modelscope.utils.constant import Tasks\n",
"\n",
"# Build the table-recognition pipeline from the model path, then run it on the\n",
"# input image. A later cell reads result['polygons'] to get the table cells.\n",
"table_model_path = 'cv_dla34_table-structure-recognition_cycle-centernet模型的路径'\n",
"table_image_path = '你需要提取的图片路径'\n",
"table_recognition = pipeline(Tasks.table_recognition, model=table_model_path)\n",
"result = table_recognition(table_image_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from paddleocr import PaddleOCR\n",
"\n",
"# use_angle_cls=True loads the text-direction classifier. Without it the\n",
"# cls=True flag passed to ocr() below has no effect (PaddleOCR only logs a\n",
"# warning that the classifier was not initialized).\n",
"ocr = PaddleOCR(use_angle_cls=True, use_gpu=True, lang='ch')\n",
"image_path = '你需要提取的图片路径'\n",
"# Later cells read res[0] and index every entry as [corner_points, (text, ...)].\n",
"res = ocr.ocr(image_path, cls=True)\n",
"print(res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image, ImageDraw, ImageFont\n",
"import textwrap\n",
"import numpy as np\n",
"\n",
"\n",
"def draw_ocr_boxes(image_path, boxes, texts, output_path='image_with_boxes_and_text.jpg'):\n",
"    \"\"\"Draw OCR boxes and their text on a blank canvas sized like the source image.\n",
"\n",
"    :param image_path: image whose size determines the canvas size\n",
"    :param boxes: iterable of (x1, y1, x2, y2) boxes\n",
"    :param texts: iterable of strings, paired with boxes by position\n",
"    :param output_path: file the rendered image is saved to\n",
"    \"\"\"\n",
"    # Only the size is needed, so release the source file immediately.\n",
"    with Image.open(image_path) as source:\n",
"        canvas_size = source.size\n",
"    img = Image.new('RGB', canvas_size, (255, 255, 255))\n",
"\n",
"    draw = ImageDraw.Draw(img)\n",
"    font = ImageFont.truetype(\"./chinese_cht.ttf\", size=15)\n",
"\n",
"    # Walk through every box together with its text.\n",
"    for box, text in zip(boxes, texts):\n",
"        # Recent Pillow versions raise ValueError when x1 < x0 or y1 < y0,\n",
"        # which can happen for tilted quads, so order the corners first.\n",
"        left, right = sorted((box[0], box[2]))\n",
"        top, bottom = sorted((box[1], box[3]))\n",
"        draw.rectangle((left, top, right, bottom), outline='red', width=2)\n",
"        draw.text((left, top), text, font=font, fill='black')\n",
"\n",
"    img.save(output_path)\n",
"\n",
"\n",
"# Boxes are built from corner points 0 and 2 of every OCR entry.\n",
"# `or []` guards against a None first entry (presumably what PaddleOCR returns\n",
"# when no text is detected - confirm against the installed version).\n",
"ocr_lines = res[0] or []\n",
"boxes = [(*line[0][0], *line[0][2]) for line in ocr_lines]\n",
"texts = [line[1][0] for line in ocr_lines]\n",
"draw_ocr_boxes('你需要提取的图片路径', boxes, texts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def is_inside_text(cell, text):\n",
"    \"\"\"Return True when the text box lies completely inside the cell.\n",
"\n",
"    :param cell: cell box (x1, y1, x2, y2)\n",
"    :param text: dict whose 'coords' entry is the text box (x1, y1, x2, y2)\n",
"    \"\"\"\n",
"    cx1, cy1, cx2, cy2 = cell\n",
"    tx1, ty1, tx2, ty2 = text['coords']\n",
"    return cx1 <= tx1 and cy1 <= ty1 and cx2 >= tx2 and cy2 >= ty2\n",
"\n",
"\n",
"def _intersection_area(box_a, box_b):\n",
"    \"\"\"Return the overlap area of two (x1, y1, x2, y2) boxes, or 0.0 without overlap.\"\"\"\n",
"    x1 = max(box_a[0], box_b[0])\n",
"    y1 = max(box_a[1], box_b[1])\n",
"    x2 = min(box_a[2], box_b[2])\n",
"    y2 = min(box_a[3], box_b[3])\n",
"    if x1 >= x2 or y1 >= y2:\n",
"        return 0.0\n",
"    return (x2 - x1) * (y2 - y1)\n",
"\n",
"\n",
"def calculate_iou(cell, text):\n",
"    \"\"\"Compute the intersection-over-union (IoU) of a cell and a text box.\n",
"\n",
"    :param cell: cell box (x1, y1, x2, y2)\n",
"    :param text: dict whose 'coords' entry is the text box (x1, y1, x2, y2)\n",
"    :return: IoU, or 0.0 when the boxes do not overlap\n",
"    \"\"\"\n",
"    text_box = text['coords']\n",
"    intersection_area = _intersection_area(cell, text_box)\n",
"    if intersection_area == 0.0:\n",
"        return 0.0\n",
"\n",
"    cell_area = (cell[2] - cell[0]) * (cell[3] - cell[1])\n",
"    text_area = (text_box[2] - text_box[0]) * (text_box[3] - text_box[1])\n",
"    union_area = cell_area + text_area - intersection_area\n",
"    return intersection_area / union_area\n",
"\n",
"\n",
"def calculate_iot(cell, text):\n",
"    \"\"\"Compute the share of the text box covered by the cell (IoT).\n",
"\n",
"    :param cell: cell box (x1, y1, x2, y2)\n",
"    :param text: dict whose 'coords' entry is the text box (x1, y1, x2, y2)\n",
"    :return: intersection area divided by text-box area, or 0.0 without overlap\n",
"    \"\"\"\n",
"    text_box = text['coords']\n",
"    intersection_area = _intersection_area(cell, text_box)\n",
"    if intersection_area == 0.0:\n",
"        return 0.0\n",
"\n",
"    text_area = (text_box[2] - text_box[0]) * (text_box[3] - text_box[1])\n",
"    return intersection_area / text_area\n",
"\n",
"\n",
"def merge_text_into_cells(cell_coords, ocr_results):\n",
"    \"\"\"Assign OCR text to table cells and keep all remaining text as free-standing.\n",
"\n",
"    A text goes into every cell that covers more than half of its box. A text\n",
"    that no cell claims is kept under its own coordinates, so partially\n",
"    overlapping text is no longer dropped.\n",
"\n",
"    :param cell_coords: iterable of hashable cell boxes (x1, y1, x2, y2)\n",
"    :param ocr_results: list of dicts with 'text' and 'coords' entries\n",
"    :return: dict mapping box coordinates to the stripped text inside them\n",
"    \"\"\"\n",
"    # Key: cell box; value: texts that belong to the cell, in OCR order.\n",
"    cell_text_dict = {cell: [] for cell in cell_coords}\n",
"    noncell_text_dict = {}\n",
"\n",
"    for ocr_item in ocr_results:\n",
"        claimed = False\n",
"        for cell, cell_texts in cell_text_dict.items():\n",
"            if calculate_iot(cell, ocr_item) > 0.5:\n",
"                cell_texts.append(ocr_item['text'])\n",
"                claimed = True\n",
"        if not claimed:\n",
"            noncell_text_dict[ocr_item['coords']] = ocr_item['text']\n",
"\n",
"    # Cells come first, free-standing text afterwards.\n",
"    merged_text = {cell: ''.join(cell_texts).strip() for cell, cell_texts in cell_text_dict.items()}\n",
"    for coords, text in noncell_text_dict.items():\n",
"        merged_text[coords] = text.strip()\n",
"\n",
"    return merged_text\n",
"\n",
"\n",
"# Values 0-1 and 4-5 of every polygon are used as opposite corners of the cell\n",
"# (assumes each polygon lists four corner points in order - TODO confirm).\n",
"cell_coords = [tuple([*polygon[:2], *polygon[4:6]]) for polygon in result['polygons']]\n",
"# `or []` guards against a None first entry (presumably what PaddleOCR returns\n",
"# when no text is detected - confirm against the installed version).\n",
"ocr_results = [\n",
"    {'text': line[1][0], 'coords': tuple([*line[0][0], *line[0][2]])} for line in (res[0] or [])]\n",
"merged_text = merge_text_into_cells(cell_coords, ocr_results)\n",
"print(merged_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image, ImageDraw, ImageFont\n",
"import textwrap\n",
"import numpy as np\n",
"\n",
"\n",
"def draw_text_boxes(image_path, boxes, texts, output_path='你保存的图片路径'):\n",
"    \"\"\"Draw boxes and their (wrapped) text on a blank canvas sized like the source image.\n",
"\n",
"    :param image_path: image whose size determines the canvas size\n",
"    :param boxes: iterable of (x1, y1, x2, y2) boxes\n",
"    :param texts: iterable of strings, paired with boxes by position\n",
"    :param output_path: file the rendered image is saved to\n",
"    \"\"\"\n",
"    # Only the size is needed, so release the source file immediately.\n",
"    with Image.open(image_path) as source:\n",
"        canvas_size = source.size\n",
"    img = Image.new('RGB', canvas_size, (255, 255, 255))\n",
"    draw = ImageDraw.Draw(img)\n",
"\n",
"    # Pick a suitable font and size.\n",
"    font = ImageFont.truetype(\"./chinese_cht.ttf\", size=15)\n",
"\n",
"    for box, text in zip(boxes, texts):\n",
"        # Recent Pillow versions raise ValueError when x1 < x0 or y1 < y0,\n",
"        # so order the corners before drawing.\n",
"        left, right = sorted((box[0], box[2]))\n",
"        top, bottom = sorted((box[1], box[3]))\n",
"        draw.rectangle((left, top, right, bottom), outline='red', width=2)\n",
"\n",
"        text_bbox = draw.textbbox(xy=(left, top), text=text, font=font)\n",
"        text_width = text_bbox[2] - text_bbox[0]\n",
"        box_width = right - left\n",
"\n",
"        # Wrap text that is wider than its box. A zero-width box is skipped:\n",
"        # it would yield a wrap width of 0, which textwrap.wrap rejects.\n",
"        if box_width > 0 and text_width > box_width:\n",
"            lines_needed = np.ceil(text_width / box_width)\n",
"            chars_per_line = int(np.ceil(len(text) / lines_needed))\n",
"            text = '\\n'.join(textwrap.wrap(text, width=chars_per_line))\n",
"\n",
"        # The text starts at the top-left corner of the box.\n",
"        draw.text((left, top), text, font=font, fill='black')\n",
"\n",
"    # Save the canvas with boxes and text.\n",
"    img.save(output_path)\n",
"\n",
"\n",
"# Boxes are the keys of merged_text, texts are the matching values.\n",
"boxes = list(merged_text)\n",
"texts = list(merged_text.values())\n",
"draw_text_boxes('你需要提取的图片路径', boxes, texts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image\n",
"\n",
"\n",
"def adjust_coordinates(merged_text, image_path):\n",
"    \"\"\"Align boxes that sit on roughly the same row to one shared top y coordinate.\n",
"\n",
"    Boxes whose top y differs from a group's first top y by at most 1% of the\n",
"    image height join that group; every box in a group gets the group's mean\n",
"    top y.\n",
"\n",
"    :param merged_text: dict mapping (x1, y1, x2, y2) boxes to text\n",
"    :param image_path: image whose height sets the grouping threshold\n",
"    :return: dict mapping the adjusted boxes to the same text\n",
"    \"\"\"\n",
"    # Only the size is needed, so release the source file immediately.\n",
"    with Image.open(image_path) as image:\n",
"        _, height = image.size\n",
"    threshold = height / 100\n",
"\n",
"    # Key: top y of the first box in the group; value: list of (box, text).\n",
"    groups = {}\n",
"    for coordinates, text in merged_text.items():\n",
"        top = coordinates[1]\n",
"        # Join the first group whose anchor y is within the threshold.\n",
"        anchor_y = next((y for y in groups if abs(top - y) <= threshold), None)\n",
"        if anchor_y is None:\n",
"            groups[top] = [(coordinates, text)]\n",
"        else:\n",
"            groups[anchor_y].append((coordinates, text))\n",
"\n",
"    adjusted_coordinates = {}\n",
"    for members in groups.values():\n",
"        avg_y = sum(box[1] for box, _ in members) / len(members)\n",
"        for box, text in members:\n",
"            # Keep bottom >= top: a box lower than the averaged top would\n",
"            # otherwise become an invalid rectangle that cannot be drawn.\n",
"            bottom = max(box[3], avg_y)\n",
"            adjusted_coordinates[(box[0], avg_y, box[2], bottom)] = text\n",
"\n",
"    return adjusted_coordinates\n",
"\n",
"\n",
"# Adjust the coordinates.\n",
"adjusted_merged_text = adjust_coordinates(merged_text, '你需要提取的图片路径')\n",
"\n",
"# Show the boxes before and after the adjustment.\n",
"print(\"原始坐标:\", merged_text)\n",
"print(\"调整后的坐标:\", adjusted_merged_text)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from PIL import Image, ImageDraw, ImageFont\n",
"import textwrap\n",
"import numpy as np\n",
"\n",
"\n",
"def draw_text_boxes(image_path, boxes, texts, output_path='你需要保存的图片路径'):\n",
"    \"\"\"Draw boxes and their (wrapped) text on a blank canvas sized like the source image.\n",
"\n",
"    :param image_path: image whose size determines the canvas size\n",
"    :param boxes: iterable of (x1, y1, x2, y2) boxes\n",
"    :param texts: iterable of strings, paired with boxes by position\n",
"    :param output_path: file the rendered image is saved to\n",
"    \"\"\"\n",
"    # Only the size is needed, so release the source file immediately.\n",
"    with Image.open(image_path) as source:\n",
"        canvas_size = source.size\n",
"    img = Image.new('RGB', canvas_size, (255, 255, 255))\n",
"    draw = ImageDraw.Draw(img)\n",
"    # Pick a suitable font and size.\n",
"    font = ImageFont.truetype(\"./chinese_cht.ttf\", size=15)\n",
"\n",
"    for box, text in zip(boxes, texts):\n",
"        # Recent Pillow versions raise ValueError when x1 < x0 or y1 < y0,\n",
"        # so order the corners before drawing.\n",
"        left, right = sorted((box[0], box[2]))\n",
"        top, bottom = sorted((box[1], box[3]))\n",
"        draw.rectangle((left, top, right, bottom), outline='red', width=2)\n",
"\n",
"        text_bbox = draw.textbbox(xy=(left, top), text=text, font=font)\n",
"        text_width = text_bbox[2] - text_bbox[0]\n",
"        box_width = right - left\n",
"\n",
"        # Wrap text that is wider than its box. A zero-width box is skipped:\n",
"        # it would yield a wrap width of 0, which textwrap.wrap rejects.\n",
"        if box_width > 0 and text_width > box_width:\n",
"            lines_needed = np.ceil(text_width / box_width)\n",
"            chars_per_line = int(np.ceil(len(text) / lines_needed))\n",
"            text = '\\n'.join(textwrap.wrap(text, width=chars_per_line))\n",
"\n",
"        # The text starts at the top-left corner of the box.\n",
"        draw.text((left, top), text, font=font, fill='black')\n",
"\n",
"    img.save(output_path)\n",
"\n",
"\n",
"# Boxes are the keys of adjusted_merged_text, texts are the matching values.\n",
"boxes = list(adjusted_merged_text)\n",
"texts = list(adjusted_merged_text.values())\n",
"draw_text_boxes('你需要提取的图片路径', boxes, texts)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Print the final text: one output line per row, entries ordered left to right.\n",
"adjusted_merged_text_sorted = sorted(\n",
"    adjusted_merged_text.items(),\n",
"    key=lambda item: (item[0][1], item[0][0]),\n",
")\n",
"\n",
"# Entries that share the same top y form one row.\n",
"adjusted_merged_text_sorted_group = {}\n",
"for coordinates, text in adjusted_merged_text_sorted:\n",
"    adjusted_merged_text_sorted_group.setdefault(coordinates[1], []).append(text)\n",
"\n",
"for text_list in adjusted_merged_text_sorted_group.values():\n",
"    print(' | '.join(text_list))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "wyf",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 9ef69f3

Please sign in to comment.