Skip to content

Commit

Permalink
Merge pull request #23 from zRzRzRzRzRzRzR/main
Browse files Browse the repository at this point in the history
fix dict bug
  • Loading branch information
CosmosShadow authored Jul 9, 2024
2 parents 62f2623 + d3b26af commit ceacaf5
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 585 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ gp/*
dist/*
.idea
venv
test_output
35 changes: 23 additions & 12 deletions gptpdf/parse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
from typing import List, Tuple, Optional, Dict
import logging

Expand All @@ -17,8 +18,7 @@
3. 内容不要包含在```markdown ```中、段落公式使用 $$ $$ 的形式、行内公式使用 $ $ 的形式、忽略掉长直线、忽略掉页码。
再次强调,不要解释和输出无关的文字,直接输出图片中的内容。
"""
DEFAULT_RECT_PROMPT = """图片中用红色框和名称(%s)标注出了一些区域。
如果区域是表格或者图片,使用 ![]() 的形式插入到输出内容中,否则直接输出文字内容。
DEFAULT_RECT_PROMPT = """图片中用红色框和名称(%s)标注出了一些区域。如果区域是表格或者图片,使用 ![]() 的形式插入到输出内容中,否则直接输出文字内容。
"""
DEFAULT_ROLE_PROMPT = """你是一个PDF文档解析器,使用markdown和latex语法输出图片的内容。
"""
Expand Down Expand Up @@ -114,14 +114,14 @@ def _parse_rects(page: fitz.Page) -> List[Tuple[float, float, float, float]]:

merged_rects = _merge_rects(rect_list, distance=10, horizontal_distance=100)
merged_rects = [rect for rect in merged_rects if explain_validity(rect) == 'Valid Geometry']

# 将大文本区域和小文本区域分开处理: 大文本相交合并,小文本靠近合并
is_large_content = lambda x: (len(x[4]) / max(1, len(x[4].split('\n')))) > 5
small_text_area_rects = [sg.box(*x[:4]) for x in page.get_text('blocks') if not is_large_content(x)]
large_text_area_rects = [sg.box(*x[:4]) for x in page.get_text('blocks') if is_large_content(x)]
_, merged_rects = _adsorb_rects_to_rects(large_text_area_rects, merged_rects, distance=0.1) # 完全相交
_, merged_rects = _adsorb_rects_to_rects(small_text_area_rects, merged_rects, distance=5) # 靠近

# 再次自身合并
merged_rects = _merge_rects(merged_rects, distance=10)

Expand Down Expand Up @@ -175,7 +175,7 @@ def _parse_pdf_to_images(pdf_path: str, output_dir: str = './') -> List[Tuple[st

def _gpt_parse_images(
image_infos: List[Tuple[str, List[str]]],
prompt: Optional[Dict] = None,
prompt_dict: Optional[Dict] = None,
output_dir: str = './',
api_key: Optional[str] = None,
base_url: Optional[str] = None,
Expand All @@ -188,17 +188,20 @@ def _gpt_parse_images(
"""
from GeneralAgent import Agent

if prompt is None:
if isinstance(prompt_dict, dict) and 'prompt' in prompt_dict:
prompt = prompt_dict['prompt']
logging.info("prompt is provided, using user prompt.")
else:
prompt = DEFAULT_PROMPT
logging.info("prompt is not provided, using default prompt.")
if isinstance(prompt, dict) and 'rect_prompt' in prompt:
rect_prompt = prompt['rect_prompt']
if isinstance(prompt_dict, dict) and 'rect_prompt' in prompt_dict:
rect_prompt = prompt_dict['rect_prompt']
logging.info("rect_prompt is provided, using user prompt.")
else:
rect_prompt = DEFAULT_RECT_PROMPT
logging.info("rect_prompt is not provided, using default prompt.")
if isinstance(prompt, dict) and 'role_prompt' in prompt:
role_prompt = prompt['role_prompt']
if isinstance(prompt_dict, dict) and 'role_prompt' in prompt_dict:
role_prompt = prompt_dict['role_prompt']
logging.info("role_prompt is provided, using user prompt.")
else:
role_prompt = DEFAULT_ROLE_PROMPT
Expand All @@ -210,7 +213,7 @@ def _process_page(index: int, image_info: Tuple[str, List[str]]) -> Tuple[int, s
page_image, rect_images = image_info
local_prompt = prompt
if rect_images:
local_prompt += rect_prompt % ', '.join(rect_images)
local_prompt += rect_prompt + ', '.join(rect_images)
content = agent.run([local_prompt, {'image': page_image}], show_stream=verbose)
return index, content

Expand All @@ -219,6 +222,14 @@ def _process_page(index: int, image_info: Tuple[str, List[str]]) -> Tuple[int, s
futures = [executor.submit(_process_page, index, image_info) for index, image_info in enumerate(image_infos)]
for future in concurrent.futures.as_completed(futures):
index, content = future.result()

# 在某些情况下大模型还是会输出 ```markdown ```字符串
if '```markdown' in content:
content = content.replace('```markdown\n', '')
last_backticks_pos = content.rfind('```')
if last_backticks_pos != -1:
content = content[:last_backticks_pos] + content[last_backticks_pos + 3:]

contents[index] = content

output_path = os.path.join(output_dir, 'output.md')
Expand Down Expand Up @@ -248,7 +259,7 @@ def parse_pdf(
content = _gpt_parse_images(
image_infos=image_infos,
output_dir=output_dir,
prompt=prompt,
prompt_dict=prompt,
api_key=api_key,
base_url=base_url,
model=model,
Expand Down
Loading

0 comments on commit ceacaf5

Please sign in to comment.