dots.ocr release

This commit is contained in:
zhangwei13
2025-07-30 19:12:56 +08:00
commit be77dff22c
55 changed files with 5187 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
from .prompts import dict_promptmode_to_prompt
+5
View File
@@ -0,0 +1,5 @@
# Lower and upper bounds on the total pixel count of a resized image
# (3136 == 4 * 28 * 28 and 11289600 == 14400 * 28 * 28).
MIN_PIXELS=3136
MAX_PIXELS=11289600
# Resize granularity: passed as `factor` to smart_resize, which keeps both
# output dimensions divisible by it.
IMAGE_FACTOR=28
# Lower-case file suffixes; presumably the image types accepted as input —
# TODO confirm against the callers that read this set.
image_extensions = {'.jpg', '.jpeg', '.png'}
+61
View File
@@ -0,0 +1,61 @@
import os
from PIL import Image
def is_valid_image_path(image_path):
    """
    Checks whether a path points to an existing image file.

    Args:
        image_path: The path to the image.

    Returns:
        bool: True if the path is an existing regular file whose extension
            (compared case-insensitively) is one of the known image formats,
            False otherwise.
    """
    # isfile (not exists) so that a directory that merely carries an
    # image-like name, e.g. "scans.png/", is not reported as a valid image.
    if not os.path.isfile(image_path):
        return False
    # Check if the file extension is one of the common image formats.
    image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')
    _, extension = os.path.splitext(image_path)
    return extension.lower() in image_extensions
def read_image(image_path, use_native=False):
    """
    Reads an image and resizes it while maintaining aspect ratio.

    The longer side is scaled to the target size (this can enlarge images
    whose longer side is below 1024 when ``use_native`` is False) and the
    shorter side is scaled proportionally.

    Args:
        image_path: The path to the image.
        use_native: If True, the max dimension of the original image is used as the max size.
            If False, max size is set to 1024.

    Returns:
        tuple: (resized_image, original_width, original_height)

    Raises:
        FileNotFoundError: If ``is_valid_image_path`` rejects ``image_path``.
    """
    if not is_valid_image_path(image_path):
        raise FileNotFoundError(f"{image_path}: Image path does not exist")
    image = Image.open(image_path)

    w, h = image.size
    max_size = max(w, h) if use_native else 1024

    # Clamp the shorter side to at least 1 pixel: for extreme aspect ratios the
    # integer truncation would otherwise yield 0, which resize() rejects.
    if w > h:
        new_w = max_size
        new_h = max(1, int(h * max_size / w))
    else:
        new_h = max_size
        new_w = max(1, int(w * max_size / h))

    image = image.resize((new_w, new_h))
    return image, w, h
+60
View File
@@ -0,0 +1,60 @@
import fitz
import numpy as np
import enum
from pydantic import BaseModel, Field
from PIL import Image
class SupportedPdfParseMethod(enum.Enum):
    """Enumeration of the two PDF parse method identifiers, 'ocr' and 'txt'."""
    OCR = 'ocr'
    TXT = 'txt'
class PageInfo(BaseModel):
    """Pydantic model holding the width and height of a page as floats.

    NOTE(review): the unit (points vs. pixels) is not stated in this file —
    confirm against the callers.
    """
    w: float = Field(description='the width of page')
    h: float = Field(description='the height of page')
def fitz_doc_to_image(doc, target_dpi=200, origin_dpi=None) -> Image.Image:
    """Render a single PyMuPDF page to an RGB PIL image.

    Args:
        doc: The page to render; despite its name this is a page object
            (it must expose ``get_pixmap``), not a whole document.
        target_dpi (int, optional): Rendering resolution. Defaults to 200.
        origin_dpi: Accepted for call compatibility; currently unused.

    Returns:
        PIL.Image.Image: The rendered page in RGB mode. If rendering at
        ``target_dpi`` yields a side longer than 4500 pixels, the page is
        re-rendered with an identity matrix (fitz default dpi) instead.
    """
    # The rendering matrix is expressed relative to 72, so scale by dpi / 72.
    zoom = target_dpi / 72
    pm = doc.get_pixmap(matrix=fitz.Matrix(zoom, zoom), alpha=False)

    # Guard against huge bitmaps: fall back to the unscaled rendering.
    if pm.width > 4500 or pm.height > 4500:
        pm = doc.get_pixmap(matrix=fitz.Matrix(1, 1), alpha=False)  # use fitz default dpi

    return Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
def load_images_from_pdf(pdf_file, dpi=200, start_page_id=0, end_page_id=None) -> list:
    """Render an inclusive page range of a PDF into a list of images.

    Args:
        pdf_file: Path of the PDF, handed to ``fitz.open``.
        dpi: Target dpi forwarded to ``fitz_doc_to_image``.
        start_page_id: Zero-based index of the first page to render.
        end_page_id: Zero-based index of the last page to render (inclusive).
            ``None`` or a negative value selects the last page; a value past
            the end is clamped to the last page after printing a notice.

    Returns:
        list: One rendered image per selected page, in page order.
    """
    with fitz.open(pdf_file) as doc:
        last_page = doc.page_count - 1
        if end_page_id is None or end_page_id < 0:
            end_page_id = last_page
        if end_page_id > last_page:
            print('end_page_id is out of range, use images length')
            end_page_id = last_page
        return [
            fitz_doc_to_image(doc[page_no], target_dpi=dpi)
            for page_no in range(0, doc.page_count)
            if start_page_id <= page_no <= end_page_id
        ]
+205
View File
@@ -0,0 +1,205 @@
import os
import sys
import json
import re
from PIL import Image
from dots_ocr.utils.image_utils import PILimage_to_base64
def has_latex_markdown(text: str) -> bool:
    """
    Tells whether a string contains something that looks like LaTeX markup.

    Args:
        text (str): The string to check. Non-string input is never a match.

    Returns:
        bool: True as soon as one of the LaTeX patterns is found, otherwise False.
    """
    if not isinstance(text, str):
        return False

    # Each pattern is searched with DOTALL so that "." also spans newlines.
    latex_patterns = (
        r'\$\$.*?\$\$',                    # Block-level math formula $$...$$
        r'\$[^$\n]+?\$',                   # Inline math formula $...$
        r'\\begin\{.*?\}.*?\\end\{.*?\}',  # LaTeX environment \begin{...}...\end{...}
        r'\\[a-zA-Z]+\{.*?\}',             # LaTeX command \command{...}
        r'\\[a-zA-Z]+',                    # Simple LaTeX command \command
        r'\\\[.*?\\\]',                    # Display math formula \[...\]
        r'\\\(.*?\\\)',                    # Inline math formula \(...\)
    )
    return any(re.search(candidate, text, re.DOTALL) for candidate in latex_patterns)
