{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Table recognition: cells + OCR text\n",
    "\n",
    "1. Detect table cells with ModelScope `cv_dla34_table-structure-recognition_cycle-centernet`.\n",
    "2. Recognise text lines with PaddleOCR.\n",
    "3. Assign each text line to the table cell that contains it.\n",
    "4. Snap boxes on the same visual row to a shared top edge.\n",
    "5. Print the table row by row, cells separated by ` | `.\n",
    "\n",
    "Fill in the paths in the configuration cell, then **Restart & Run All**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import textwrap\n",
    "\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from modelscope.pipelines import pipeline\n",
    "from modelscope.utils.constant import Tasks\n",
    "from paddleocr import PaddleOCR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths to fill in before running.\n",
    "TABLE_MODEL_PATH = 'cv_dla34_table-structure-recognition_cycle-centernet模型的路径'\n",
    "IMAGE_PATH = '你需要提取的图片路径'\n",
    "\n",
    "# TrueType font used for the preview images; it needs CJK glyphs to render the Chinese OCR text.\n",
    "FONT_PATH = './chinese_cht.ttf'\n",
    "FONT_SIZE = 15\n",
    "\n",
    "# Output files for the three preview images.\n",
    "OCR_PREVIEW_PATH = 'image_with_boxes_and_text.jpg'\n",
    "MERGED_PREVIEW_PATH = '你保存的图片路径'\n",
    "ADJUSTED_PREVIEW_PATH = '你需要保存的图片路径'\n",
    "\n",
    "# A text box is assigned to a cell when more than this fraction of its area lies inside the cell.\n",
    "CELL_ASSIGN_IOT = 0.5\n",
    "# A text box is kept as free-standing text when it overlaps every cell by less than this fraction.\n",
    "NON_CELL_MAX_IOT = 0.1\n",
    "# Boxes whose top edges differ by at most this fraction of the image height form one row.\n",
    "ROW_TOLERANCE = 0.01"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Detect table cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "table_recognition = pipeline(Tasks.table_recognition, model=TABLE_MODEL_PATH)\n",
    "table_result = table_recognition(IMAGE_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Recognise text with PaddleOCR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ocr = PaddleOCR(use_gpu=True, lang='ch')\n",
    "ocr_result = ocr.ocr(IMAGE_PATH, cls=True)\n",
    "print(ocr_result)\n",
    "\n",
    "# Entries of the first page are used below as line[0] = corner points, line[1][0] = recognised text.\n",
    "# The page can be None when no text is detected, so fall back to an empty list.\n",
    "ocr_lines = ocr_result[0] or []"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Helpers\n",
    "\n",
    "Boxes are `(x1, y1, x2, y2)` tuples with `x1 <= x2` and `y1 <= y2`. OCR entries are dicts `{'text': str, 'coords': box}`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_bbox(points):\n",
    "    \"\"\"Return the axis-aligned bounding box (x1, y1, x2, y2) enclosing a sequence of (x, y) points.\n",
    "\n",
    "    Using min/max instead of picking two fixed corners keeps the box valid even when the\n",
    "    points are not ordered top-left first or the region is slightly rotated.\n",
    "    \"\"\"\n",
    "    pts = [(float(x), float(y)) for x, y in points]\n",
    "    xs = [x for x, _ in pts]\n",
    "    ys = [y for _, y in pts]\n",
    "    return (min(xs), min(ys), max(xs), max(ys))\n",
    "\n",
    "\n",
    "def polygon_to_bbox(polygon):\n",
    "    \"\"\"Bounding box of a flat polygon [x1, y1, x2, y2, ..., xn, yn], e.g. a table-cell polygon.\"\"\"\n",
    "    return to_bbox(zip(polygon[0::2], polygon[1::2]))\n",
    "\n",
    "\n",
    "def _box_area(box):\n",
    "    \"\"\"Area of an (x1, y1, x2, y2) box; degenerate boxes have area 0.\"\"\"\n",
    "    return max(0.0, box[2] - box[0]) * max(0.0, box[3] - box[1])\n",
    "\n",
    "\n",
    "def _intersection_area(box_a, box_b):\n",
    "    \"\"\"Area of the overlap of two (x1, y1, x2, y2) boxes, or 0.0 if they do not overlap.\"\"\"\n",
    "    width = min(box_a[2], box_b[2]) - max(box_a[0], box_b[0])\n",
    "    height = min(box_a[3], box_b[3]) - max(box_a[1], box_b[1])\n",
    "    if width <= 0 or height <= 0:\n",
    "        return 0.0\n",
    "    return width * height\n",
    "\n",
    "\n",
    "def is_inside_text(cell, text):\n",
    "    \"\"\"Return True if the text box text['coords'] lies entirely inside the cell box.\"\"\"\n",
    "    cx1, cy1, cx2, cy2 = cell\n",
    "    tx1, ty1, tx2, ty2 = text['coords']\n",
    "    return cx1 <= tx1 and cy1 <= ty1 and cx2 >= tx2 and cy2 >= ty2\n",
    "\n",
    "\n",
    "def calculate_iou(cell, text):\n",
    "    \"\"\"Intersection-over-union of a cell box and the text box text['coords'].\n",
    "\n",
    "    :param cell: cell box (x1, y1, x2, y2)\n",
    "    :param text: OCR entry whose 'coords' is the text box (x1, y1, x2, y2)\n",
    "    :return: IoU in [0, 1]; 0.0 when the boxes do not overlap\n",
    "    \"\"\"\n",
    "    intersection = _intersection_area(cell, text['coords'])\n",
    "    if intersection == 0.0:\n",
    "        return 0.0\n",
    "    union = _box_area(cell) + _box_area(text['coords']) - intersection\n",
    "    return intersection / union if union > 0 else 0.0\n",
    "\n",
    "\n",
    "def calculate_iot(cell, text):\n",
    "    \"\"\"Fraction of the text box text['coords'] lying inside the cell (intersection over text area).\n",
    "\n",
    "    :param cell: cell box (x1, y1, x2, y2)\n",
    "    :param text: OCR entry whose 'coords' is the text box (x1, y1, x2, y2)\n",
    "    :return: IoT in [0, 1]; 0.0 when the boxes do not overlap\n",
    "    \"\"\"\n",
    "    intersection = _intersection_area(cell, text['coords'])\n",
    "    if intersection == 0.0:\n",
    "        return 0.0\n",
    "    text_area = _box_area(text['coords'])\n",
    "    return intersection / text_area if text_area > 0 else 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _wrap_to_width(draw, text, font, box_width, origin=(0, 0)):\n",
    "    \"\"\"Insert line breaks so that ``text`` roughly fits within ``box_width`` pixels.\n",
    "\n",
    "    The number of lines is estimated from the rendered width, then the characters are split\n",
    "    evenly across them. Text that already fits, or a zero-width box, is returned unchanged.\n",
    "    \"\"\"\n",
    "    left, _, right, _ = draw.textbbox(origin, text, font=font)\n",
    "    text_width = right - left\n",
    "    if box_width <= 0 or text_width <= box_width:\n",
    "        return text\n",
    "    n_lines = math.ceil(text_width / box_width)\n",
    "    chars_per_line = math.ceil(len(text) / n_lines)\n",
    "    return '\\n'.join(textwrap.wrap(text, width=chars_per_line))\n",
    "\n",
    "\n",
    "def draw_text_boxes(image_path, boxes, texts, output_path=ADJUSTED_PREVIEW_PATH,\n",
    "                    font_path=FONT_PATH, font_size=FONT_SIZE, wrap=True):\n",
    "    \"\"\"Draw boxes and their texts on a blank white canvas the size of the source image and save it.\n",
    "\n",
    "    :param image_path: source image; only its size is used\n",
    "    :param boxes: iterable of (x1, y1, x2, y2) boxes\n",
    "    :param texts: strings, one per box (extra items on either side are ignored, as with zip)\n",
    "    :param output_path: file the preview image is written to\n",
    "    :param font_path: TrueType font for the text (must contain CJK glyphs for Chinese text)\n",
    "    :param font_size: size passed to ImageFont.truetype\n",
    "    :param wrap: wrap text that is wider than its box onto several lines\n",
    "    \"\"\"\n",
    "    # Only the size is needed; the context manager closes the file handle right away.\n",
    "    with Image.open(image_path) as source:\n",
    "        size = source.size\n",
    "    canvas = Image.new('RGB', size, (255, 255, 255))\n",
    "    draw = ImageDraw.Draw(canvas)\n",
    "    font = ImageFont.truetype(font_path, size=font_size)\n",
    "\n",
    "    for box, text in zip(boxes, texts):\n",
    "        # Pillow rejects rectangles whose corners are given in reverse order, so normalise them.\n",
    "        left, right = sorted((box[0], box[2]))\n",
    "        top, bottom = sorted((box[1], box[3]))\n",
    "        draw.rectangle((left, top, right, bottom), outline='red', width=2)\n",
    "        if wrap:\n",
    "            text = _wrap_to_width(draw, text, font, right - left, origin=(left, top))\n",
    "        draw.text((left, top), text, font=font, fill='black')\n",
    "\n",
    "    canvas.save(output_path)\n",
    "\n",
    "\n",
    "def draw_ocr_boxes(image_path, boxes, texts, output_path=OCR_PREVIEW_PATH):\n",
    "    \"\"\"Preview raw OCR results: draw every box with its text, without wrapping.\"\"\"\n",
    "    draw_text_boxes(image_path, boxes, texts, output_path=output_path, wrap=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Preview the raw OCR boxes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ocr_results = [{'text': line[1][0], 'coords': to_bbox(line[0])} for line in ocr_lines]\n",
    "draw_ocr_boxes(IMAGE_PATH, [r['coords'] for r in ocr_results], [r['text'] for r in ocr_results])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Assign OCR text to table cells"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def merge_text_into_cells(cell_coords, ocr_results, assign_iot=CELL_ASSIGN_IOT, outside_iot=NON_CELL_MAX_IOT):\n",
    "    \"\"\"Group OCR text by table cell.\n",
    "\n",
    "    A text box is appended to every cell that contains more than ``assign_iot`` of its area.\n",
    "    A text box that overlaps every cell by less than ``outside_iot`` is kept on its own, keyed\n",
    "    by its own coordinates (e.g. a title outside the table). Text boxes in between, such as\n",
    "    ones straddling a cell border, are dropped.\n",
    "\n",
    "    :param cell_coords: list of cell boxes (x1, y1, x2, y2)\n",
    "    :param ocr_results: list of {'text': str, 'coords': (x1, y1, x2, y2)}\n",
    "    :param assign_iot: minimum IoT (exclusive) for assigning a text box to a cell\n",
    "    :param outside_iot: IoT (exclusive) below which, for every cell, a text box is free-standing\n",
    "    :return: dict box -> text; all cells first in input order ('' for empty cells),\n",
    "             then the free-standing text boxes\n",
    "    \"\"\"\n",
    "    cell_texts = {cell: [] for cell in cell_coords}\n",
    "    free_texts = {}\n",
    "\n",
    "    for entry in ocr_results:\n",
    "        overlaps = [calculate_iot(cell, entry) for cell in cell_coords]\n",
    "        for cell, iot in zip(cell_coords, overlaps):\n",
    "            if iot > assign_iot:\n",
    "                cell_texts[cell].append(entry['text'])\n",
    "        if all(iot < outside_iot for iot in overlaps):\n",
    "            free_texts[entry['coords']] = entry['text']\n",
    "\n",
    "    merged = {cell: ''.join(texts).strip() for cell, texts in cell_texts.items()}\n",
    "    merged.update((coords, text.strip()) for coords, text in free_texts.items())\n",
    "    return merged"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cell_coords = [polygon_to_bbox(polygon) for polygon in table_result['polygons']]\n",
    "merged_text = merge_text_into_cells(cell_coords, ocr_results)\n",
    "print(merged_text)\n",
    "draw_text_boxes(IMAGE_PATH, list(merged_text.keys()), list(merged_text.values()),\n",
    "                output_path=MERGED_PREVIEW_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Align boxes on the same row"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def adjust_coordinates(merged_text, image_path, row_tolerance=ROW_TOLERANCE):\n",
    "    \"\"\"Snap boxes that sit on the same visual row to a common top edge.\n",
    "\n",
    "    Boxes are grouped greedily: each box joins the first existing group whose seed y1 is within\n",
    "    ``row_tolerance * image height`` of its own y1, otherwise it seeds a new group. Every box in\n",
    "    a group then gets its y1 replaced by the group's mean y1; x1, x2 and y2 are kept.\n",
    "\n",
    "    :param merged_text: dict mapping (x1, y1, x2, y2) -> text\n",
    "    :param image_path: source image; only its height is used\n",
    "    :param row_tolerance: row tolerance as a fraction of the image height\n",
    "    :return: dict mapping adjusted (x1, y1, x2, y2) -> text\n",
    "    \"\"\"\n",
    "    with Image.open(image_path) as image:\n",
    "        height = image.size[1]\n",
    "    threshold = height * row_tolerance\n",
    "\n",
    "    rows = {}  # seed y1 -> list of (box, text)\n",
    "    for box, text in merged_text.items():\n",
    "        seed = next((y for y in rows if abs(box[1] - y) <= threshold), None)\n",
    "        if seed is None:\n",
    "            rows[box[1]] = [(box, text)]\n",
    "        else:\n",
    "            rows[seed].append((box, text))\n",
    "\n",
    "    adjusted = {}\n",
    "    for members in rows.values():\n",
    "        mean_y = sum(box[1] for box, _ in members) / len(members)\n",
    "        for box, text in members:\n",
    "            adjusted[(box[0], mean_y, box[2], box[3])] = text\n",
    "    return adjusted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "adjusted_merged_text = adjust_coordinates(merged_text, IMAGE_PATH)\n",
    "print(\"原始坐标:\", merged_text)\n",
    "print(\"调整后的坐标:\", adjusted_merged_text)\n",
    "draw_text_boxes(IMAGE_PATH, list(adjusted_merged_text.keys()), list(adjusted_merged_text.values()),\n",
    "                output_path=ADJUSTED_PREVIEW_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Print the table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows top to bottom (by the snapped y1), cells left to right (by x1), joined with ' | '.\n",
    "table_rows = {}\n",
    "for (_, y1, _, _), text in sorted(adjusted_merged_text.items(), key=lambda item: (item[0][1], item[0][0])):\n",
    "    table_rows.setdefault(y1, []).append(text)\n",
    "for row_texts in table_rows.values():\n",
    "    print(' | '.join(row_texts))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "wyf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}