dots.ocr release
This commit is contained in:
Executable
+1
@@ -0,0 +1 @@
|
||||
from .parser import DotsOCRParser
|
||||
Executable
+50
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
import math
|
||||
from PIL import Image
|
||||
import requests
|
||||
from dots_ocr.utils.image_utils import PILimage_to_base64
|
||||
from openai import OpenAI
|
||||
import os
|
||||
|
||||
|
||||
def inference_with_vllm(
|
||||
image,
|
||||
prompt,
|
||||
ip="localhost",
|
||||
port=8000,
|
||||
temperature=0.1,
|
||||
top_p=0.9,
|
||||
max_completion_tokens=32768,
|
||||
model_name='model',
|
||||
):
|
||||
|
||||
addr = f"http://{ip}:{port}/v1"
|
||||
client = OpenAI(api_key="{}".format(os.environ.get("API_KEY", "0")), base_url=addr)
|
||||
messages = []
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": PILimage_to_base64(image)},
|
||||
},
|
||||
{"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"} # if no "<|img|><|imgpad|><|endofimg|>" here,vllm v1 will add "\n" here
|
||||
],
|
||||
}
|
||||
)
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=model_name,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p)
|
||||
response = response.choices[0].message.content
|
||||
return response
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"request error: {e}")
|
||||
return None
|
||||
|
||||
Executable
+349
@@ -0,0 +1,349 @@
|
||||
import os
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
from multiprocessing.pool import ThreadPool, Pool
|
||||
import argparse
|
||||
|
||||
|
||||
from dots_ocr.model.inference import inference_with_vllm
|
||||
from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS
|
||||
from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize
|
||||
from dots_ocr.utils.doc_utils import fitz_doc_to_image, load_images_from_pdf
|
||||
from dots_ocr.utils.prompts import dict_promptmode_to_prompt
|
||||
from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes
|
||||
from dots_ocr.utils.format_transformer import layoutjson2md
|
||||
|
||||
|
||||
class DotsOCRParser:
|
||||
"""
|
||||
parse image or pdf file
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
ip='localhost',
|
||||
port=8000,
|
||||
model_name='model',
|
||||
temperature=0.1,
|
||||
top_p=1.0,
|
||||
max_completion_tokens=16384,
|
||||
num_thread=64,
|
||||
dpi = 200,
|
||||
output_dir="./output",
|
||||
min_pixels=None,
|
||||
max_pixels=None,
|
||||
):
|
||||
self.dpi = dpi
|
||||
|
||||
# default args for vllm server
|
||||
self.ip = ip
|
||||
self.port = port
|
||||
self.model_name = model_name
|
||||
# default args for inference
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.num_thread = num_thread
|
||||
self.output_dir = output_dir
|
||||
self.min_pixels = min_pixels
|
||||
self.max_pixels = max_pixels
|
||||
assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
|
||||
assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
|
||||
|
||||
|
||||
def _inference_with_vllm(self, image, prompt):
|
||||
response = inference_with_vllm(
|
||||
image,
|
||||
prompt,
|
||||
model_name=self.model_name,
|
||||
ip=self.ip,
|
||||
port=self.port,
|
||||
temperature=self.temperature,
|
||||
top_p=self.top_p,
|
||||
max_completion_tokens=self.max_completion_tokens,
|
||||
)
|
||||
return response
|
||||
|
||||
def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None):
|
||||
prompt = dict_promptmode_to_prompt[prompt_mode]
|
||||
if prompt_mode == 'prompt_grounding_ocr':
|
||||
assert bbox is not None
|
||||
bboxes = [bbox]
|
||||
bbox = pre_process_bboxes(origin_image, bboxes, input_width=image.width, input_height=image.height, min_pixels=min_pixels, max_pixels=max_pixels)[0]
|
||||
prompt = prompt + str(bbox)
|
||||
return prompt
|
||||
|
||||
# def post_process_results(self, response, prompt_mode, save_dir, save_name, origin_image, image, min_pixels, max_pixels)
|
||||
def _parse_single_image(
|
||||
self,
|
||||
origin_image,
|
||||
prompt_mode,
|
||||
save_dir,
|
||||
save_name,
|
||||
source="image",
|
||||
page_idx=0,
|
||||
bbox=None,
|
||||
fitz_preprocess=False,
|
||||
):
|
||||
min_pixels, max_pixels = self.min_pixels, self.max_pixels
|
||||
if prompt_mode == "prompt_grounding_ocr":
|
||||
min_pixels = min_pixels or MIN_PIXELS # preprocess image to the final input
|
||||
max_pixels = max_pixels or MAX_PIXELS
|
||||
if min_pixels is not None: assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}"
|
||||
if max_pixels is not None: assert max_pixels <= MAX_PIXELS, f"max_pixels should <+ {MAX_PIXELS}"
|
||||
|
||||
if source == 'image' and fitz_preprocess:
|
||||
image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi)
|
||||
image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
else:
|
||||
image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
input_height, input_width = smart_resize(image.height, image.width)
|
||||
prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
response = self._inference_with_vllm(image, prompt)
|
||||
result = {'page_no': page_idx,
|
||||
"input_height": input_height,
|
||||
"input_width": input_width
|
||||
}
|
||||
if source == 'pdf':
|
||||
save_name = f"{save_name}_page_{page_idx}"
|
||||
if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
|
||||
cells, filtered = post_process_output(
|
||||
response,
|
||||
prompt_mode,
|
||||
origin_image,
|
||||
image,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
if filtered and prompt_mode != 'prompt_layout_only_en': # model output json failed, use filtered process
|
||||
json_file_path = os.path.join(save_dir, f"{save_name}.json")
|
||||
with open(json_file_path, 'w') as w:
|
||||
json.dump(response, w, ensure_ascii=False)
|
||||
|
||||
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
|
||||
origin_image.save(image_layout_path)
|
||||
result.update({
|
||||
'layout_info_path': json_file_path,
|
||||
'layout_image_path': image_layout_path,
|
||||
})
|
||||
|
||||
md_file_path = os.path.join(save_dir, f"{save_name}.md")
|
||||
with open(md_file_path, "w", encoding="utf-8") as md_file:
|
||||
md_file.write(cells)
|
||||
result.update({
|
||||
'md_content_path': md_file_path
|
||||
})
|
||||
result.update({
|
||||
'filtered': True
|
||||
})
|
||||
else:
|
||||
try:
|
||||
image_with_layout = draw_layout_on_image(origin_image, cells)
|
||||
except Exception as e:
|
||||
print(f"Error drawing layout on image: {e}")
|
||||
image_with_layout = origin_image
|
||||
|
||||
json_file_path = os.path.join(save_dir, f"{save_name}.json")
|
||||
with open(json_file_path, 'w') as w:
|
||||
json.dump(cells, w, ensure_ascii=False)
|
||||
|
||||
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
|
||||
image_with_layout.save(image_layout_path)
|
||||
result.update({
|
||||
'layout_info_path': json_file_path,
|
||||
'layout_image_path': image_layout_path,
|
||||
})
|
||||
if prompt_mode != "prompt_layout_only_en": # no text md when detection only
|
||||
md_content = layoutjson2md(origin_image, cells, text_key='text')
|
||||
md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True) # used for clean output or metric of omnidocbench、olmbench
|
||||
md_file_path = os.path.join(save_dir, f"{save_name}.md")
|
||||
with open(md_file_path, "w", encoding="utf-8") as md_file:
|
||||
md_file.write(md_content)
|
||||
md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md")
|
||||
with open(md_nohf_file_path, "w", encoding="utf-8") as md_file:
|
||||
md_file.write(md_content_no_hf)
|
||||
result.update({
|
||||
'md_content_path': md_file_path,
|
||||
'md_content_nohf_path': md_nohf_file_path,
|
||||
})
|
||||
else:
|
||||
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
|
||||
origin_image.save(image_layout_path)
|
||||
result.update({
|
||||
'layout_image_path': image_layout_path,
|
||||
})
|
||||
|
||||
md_content = response
|
||||
md_file_path = os.path.join(save_dir, f"{save_name}.md")
|
||||
with open(md_file_path, "w", encoding="utf-8") as md_file:
|
||||
md_file.write(md_content)
|
||||
result.update({
|
||||
'md_content_path': md_file_path,
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False):
|
||||
origin_image = fetch_image(input_path)
|
||||
result = self._parse_single_image(origin_image, prompt_mode, save_dir, filename, source="image", bbox=bbox, fitz_preprocess=fitz_preprocess)
|
||||
result['file_path'] = input_path
|
||||
return [result]
|
||||
|
||||
def parse_pdf(self, input_path, filename, prompt_mode, save_dir):
|
||||
print(f"loading pdf: {input_path}")
|
||||
images_origin = load_images_from_pdf(input_path)
|
||||
total_pages = len(images_origin)
|
||||
tasks = [
|
||||
{
|
||||
"origin_image": image,
|
||||
"prompt_mode": prompt_mode,
|
||||
"save_dir": save_dir,
|
||||
"save_name": filename,
|
||||
"source":"pdf",
|
||||
"page_idx": i,
|
||||
} for i, image in enumerate(images_origin)
|
||||
]
|
||||
|
||||
def _execute_task(task_args):
|
||||
return self._parse_single_image(**task_args)
|
||||
|
||||
num_thread = min(total_pages, self.num_thread)
|
||||
print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
|
||||
|
||||
results = []
|
||||
with ThreadPool(num_thread) as pool:
|
||||
with tqdm(total=total_pages, desc="Processing PDF pages") as pbar:
|
||||
for result in pool.imap_unordered(_execute_task, tasks):
|
||||
results.append(result)
|
||||
pbar.update(1)
|
||||
|
||||
results.sort(key=lambda x: x["page_no"])
|
||||
for i in range(len(results)):
|
||||
results[i]['file_path'] = input_path
|
||||
return results
|
||||
|
||||
def parse_file(self,
|
||||
input_path,
|
||||
output_dir="",
|
||||
prompt_mode="prompt_layout_all_en",
|
||||
bbox=None,
|
||||
fitz_preprocess=False
|
||||
):
|
||||
output_dir = output_dir or self.output_dir
|
||||
output_dir = os.path.abspath(output_dir)
|
||||
filename, file_ext = os.path.splitext(os.path.basename(input_path))
|
||||
save_dir = os.path.join(output_dir, filename)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if file_ext == '.pdf':
|
||||
results = self.parse_pdf(input_path, filename, prompt_mode, save_dir)
|
||||
elif file_ext in image_extensions:
|
||||
results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess)
|
||||
else:
|
||||
raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} and pdf")
|
||||
|
||||
print(f"Parsing finished, results saving to {save_dir}")
|
||||
with open(os.path.join(output_dir, os.path.basename(filename)+'.jsonl'), 'w') as w:
|
||||
for result in results:
|
||||
w.write(json.dumps(result, ensure_ascii=False) + '\n')
|
||||
|
||||
return results
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
prompts = list(dict_promptmode_to_prompt.keys())
|
||||
parser = argparse.ArgumentParser(
|
||||
description="dots.ocr Multilingual Document Layout Parser",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"input_path", type=str,
|
||||
help="Input PDF/image file path"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output", type=str, default="./output",
|
||||
help="Output directory (default: ./output)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--prompt", choices=prompts, type=str, default="prompt_layout_all_en",
|
||||
help="prompt to query the model, different prompts for different tasks"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--bbox',
|
||||
type=int,
|
||||
nargs=4,
|
||||
metavar=('x1', 'y1', 'x2', 'y2'),
|
||||
help='should give this argument if you want to prompt_grounding_ocr'
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ip", type=str, default="localhost",
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=8000,
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name", type=str, default="model",
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature", type=float, default=0.1,
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top_p", type=float, default=1.0,
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dpi", type=int, default=200,
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_completion_tokens", type=int, default=16384,
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_thread", type=int, default=16,
|
||||
help=""
|
||||
)
|
||||
# parser.add_argument(
|
||||
# "--fitz_preprocess", type=bool, default=False,
|
||||
# help="False will use tikz dpi upsample pipeline, good for images which has been render with low dpi, but maybe result in higher computational costs"
|
||||
# )
|
||||
parser.add_argument(
|
||||
"--min_pixels", type=int, default=None,
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_pixels", type=int, default=None,
|
||||
help=""
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dots_ocr_parser = DotsOCRParser(
|
||||
ip=args.ip,
|
||||
port=args.port,
|
||||
model_name=args.model_name,
|
||||
temperature=args.temperature,
|
||||
top_p=args.top_p,
|
||||
max_completion_tokens=args.max_completion_tokens,
|
||||
num_thread=args.num_thread,
|
||||
dpi=args.dpi,
|
||||
output_dir=args.output,
|
||||
min_pixels=args.min_pixels,
|
||||
max_pixels=args.max_pixels,
|
||||
)
|
||||
|
||||
result = dots_ocr_parser.parse_file(
|
||||
args.input_path,
|
||||
prompt_mode=args.prompt,
|
||||
bbox=args.bbox,
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Executable
+1
@@ -0,0 +1 @@
|
||||
from .prompts import dict_promptmode_to_prompt
|
||||
Executable
+5
@@ -0,0 +1,5 @@
|
||||
MIN_PIXELS=3136
|
||||
MAX_PIXELS=11289600
|
||||
IMAGE_FACTOR=28
|
||||
|
||||
image_extensions = {'.jpg', '.jpeg', '.png'}
|
||||
Executable
+61
@@ -0,0 +1,61 @@
|
||||
import os
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def is_valid_image_path(image_path):
|
||||
"""
|
||||
Checks if the image path is valid.
|
||||
|
||||
Args:
|
||||
image_path: The path to the image.
|
||||
|
||||
Returns:
|
||||
bool: True if the path is valid, False otherwise.
|
||||
"""
|
||||
if not os.path.exists(image_path):
|
||||
return False
|
||||
|
||||
# Check if the file extension is one of the common image formats.
|
||||
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp']
|
||||
_, extension = os.path.splitext(image_path)
|
||||
if extension.lower() in image_extensions:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def read_image(image_path, use_native=False):
|
||||
"""
|
||||
Reads an image and resizes it while maintaining aspect ratio.
|
||||
|
||||
Args:
|
||||
image_path: The path to the image.
|
||||
use_native: If True, the max dimension of the original image is used as the max size.
|
||||
If False, max size is set to 1024.
|
||||
|
||||
Returns:
|
||||
tuple: (resized_image, original_width, original_height)
|
||||
"""
|
||||
# Create a default 512x512 blue image as a fallback.
|
||||
image = Image.new('RGB', (512, 512), color=(0, 0, 255))
|
||||
|
||||
if is_valid_image_path(image_path):
|
||||
image = Image.open(image_path)
|
||||
else:
|
||||
raise FileNotFoundError(f"{image_path}: Image path does not exist")
|
||||
|
||||
w, h = image.size
|
||||
if use_native:
|
||||
max_size = max(w, h)
|
||||
else:
|
||||
max_size = 1024
|
||||
|
||||
if w > h:
|
||||
new_w = max_size
|
||||
new_h = int(h * max_size / w)
|
||||
else:
|
||||
new_h = max_size
|
||||
new_w = int(w * max_size / h)
|
||||
|
||||
image = image.resize((new_w, new_h))
|
||||
return image, w, h
|
||||
Executable
+60
@@ -0,0 +1,60 @@
|
||||
import fitz
|
||||
import numpy as np
|
||||
import enum
|
||||
from pydantic import BaseModel, Field
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class SupportedPdfParseMethod(enum.Enum):
|
||||
OCR = 'ocr'
|
||||
TXT = 'txt'
|
||||
|
||||
|
||||
class PageInfo(BaseModel):
|
||||
"""The width and height of page
|
||||
"""
|
||||
w: float = Field(description='the width of page')
|
||||
h: float = Field(description='the height of page')
|
||||
|
||||
|
||||
def fitz_doc_to_image(doc, target_dpi=200, origin_dpi=None) -> dict:
|
||||
"""Convert fitz.Document to image, Then convert the image to numpy array.
|
||||
|
||||
Args:
|
||||
doc (_type_): pymudoc page
|
||||
dpi (int, optional): reset the dpi of dpi. Defaults to 200.
|
||||
|
||||
Returns:
|
||||
dict: {'img': numpy array, 'width': width, 'height': height }
|
||||
"""
|
||||
from PIL import Image
|
||||
mat = fitz.Matrix(target_dpi / 72, target_dpi / 72)
|
||||
pm = doc.get_pixmap(matrix=mat, alpha=False)
|
||||
|
||||
if pm.width > 4500 or pm.height > 4500:
|
||||
mat = fitz.Matrix(72 / 72, 72 / 72) # use fitz default dpi
|
||||
pm = doc.get_pixmap(matrix=mat, alpha=False)
|
||||
|
||||
image = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
|
||||
return image
|
||||
|
||||
|
||||
def load_images_from_pdf(pdf_file, dpi=200, start_page_id=0, end_page_id=None) -> list:
|
||||
images = []
|
||||
with fitz.open(pdf_file) as doc:
|
||||
pdf_page_num = doc.page_count
|
||||
end_page_id = (
|
||||
end_page_id
|
||||
if end_page_id is not None and end_page_id >= 0
|
||||
else pdf_page_num - 1
|
||||
)
|
||||
if end_page_id > pdf_page_num - 1:
|
||||
print('end_page_id is out of range, use images length')
|
||||
end_page_id = pdf_page_num - 1
|
||||
|
||||
for index in range(0, doc.page_count):
|
||||
if start_page_id <= index <= end_page_id:
|
||||
page = doc[index]
|
||||
img = fitz_doc_to_image(page, target_dpi=dpi)
|
||||
images.append(img)
|
||||
return images
|
||||
Executable
+205
@@ -0,0 +1,205 @@
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import re
|
||||
|
||||
from PIL import Image
|
||||
from dots_ocr.utils.image_utils import PILimage_to_base64
|
||||
|
||||
|
||||
def has_latex_markdown(text: str) -> bool:
|
||||
"""
|
||||
Checks if a string contains LaTeX markdown patterns.
|
||||
|
||||
Args:
|
||||
text (str): The string to check.
|
||||
|
||||
Returns:
|
||||
bool: True if LaTeX markdown is found, otherwise False.
|
||||
"""
|
||||
if not isinstance(text, str):
|
||||
return False
|
||||
|
||||
# Define regular expression patterns for LaTeX markdown
|
||||
latex_patterns = [
|
||||
r'\$\$.*?\$\$', # Block-level math formula $$...$$
|
||||
r'\$[^$\n]+?\$', # Inline math formula $...$
|
||||
r'\\begin\{.*?\}.*?\\end\{.*?\}', # LaTeX environment \begin{...}...\end{...}
|
||||
r'\\[a-zA-Z]+\{.*?\}', # LaTeX command \command{...}
|
||||
r'\\[a-zA-Z]+', # Simple LaTeX command \command
|
||||
r'\\\[.*?\\\]', # Display math formula \[...\]
|
||||
r'\\\(.*?\\\)', # Inline math formula \(...\)
|
||||
]
|
||||
|
||||
# Check if any of the patterns match
|
||||
for pattern in latex_patterns:
|
||||
if re.search(pattern, text, re.DOTALL):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def clean_latex_preamble(latex_text: str) -> str:
|
||||
"""
|
||||
Removes LaTeX preamble commands like document class and package imports.
|
||||
|
||||
Args:
|
||||
latex_text (str): The original LaTeX text.
|
||||
|
||||
Returns:
|
||||
str: The cleaned LaTeX text without preamble commands.
|
||||
"""
|
||||
# Define patterns to be removed
|
||||
patterns = [
|
||||
r'\\documentclass\{[^}]+\}', # \documentclass{...}
|
||||
r'\\usepackage\{[^}]+\}', # \usepackage{...}
|
||||
r'\\usepackage\[[^\]]*\]\{[^}]+\}', # \usepackage[options]{...}
|
||||
r'\\begin\{document\}', # \begin{document}
|
||||
r'\\end\{document\}', # \end{document}
|
||||
]
|
||||
|
||||
# Apply each pattern to clean the text
|
||||
cleaned_text = latex_text
|
||||
for pattern in patterns:
|
||||
cleaned_text = re.sub(pattern, '', cleaned_text, flags=re.IGNORECASE)
|
||||
|
||||
return cleaned_text
|
||||
|
||||
|
||||
def get_formula_in_markdown(text: str) -> str:
|
||||
"""
|
||||
Formats a string containing a formula into a standard Markdown block.
|
||||
|
||||
Args:
|
||||
text (str): The input string, potentially containing a formula.
|
||||
|
||||
Returns:
|
||||
str: The formatted string, ready for Markdown rendering.
|
||||
"""
|
||||
# Remove leading/trailing whitespace
|
||||
text = text.strip()
|
||||
|
||||
# Check if it's already enclosed in $$
|
||||
if text.startswith('$$') and text.endswith('$$'):
|
||||
text_new = text[2:-2].strip()
|
||||
if not '$' in text_new:
|
||||
return f"$$\n{text_new}\n$$"
|
||||
else:
|
||||
return text
|
||||
|
||||
# Handle \[...\] format, convert to $$...$$
|
||||
if text.startswith('\\[') and text.endswith('\\]'):
|
||||
inner_content = text[2:-2].strip()
|
||||
return f"$$\n{inner_content}\n$$"
|
||||
|
||||
# Check if it's enclosed in \[ \]
|
||||
if len(re.findall(r'.*\\\[.*\\\].*', text)) > 0:
|
||||
return text
|
||||
|
||||
# Handle inline formulas ($...$)
|
||||
pattern = r'\$([^$]+)\$'
|
||||
matches = re.findall(pattern, text)
|
||||
if len(matches) > 0:
|
||||
# It's an inline formula, return it as is
|
||||
return text
|
||||
|
||||
# If no LaTeX markdown syntax is present, return directly
|
||||
if not has_latex_markdown(text):
|
||||
return text
|
||||
|
||||
# Handle unnecessary LaTeX formatting like \usepackage
|
||||
if 'usepackage' in text:
|
||||
text = clean_latex_preamble(text)
|
||||
|
||||
if text[0] == '`' and text[-1] == '`':
|
||||
text = text[1:-1]
|
||||
|
||||
# Enclose the final text in a $$ block with newlines
|
||||
text = f"$$\n{text}\n$$"
|
||||
return text
|
||||
|
||||
|
||||
def clean_text(text: str) -> str:
|
||||
"""
|
||||
Cleans text by removing extra whitespace.
|
||||
|
||||
Args:
|
||||
text: The original text.
|
||||
|
||||
Returns:
|
||||
str: The cleaned text.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
# Remove leading and trailing whitespace
|
||||
text = text.strip()
|
||||
|
||||
# Replace multiple consecutive whitespace characters with a single space
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def layoutjson2md(image: Image.Image, cells: list, text_key: str = 'text', no_page_hf: bool = False) -> str:
|
||||
"""
|
||||
Converts a layout JSON format to Markdown.
|
||||
|
||||
In the layout JSON, formulas are LaTeX, tables are HTML, and text is Markdown.
|
||||
|
||||
Args:
|
||||
image: A PIL Image object.
|
||||
cells: A list of dictionaries, each representing a layout cell.
|
||||
text_key: The key for the text field in the cell dictionary.
|
||||
no_page_header_footer: If True, skips page headers and footers.
|
||||
|
||||
Returns:
|
||||
str: The text in Markdown format.
|
||||
"""
|
||||
text_items = []
|
||||
|
||||
for i, cell in enumerate(cells):
|
||||
x1, y1, x2, y2 = [int(coord) for coord in cell['bbox']]
|
||||
text = cell.get(text_key, "")
|
||||
|
||||
if no_page_hf and cell['category'] in ['Page-header', 'Page-footer']:
|
||||
continue
|
||||
|
||||
if cell['category'] == 'Picture':
|
||||
image_crop = image.crop((x1, y1, x2, y2))
|
||||
image_base64 = PILimage_to_base64(image_crop)
|
||||
text_items.append(f"")
|
||||
elif cell['category'] == 'Formula':
|
||||
text_items.append(get_formula_in_markdown(text))
|
||||
else:
|
||||
text = clean_text(text)
|
||||
text_items.append(f"{text}")
|
||||
|
||||
markdown_text = '\n\n'.join(text_items)
|
||||
return markdown_text
|
||||
|
||||
|
||||
def fix_streamlit_formulas(md: str) -> str:
|
||||
"""
|
||||
Fixes the format of formulas in Markdown to ensure they display correctly in Streamlit.
|
||||
It adds a newline after the opening $$ and before the closing $$ if they don't already exist.
|
||||
|
||||
Args:
|
||||
md_text (str): The Markdown text to fix.
|
||||
|
||||
Returns:
|
||||
str: The fixed Markdown text.
|
||||
"""
|
||||
|
||||
# This inner function will be used by re.sub to perform the replacement
|
||||
def replace_formula(match):
|
||||
content = match.group(1)
|
||||
# If the content already has surrounding newlines, don't add more.
|
||||
if content.startswith('\n'):
|
||||
content = content[1:]
|
||||
if content.endswith('\n'):
|
||||
content = content[:-1]
|
||||
return f'$$\n{content}\n$$'
|
||||
|
||||
# Use regex to find all $$....$$ patterns and replace them using the helper function.
|
||||
return re.sub(r'\$\$(.*?)\$\$', replace_formula, md, flags=re.DOTALL)
|
||||
Executable
+196
@@ -0,0 +1,196 @@
|
||||
import math
|
||||
import base64
|
||||
from PIL import Image
|
||||
from typing import Tuple
|
||||
import os
|
||||
from dots_ocr.utils.consts import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS
|
||||
from dots_ocr.utils.doc_utils import fitz_doc_to_image
|
||||
from io import BytesIO
|
||||
import fitz
|
||||
import requests
|
||||
import copy
|
||||
|
||||
|
||||
def round_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: int, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
|
||||
def smart_resize(
|
||||
height: int,
|
||||
width: int,
|
||||
factor: int = 28,
|
||||
min_pixels: int = 3136,
|
||||
max_pixels: int = 11289600,
|
||||
):
|
||||
"""Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
|
||||
"""
|
||||
if max(height, width) / min(height, width) > 200:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = max(factor, floor_by_factor(height / beta, factor))
|
||||
w_bar = max(factor, floor_by_factor(width / beta, factor))
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
if h_bar * w_bar > max_pixels: # max_pixels first to control the token length
|
||||
beta = math.sqrt((h_bar * w_bar) / max_pixels)
|
||||
h_bar = max(factor, floor_by_factor(h_bar / beta, factor))
|
||||
w_bar = max(factor, floor_by_factor(w_bar / beta, factor))
|
||||
return h_bar, w_bar
|
||||
|
||||
|
||||
|
||||
def PILimage_to_base64(image, format='PNG'):
|
||||
buffered = BytesIO()
|
||||
image.save(buffered, format=format)
|
||||
base64_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
return f"data:image;base64,{base64_str}"
|
||||
|
||||
|
||||
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
||||
if pil_image.mode == 'RGBA':
|
||||
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
||||
return white_background
|
||||
else:
|
||||
return pil_image.convert("RGB")
|
||||
|
||||
|
||||
# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||
def fetch_image(
|
||||
image,
|
||||
min_pixels=None,
|
||||
max_pixels=None,
|
||||
resized_height=None,
|
||||
resized_width=None,
|
||||
) -> Image.Image:
|
||||
assert image is not None, f"image not found, maybe input format error: {image}"
|
||||
image_obj = None
|
||||
if isinstance(image, Image.Image):
|
||||
image_obj = image
|
||||
elif image.startswith("http://") or image.startswith("https://"):
|
||||
# fix memory leak issue while using BytesIO
|
||||
with requests.get(image, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
with BytesIO(response.content) as bio:
|
||||
image_obj = copy.deepcopy(Image.open(bio))
|
||||
elif image.startswith("file://"):
|
||||
image_obj = Image.open(image[7:])
|
||||
elif image.startswith("data:image"):
|
||||
if "base64," in image:
|
||||
_, base64_data = image.split("base64,", 1)
|
||||
data = base64.b64decode(base64_data)
|
||||
# fix memory leak issue while using BytesIO
|
||||
with BytesIO(data) as bio:
|
||||
image_obj = copy.deepcopy(Image.open(bio))
|
||||
else:
|
||||
image_obj = Image.open(image)
|
||||
if image_obj is None:
|
||||
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
||||
image = to_rgb(image_obj)
|
||||
## resize
|
||||
if resized_height and resized_width:
|
||||
resized_height, resized_width = smart_resize(
|
||||
resized_height,
|
||||
resized_width,
|
||||
factor=IMAGE_FACTOR,
|
||||
)
|
||||
assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
|
||||
image = image.resize((resized_width, resized_height))
|
||||
elif min_pixels or max_pixels:
|
||||
width, height = image.size
|
||||
if not min_pixels:
|
||||
min_pixels = MIN_PIXELS
|
||||
if not max_pixels:
|
||||
max_pixels = MAX_PIXELS
|
||||
resized_height, resized_width = smart_resize(
|
||||
height,
|
||||
width,
|
||||
factor=IMAGE_FACTOR,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels,
|
||||
)
|
||||
assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
|
||||
image = image.resize((resized_width, resized_height))
|
||||
|
||||
return image
|
||||
|
||||
def get_input_dimensions(
|
||||
image: Image.Image,
|
||||
min_pixels: int,
|
||||
max_pixels: int,
|
||||
factor: int = 28
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Gets the resized dimensions of the input image.
|
||||
|
||||
Args:
|
||||
image: The original image.
|
||||
min_pixels: The minimum number of pixels.
|
||||
max_pixels: The maximum number of pixels.
|
||||
factor: The resizing factor.
|
||||
|
||||
Returns:
|
||||
The resized (width, height).
|
||||
"""
|
||||
input_height, input_width = smart_resize(
|
||||
image.height,
|
||||
image.width,
|
||||
factor=factor,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels
|
||||
)
|
||||
return input_width, input_height
|
||||
|
||||
|
||||
def get_image_by_fitz_doc(image, target_dpi=200):
|
||||
# get image through fitz, to get target dpi image, mainly for higher image
|
||||
if not isinstance(image, Image.Image):
|
||||
assert isinstance(image, str)
|
||||
_, file_ext = os.path.splitext(image)
|
||||
assert file_ext in {'.jpg', '.jpeg', '.png'}
|
||||
|
||||
if image.startswith("http://") or image.startswith("https://"):
|
||||
with requests.get(image, stream=True) as response:
|
||||
response.raise_for_status()
|
||||
data_bytes = response.content
|
||||
else:
|
||||
with open(image, 'rb') as f:
|
||||
data_bytes = f.read()
|
||||
|
||||
image = Image.open(BytesIO(data_bytes))
|
||||
else:
|
||||
data_bytes = BytesIO()
|
||||
image.save(data_bytes, format='PNG')
|
||||
|
||||
origin_dpi = image.info.get('dpi', None)
|
||||
pdf_bytes = fitz.open(stream=data_bytes).convert_to_pdf()
|
||||
doc = fitz.open('pdf', pdf_bytes)
|
||||
page = doc[0]
|
||||
image_fitz = fitz_doc_to_image(page, target_dpi=target_dpi, origin_dpi=origin_dpi)
|
||||
|
||||
return image_fitz
|
||||
Executable
+228
@@ -0,0 +1,228 @@
|
||||
from PIL import Image
|
||||
from typing import Dict, List
|
||||
|
||||
import fitz
|
||||
from io import BytesIO
|
||||
import json
|
||||
|
||||
from dots_ocr.utils.image_utils import smart_resize
|
||||
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||
from dots_ocr.utils.output_cleaner import OutputCleaner
|
||||
|
||||
|
||||
# Define a color map (using RGBA format)
|
||||
dict_layout_type_to_color = {
|
||||
"Text": (0, 128, 0, 256), # Green, translucent
|
||||
"Picture": (255, 0, 255, 256), # Magenta, translucent
|
||||
"Caption": (255, 165, 0, 256), # Orange, translucent
|
||||
"Section-header": (0, 255, 255, 256), # Cyan, translucent
|
||||
"Footnote": (0, 128, 0, 256), # Green, translucent
|
||||
"Formula": (128, 128, 128, 256), # Gray, translucent
|
||||
"Table": (255, 192, 203, 256), # Pink, translucent
|
||||
"Title": (255, 0, 0, 256), # Red, translucent
|
||||
"List-item": (0, 0, 255, 256), # Blue, translucent
|
||||
"Page-header": (0, 128, 0, 256), # Green, translucent
|
||||
"Page-footer": (128, 0, 128, 256), # Purple, translucent
|
||||
"Other": (165, 42, 42, 256), # Brown, translucent
|
||||
"Unknown": (0, 0, 0, 0),
|
||||
}
|
||||
|
||||
|
||||
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
|
||||
"""
|
||||
Draw transparent boxes on an image.
|
||||
|
||||
Args:
|
||||
image: The source PIL Image.
|
||||
cells: A list of cells containing bounding box information.
|
||||
resized_height: The resized height.
|
||||
resized_width: The resized width.
|
||||
fill_bbox: Whether to fill the bounding box.
|
||||
draw_bbox: Whether to draw the bounding box.
|
||||
|
||||
Returns:
|
||||
PIL.Image: The image with drawings.
|
||||
"""
|
||||
# origin_image = Image.open(image_path)
|
||||
original_width, original_height = image.size
|
||||
|
||||
# Create a new PDF document
|
||||
doc = fitz.open()
|
||||
|
||||
# Get image information
|
||||
img_bytes = BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
# pix = fitz.Pixmap(image_path)
|
||||
pix = fitz.Pixmap(img_bytes)
|
||||
|
||||
# Create a page
|
||||
page = doc.new_page(width=pix.width, height=pix.height)
|
||||
page.insert_image(
|
||||
fitz.Rect(0, 0, pix.width, pix.height),
|
||||
# filename=image_path
|
||||
pixmap=pix
|
||||
)
|
||||
|
||||
for i, cell in enumerate(cells):
|
||||
bbox = cell['bbox']
|
||||
layout_type = cell['category']
|
||||
order = i
|
||||
|
||||
top_left = (bbox[0], bbox[1])
|
||||
down_right = (bbox[2], bbox[3])
|
||||
if resized_height and resized_width:
|
||||
scale_x = resized_width / original_width
|
||||
scale_y = resized_height / original_height
|
||||
top_left = (int(bbox[0] / scale_x), int(bbox[1] / scale_y))
|
||||
down_right = (int(bbox[2] / scale_x), int(bbox[3] / scale_y))
|
||||
|
||||
color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
|
||||
color = [col/255 for col in color[:3]]
|
||||
|
||||
x0, y0, x1, y1 = top_left[0], top_left[1], down_right[0], down_right[1]
|
||||
rect_coords = fitz.Rect(x0, y0, x1, y1)
|
||||
if draw_bbox:
|
||||
if fill_bbox:
|
||||
page.draw_rect(
|
||||
rect_coords,
|
||||
color=None,
|
||||
fill=color,
|
||||
fill_opacity=0.3,
|
||||
width=0.5,
|
||||
overlay=True,
|
||||
) # Draw the rectangle
|
||||
else:
|
||||
page.draw_rect(
|
||||
rect_coords,
|
||||
color=color,
|
||||
fill=None,
|
||||
fill_opacity=1,
|
||||
width=0.5,
|
||||
overlay=True,
|
||||
) # Draw the rectangle
|
||||
order_cate = f"{order}_{layout_type}"
|
||||
page.insert_text(
|
||||
(x1, y0 + 20), order_cate, fontsize=20, color=color
|
||||
) # Insert the index in the top left corner of the rectangle
|
||||
|
||||
# Convert to a Pixmap (maintaining original dimensions)
|
||||
mat = fitz.Matrix(1.0, 1.0)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
|
||||
|
||||
def pre_process_bboxes(
|
||||
origin_image,
|
||||
bboxes,
|
||||
input_width,
|
||||
input_height,
|
||||
factor: int = 28,
|
||||
min_pixels: int = 3136,
|
||||
max_pixels: int = 11289600
|
||||
):
|
||||
assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)
|
||||
min_pixels = min_pixels or MIN_PIXELS
|
||||
max_pixels = max_pixels or MAX_PIXELS
|
||||
original_width, original_height = origin_image.size
|
||||
|
||||
input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
scale_x = original_width / input_width
|
||||
scale_y = original_height / input_height
|
||||
|
||||
bboxes_out = []
|
||||
for bbox in bboxes:
|
||||
bbox_resized = [
|
||||
int(float(bbox[0]) / scale_x),
|
||||
int(float(bbox[1]) / scale_y),
|
||||
int(float(bbox[2]) / scale_x),
|
||||
int(float(bbox[3]) / scale_y)
|
||||
]
|
||||
bboxes_out.append(bbox_resized)
|
||||
|
||||
return bboxes_out
|
||||
|
||||
def post_process_cells(
|
||||
origin_image: Image.Image,
|
||||
cells: List[Dict],
|
||||
input_width, # server input width, also has smart_resize in server
|
||||
input_height,
|
||||
factor: int = 28,
|
||||
min_pixels: int = 3136,
|
||||
max_pixels: int = 11289600
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.
|
||||
|
||||
Args:
|
||||
origin_image: The original PIL Image.
|
||||
cells: A list of cells containing bounding box information.
|
||||
input_width: The width of the input image sent to the server.
|
||||
input_height: The height of the input image sent to the server.
|
||||
factor: Resizing factor.
|
||||
min_pixels: Minimum number of pixels.
|
||||
max_pixels: Maximum number of pixels.
|
||||
|
||||
Returns:
|
||||
A list of post-processed cells.
|
||||
"""
|
||||
assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
|
||||
min_pixels = min_pixels or MIN_PIXELS
|
||||
max_pixels = max_pixels or MAX_PIXELS
|
||||
original_width, original_height = origin_image.size
|
||||
|
||||
input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
scale_x = input_width / original_width
|
||||
scale_y = input_height / original_height
|
||||
|
||||
cells_out = []
|
||||
for cell in cells:
|
||||
bbox = cell['bbox']
|
||||
bbox_resized = [
|
||||
int(float(bbox[0]) / scale_x),
|
||||
int(float(bbox[1]) / scale_y),
|
||||
int(float(bbox[2]) / scale_x),
|
||||
int(float(bbox[3]) / scale_y)
|
||||
]
|
||||
cell_copy = cell.copy()
|
||||
cell_copy['bbox'] = bbox_resized
|
||||
cells_out.append(cell_copy)
|
||||
|
||||
return cells_out
|
||||
|
||||
def is_legal_bbox(cells):
|
||||
for cell in cells:
|
||||
bbox = cell['bbox']
|
||||
if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
|
||||
if prompt_mode in ["prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"]:
|
||||
return response
|
||||
|
||||
json_load_failed = False
|
||||
cells = response
|
||||
try:
|
||||
cells = json.loads(cells)
|
||||
cells = post_process_cells(
|
||||
origin_image,
|
||||
cells,
|
||||
input_image.width,
|
||||
input_image.height,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels
|
||||
)
|
||||
return cells, False
|
||||
except Exception as e:
|
||||
print(f"cells post process error: {e}, when using {prompt_mode}")
|
||||
json_load_failed = True
|
||||
|
||||
if json_load_failed:
|
||||
cleaner = OutputCleaner()
|
||||
response_clean = cleaner.clean_model_output(cells)
|
||||
if isinstance(response_clean, list):
|
||||
response_clean = "\n\n".join([cell['text'] for cell in response_clean if 'text' in cell])
|
||||
return response_clean, True
|
||||
Executable
+623
@@ -0,0 +1,623 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Data Cleaning Script - Cleans all data using a simplified regex method and saves the results
|
||||
|
||||
Features:
|
||||
1. Cleans all cases using a simplified regex method.
|
||||
2. Saves the cleaned data for each case.
|
||||
3. Ensures the relative order of dicts remains unchanged.
|
||||
4. Generates a before-and-after cleaning report.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
from typing import Dict, List, Tuple, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
from collections import Counter
|
||||
import traceback
|
||||
|
||||
|
||||
@dataclass
|
||||
class CleanedData:
|
||||
"""Data structure for cleaned data"""
|
||||
case_id: int
|
||||
original_type: str # 'list' or 'str'
|
||||
original_length: int
|
||||
cleaned_data: List[Dict]
|
||||
cleaning_operations: Dict[str, Any] # Records the cleaning operations performed
|
||||
success: bool
|
||||
|
||||
|
||||
class OutputCleaner:
|
||||
"""Data Cleaner - Based on a simplified regex method"""
|
||||
|
||||
def __init__(self):
|
||||
# Simplified regular expression patterns
|
||||
self.dict_pattern = re.compile(r'\{[^{}]*?"bbox"\s*:\s*\[[^\]]*?\][^{}]*?\}', re.DOTALL)
|
||||
self.bbox_pattern = re.compile(r'"bbox"\s*:\s*\[([^\]]+)\]')
|
||||
self.missing_delimiter_pattern = re.compile(r'\}\s*\{(?!")')
|
||||
|
||||
self.cleaned_results: List[CleanedData] = []
|
||||
|
||||
def clean_list_data(self, data: List[Dict], case_id: int) -> CleanedData:
|
||||
"""Cleans list-type data"""
|
||||
|
||||
print(f"🔧 Cleaning List data - Case {case_id}")
|
||||
print(f" Original items: {len(data)}")
|
||||
|
||||
cleaned_data = []
|
||||
operations = {
|
||||
'type': 'list',
|
||||
'bbox_fixes': 0,
|
||||
'removed_items': 0,
|
||||
'original_count': len(data)
|
||||
}
|
||||
|
||||
for i, item in enumerate(data):
|
||||
if not isinstance(item, dict):
|
||||
operations['removed_items'] += 1
|
||||
continue
|
||||
|
||||
# Check the bbox field
|
||||
if 'bbox' in item:
|
||||
bbox = item['bbox']
|
||||
|
||||
# Check bbox length - core logic
|
||||
if isinstance(bbox, list) and len(bbox) == 3:
|
||||
print(f" ⚠️ Item {i}: bbox has only 3 coordinates. Removing bbox, keeping category and text.")
|
||||
# Keep only category and text, ensuring order is preserved
|
||||
new_item = {}
|
||||
if 'category' in item:
|
||||
new_item['category'] = item['category']
|
||||
if 'text' in item:
|
||||
new_item['text'] = item['text']
|
||||
if new_item: # Add only if there is valid content
|
||||
cleaned_data.append(new_item)
|
||||
operations['bbox_fixes'] += 1
|
||||
else:
|
||||
operations['removed_items'] += 1
|
||||
continue
|
||||
elif isinstance(bbox, list) and len(bbox) == 4:
|
||||
# bbox is normal, add directly, preserving original order
|
||||
cleaned_data.append(item.copy())
|
||||
continue
|
||||
else:
|
||||
print(f" ❌ Item {i}: Abnormal bbox format, skipping.")
|
||||
operations['removed_items'] += 1
|
||||
continue
|
||||
else:
|
||||
# No bbox field, keep if category exists
|
||||
if 'category' in item:
|
||||
cleaned_data.append(item.copy())
|
||||
continue
|
||||
else:
|
||||
operations['removed_items'] += 1
|
||||
|
||||
operations['final_count'] = len(cleaned_data)
|
||||
print(f" ✅ Cleaning complete: {len(cleaned_data)} items, {operations['bbox_fixes']} bbox fixes, {operations['removed_items']} items removed")
|
||||
|
||||
return CleanedData(
|
||||
case_id=case_id,
|
||||
original_type='list',
|
||||
original_length=len(data),
|
||||
cleaned_data=cleaned_data,
|
||||
cleaning_operations=operations,
|
||||
success=True
|
||||
)
|
||||
|
||||
def clean_string_data(self, data_str: str, case_id: int) -> CleanedData:
|
||||
"""Cleans string-type data"""
|
||||
|
||||
print(f"🔧 Cleaning String data - Case {case_id}")
|
||||
print(f" Original length: {len(data_str):,}")
|
||||
|
||||
operations = {
|
||||
'type': 'str',
|
||||
'original_length': len(data_str),
|
||||
'delimiter_fixes': 0,
|
||||
'tail_truncated': False,
|
||||
'truncated_length': 0,
|
||||
'duplicate_dicts_removed': 0,
|
||||
'final_objects': 0
|
||||
}
|
||||
|
||||
try:
|
||||
# Step 1: Detect and fix missing delimiters
|
||||
data_str, delimiter_fixes = self._fix_missing_delimiters(data_str)
|
||||
operations['delimiter_fixes'] = delimiter_fixes
|
||||
|
||||
# Step 2: Truncate the last incomplete element
|
||||
data_str, tail_truncated = self._truncate_last_incomplete_element(data_str)
|
||||
operations['tail_truncated'] = tail_truncated
|
||||
operations['truncated_length'] = len(data_str)
|
||||
|
||||
# Step 3: Remove duplicate complete dict objects, preserving order
|
||||
data_str, duplicate_removes = self._remove_duplicate_complete_dicts_preserve_order(data_str)
|
||||
operations['duplicate_dicts_removed'] = duplicate_removes
|
||||
|
||||
# Step 4: Ensure correct JSON format
|
||||
data_str = self._ensure_json_format(data_str)
|
||||
|
||||
# Step 5: Try to parse the final result
|
||||
final_data = self._parse_final_json(data_str)
|
||||
|
||||
if final_data is not None:
|
||||
operations['final_objects'] = len(final_data)
|
||||
print(f" ✅ Cleaning complete: {len(final_data)} objects")
|
||||
|
||||
return CleanedData(
|
||||
case_id=case_id,
|
||||
original_type='str',
|
||||
original_length=operations['original_length'],
|
||||
cleaned_data=final_data,
|
||||
cleaning_operations=operations,
|
||||
success=True
|
||||
)
|
||||
else:
|
||||
raise Exception("Could not parse the cleaned data")
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Cleaning failed: {e}")
|
||||
return CleanedData(
|
||||
case_id=case_id,
|
||||
original_type='str',
|
||||
original_length=operations['original_length'],
|
||||
cleaned_data=[],
|
||||
cleaning_operations=operations,
|
||||
success=False
|
||||
)
|
||||
|
||||
def _fix_missing_delimiters(self, text: str) -> Tuple[str, int]:
|
||||
"""Fixes missing delimiters"""
|
||||
|
||||
fixes = 0
|
||||
|
||||
def replace_delimiter(match):
|
||||
nonlocal fixes
|
||||
fixes += 1
|
||||
return '},{'
|
||||
|
||||
text = self.missing_delimiter_pattern.sub(replace_delimiter, text)
|
||||
|
||||
if fixes > 0:
|
||||
print(f" ✅ Fixed {fixes} missing delimiters")
|
||||
|
||||
return text, fixes
|
||||
|
||||
def _truncate_last_incomplete_element(self, text: str) -> Tuple[str, bool]:
|
||||
"""Truncates the last incomplete element"""
|
||||
|
||||
# For very long text (>50k) or text not ending with ']', directly truncate the last '{"bbox":'
|
||||
needs_truncation = (
|
||||
len(text) > 50000 or
|
||||
not text.strip().endswith(']')
|
||||
)
|
||||
|
||||
if needs_truncation:
|
||||
# Check how many dict objects there are
|
||||
bbox_count = text.count('{"bbox":')
|
||||
|
||||
# If there is only one dict object, do not truncate to avoid deleting the only object
|
||||
if bbox_count <= 1:
|
||||
print(f" ⚠️ Only {bbox_count} dict objects found, skipping truncation to avoid deleting all content")
|
||||
return text, False
|
||||
|
||||
# Find the position of the last '{"bbox":'
|
||||
last_bbox_pos = text.rfind('{"bbox":')
|
||||
|
||||
if last_bbox_pos > 0:
|
||||
# Truncate before this position
|
||||
truncated_text = text[:last_bbox_pos].rstrip()
|
||||
|
||||
# Remove trailing comma
|
||||
if truncated_text.endswith(','):
|
||||
truncated_text = truncated_text[:-1]
|
||||
|
||||
print(f" ✂️ Truncated the last incomplete element, length reduced from {len(text):,} to {len(truncated_text):,}")
|
||||
return truncated_text, True
|
||||
|
||||
return text, False
|
||||
|
||||
def _remove_duplicate_complete_dicts_preserve_order(self, text: str) -> Tuple[str, int]:
|
||||
"""Removes duplicate complete dict objects, preserving original order"""
|
||||
|
||||
# Extract all dict objects, preserving order
|
||||
dict_matches = list(self.dict_pattern.finditer(text))
|
||||
|
||||
if not dict_matches:
|
||||
return text, 0
|
||||
|
||||
print(f" 📊 Found {len(dict_matches)} dict objects")
|
||||
|
||||
# Deduplication while preserving order: only keep the first occurrence of a dict
|
||||
unique_dicts = []
|
||||
seen_dict_strings = set()
|
||||
total_duplicates = 0
|
||||
|
||||
for match in dict_matches:
|
||||
dict_str = match.group()
|
||||
|
||||
if dict_str not in seen_dict_strings:
|
||||
unique_dicts.append(dict_str)
|
||||
seen_dict_strings.add(dict_str)
|
||||
else:
|
||||
total_duplicates += 1
|
||||
|
||||
if total_duplicates > 0:
|
||||
# Reconstruct the JSON array, preserving the original order
|
||||
new_text = '[' + ', '.join(unique_dicts) + ']'
|
||||
print(f" ✅ Removed {total_duplicates} duplicate dicts, keeping {len(unique_dicts)} unique dicts (order preserved)")
|
||||
return new_text, total_duplicates
|
||||
else:
|
||||
print(f" ✅ No duplicate dict objects found")
|
||||
return text, 0
|
||||
|
||||
def _ensure_json_format(self, text: str) -> str:
|
||||
"""Ensures correct JSON format"""
|
||||
|
||||
text = text.strip()
|
||||
|
||||
if not text.startswith('['):
|
||||
text = '[' + text
|
||||
|
||||
if not text.endswith(']'):
|
||||
# Remove trailing comma
|
||||
text = text.rstrip(',').rstrip()
|
||||
text += ']'
|
||||
|
||||
return text
|
||||
|
||||
def _parse_final_json(self, text: str) -> Optional[List[Dict]]:
|
||||
"""Tries to parse the final JSON"""
|
||||
|
||||
try:
|
||||
data = json.loads(text)
|
||||
if isinstance(data, list):
|
||||
return data
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" ❌ JSON parsing failed: {e}")
|
||||
|
||||
# fallback1: Extract valid dict objects
|
||||
valid_dicts = []
|
||||
|
||||
for match in self.dict_pattern.finditer(text):
|
||||
dict_str = match.group()
|
||||
try:
|
||||
dict_obj = json.loads(dict_str)
|
||||
valid_dicts.append(dict_obj)
|
||||
except:
|
||||
continue
|
||||
|
||||
if valid_dicts:
|
||||
print(f" ✅ Extracted {len(valid_dicts)} valid dicts")
|
||||
return valid_dicts
|
||||
|
||||
# fallback2: Special handling for a single incomplete dict
|
||||
return self._handle_single_incomplete_dict(text)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_single_incomplete_dict(self, text: str) -> Optional[List[Dict]]:
|
||||
"""Handles the special case of a single incomplete dict"""
|
||||
|
||||
# Check if it's a single incomplete dict case
|
||||
if not text.strip().startswith('[{"bbox":'):
|
||||
return None
|
||||
|
||||
try:
|
||||
# Try to extract bbox coordinates
|
||||
bbox_match = re.search(r'"bbox"\s*:\s*\[([^\]]+)\]', text)
|
||||
if not bbox_match:
|
||||
return None
|
||||
|
||||
bbox_str = bbox_match.group(1)
|
||||
bbox_coords = [int(x.strip()) for x in bbox_str.split(',')]
|
||||
|
||||
if len(bbox_coords) != 4:
|
||||
return None
|
||||
|
||||
# Try to extract category
|
||||
category_match = re.search(r'"category"\s*:\s*"([^"]+)"', text)
|
||||
category = category_match.group(1) if category_match else "Text"
|
||||
|
||||
# Try to extract the beginning of the text (first 10000 characters)
|
||||
text_match = re.search(r'"text"\s*:\s*"([^"]{0,10000})', text)
|
||||
if text_match:
|
||||
text_content = text_match.group(1)
|
||||
else:
|
||||
text_content = ""
|
||||
|
||||
# Construct the fixed dict
|
||||
fixed_dict = {
|
||||
"bbox": bbox_coords,
|
||||
"category": category
|
||||
}
|
||||
|
||||
if text_content:
|
||||
fixed_dict["text"] = text_content
|
||||
|
||||
print(f" 🔧 Special fix: single incomplete dict → {fixed_dict}")
|
||||
return [fixed_dict]
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Special fix failed: {e}")
|
||||
return None
|
||||
|
||||
def remove_duplicate_category_text_pairs_and_bbox(self, data_list: List[dict], case_id: int) -> List[dict]:
|
||||
"""Removes duplicate category-text pairs and duplicate bboxes"""
|
||||
|
||||
if not data_list or len(data_list) <= 1:
|
||||
print(f" 📊 Data length {len(data_list)} <= 1, skipping deduplication check")
|
||||
return data_list
|
||||
|
||||
print(f" 📊 Original data length: {len(data_list)}")
|
||||
|
||||
# 1. Count occurrences and positions of each category-text pair
|
||||
category_text_pairs = {}
|
||||
for i, item in enumerate(data_list):
|
||||
if isinstance(item, dict) and 'category' in item and 'text' in item:
|
||||
pair_key = (item.get('category', ''), item.get('text', ''))
|
||||
if pair_key not in category_text_pairs:
|
||||
category_text_pairs[pair_key] = []
|
||||
category_text_pairs[pair_key].append(i)
|
||||
|
||||
# 2. Count occurrences and positions of each bbox
|
||||
bbox_pairs = {}
|
||||
for i, item in enumerate(data_list):
|
||||
if isinstance(item, dict) and 'bbox' in item:
|
||||
bbox = item.get('bbox')
|
||||
if isinstance(bbox, list) and len(bbox) > 0:
|
||||
bbox_key = tuple(bbox) # Convert to tuple to use as a dictionary key
|
||||
if bbox_key not in bbox_pairs:
|
||||
bbox_pairs[bbox_key] = []
|
||||
bbox_pairs[bbox_key].append(i)
|
||||
|
||||
# 3. Identify items to be removed
|
||||
duplicates_to_remove = set()
|
||||
|
||||
# 3a. Process category-text pairs that appear 5 or more times
|
||||
for pair_key, positions in category_text_pairs.items():
|
||||
if len(positions) >= 5:
|
||||
category, text = pair_key
|
||||
# Keep the first occurrence, remove subsequent duplicates
|
||||
positions_to_remove = positions[1:]
|
||||
duplicates_to_remove.update(positions_to_remove)
|
||||
|
||||
print(f" 🔍 Found duplicate category-text pair: category='{category}', first 50 chars of text='{text[:50]}...'")
|
||||
print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
|
||||
|
||||
# 3b. Process bboxes that appear 2 or more times
|
||||
for bbox_key, positions in bbox_pairs.items():
|
||||
if len(positions) >= 2:
|
||||
# Keep the first occurrence, remove subsequent duplicates
|
||||
positions_to_remove = positions[1:]
|
||||
duplicates_to_remove.update(positions_to_remove)
|
||||
|
||||
print(f" 🔍 Found duplicate bbox: {list(bbox_key)}")
|
||||
print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
|
||||
|
||||
if not duplicates_to_remove:
|
||||
print(f" ✅ No category-text pairs or bboxes found exceeding the duplication threshold")
|
||||
return data_list
|
||||
|
||||
# 4. Remove duplicate items from the original data (preserving order)
|
||||
cleaned_data = []
|
||||
removed_count = 0
|
||||
for i, item in enumerate(data_list):
|
||||
if i not in duplicates_to_remove:
|
||||
cleaned_data.append(item)
|
||||
else:
|
||||
removed_count += 1
|
||||
|
||||
print(f" ✅ Deduplication complete: Removed {removed_count} duplicate items")
|
||||
print(f" 📊 Cleaned data length: {len(cleaned_data)}")
|
||||
|
||||
return cleaned_data
|
||||
|
||||
def clean_model_output(self, model_output: str):
|
||||
try:
|
||||
# Select cleaning method based on data type
|
||||
if isinstance(model_output, list):
|
||||
result = self.clean_list_data(model_output, case_id=0)
|
||||
else:
|
||||
result = self.clean_string_data(str(model_output), case_id=0)
|
||||
|
||||
# Add deduplication step: remove duplicate category-text pairs and bboxes
|
||||
if result and hasattr(result, 'success') and result.success and result.cleaned_data:
|
||||
original_data = result.cleaned_data
|
||||
deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id=0)
|
||||
# Update the cleaned_data in the CleanedData object
|
||||
result.cleaned_data = deduplicated_data
|
||||
return result.cleaned_data
|
||||
except Exception as e:
|
||||
print(f"❌ Case cleaning failed: {e}")
|
||||
return model_output
|
||||
|
||||
def clean_all_data(self, jsonl_path: str) -> List[CleanedData]:
|
||||
"""Cleans all data from a JSONL file"""
|
||||
|
||||
print(f"🚀 Starting to clean JSONL file: {jsonl_path}")
|
||||
|
||||
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
datas = []
|
||||
for i, line in enumerate(lines):
|
||||
if line.strip():
|
||||
try:
|
||||
data = json.loads(line)
|
||||
predict_field = data.get('predict')
|
||||
case_id = i + 1
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print(f"🎯 Cleaning Case {case_id}")
|
||||
print(f"{'='*50}")
|
||||
|
||||
# Select cleaning method based on data type
|
||||
if isinstance(predict_field, list):
|
||||
print("📊 Data type: List")
|
||||
result = self.clean_list_data(predict_field, case_id)
|
||||
else:
|
||||
print("📊 Data type: String")
|
||||
result = self.clean_string_data(str(predict_field), case_id)
|
||||
|
||||
# Add deduplication step: remove duplicate category-text pairs and bboxes
|
||||
if result and hasattr(result, 'success') and result.success and result.cleaned_data:
|
||||
print("🔄 Checking for and removing duplicate category-text pairs and bboxes...")
|
||||
original_data = result.cleaned_data
|
||||
deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id)
|
||||
# Update the cleaned_data in the CleanedData object
|
||||
result.cleaned_data = deduplicated_data
|
||||
data['predict_resized'] = result.cleaned_data
|
||||
|
||||
datas.append(data)
|
||||
self.cleaned_results.append(result)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Case {i+1} cleaning failed: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
save_path = jsonl_path.replace('.jsonl', '_filtered.jsonl')
|
||||
with open(save_path, 'w') as w:
|
||||
for data in datas:
|
||||
w.write(json.dumps(data, ensure_ascii=False) + '\n')
|
||||
print(f"✅ Saved cleaned data to: {save_path}")
|
||||
|
||||
return self.cleaned_results
|
||||
|
||||
def save_cleaned_data(self, output_dir: str):
|
||||
"""Saves the cleaned data"""
|
||||
|
||||
print(f"\n💾 Saving cleaned data to: {output_dir}")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# 1. Save cleaned data for each case
|
||||
for result in self.cleaned_results:
|
||||
case_filename = f"cleaned_case_{result.case_id:02d}.json"
|
||||
case_filepath = os.path.join(output_dir, case_filename)
|
||||
|
||||
# Save the cleaned data
|
||||
with open(case_filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(result.cleaned_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f" ✅ Case {result.case_id}: {len(result.cleaned_data)} objects → {case_filename}")
|
||||
|
||||
# 2. Save all cleaned data to a single file
|
||||
all_cleaned_data = []
|
||||
for result in self.cleaned_results:
|
||||
all_cleaned_data.append({
|
||||
'case_id': result.case_id,
|
||||
'original_type': result.original_type,
|
||||
'original_length': result.original_length,
|
||||
'cleaned_objects_count': len(result.cleaned_data),
|
||||
'success': result.success,
|
||||
'cleaning_operations': result.cleaning_operations,
|
||||
'cleaned_data': result.cleaned_data
|
||||
})
|
||||
|
||||
all_data_filepath = os.path.join(output_dir, "all_cleaned_data.json")
|
||||
with open(all_data_filepath, 'w', encoding='utf-8') as f:
|
||||
json.dump(all_cleaned_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
print(f" 📁 All data: {len(all_cleaned_data)} cases → all_cleaned_data.json")
|
||||
|
||||
# 3. Generate a cleaning report
|
||||
self._generate_cleaning_report(output_dir)
|
||||
|
||||
def _generate_cleaning_report(self, output_dir: str):
|
||||
"""Generates a cleaning report"""
|
||||
|
||||
report = []
|
||||
report.append("📊 Data Cleaning Report")
|
||||
report.append("=" * 60)
|
||||
import datetime
|
||||
report.append(f"Processing Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
report.append("")
|
||||
|
||||
# Overall statistics
|
||||
total_cases = len(self.cleaned_results)
|
||||
successful_cases = sum(1 for r in self.cleaned_results if r.success)
|
||||
total_objects = sum(len(r.cleaned_data) for r in self.cleaned_results)
|
||||
|
||||
report.append("📈 Overall Statistics:")
|
||||
report.append(f" Total Cases: {total_cases}")
|
||||
report.append(f" Successfully Cleaned: {successful_cases}")
|
||||
report.append(f" Success Rate: {successful_cases/total_cases*100:.1f}%")
|
||||
report.append(f" Total Recovered Objects: {total_objects}")
|
||||
report.append("")
|
||||
|
||||
# Detailed statistics
|
||||
list_results = [r for r in self.cleaned_results if r.original_type == 'list']
|
||||
str_results = [r for r in self.cleaned_results if r.original_type == 'str']
|
||||
|
||||
if list_results:
|
||||
report.append("📋 List Type Cleaning Statistics:")
|
||||
for r in list_results:
|
||||
ops = r.cleaning_operations
|
||||
report.append(f" Case {r.case_id}: {ops['original_count']} → {ops['final_count']} objects")
|
||||
if ops['bbox_fixes'] > 0:
|
||||
report.append(f" - bbox fixes: {ops['bbox_fixes']}")
|
||||
if ops['removed_items'] > 0:
|
||||
report.append(f" - invalid items removed: {ops['removed_items']}")
|
||||
report.append("")
|
||||
|
||||
if str_results:
|
||||
report.append("📝 String Type Cleaning Statistics:")
|
||||
for r in str_results:
|
||||
ops = r.cleaning_operations
|
||||
status = "✅" if r.success else "❌"
|
||||
report.append(f" Case {r.case_id} {status}: {ops['original_length']:,} chars → {ops['final_objects']} objects")
|
||||
details = []
|
||||
if ops['delimiter_fixes'] > 0:
|
||||
details.append(f"Delimiter fixes: {ops['delimiter_fixes']}")
|
||||
if ops['tail_truncated']:
|
||||
reduction = ops['original_length'] - ops['truncated_length']
|
||||
details.append(f"Tail truncation: -{reduction:,} chars")
|
||||
if ops['duplicate_dicts_removed'] > 0:
|
||||
details.append(f"Duplicates removed: {ops['duplicate_dicts_removed']}")
|
||||
if details:
|
||||
report.append(f" - {', '.join(details)}")
|
||||
report.append("")
|
||||
|
||||
# Note on data order
|
||||
report.append("🔄 Data Order Guarantee:")
|
||||
report.append(" ✅ The relative order of all dict objects is preserved during cleaning.")
|
||||
report.append(" ✅ When deduplicating, the first occurrence of a dict is kept, and subsequent duplicates are removed.")
|
||||
report.append(" ✅ The order of items in List-type data is fully preserved.")
|
||||
|
||||
# Save the report
|
||||
report_filepath = os.path.join(output_dir, "cleaning_report.txt")
|
||||
with open(report_filepath, 'w', encoding='utf-8') as f:
|
||||
f.write('\n'.join(report))
|
||||
|
||||
print(f" 📋 Cleaning report: cleaning_report.txt")
|
||||
|
||||
# Also print to console
|
||||
print(f"\n{chr(10).join(report)}")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function"""
|
||||
|
||||
# Create a data cleaner instance
|
||||
cleaner = OutputCleaner()
|
||||
|
||||
# Input file
|
||||
jsonl_path = "output_with_failcase.jsonl"
|
||||
|
||||
# Output directory
|
||||
output_dir = "output_with_failcase_cleaned"
|
||||
|
||||
# Clean all data
|
||||
results = cleaner.clean_all_data(jsonl_path)
|
||||
|
||||
# Save the cleaned data
|
||||
cleaner.save_cleaned_data(output_dir)
|
||||
|
||||
print(f"\n🎉 Data cleaning complete!")
|
||||
print(f"📁 Cleaned data saved in: {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Executable
+34
@@ -0,0 +1,34 @@
|
||||
dict_promptmode_to_prompt = {
|
||||
# prompt_layout_all_en: parse all layout info in json format.
|
||||
"prompt_layout_all_en": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
|
||||
|
||||
1. Bbox format: [x1, y1, x2, y2]
|
||||
|
||||
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
|
||||
|
||||
3. Text Extraction & Formatting Rules:
|
||||
- Picture: For the 'Picture' category, the text field should be omitted.
|
||||
- Formula: Format its text as LaTeX.
|
||||
- Table: Format its text as HTML.
|
||||
- All Others (Text, Title, etc.): Format their text as Markdown.
|
||||
|
||||
4. Constraints:
|
||||
- The output text must be the original text from the image, with no translation.
|
||||
- All layout elements must be sorted according to human reading order.
|
||||
|
||||
5. Final Output: The entire output must be a single JSON object.
|
||||
""",
|
||||
|
||||
# prompt_layout_only_en: layout detection
|
||||
"prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
|
||||
|
||||
# prompt_layout_only_en: parse ocr text except the Page-header and Page-footer
|
||||
"prompt_ocr": """Extract the text content from this image.""",
|
||||
|
||||
# prompt_grounding_ocr: extract text content in the given bounding box
|
||||
"prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
|
||||
|
||||
# "prompt_table_html": """Convert the table in this image to HTML.""",
|
||||
# "prompt_table_latex": """Convert the table in this image to LaTeX.""",
|
||||
# "prompt_formula_latex": """Convert the formula in this image to LaTeX.""",
|
||||
}
|
||||
Reference in New Issue
Block a user