def clean_latex_preamble(latex_text: str) -> str:
    """
    Removes LaTeX preamble commands like document class and package imports.

    The following constructs are deleted (case-insensitively), everything else
    is left untouched: ``\\documentclass{...}`` with or without ``[options]``,
    ``\\usepackage{...}`` with or without ``[options]``, ``\\begin{document}``
    and ``\\end{document}``.

    Args:
        latex_text (str): The original LaTeX text.

    Returns:
        str: The cleaned LaTeX text without preamble commands.
    """
    # Define patterns to be removed; they are applied one after another.
    patterns = (
        r'\\documentclass(?:\[[^\]]*\])?\{[^}]+\}',  # \documentclass{...} / \documentclass[options]{...}
        r'\\usepackage\{[^}]+\}',                    # \usepackage{...}
        r'\\usepackage\[[^\]]*\]\{[^}]+\}',          # \usepackage[options]{...}
        r'\\begin\{document\}',                      # \begin{document}
        r'\\end\{document\}',                        # \end{document}
    )

    cleaned_text = latex_text
    for pattern in patterns:
        cleaned_text = re.sub(pattern, '', cleaned_text, flags=re.IGNORECASE)
    return cleaned_text
def get_formula_in_markdown(text: str) -> str:
    """
    Formats a string containing a formula into a standard Markdown block.

    Rules, applied in order on the stripped input:
      1. ``$$...$$`` with no further ``$`` inside is re-emitted as a block with
         newlines after/before the delimiters; with inner ``$`` it is returned as is.
      2. A string fully wrapped in ``\\[...\\]`` is converted to a ``$$`` block.
      3. A string that merely contains ``\\[...\\]`` or an inline ``$...$`` is returned as is.
      4. A string without any LaTeX markup is returned as is.
      5. Anything else has preamble commands removed (only when it mentions
         ``usepackage``), one pair of surrounding backticks dropped, and is wrapped in a ``$$`` block.

    Args:
        text (str): The input string, potentially containing a formula.

    Returns:
        str: The formatted string, ready for Markdown rendering.
    """
    # Remove leading/trailing whitespace
    text = text.strip()

    # Check if it's already enclosed in $$
    if text.startswith('$$') and text.endswith('$$'):
        text_new = text[2:-2].strip()
        if '$' not in text_new:
            return f"$$\n{text_new}\n$$"
        return text

    # Handle \[...\] format, convert to $$...$$
    if text.startswith('\\[') and text.endswith('\\]'):
        inner_content = text[2:-2].strip()
        return f"$$\n{inner_content}\n$$"

    # \[ ... \] appears somewhere inside the text: leave untouched
    if re.findall(r'.*\\\[.*\\\].*', text):
        return text

    # Inline formulas ($...$) are returned as is
    if re.findall(r'\$([^$]+)\$', text):
        return text

    # If no LaTeX markdown syntax is present, return directly
    if not has_latex_markdown(text):
        return text

    # Handle unnecessary LaTeX formatting like \usepackage
    if 'usepackage' in text:
        text = clean_latex_preamble(text)

    # startswith/endswith instead of text[0]/text[-1]: the preamble cleaning
    # above can leave an empty string, on which indexing would raise IndexError.
    if text.startswith('`') and text.endswith('`'):
        text = text[1:-1]

    # Enclose the final text in a $$ block with newlines
    return f"$$\n{text}\n$$"
def clean_text(text: str) -> str:
    """
    Normalizes whitespace in a piece of text.

    Args:
        text: The original text; any falsy value (``None``, ``""``) yields ``""``.

    Returns:
        str: The text without leading/trailing whitespace and with every run
        of whitespace characters collapsed into a single space.
    """
    if not text:
        return ""
    trimmed = text.strip()
    return re.sub(r'\s+', ' ', trimmed)
def layoutjson2md(image: Image.Image, cells: list, text_key: str = 'text', no_page_hf: bool = False) -> str:
    """
    Converts a layout JSON format to Markdown.

    In the layout JSON, formulas are LaTeX, tables are HTML, and text is Markdown.

    Args:
        image: A PIL Image object; only used to crop 'Picture' cells.
        cells: A list of dictionaries, each representing a layout cell. Every
            cell needs a 'category'; a 'bbox' is only required for 'Picture' cells.
        text_key: The key for the text field in the cell dictionary.
        no_page_hf: If True, skips 'Page-header' and 'Page-footer' cells.

    Returns:
        str: The text in Markdown format, cells separated by blank lines.
    """
    text_items = []

    for cell in cells:
        category = cell['category']

        if no_page_hf and category in ['Page-header', 'Page-footer']:
            continue

        if category == 'Picture':
            # The bbox is parsed only here, where it is actually needed, so
            # text-only cells without a usable bbox no longer abort the conversion.
            x1, y1, x2, y2 = [int(coord) for coord in cell['bbox']]
            image_crop = image.crop((x1, y1, x2, y2))
            image_base64 = PILimage_to_base64(image_crop)
            text_items.append(f"![]({image_base64})")
        elif category == 'Formula':
            text_items.append(get_formula_in_markdown(cell.get(text_key, "")))
        else:
            text_items.append(clean_text(cell.get(text_key, "")))

    return '\n\n'.join(text_items)
def fix_streamlit_formulas(md: str) -> str:
    """
    Rewrites every ``$$...$$`` formula so the delimiters sit on their own lines.

    At most one newline directly after the opening ``$$`` and one directly
    before the closing ``$$`` are dropped from the formula body, then the body
    is re-wrapped as ``$$\\n<body>\\n$$``.

    Args:
        md (str): The Markdown text to fix.

    Returns:
        str: The fixed Markdown text.
    """
    def _rewrap(found):
        body = found.group(1)
        # Drop a single pre-existing newline on each side so none is doubled.
        body = body[1:] if body.startswith('\n') else body
        body = body[:-1] if body.endswith('\n') else body
        return f'$$\n{body}\n$$'

    # DOTALL lets a formula span several lines.
    return re.sub(r'\$\$(.*?)\$\$', _rewrap, md, flags=re.DOTALL)
+196
View File
@@ -0,0 +1,196 @@
import math
import base64
from PIL import Image
from typing import Tuple
import os
from dots_ocr.utils.consts import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.doc_utils import fitz_doc_to_image
from io import BytesIO
import fitz
import requests
import copy
def round_by_factor(number: int, factor: int) -> int:
    """Return the multiple of 'factor' nearest to 'number' (ties follow Python's round())."""
    multiples = round(number / factor)
    return multiples * factor
def ceil_by_factor(number: int, factor: int) -> int:
    """Return the smallest multiple of 'factor' that is greater than or equal to 'number'."""
    multiples = math.ceil(number / factor)
    return multiples * factor
def floor_by_factor(number: int, factor: int) -> int:
    """Return the largest multiple of 'factor' that is less than or equal to 'number'."""
    multiples = math.floor(number / factor)
    return multiples * factor
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600,
):
    """Compute target dimensions for an image.

    The result satisfies:
      1. Both dimensions (height and width) are divisible by 'factor'.
      2. The total number of pixels is within the range ['min_pixels', 'max_pixels'];
         when both bounds cannot hold, 'max_pixels' wins.
      3. The aspect ratio of the image is maintained as closely as possible.

    Returns:
        tuple: (new_height, new_width)

    Raises:
        ValueError: If the longer side is more than 200 times the shorter side.
    """
    aspect_ratio = max(height, width) / min(height, width)
    if aspect_ratio > 200:
        raise ValueError(
            f"absolute aspect ratio must be smaller than 200, got {aspect_ratio}"
        )

    # Start from the nearest multiples of factor, never below one factor.
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    area = h_bar * w_bar

    if area > max_pixels:
        # Too large: shrink both sides by the same ratio, rounding down.
        shrink = math.sqrt((height * width) / max_pixels)
        h_bar = max(factor, floor_by_factor(height / shrink, factor))
        w_bar = max(factor, floor_by_factor(width / shrink, factor))
    elif area < min_pixels:
        # Too small: grow both sides by the same ratio, rounding up.
        grow = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * grow, factor)
        w_bar = ceil_by_factor(width * grow, factor)
        if h_bar * w_bar > max_pixels:  # max_pixels first to control the token length
            shrink = math.sqrt((h_bar * w_bar) / max_pixels)
            h_bar = max(factor, floor_by_factor(h_bar / shrink, factor))
            w_bar = max(factor, floor_by_factor(w_bar / shrink, factor))
    return h_bar, w_bar
def PILimage_to_base64(image, format='PNG'):
    """Serialize an image into a base64 data URI string.

    Args:
        image: Object with a ``save(fp, format=...)`` method (a PIL image).
        format: Image format name handed to ``image.save``.

    Returns:
        str: ``"data:image;base64,"`` followed by the base64 text of the saved bytes.
        NOTE(review): the prefix carries no MIME subtype regardless of ``format``.
    """
    with BytesIO() as sink:
        image.save(sink, format=format)
        raw_bytes = sink.getvalue()
    encoded = base64.b64encode(raw_bytes).decode('utf-8')
    return "data:image;base64," + encoded
def to_rgb(pil_image: Image.Image) -> Image.Image:
    """Return an RGB version of the image.

    RGBA images are pasted onto a white background using their alpha channel
    as the mask; every other mode goes through ``convert("RGB")``.
    """
    if pil_image.mode != 'RGBA':
        return pil_image.convert("RGB")

    canvas = Image.new("RGB", pil_image.size, (255, 255, 255))
    alpha_channel = pil_image.split()[3]  # Use alpha channel as mask
    canvas.paste(pil_image, mask=alpha_channel)
    return canvas
# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def fetch_image(
        image,
        min_pixels=None,
        max_pixels=None,
        resized_height=None,
        resized_width=None,
    ) -> Image.Image:
    """Load an image from one of several sources, convert it to RGB and optionally resize it.

    Args:
        image: A PIL image, an ``http(s)://`` URL, a ``file://`` URL, a
            ``data:image...base64,`` URI, or a local file path.
        min_pixels: Lower pixel bound for resizing; MIN_PIXELS is used when only
            ``max_pixels`` is given.
        max_pixels: Upper pixel bound for resizing; MAX_PIXELS is used when only
            ``min_pixels`` is given.
        resized_height: Requested height; takes effect only together with
            ``resized_width`` and then takes precedence over the pixel bounds.
        resized_width: Requested width (see ``resized_height``).

    Returns:
        PIL.Image.Image: The RGB image. It is resized only when both
        ``resized_height``/``resized_width`` or at least one pixel bound is given.

    Raises:
        AssertionError: If ``image`` is None or a computed size is not positive.
        ValueError: If the input is not recognized (e.g. a ``data:image`` URI
            without a ``base64,`` part).
    """
    assert image is not None, f"image not found, maybe input format error: {image}"
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        # fix memory leak issue while using BytesIO
        with requests.get(image, stream=True) as response:
            response.raise_for_status()
            with BytesIO(response.content) as bio:
                image_obj = copy.deepcopy(Image.open(bio))
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            # fix memory leak issue while using BytesIO
            with BytesIO(data) as bio:
                image_obj = copy.deepcopy(Image.open(bio))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = to_rgb(image_obj)

    # Read the size up front: both assert messages below report it, and the
    # first branch previously referenced these names before they were defined.
    width, height = image.size

    ## resize
    if resized_height and resized_width:
        resized_height, resized_width = smart_resize(
            resized_height,
            resized_width,
            factor=IMAGE_FACTOR,
        )
        assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
        image = image.resize((resized_width, resized_height))
    elif min_pixels or max_pixels:
        if not min_pixels:
            min_pixels = MIN_PIXELS
        if not max_pixels:
            max_pixels = MAX_PIXELS
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=IMAGE_FACTOR,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
        assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
        image = image.resize((resized_width, resized_height))

    return image
def get_input_dimensions(
    image: Image.Image,
    min_pixels: int,
    max_pixels: int,
    factor: int = 28
) -> Tuple[int, int]:
    """
    Computes the dimensions the image would have after ``smart_resize``.

    Args:
        image: The original image (only ``height`` and ``width`` are read).
        min_pixels: The minimum number of pixels.
        max_pixels: The maximum number of pixels.
        factor: The resizing factor.

    Returns:
        The resized (width, height) — note the order is swapped with respect to
        what ``smart_resize`` returns.
    """
    resized = smart_resize(
        image.height,
        image.width,
        factor=factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels
    )
    new_height, new_width = resized
    return new_width, new_height
def get_image_by_fitz_doc(image, target_dpi=200):
    """Re-render an image through fitz to obtain it at ``target_dpi``.

    The image bytes are opened as a fitz document, converted to a one-page PDF
    and that page is rendered with ``fitz_doc_to_image``.

    Args:
        image: A PIL image, a local path, or an ``http(s)://`` URL. String
            inputs must end in .jpg, .jpeg or .png (case-insensitive).
        target_dpi: Rendering resolution forwarded to ``fitz_doc_to_image``.

    Returns:
        The rendered image as returned by ``fitz_doc_to_image``.

    Raises:
        AssertionError: If ``image`` is neither a PIL image nor a string with
            a supported extension.
    """
    # get image through fitz, to get target dpi image, mainly for higher image
    if not isinstance(image, Image.Image):
        assert isinstance(image, str)
        _, file_ext = os.path.splitext(image)
        # Compare case-insensitively so that e.g. "scan.PNG" is accepted too.
        assert file_ext.lower() in {'.jpg', '.jpeg', '.png'}

        if image.startswith("http://") or image.startswith("https://"):
            with requests.get(image, stream=True) as response:
                response.raise_for_status()
                data_bytes = response.content
        else:
            with open(image, 'rb') as f:
                data_bytes = f.read()

        image = Image.open(BytesIO(data_bytes))
    else:
        data_bytes = BytesIO()
        image.save(data_bytes, format='PNG')

    origin_dpi = image.info.get('dpi', None)

    # Both documents are closed as soon as they are no longer needed instead
    # of being left open until garbage collection.
    with fitz.open(stream=data_bytes) as image_doc:
        pdf_bytes = image_doc.convert_to_pdf()
    with fitz.open('pdf', pdf_bytes) as doc:
        page = doc[0]
        image_fitz = fitz_doc_to_image(page, target_dpi=target_dpi, origin_dpi=origin_dpi)

    return image_fitz
+228
View File
@@ -0,0 +1,228 @@
from PIL import Image
from typing import Dict, List
import fitz
from io import BytesIO
import json
from dots_ocr.utils.image_utils import smart_resize
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.output_cleaner import OutputCleaner
# Maps each layout category name to a 4-component color tuple.
# NOTE(review): draw_layout_on_image reads only the first three (RGB)
# components and applies its own fill_opacity; the fourth component is 256,
# which is outside the 0-255 range of an 8-bit alpha — confirm whether any
# other caller uses it.
dict_layout_type_to_color = {
    "Text": (0, 128, 0, 256),  # Green
    "Picture": (255, 0, 255, 256),  # Magenta
    "Caption": (255, 165, 0, 256),  # Orange
    "Section-header": (0, 255, 255, 256),  # Cyan
    "Footnote": (0, 128, 0, 256),  # Green
    "Formula": (128, 128, 128, 256),  # Gray
    "Table": (255, 192, 203, 256),  # Pink
    "Title": (255, 0, 0, 256),  # Red
    "List-item": (0, 0, 255, 256),  # Blue
    "Page-header": (0, 128, 0, 256),  # Green
    "Page-footer": (128, 0, 128, 256),  # Purple
    "Other": (165, 42, 42, 256),  # Brown
    "Unknown": (0, 0, 0, 0),
}
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
    """
    Draw layout boxes and "<index>_<category>" labels on a copy of an image.

    The image is placed on a temporary single-page PDF, the boxes and labels
    are drawn on that page, and the page is rendered back to an image.

    Args:
        image: The source PIL Image.
        cells: A list of cells; each needs a 'bbox' (x0, y0, x1, y1) and a 'category'.
        resized_height: Height of the coordinate space the bboxes are expressed in.
        resized_width: Width of that coordinate space. Only when both resized
            values are given are the bboxes scaled back to the image size.
        fill_bbox: If True the box is filled with 30% opacity, otherwise only outlined.
        draw_bbox: Whether to draw the bounding box at all (labels are always drawn).

    Returns:
        PIL.Image.Image: The image with drawings.
    """
    original_width, original_height = image.size

    # Hand the image to fitz as in-memory PNG data.
    img_bytes = BytesIO()
    image.save(img_bytes, format='PNG')
    pix = fitz.Pixmap(img_bytes)

    # The scale is the same for every cell, so compute it once.
    rescale = bool(resized_height and resized_width)
    if rescale:
        scale_x = resized_width / original_width
        scale_y = resized_height / original_height

    # The context manager closes the temporary PDF document once the result
    # has been rendered (it was previously never closed).
    with fitz.open() as doc:
        page = doc.new_page(width=pix.width, height=pix.height)
        page.insert_image(
            fitz.Rect(0, 0, pix.width, pix.height),
            pixmap=pix
        )

        for order, cell in enumerate(cells):
            bbox = cell['bbox']
            layout_type = cell['category']

            if rescale:
                x0, y0 = int(bbox[0] / scale_x), int(bbox[1] / scale_y)
                x1, y1 = int(bbox[2] / scale_x), int(bbox[3] / scale_y)
            else:
                x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]

            # Only the RGB part of the table entry is used, normalized by 255.
            color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
            color = [col/255 for col in color[:3]]

            rect_coords = fitz.Rect(x0, y0, x1, y1)
            if draw_bbox:
                if fill_bbox:
                    page.draw_rect(
                        rect_coords,
                        color=None,
                        fill=color,
                        fill_opacity=0.3,
                        width=0.5,
                        overlay=True,
                    )  # Draw the rectangle
                else:
                    page.draw_rect(
                        rect_coords,
                        color=color,
                        fill=None,
                        fill_opacity=1,
                        width=0.5,
                        overlay=True,
                    )  # Draw the rectangle

            order_cate = f"{order}_{layout_type}"
            page.insert_text(
                (x1, y0 + 20), order_cate, fontsize=20, color=color
            )  # Label starts at the box's right edge, 20 units below its top

        # Convert to a Pixmap (maintaining original dimensions)
        mat = fitz.Matrix(1.0, 1.0)
        rendered = page.get_pixmap(matrix=mat)
        return Image.frombytes("RGB", [rendered.width, rendered.height], rendered.samples)
def pre_process_bboxes(
    origin_image,
    bboxes,
    input_width,
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
):
    """
    Converts bounding boxes from original-image coordinates to the coordinates
    of the smart-resized model input.

    Args:
        origin_image: The original image (only ``size`` is read).
        bboxes: A non-empty list of [x0, y0, x1, y1] lists.
        input_width: Width of the input image before ``smart_resize``.
        input_height: Height of the input image before ``smart_resize``.
        factor: Resizing factor forwarded to ``smart_resize``.
        min_pixels: Minimum number of pixels (falsy values fall back to MIN_PIXELS).
        max_pixels: Maximum number of pixels (falsy values fall back to MAX_PIXELS).

    Returns:
        list: One [x0, y0, x1, y1] list of ints per input bbox.

    Raises:
        AssertionError: If ``bboxes`` is not a non-empty list of lists.
    """
    assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)
    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS
    original_width, original_height = origin_image.size

    # factor is now forwarded; it used to be accepted but silently ignored.
    input_height, input_width = smart_resize(
        input_height, input_width, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels
    )

    scale_x = original_width / input_width
    scale_y = original_height / input_height

    return [
        [
            int(float(bbox[0]) / scale_x),
            int(float(bbox[1]) / scale_y),
            int(float(bbox[2]) / scale_x),
            int(float(bbox[3]) / scale_y)
        ]
        for bbox in bboxes
    ]
def post_process_cells(
    origin_image: Image.Image,
    cells: List[Dict],
    input_width,  # server input width, also has smart_resize in server
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
) -> List[Dict]:
    """
    Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.

    Args:
        origin_image: The original PIL Image.
        cells: A non-empty list of cells containing bounding box information.
        input_width: The width of the input image sent to the server.
        input_height: The height of the input image sent to the server.
        factor: Resizing factor forwarded to ``smart_resize``.
        min_pixels: Minimum number of pixels (falsy values fall back to MIN_PIXELS).
        max_pixels: Maximum number of pixels (falsy values fall back to MAX_PIXELS).

    Returns:
        A list of shallow-copied cells whose 'bbox' is expressed in original
        image coordinates; the input cells are not modified.

    Raises:
        AssertionError: If ``cells`` is not a non-empty list of dicts.
    """
    assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
    min_pixels = min_pixels or MIN_PIXELS
    max_pixels = max_pixels or MAX_PIXELS
    original_width, original_height = origin_image.size

    # factor is now forwarded; it used to be accepted but silently ignored.
    input_height, input_width = smart_resize(
        input_height, input_width, factor=factor, min_pixels=min_pixels, max_pixels=max_pixels
    )

    scale_x = input_width / original_width
    scale_y = input_height / original_height

    cells_out = []
    for cell in cells:
        bbox = cell['bbox']
        bbox_resized = [
            int(float(bbox[0]) / scale_x),
            int(float(bbox[1]) / scale_y),
            int(float(bbox[2]) / scale_x),
            int(float(bbox[3]) / scale_y)
        ]
        cell_copy = cell.copy()
        cell_copy['bbox'] = bbox_resized
        cells_out.append(cell_copy)

    return cells_out
def is_legal_bbox(cells):
    """Return True when no cell has a degenerate 'bbox'.

    A bbox [x0, y0, x1, y1] is degenerate when x1 <= x0 or y1 <= y0.
    An empty list of cells is considered legal.
    """
    boxes = (cell['bbox'] for cell in cells)
    return not any(box[2] <= box[0] or box[3] <= box[1] for box in boxes)
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
    """Turn the raw model response into cells in original-image coordinates.

    Args:
        response: The raw model output (expected to be a JSON string of cells).
        prompt_mode: Name of the prompt that produced the response.
        origin_image: The original image, forwarded to ``post_process_cells``.
        input_image: The image sent to the model; its width/height are used.
        min_pixels: Forwarded to ``post_process_cells``.
        max_pixels: Forwarded to ``post_process_cells``.

    Returns:
        - ``response`` unchanged (not a tuple) for the OCR / table / formula prompt modes;
        - ``(cells, False)`` when parsing and post-processing succeed;
        - ``(cleaned, True)`` otherwise, where ``cleaned`` is the OutputCleaner
          result, reduced to the 'text' fields joined by blank lines when it is a list.
    """
    plain_text_modes = ("prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex")
    if prompt_mode in plain_text_modes:
        return response

    cells = response
    try:
        cells = json.loads(cells)
        cells = post_process_cells(
            origin_image,
            cells,
            input_image.width,
            input_image.height,
            min_pixels=min_pixels,
            max_pixels=max_pixels
        )
    except Exception as e:
        print(f"cells post process error: {e}, when using {prompt_mode}")
    else:
        return cells, False

    # Fallback: `cells` is either the raw response or the parsed-but-unprocessed JSON.
    response_clean = OutputCleaner().clean_model_output(cells)
    if isinstance(response_clean, list):
        response_clean = "\n\n".join(cell['text'] for cell in response_clean if 'text' in cell)
    return response_clean, True
+623
View File
@@ -0,0 +1,623 @@
#!/usr/bin/env python3
"""
Data Cleaning Script - Cleans all data using a simplified regex method and saves the results
Features:
1. Cleans all cases using a simplified regex method.
2. Saves the cleaned data for each case.
3. Ensures the relative order of dicts remains unchanged.
4. Generates a before-and-after cleaning report.
"""
import json
import re
import os
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from collections import Counter
import traceback
@dataclass
class CleanedData:
    """Data structure for cleaned data"""
    case_id: int  # Case identifier passed in by the caller of the cleaning method
    original_type: str  # 'list' or 'str'
    original_length: int  # Item count for list input, character count for string input
    cleaned_data: List[Dict]  # The cleaned items; an empty list when string cleaning failed
    cleaning_operations: Dict[str, Any]  # Records the cleaning operations performed
    success: bool  # False only when string cleaning raised or could not be parsed
class OutputCleaner:
"""Data Cleaner - Based on a simplified regex method"""
def __init__(self):
# Simplified regular expression patterns
self.dict_pattern = re.compile(r'\{[^{}]*?"bbox"\s*:\s*\[[^\]]*?\][^{}]*?\}', re.DOTALL)
self.bbox_pattern = re.compile(r'"bbox"\s*:\s*\[([^\]]+)\]')
self.missing_delimiter_pattern = re.compile(r'\}\s*\{(?!")')
self.cleaned_results: List[CleanedData] = []
def clean_list_data(self, data: List[Dict], case_id: int) -> CleanedData:
    """Cleans list-type data.

    Per item:
      - non-dict items are dropped;
      - a 4-element list bbox: the item is kept (shallow copy);
      - a 3-element list bbox: only 'category' and 'text' are kept (dropped when neither exists);
      - any other bbox: the item is dropped;
      - no bbox: the item is kept only when it has a 'category'.

    Args:
        data: The parsed model output.
        case_id: Identifier used in log output and stored in the result.

    Returns:
        CleanedData with ``success=True``, the kept items in their original
        relative order, and counters in ``cleaning_operations``.
    """
    print(f"🔧 Cleaning List data - Case {case_id}")
    print(f" Original items: {len(data)}")

    kept = []
    stats = {
        'type': 'list',
        'bbox_fixes': 0,
        'removed_items': 0,
        'original_count': len(data)
    }

    for idx, entry in enumerate(data):
        if not isinstance(entry, dict):
            stats['removed_items'] += 1
            continue

        if 'bbox' not in entry:
            # No bbox field, keep if category exists
            if 'category' in entry:
                kept.append(entry.copy())
            else:
                stats['removed_items'] += 1
            continue

        box = entry['bbox']
        coord_count = len(box) if isinstance(box, list) else None

        if coord_count == 4:
            # bbox is normal, add directly, preserving original order
            kept.append(entry.copy())
        elif coord_count == 3:
            print(f" ⚠️ Item {idx}: bbox has only 3 coordinates. Removing bbox, keeping category and text.")
            # Keep only category and text, ensuring order is preserved
            slim = {key: entry[key] for key in ('category', 'text') if key in entry}
            if slim:  # Add only if there is valid content
                kept.append(slim)
                stats['bbox_fixes'] += 1
            else:
                stats['removed_items'] += 1
        else:
            print(f" ❌ Item {idx}: Abnormal bbox format, skipping.")
            stats['removed_items'] += 1

    stats['final_count'] = len(kept)
    print(f" ✅ Cleaning complete: {len(kept)} items, {stats['bbox_fixes']} bbox fixes, {stats['removed_items']} items removed")

    return CleanedData(
        case_id=case_id,
        original_type='list',
        original_length=len(data),
        cleaned_data=kept,
        cleaning_operations=stats,
        success=True
    )
def clean_string_data(self, data_str: str, case_id: int) -> CleanedData:
    """Cleans string-type data.

    Runs the repair pipeline (missing delimiters, tail truncation, duplicate
    removal, bracket normalization, parsing) and wraps the outcome.

    Args:
        data_str: The raw model output.
        case_id: Identifier used in log output and stored in the result.

    Returns:
        CleanedData with ``success=True`` and the parsed objects, or with
        ``success=False`` and an empty list when any step raised or the text
        could not be parsed. ``cleaning_operations`` holds whatever counters
        were filled in before a failure.
    """
    print(f"🔧 Cleaning String data - Case {case_id}")
    print(f" Original length: {len(data_str):,}")

    operations = {
        'type': 'str',
        'original_length': len(data_str),
        'delimiter_fixes': 0,
        'tail_truncated': False,
        'truncated_length': 0,
        'duplicate_dicts_removed': 0,
        'final_objects': 0
    }

    def _result(items, ok):
        # Both outcomes share everything except the payload and the flag.
        return CleanedData(
            case_id=case_id,
            original_type='str',
            original_length=operations['original_length'],
            cleaned_data=items,
            cleaning_operations=operations,
            success=ok
        )

    try:
        # Step 1: Detect and fix missing delimiters
        data_str, operations['delimiter_fixes'] = self._fix_missing_delimiters(data_str)

        # Step 2: Truncate the last incomplete element
        data_str, operations['tail_truncated'] = self._truncate_last_incomplete_element(data_str)
        operations['truncated_length'] = len(data_str)

        # Step 3: Remove duplicate complete dict objects, preserving order
        data_str, operations['duplicate_dicts_removed'] = (
            self._remove_duplicate_complete_dicts_preserve_order(data_str)
        )

        # Step 4: Ensure correct JSON format
        data_str = self._ensure_json_format(data_str)

        # Step 5: Try to parse the final result
        final_data = self._parse_final_json(data_str)
        if final_data is None:
            raise Exception("Could not parse the cleaned data")

        operations['final_objects'] = len(final_data)
        print(f" ✅ Cleaning complete: {len(final_data)} objects")
        return _result(final_data, True)

    except Exception as e:
        print(f" ❌ Cleaning failed: {e}")
        return _result([], False)
def _fix_missing_delimiters(self, text: str) -> Tuple[str, int]:
"""Fixes missing delimiters"""
fixes = 0
def replace_delimiter(match):
nonlocal fixes
fixes += 1
return '},{'
text = self.missing_delimiter_pattern.sub(replace_delimiter, text)
if fixes > 0:
print(f" ✅ Fixed {fixes} missing delimiters")
return text, fixes
def _truncate_last_incomplete_element(self, text: str) -> Tuple[str, bool]:
"""Truncates the last incomplete element"""
# For very long text (>50k) or text not ending with ']', directly truncate the last '{"bbox":'
needs_truncation = (
len(text) > 50000 or
not text.strip().endswith(']')
)
if needs_truncation:
# Check how many dict objects there are
bbox_count = text.count('{"bbox":')
# If there is only one dict object, do not truncate to avoid deleting the only object
if bbox_count <= 1:
print(f" ⚠️ Only {bbox_count} dict objects found, skipping truncation to avoid deleting all content")
return text, False
# Find the position of the last '{"bbox":'
last_bbox_pos = text.rfind('{"bbox":')
if last_bbox_pos > 0:
# Truncate before this position
truncated_text = text[:last_bbox_pos].rstrip()
# Remove trailing comma
if truncated_text.endswith(','):
truncated_text = truncated_text[:-1]
print(f" ✂️ Truncated the last incomplete element, length reduced from {len(text):,} to {len(truncated_text):,}")
return truncated_text, True
return text, False
def _remove_duplicate_complete_dicts_preserve_order(self, text: str) -> Tuple[str, int]:
"""Removes duplicate complete dict objects, preserving original order"""
# Extract all dict objects, preserving order
dict_matches = list(self.dict_pattern.finditer(text))
if not dict_matches:
return text, 0
print(f" 📊 Found {len(dict_matches)} dict objects")
# Deduplication while preserving order: only keep the first occurrence of a dict
unique_dicts = []
seen_dict_strings = set()
total_duplicates = 0
for match in dict_matches:
dict_str = match.group()
if dict_str not in seen_dict_strings:
unique_dicts.append(dict_str)
seen_dict_strings.add(dict_str)
else:
total_duplicates += 1
if total_duplicates > 0:
# Reconstruct the JSON array, preserving the original order
new_text = '[' + ', '.join(unique_dicts) + ']'
print(f" ✅ Removed {total_duplicates} duplicate dicts, keeping {len(unique_dicts)} unique dicts (order preserved)")
return new_text, total_duplicates
else:
print(f" ✅ No duplicate dict objects found")
return text, 0
def _ensure_json_format(self, text: str) -> str:
"""Ensures correct JSON format"""
text = text.strip()
if not text.startswith('['):
text = '[' + text
if not text.endswith(']'):
# Remove trailing comma
text = text.rstrip(',').rstrip()
text += ']'
return text
def _parse_final_json(self, text: str) -> Optional[List[Dict]]:
"""Tries to parse the final JSON"""
try:
data = json.loads(text)
if isinstance(data, list):
return data
except json.JSONDecodeError as e:
print(f" ❌ JSON parsing failed: {e}")
# fallback1: Extract valid dict objects
valid_dicts = []
for match in self.dict_pattern.finditer(text):
dict_str = match.group()
try:
dict_obj = json.loads(dict_str)
valid_dicts.append(dict_obj)
except:
continue
if valid_dicts:
print(f" ✅ Extracted {len(valid_dicts)} valid dicts")
return valid_dicts
# fallback2: Special handling for a single incomplete dict
return self._handle_single_incomplete_dict(text)
return None
def _handle_single_incomplete_dict(self, text: str) -> Optional[List[Dict]]:
"""Handles the special case of a single incomplete dict"""
# Check if it's a single incomplete dict case
if not text.strip().startswith('[{"bbox":'):
return None
try:
# Try to extract bbox coordinates
bbox_match = re.search(r'"bbox"\s*:\s*\[([^\]]+)\]', text)
if not bbox_match:
return None
bbox_str = bbox_match.group(1)
bbox_coords = [int(x.strip()) for x in bbox_str.split(',')]
if len(bbox_coords) != 4:
return None
# Try to extract category
category_match = re.search(r'"category"\s*:\s*"([^"]+)"', text)
category = category_match.group(1) if category_match else "Text"
# Try to extract the beginning of the text (first 10000 characters)
text_match = re.search(r'"text"\s*:\s*"([^"]{0,10000})', text)
if text_match:
text_content = text_match.group(1)
else:
text_content = ""
# Construct the fixed dict
fixed_dict = {
"bbox": bbox_coords,
"category": category
}
if text_content:
fixed_dict["text"] = text_content
print(f" 🔧 Special fix: single incomplete dict → {fixed_dict}")
return [fixed_dict]
except Exception as e:
print(f" ❌ Special fix failed: {e}")
return None
def remove_duplicate_category_text_pairs_and_bbox(self, data_list: List[dict], case_id: int) -> List[dict]:
"""Removes duplicate category-text pairs and duplicate bboxes"""
if not data_list or len(data_list) <= 1:
print(f" 📊 Data length {len(data_list)} <= 1, skipping deduplication check")
return data_list
print(f" 📊 Original data length: {len(data_list)}")
# 1. Count occurrences and positions of each category-text pair
category_text_pairs = {}
for i, item in enumerate(data_list):
if isinstance(item, dict) and 'category' in item and 'text' in item:
pair_key = (item.get('category', ''), item.get('text', ''))
if pair_key not in category_text_pairs:
category_text_pairs[pair_key] = []
category_text_pairs[pair_key].append(i)
# 2. Count occurrences and positions of each bbox
bbox_pairs = {}
for i, item in enumerate(data_list):
if isinstance(item, dict) and 'bbox' in item:
bbox = item.get('bbox')
if isinstance(bbox, list) and len(bbox) > 0:
bbox_key = tuple(bbox) # Convert to tuple to use as a dictionary key
if bbox_key not in bbox_pairs:
bbox_pairs[bbox_key] = []
bbox_pairs[bbox_key].append(i)
# 3. Identify items to be removed
duplicates_to_remove = set()
# 3a. Process category-text pairs that appear 5 or more times
for pair_key, positions in category_text_pairs.items():
if len(positions) >= 5:
category, text = pair_key
# Keep the first occurrence, remove subsequent duplicates
positions_to_remove = positions[1:]
duplicates_to_remove.update(positions_to_remove)
print(f" 🔍 Found duplicate category-text pair: category='{category}', first 50 chars of text='{text[:50]}...'")
print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
# 3b. Process bboxes that appear 2 or more times
for bbox_key, positions in bbox_pairs.items():
if len(positions) >= 2:
# Keep the first occurrence, remove subsequent duplicates
positions_to_remove = positions[1:]
duplicates_to_remove.update(positions_to_remove)
print(f" 🔍 Found duplicate bbox: {list(bbox_key)}")
print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
if not duplicates_to_remove:
print(f" ✅ No category-text pairs or bboxes found exceeding the duplication threshold")
return data_list
# 4. Remove duplicate items from the original data (preserving order)
cleaned_data = []
removed_count = 0
for i, item in enumerate(data_list):
if i not in duplicates_to_remove:
cleaned_data.append(item)
else:
removed_count += 1
print(f" ✅ Deduplication complete: Removed {removed_count} duplicate items")
print(f" 📊 Cleaned data length: {len(cleaned_data)}")
return cleaned_data
def clean_model_output(self, model_output: str):
    """Clean one raw model output and return the cleaned payload.

    A list is handed to ``clean_list_data``; any other value is converted
    with ``str`` and handed to ``clean_string_data``. If the cleaning result
    reports success and carries non-empty data, duplicated category/text
    pairs and duplicated bboxes are removed from it afterwards.

    Args:
        model_output: The raw model output (a list, or a value that will be
            stringified).

    Returns:
        The ``cleaned_data`` of the cleaning result, or ``model_output``
        itself, unchanged, when any step raises.
    """
    try:
        # Route the payload to the cleaner that matches its runtime type.
        if not isinstance(model_output, list):
            outcome = self.clean_string_data(str(model_output), case_id=0)
        else:
            outcome = self.clean_list_data(model_output, case_id=0)

        # Only de-duplicate results that succeeded and actually hold data.
        worth_deduplicating = (
            outcome
            and hasattr(outcome, 'success')
            and outcome.success
            and outcome.cleaned_data
        )
        if worth_deduplicating:
            outcome.cleaned_data = self.remove_duplicate_category_text_pairs_and_bbox(
                outcome.cleaned_data, case_id=0
            )
        return outcome.cleaned_data
    except Exception as e:
        # Best effort: fall back to the untouched input on any failure.
        print(f"❌ Case cleaning failed: {e}")
        return model_output
def clean_all_data(self, jsonl_path: str) -> List[CleanedData]:
    """Clean every record of a JSONL file and write a filtered copy.

    Each non-blank line is parsed as JSON and its ``predict`` field is
    cleaned (``clean_list_data`` for lists, ``clean_string_data`` for
    anything else), then de-duplicated when cleaning succeeded with
    non-empty data. The cleaned payload is stored on the record under
    ``predict_resized``. A record that raises during processing is reported
    with a traceback and left out of the output.

    The filtered records are written next to the input as
    ``<name>_filtered<ext>`` (``.jsonl`` is used when the input has no
    extension). Only the file suffix is rewritten, so the output path can
    never equal the input path and the input is never overwritten.

    Args:
        jsonl_path: Path of the input JSONL file (read as UTF-8).

    Returns:
        ``self.cleaned_results``, extended with one result per successfully
        processed line.
    """
    print(f"🚀 Starting to clean JSONL file: {jsonl_path}")

    with open(jsonl_path, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    datas = []
    # Case ids are 1-based line numbers, so blank lines still consume an id.
    for case_id, line in enumerate(lines, start=1):
        if not line.strip():
            continue
        try:
            data = json.loads(line)
            predict_field = data.get('predict')

            print(f"\n{'='*50}")
            print(f"🎯 Cleaning Case {case_id}")
            print(f"{'='*50}")

            # Select cleaning method based on data type
            if isinstance(predict_field, list):
                print("📊 Data type: List")
                result = self.clean_list_data(predict_field, case_id)
            else:
                print("📊 Data type: String")
                result = self.clean_string_data(str(predict_field), case_id)

            # Deduplication step: remove duplicate category-text pairs and bboxes
            if result and hasattr(result, 'success') and result.success and result.cleaned_data:
                print("🔄 Checking for and removing duplicate category-text pairs and bboxes...")
                result.cleaned_data = self.remove_duplicate_category_text_pairs_and_bbox(
                    result.cleaned_data, case_id
                )

            data['predict_resized'] = result.cleaned_data
            datas.append(data)
            self.cleaned_results.append(result)
        except Exception as e:
            # Best effort: one bad record must not abort the whole file.
            print(f"❌ Case {case_id} cleaning failed: {e}")
            traceback.print_exc()

    # Derive the output name from the suffix only; a plain str.replace would
    # return the input path unchanged (and clobber it) for non-.jsonl names.
    root, ext = os.path.splitext(jsonl_path)
    save_path = f"{root}_filtered{ext or '.jsonl'}"
    # Records are dumped with ensure_ascii=False, so the encoding must be
    # pinned to UTF-8 instead of relying on the platform default.
    with open(save_path, 'w', encoding='utf-8') as w:
        w.writelines(json.dumps(data, ensure_ascii=False) + '\n' for data in datas)
    print(f"✅ Saved cleaned data to: {save_path}")

    return self.cleaned_results
def save_cleaned_data(self, output_dir: str):
    """Write all accumulated cleaning results to ``output_dir``.

    Produces one ``cleaned_case_<id>.json`` file per entry of
    ``self.cleaned_results``, a combined ``all_cleaned_data.json`` summary,
    and finally calls ``_generate_cleaning_report`` for the text report.

    Args:
        output_dir: Destination directory; created when it does not exist.
    """
    print(f"\n💾 Saving cleaned data to: {output_dir}")
    os.makedirs(output_dir, exist_ok=True)

    # 1. One JSON file per case, holding only that case's cleaned data.
    for entry in self.cleaned_results:
        case_filename = f"cleaned_case_{entry.case_id:02d}.json"
        with open(os.path.join(output_dir, case_filename), 'w', encoding='utf-8') as handle:
            json.dump(entry.cleaned_data, handle, ensure_ascii=False, indent=2)
        print(f" ✅ Case {entry.case_id}: {len(entry.cleaned_data)} objects → {case_filename}")

    # 2. A single file with every case plus its cleaning metadata.
    all_cleaned_data = [
        {
            'case_id': entry.case_id,
            'original_type': entry.original_type,
            'original_length': entry.original_length,
            'cleaned_objects_count': len(entry.cleaned_data),
            'success': entry.success,
            'cleaning_operations': entry.cleaning_operations,
            'cleaned_data': entry.cleaned_data,
        }
        for entry in self.cleaned_results
    ]
    with open(os.path.join(output_dir, "all_cleaned_data.json"), 'w', encoding='utf-8') as handle:
        json.dump(all_cleaned_data, handle, ensure_ascii=False, indent=2)
    print(f" 📁 All data: {len(all_cleaned_data)} cases → all_cleaned_data.json")

    # 3. Human-readable report.
    self._generate_cleaning_report(output_dir)
def _generate_cleaning_report(self, output_dir: str):
    """Build the plain-text cleaning report, save it and print it.

    The report summarises ``self.cleaned_results``: overall counts and
    success rate, per-case statistics for list-type and string-type inputs,
    and a closing note on ordering. It is written to
    ``cleaning_report.txt`` inside ``output_dir`` and echoed to stdout.

    Args:
        output_dir: Existing directory that receives ``cleaning_report.txt``.
    """
    import datetime

    report = [
        "📊 Data Cleaning Report",
        "=" * 60,
        f"Processing Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
    ]

    # Overall statistics
    total_cases = len(self.cleaned_results)
    successful_cases = sum(1 for r in self.cleaned_results if r.success)
    total_objects = sum(len(r.cleaned_data) for r in self.cleaned_results)
    # With no results at all the rate is reported as 0 instead of dividing
    # by zero (e.g. an empty input file or every case failing to parse).
    success_rate = successful_cases / total_cases * 100 if total_cases else 0.0

    report.append("📈 Overall Statistics:")
    report.append(f" Total Cases: {total_cases}")
    report.append(f" Successfully Cleaned: {successful_cases}")
    report.append(f" Success Rate: {success_rate:.1f}%")
    report.append(f" Total Recovered Objects: {total_objects}")
    report.append("")

    # Detailed statistics, grouped by the type of the original payload
    list_results = [r for r in self.cleaned_results if r.original_type == 'list']
    str_results = [r for r in self.cleaned_results if r.original_type == 'str']

    if list_results:
        report.append("📋 List Type Cleaning Statistics:")
        for r in list_results:
            ops = r.cleaning_operations
            # Separate the before/after counts; printed back to back they
            # read as one number.
            report.append(f" Case {r.case_id}: {ops['original_count']} → {ops['final_count']} objects")
            if ops['bbox_fixes'] > 0:
                report.append(f" - bbox fixes: {ops['bbox_fixes']}")
            if ops['removed_items'] > 0:
                report.append(f" - invalid items removed: {ops['removed_items']}")
        report.append("")

    if str_results:
        report.append("📝 String Type Cleaning Statistics:")
        for r in str_results:
            ops = r.cleaning_operations
            # Make success/failure visible; both branches used to be empty.
            status = "✅" if r.success else "❌"
            report.append(f" Case {r.case_id} {status}: {ops['original_length']:,} chars → {ops['final_objects']} objects")
            details = []
            if ops['delimiter_fixes'] > 0:
                details.append(f"Delimiter fixes: {ops['delimiter_fixes']}")
            if ops['tail_truncated']:
                reduction = ops['original_length'] - ops['truncated_length']
                details.append(f"Tail truncation: -{reduction:,} chars")
            if ops['duplicate_dicts_removed'] > 0:
                details.append(f"Duplicates removed: {ops['duplicate_dicts_removed']}")
            if details:
                report.append(f" - {', '.join(details)}")
        report.append("")

    # Note on data order
    report.append("🔄 Data Order Guarantee:")
    report.append(" ✅ The relative order of all dict objects is preserved during cleaning.")
    report.append(" ✅ When deduplicating, the first occurrence of a dict is kept, and subsequent duplicates are removed.")
    report.append(" ✅ The order of items in List-type data is fully preserved.")

    # Save the report
    report_text = '\n'.join(report)
    report_filepath = os.path.join(output_dir, "cleaning_report.txt")
    with open(report_filepath, 'w', encoding='utf-8') as f:
        f.write(report_text)
    print(" 📋 Cleaning report: cleaning_report.txt")

    # Also print to console
    print(f"\n{report_text}")
def main(jsonl_path: str = "output_with_failcase.jsonl",
         output_dir: str = "output_with_failcase_cleaned"):
    """Clean a JSONL file of model outputs and save the results.

    Creates an ``OutputCleaner``, runs ``clean_all_data`` on ``jsonl_path``
    and then ``save_cleaned_data`` into ``output_dir``.

    Args:
        jsonl_path: Input JSONL file; defaults to the path that was
            previously hard-coded.
        output_dir: Directory for the cleaned output; defaults to the
            directory that was previously hard-coded.
    """
    # Create a data cleaner instance
    cleaner = OutputCleaner()

    # Clean all data; the results stay on the cleaner for the save step,
    # so the return value is not needed here.
    cleaner.clean_all_data(jsonl_path)

    # Save the cleaned data
    cleaner.save_cleaned_data(output_dir)

    print(f"\n🎉 Data cleaning complete!")
    print(f"📁 Cleaned data saved in: {output_dir}")


if __name__ == "__main__":
    main()
+34
View File
@@ -0,0 +1,34 @@
dict_promptmode_to_prompt = {
# prompt_layout_all_en: parse all layout info in json format.
"prompt_layout_all_en": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.
""",
# prompt_layout_only_en: layout detection
"prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
# prompt_layout_only_en: parse ocr text except the Page-header and Page-footer
"prompt_ocr": """Extract the text content from this image.""",
# prompt_grounding_ocr: extract text content in the given bounding box
"prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
# "prompt_table_html": """Convert the table in this image to HTML.""",
# "prompt_table_latex": """Convert the table in this image to LaTeX.""",
# "prompt_formula_latex": """Convert the formula in this image to LaTeX.""",
}