dots.ocr release

This commit is contained in:
zhangwei13
2025-07-30 19:12:56 +08:00
commit be77dff22c
55 changed files with 5187 additions and 0 deletions
Executable
+123
View File
@@ -0,0 +1,123 @@
# Byte-compiled / optimized / DLL files
weights/
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# IDEs
.vscode/
.idea/
*.swp
*.swo
*~
# MacOS
.DS_Store
# OCR related
#*.jpg
# *.jpeg
#*.png
#*.pdf
temp/
output/
# playground/
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 rednote-hilab
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Executable
+1214
View File
File diff suppressed because it is too large Load Diff
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 63 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 MiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.0 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1013 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 MiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 3.7 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 MiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 2.9 MiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.8 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 943 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 662 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 292 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 263 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 445 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 673 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.7 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 755 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 920 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 937 KiB

BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 203 KiB

+948
View File
@@ -0,0 +1,948 @@
"""
Layout Inference Web Application with Gradio
A Gradio-based layout inference tool that supports image uploads and multiple backend inference engines.
It adopts a reference-style interface design while preserving the original inference logic.
"""
import gradio as gr
import json
import os
import io
import tempfile
import base64
import zipfile
import uuid
import re
from pathlib import Path
from PIL import Image
import requests
# Local tool imports
from dots_ocr.utils import dict_promptmode_to_prompt
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.demo_utils.display import read_image
from dots_ocr.utils.doc_utils import load_images_from_pdf
# Add DotsOCRParser import
from dots_ocr.parser import DotsOCRParser
# ==================== Configuration ====================
DEFAULT_CONFIG = {
'ip': "127.0.0.1",
'port_vllm': 8000,
'min_pixels': MIN_PIXELS,
'max_pixels': MAX_PIXELS,
'test_images_dir': "./assets/showcase_origin",
}
# ==================== Global Variables ====================
# Store current configuration
current_config = DEFAULT_CONFIG.copy()
# Create DotsOCRParser instance
dots_parser = DotsOCRParser(
ip=DEFAULT_CONFIG['ip'],
port=DEFAULT_CONFIG['port_vllm'],
dpi=200,
min_pixels=DEFAULT_CONFIG['min_pixels'],
max_pixels=DEFAULT_CONFIG['max_pixels']
)
# Store processing results
processing_results = {
'original_image': None,
'processed_image': None,
'layout_result': None,
'markdown_content': None,
'cells_data': None,
'temp_dir': None,
'session_id': None,
'result_paths': None,
'pdf_results': None # Store multi-page PDF results
}
# PDF caching mechanism
pdf_cache = {
"images": [],
"current_page": 0,
"total_pages": 0,
"file_type": None, # 'image' or 'pdf'
"is_parsed": False, # Whether it has been parsed
"results": [] # Store parsing results for each page
}
def read_image_v2(img):
"""Reads an image, supports URLs and local paths"""
if isinstance(img, str) and img.startswith(("http://", "https://")):
with requests.get(img, stream=True) as response:
response.raise_for_status()
img = Image.open(io.BytesIO(response.content))
elif isinstance(img, str):
img, _, _ = read_image(img, use_native=True)
elif isinstance(img, Image.Image):
pass
else:
raise ValueError(f"Invalid image type: {type(img)}")
return img
def load_file_for_preview(file_path):
"""Loads a file for preview, supports PDF and image files"""
global pdf_cache
if not file_path or not os.path.exists(file_path):
return None, "<div id='page_info_box'>0 / 0</div>"
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext == '.pdf':
try:
# Read PDF and convert to images (one image per page)
pages = load_images_from_pdf(file_path)
pdf_cache["file_type"] = "pdf"
except Exception as e:
return None, f"<div id='page_info_box'>PDF loading failed: {str(e)}</div>"
elif file_ext in ['.jpg', '.jpeg', '.png']:
# For image files, read directly as a single-page image
try:
image = Image.open(file_path)
pages = [image]
pdf_cache["file_type"] = "image"
except Exception as e:
return None, f"<div id='page_info_box'>Image loading failed: {str(e)}</div>"
else:
return None, "<div id='page_info_box'>Unsupported file format</div>"
pdf_cache["images"] = pages
pdf_cache["current_page"] = 0
pdf_cache["total_pages"] = len(pages)
pdf_cache["is_parsed"] = False
pdf_cache["results"] = []
return pages[0], f"<div id='page_info_box'>1 / {len(pages)}</div>"
def turn_page(direction):
"""Page turning function"""
global pdf_cache
if not pdf_cache["images"]:
return None, "<div id='page_info_box'>0 / 0</div>", "", ""
if direction == "prev":
pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
elif direction == "next":
pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
index = pdf_cache["current_page"]
current_image = pdf_cache["images"][index] # Use the original image by default
page_info = f"<div id='page_info_box'>{index + 1} / {pdf_cache['total_pages']}</div>"
# If parsed, display the results for the current page
current_md = ""
current_md_raw = ""
current_json = ""
if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
result = pdf_cache["results"][index]
if 'md_content' in result:
# Get the raw markdown content
current_md_raw = result['md_content']
# Process the content after LaTeX rendering
current_md = result['md_content'] if result['md_content'] else ""
if 'cells_data' in result:
try:
current_json = json.dumps(result['cells_data'], ensure_ascii=False, indent=2)
except:
current_json = str(result.get('cells_data', ''))
# Use the image with layout boxes (if available)
if 'layout_image' in result and result['layout_image']:
current_image = result['layout_image']
return current_image, page_info, current_json
def get_test_images():
"""Gets the list of test images"""
test_images = []
test_dir = current_config['test_images_dir']
if os.path.exists(test_dir):
test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)
if name.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf'))]
return test_images
def convert_image_to_base64(image):
"""Converts a PIL image to base64 encoding"""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode()
return f"data:image/png;base64,{img_str}"
def create_temp_session_dir():
"""Creates a unique temporary directory for each processing request"""
session_id = uuid.uuid4().hex[:8]
temp_dir = os.path.join(tempfile.gettempdir(), f"dots_ocr_demo_{session_id}")
os.makedirs(temp_dir, exist_ok=True)
return temp_dir, session_id
def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=False):
"""
Processes using the high-level API parse_image from DotsOCRParser
"""
# Create a temporary session directory
temp_dir, session_id = create_temp_session_dir()
try:
# Save the PIL Image as a temporary file
temp_image_path = os.path.join(temp_dir, f"input_{session_id}.png")
image.save(temp_image_path, "PNG")
# Use the high-level API parse_image
filename = f"demo_{session_id}"
results = parser.parse_image(
# input_path=temp_image_path,
input_path=image,
filename=filename,
prompt_mode=prompt_mode,
save_dir=temp_dir,
fitz_preprocess=fitz_preprocess
)
# Parse the results
if not results:
raise ValueError("No results returned from parser")
result = results[0] # parse_image returns a list with a single result
# Read the result files
layout_image = None
cells_data = None
md_content = None
raw_response = None
filtered = False
# Read the layout image
if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
layout_image = Image.open(result['layout_image_path'])
# Read the JSON data
if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
cells_data = json.load(f)
# Read the Markdown content
if 'md_content_path' in result and os.path.exists(result['md_content_path']):
with open(result['md_content_path'], 'r', encoding='utf-8') as f:
md_content = f.read()
# Check for the raw response file (when JSON parsing fails)
if 'filtered' in result:
filtered = result['filtered']
return {
'layout_image': layout_image,
'cells_data': cells_data,
'md_content': md_content,
'filtered': filtered,
'temp_dir': temp_dir,
'session_id': session_id,
'result_paths': result,
'input_width': result['input_width'],
'input_height': result['input_height'],
}
except Exception as e:
# Clean up the temporary directory on error
import shutil
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
raise e
def parse_pdf_with_high_level_api(parser, pdf_path, prompt_mode):
"""
Processes using the high-level API parse_pdf from DotsOCRParser
"""
# Create a temporary session directory
temp_dir, session_id = create_temp_session_dir()
try:
# Use the high-level API parse_pdf
filename = f"demo_{session_id}"
results = parser.parse_pdf(
input_path=pdf_path,
filename=filename,
prompt_mode=prompt_mode,
save_dir=temp_dir
)
# Parse the results
if not results:
raise ValueError("No results returned from parser")
# Handle multi-page results
parsed_results = []
all_md_content = []
all_cells_data = []
for i, result in enumerate(results):
page_result = {
'page_no': result.get('page_no', i),
'layout_image': None,
'cells_data': None,
'md_content': None,
'filtered': False
}
# Read the layout image
if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
page_result['layout_image'] = Image.open(result['layout_image_path'])
# Read the JSON data
if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
page_result['cells_data'] = json.load(f)
all_cells_data.extend(page_result['cells_data'])
# Read the Markdown content
if 'md_content_path' in result and os.path.exists(result['md_content_path']):
with open(result['md_content_path'], 'r', encoding='utf-8') as f:
page_content = f.read()
page_result['md_content'] = page_content
all_md_content.append(page_content)
# Check for the raw response file (when JSON parsing fails)
page_result['filtered'] = False
if 'filtered' in page_result:
page_result['filtered'] = page_result['filtered']
parsed_results.append(page_result)
# Merge the content of all pages
combined_md = "\n\n---\n\n".join(all_md_content) if all_md_content else ""
return {
'parsed_results': parsed_results,
'combined_md_content': combined_md,
'combined_cells_data': all_cells_data,
'temp_dir': temp_dir,
'session_id': session_id,
'total_pages': len(results)
}
except Exception as e:
# Clean up the temporary directory on error
import shutil
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
raise e
# ==================== Core Processing Function ====================
def process_image_inference(test_image_input, file_input,
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
fitz_preprocess=False
):
"""Core function to handle image/PDF inference"""
global current_config, processing_results, dots_parser, pdf_cache
# First, clean up previous processing results to avoid confusion with the download button
if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
import shutil
try:
shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
except Exception as e:
print(f"Failed to clean up previous temporary directory: {e}")
# Reset processing results
processing_results = {
'original_image': None,
'processed_image': None,
'layout_result': None,
'markdown_content': None,
'cells_data': None,
'temp_dir': None,
'session_id': None,
'result_paths': None,
'pdf_results': None
}
# Update configuration
current_config.update({
'ip': server_ip,
'port_vllm': server_port,
'min_pixels': min_pixels,
'max_pixels': max_pixels
})
# Update parser configuration
dots_parser.ip = server_ip
dots_parser.port = server_port
dots_parser.min_pixels = min_pixels
dots_parser.max_pixels = max_pixels
# Determine the input source
input_file_path = None
image = None
# Prioritize file input (supports PDF)
if file_input is not None:
input_file_path = file_input
file_ext = os.path.splitext(input_file_path)[1].lower()
if file_ext == '.pdf':
# PDF file processing
try:
return process_pdf_file(input_file_path, prompt_mode)
except Exception as e:
return None, f"PDF processing failed: {e}", "", "", gr.update(value=None), None, ""
elif file_ext in ['.jpg', '.jpeg', '.png']:
# Image file processing
try:
image = Image.open(input_file_path)
except Exception as e:
return None, f"Failed to read image file: {e}", "", "", gr.update(value=None), None, ""
# If no file input, check the test image input
if image is None:
if test_image_input and test_image_input != "":
file_ext = os.path.splitext(test_image_input)[1].lower()
if file_ext == '.pdf':
return process_pdf_file(test_image_input, prompt_mode)
else:
try:
image = read_image_v2(test_image_input)
except Exception as e:
return None, f"Failed to read test image: {e}", "", "", gr.update(value=None), gr.update(value=None), None, ""
if image is None:
return None, "Please upload image/PDF file or select test image", "", "", gr.update(value=None), None, ""
try:
# Clear PDF cache (for image processing)
pdf_cache["images"] = []
pdf_cache["current_page"] = 0
pdf_cache["total_pages"] = 0
pdf_cache["is_parsed"] = False
pdf_cache["results"] = []
# Process using the high-level API of DotsOCRParser
original_image = image
parse_result = parse_image_with_high_level_api(dots_parser, image, prompt_mode, fitz_preprocess)
# Extract parsing results
layout_image = parse_result['layout_image']
cells_data = parse_result['cells_data']
md_content = parse_result['md_content']
filtered = parse_result['filtered']
# Handle parsing failure case
if filtered:
# JSON parsing failed, only text content is available
info_text = f"""
**Image Information:**
- Original Size: {original_image.width} x {original_image.height}
- Processing: JSON parsing failed, using cleaned text output
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Session ID: {parse_result['session_id']}
"""
# Store results
processing_results.update({
'original_image': original_image,
'processed_image': None,
'layout_result': None,
'markdown_content': md_content,
'cells_data': None,
'temp_dir': parse_result['temp_dir'],
'session_id': parse_result['session_id'],
'result_paths': parse_result['result_paths']
})
return (
original_image, # No layout image
info_text,
md_content,
md_content, # Display raw markdown text
gr.update(visible=False), # Hide download button
None, # Page info
"" # Current page JSON output
)
# JSON parsing successful case
# Save the raw markdown content (before LaTeX processing)
md_content_raw = md_content or "No markdown content generated"
# Store results
processing_results.update({
'original_image': original_image,
'processed_image': None, # High-level API does not return processed_image
'layout_result': layout_image,
'markdown_content': md_content,
'cells_data': cells_data,
'temp_dir': parse_result['temp_dir'],
'session_id': parse_result['session_id'],
'result_paths': parse_result['result_paths']
})
# Prepare display information
num_elements = len(cells_data) if cells_data else 0
info_text = f"""
**Image Information:**
- Original Size: {original_image.width} x {original_image.height}
- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Detected {num_elements} layout elements
- Session ID: {parse_result['session_id']}
"""
# Current page JSON output
current_json = ""
if cells_data:
try:
current_json = json.dumps(cells_data, ensure_ascii=False, indent=2)
except:
current_json = str(cells_data)
# Create the download ZIP file
download_zip_path = None
if parse_result['temp_dir']:
download_zip_path = os.path.join(parse_result['temp_dir'], f"layout_results_{parse_result['session_id']}.zip")
try:
with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(parse_result['temp_dir']):
for file in files:
if file.endswith('.zip'):
continue
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, parse_result['temp_dir'])
zipf.write(file_path, arcname)
except Exception as e:
print(f"Failed to create download ZIP: {e}")
download_zip_path = None
return (
layout_image,
info_text,
md_content or "No markdown content generated",
md_content_raw, # Raw markdown text
gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False), # Set the download file
None, # Page info (not displayed for image processing)
current_json # Current page JSON
)
except Exception as e:
return None, f"Error during processing: {e}", "", "", gr.update(value=None), None, ""
def process_pdf_file(pdf_path, prompt_mode):
"""Dedicated function for processing PDF files"""
global pdf_cache, processing_results, dots_parser
try:
# First, load the PDF for preview
preview_image, page_info = load_file_for_preview(pdf_path)
# Parse the PDF using DotsOCRParser
pdf_result = parse_pdf_with_high_level_api(dots_parser, pdf_path, prompt_mode)
# Update the PDF cache
pdf_cache["is_parsed"] = True
pdf_cache["results"] = pdf_result['parsed_results']
# Handle LaTeX table rendering
combined_md = pdf_result['combined_md_content']
combined_md_raw = combined_md or "No markdown content generated" # Save the raw content
# Store results
processing_results.update({
'original_image': None,
'processed_image': None,
'layout_result': None,
'markdown_content': combined_md,
'cells_data': pdf_result['combined_cells_data'],
'temp_dir': pdf_result['temp_dir'],
'session_id': pdf_result['session_id'],
'result_paths': None,
'pdf_results': pdf_result['parsed_results']
})
# Prepare display information
total_elements = len(pdf_result['combined_cells_data'])
info_text = f"""
**PDF Information:**
- Total Pages: {pdf_result['total_pages']}
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Total Detected Elements: {total_elements}
- Session ID: {pdf_result['session_id']}
"""
# Content of the current page (first page)
current_page_md = ""
current_page_md_raw = ""
current_page_json = ""
current_page_layout_image = preview_image # Use the original preview image by default
if pdf_cache["results"] and len(pdf_cache["results"]) > 0:
current_result = pdf_cache["results"][0]
if current_result['md_content']:
# Raw markdown content
current_page_md_raw = current_result['md_content']
# Process the content after LaTeX rendering
current_page_md = current_result['md_content']
if current_result['cells_data']:
try:
current_page_json = json.dumps(current_result['cells_data'], ensure_ascii=False, indent=2)
except:
current_page_json = str(current_result['cells_data'])
# Use the image with layout boxes (if available)
if 'layout_image' in current_result and current_result['layout_image']:
current_page_layout_image = current_result['layout_image']
# Create the download ZIP file
download_zip_path = None
if pdf_result['temp_dir']:
download_zip_path = os.path.join(pdf_result['temp_dir'], f"layout_results_{pdf_result['session_id']}.zip")
try:
with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(pdf_result['temp_dir']):
for file in files:
if file.endswith('.zip'):
continue
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, pdf_result['temp_dir'])
zipf.write(file_path, arcname)
except Exception as e:
print(f"Failed to create download ZIP: {e}")
download_zip_path = None
return (
current_page_layout_image, # Use the image with layout boxes
info_text,
combined_md or "No markdown content generated", # Display the markdown for the entire PDF
combined_md_raw or "No markdown content generated", # Display the raw markdown for the entire PDF
gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False), # Set the download file
page_info,
current_page_json
)
except Exception as e:
# Reset the PDF cache
pdf_cache["images"] = []
pdf_cache["current_page"] = 0
pdf_cache["total_pages"] = 0
pdf_cache["is_parsed"] = False
pdf_cache["results"] = []
raise e
def clear_all_data():
"""Clears all data"""
global processing_results, pdf_cache
# Clean up the temporary directory
if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
import shutil
try:
shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
except Exception as e:
print(f"Failed to clean up temporary directory: {e}")
# Reset processing results
processing_results = {
'original_image': None,
'processed_image': None,
'layout_result': None,
'markdown_content': None,
'cells_data': None,
'temp_dir': None,
'session_id': None,
'result_paths': None,
'pdf_results': None
}
# Reset the PDF cache
pdf_cache = {
"images": [],
"current_page": 0,
"total_pages": 0,
"file_type": None,
"is_parsed": False,
"results": []
}
return (
None, # Clear file input
"", # Clear test image selection
None, # Clear result image
"Waiting for processing results...", # Reset info display
"## Waiting for processing results...", # Reset Markdown display
"🕐 Waiting for parsing result...", # Clear raw Markdown text
gr.update(visible=False), # Hide download button
"<div id='page_info_box'>0 / 0</div>", # Reset page info
"🕐 Waiting for parsing result..." # Clear current page JSON
)
def update_prompt_display(prompt_mode):
"""Updates the prompt display content"""
return dict_promptmode_to_prompt[prompt_mode]
# ==================== Gradio Interface ====================
def create_gradio_interface():
"""Creates the Gradio interface"""
# CSS styles, matching the reference style
css = """
#parse_button {
background: #FF576D !important; /* !important 确保覆盖主题默认样式 */
border-color: #FF576D !important;
}
/* 鼠标悬停时的颜色 */
#parse_button:hover {
background: #F72C49 !important;
border-color: #F72C49 !important;
}
#page_info_html {
display: flex;
align-items: center;
justify-content: center;
height: 100%;
margin: 0 12px;
}
#page_info_box {
padding: 8px 20px;
font-size: 16px;
border: 1px solid #bbb;
border-radius: 8px;
background-color: #f8f8f8;
text-align: center;
min-width: 80px;
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
}
#markdown_output {
min-height: 800px;
overflow: auto;
}
footer {
visibility: hidden;
}
#info_box {
padding: 10px;
background-color: #f8f9fa;
border-radius: 8px;
border: 1px solid #dee2e6;
margin: 10px 0;
font-size: 14px;
}
#result_image {
border-radius: 8px;
}
#markdown_tabs {
height: 100%;
}
"""
with gr.Blocks(theme="ocean", css=css, title='dots.ocr') as demo:
# Title
gr.HTML("""
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
<h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr</h1>
</div>
<div style="text-align: center; margin-bottom: 10px;">
<em>Supports image/PDF layout analysis and structured output</em>
</div>
""")
with gr.Row():
# Left side: Input and Configuration
with gr.Column(scale=1, elem_id="left-panel"):
gr.Markdown("### 📥 Upload & Select")
file_input = gr.File(
label="Upload PDF/Image",
type="filepath",
file_types=[".pdf", ".jpg", ".jpeg", ".png"],
)
test_images = get_test_images()
test_image_input = gr.Dropdown(
label="Or Select an Example",
choices=[""] + test_images,
value="",
)
gr.Markdown("### ⚙️ Prompt & Actions")
prompt_mode = gr.Dropdown(
label="Select Prompt",
choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr"],
value="prompt_layout_all_en",
show_label=True
)
# Display current prompt content
prompt_display = gr.Textbox(
label="Current Prompt Content",
value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
lines=4,
max_lines=8,
interactive=False,
show_copy_button=True
)
with gr.Row():
process_btn = gr.Button("🔍 Parse", variant="primary", scale=2, elem_id="parse_button")
clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
with gr.Accordion("🛠️ Advanced Configuration", open=False):
fitz_preprocess = gr.Checkbox(
label="Enable fitz_preprocess for images",
value=True,
info="Processes image via a PDF-like pipeline (image->pdf->200dpi image). Recommended if your image DPI is low."
)
with gr.Row():
server_ip = gr.Textbox(label="Server IP", value=DEFAULT_CONFIG['ip'])
server_port = gr.Number(label="Port", value=DEFAULT_CONFIG['port_vllm'], precision=0)
with gr.Row():
min_pixels = gr.Number(label="Min Pixels", value=DEFAULT_CONFIG['min_pixels'], precision=0)
max_pixels = gr.Number(label="Max Pixels", value=DEFAULT_CONFIG['max_pixels'], precision=0)
# Right side: Result Display
with gr.Column(scale=6, variant="compact"):
with gr.Row():
# Result Image
with gr.Column(scale=3):
gr.Markdown("### 👁️ File Preview")
result_image = gr.Image(
label="Layout Preview",
visible=True,
height=800,
show_label=False
)
# Page navigation (shown during PDF preview)
with gr.Row():
prev_btn = gr.Button("⬅ Previous", size="sm")
page_info = gr.HTML(
value="<div id='page_info_box'>0 / 0</div>",
elem_id="page_info_html"
)
next_btn = gr.Button("Next ➡", size="sm")
# Info Display
info_display = gr.Markdown(
"Waiting for processing results...",
elem_id="info_box"
)
# Markdown Result
with gr.Column(scale=3):
gr.Markdown("### ✔️ Result Display")
with gr.Tabs(elem_id="markdown_tabs"):
with gr.TabItem("Markdown Render Preview"):
md_output = gr.Markdown(
"## Please click the parse button to parse or select for single-task recognition...",
label="Markdown Preview",
max_height=600,
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
],
show_copy_button=False,
elem_id="markdown_output"
)
with gr.TabItem("Markdown Raw Text"):
md_raw_output = gr.Textbox(
value="🕐 Waiting for parsing result...",
label="Markdown Raw Text",
max_lines=100,
lines=38,
show_copy_button=True,
elem_id="markdown_output",
show_label=False
)
with gr.TabItem("Current Page JSON"):
current_page_json = gr.Textbox(
value="🕐 Waiting for parsing result...",
label="Current Page JSON",
max_lines=100,
lines=38,
show_copy_button=True,
elem_id="markdown_output",
show_label=False
)
# Download Button
with gr.Row():
download_btn = gr.DownloadButton(
"⬇️ Download Results",
visible=False
)
# When the prompt mode changes, update the display content
prompt_mode.change(
fn=update_prompt_display,
inputs=prompt_mode,
outputs=prompt_display,
show_progress=False
)
# Show preview on file upload
file_input.upload(
fn=load_file_for_preview,
inputs=file_input,
outputs=[result_image, page_info],
show_progress=False
)
# Page navigation
prev_btn.click(
fn=lambda: turn_page("prev"),
outputs=[result_image, page_info, current_page_json],
show_progress=False
)
next_btn.click(
fn=lambda: turn_page("next"),
outputs=[result_image, page_info, current_page_json],
show_progress=False
)
process_btn.click(
fn=process_image_inference,
inputs=[
test_image_input, file_input,
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
fitz_preprocess
],
outputs=[
result_image, info_display, md_output, md_raw_output,
download_btn, page_info, current_page_json
],
show_progress=True
)
clear_btn.click(
fn=clear_all_data,
outputs=[
file_input, test_image_input,
result_image, info_display, md_output, md_raw_output,
download_btn, page_info, current_page_json
],
show_progress=False
)
return demo
# ==================== Main Program ====================
if __name__ == "__main__":
demo = create_gradio_interface()
demo.queue().launch(
server_name="0.0.0.0",
server_port=7860,
debug=True
)
+666
View File
@@ -0,0 +1,666 @@
"""
Layout Inference Web Application with Gradio - Annotation Version
A Gradio-based layout inference tool that supports image uploads and multiple backend inference engines.
This version adds an image annotation feature, allowing users to draw bounding boxes on an image and send both the image and the boxes to the model.
"""
import gradio as gr
import json
import os
import io
import tempfile
import base64
import zipfile
import uuid
import re
from pathlib import Path
from PIL import Image
import requests
from gradio_image_annotation import image_annotator
# Local utility imports
from dots_ocr.utils import dict_promptmode_to_prompt
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.demo_utils.display import read_image
from dots_ocr.utils.doc_utils import load_images_from_pdf
# Add DotsOCRParser import
from dots_ocr.parser import DotsOCRParser
# ==================== Configuration ====================
DEFAULT_CONFIG = {
'ip': "127.0.0.1",
'port_vllm': 8000,
'min_pixels': MIN_PIXELS,
'max_pixels': MAX_PIXELS,
'test_images_dir': "./assets/showcase_origin",
}
# ==================== Global Variables ====================
# Store the current configuration
current_config = DEFAULT_CONFIG.copy()
# Create a DotsOCRParser instance
dots_parser = DotsOCRParser(
ip=DEFAULT_CONFIG['ip'],
port=DEFAULT_CONFIG['port_vllm'],
dpi=200,
min_pixels=DEFAULT_CONFIG['min_pixels'],
max_pixels=DEFAULT_CONFIG['max_pixels']
)
# Store processing results
processing_results = {
'original_image': None,
'processed_image': None,
'layout_result': None,
'markdown_content': None,
'cells_data': None,
'temp_dir': None,
'session_id': None,
'result_paths': None,
'annotation_data': None # Store annotation data
}
# ==================== Utility Functions ====================
def read_image_v2(img):
"""Reads an image, supporting URLs and local paths."""
if isinstance(img, str) and img.startswith(("http://", "https://")):
with requests.get(img, stream=True) as response:
response.raise_for_status()
img = Image.open(io.BytesIO(response.content))
elif isinstance(img, str):
img, _, _ = read_image(img, use_native=True)
elif isinstance(img, Image.Image):
pass
else:
raise ValueError(f"Invalid image type: {type(img)}")
return img
def get_test_images():
"""Gets the list of test images."""
test_images = []
test_dir = current_config['test_images_dir']
if os.path.exists(test_dir):
test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)
if name.lower().endswith(('.png', '.jpg', '.jpeg'))]
return test_images
def create_temp_session_dir():
"""Creates a unique temporary directory for each processing request."""
session_id = uuid.uuid4().hex[:8]
temp_dir = os.path.join(tempfile.gettempdir(), f"dots_ocr_demo_{session_id}")
os.makedirs(temp_dir, exist_ok=True)
return temp_dir, session_id
def parse_image_with_bbox(parser, image, prompt_mode, bbox=None, fitz_preprocess=False):
"""
Processes an image using DotsOCRParser, with support for the bbox parameter.
"""
# Create a temporary session directory
temp_dir, session_id = create_temp_session_dir()
try:
# Save the PIL Image to a temporary file
temp_image_path = os.path.join(temp_dir, f"input_{session_id}.png")
image.save(temp_image_path, "PNG")
# Use the high-level parse_image interface, passing the bbox parameter
filename = f"demo_{session_id}"
results = parser.parse_image(
input_path=temp_image_path,
filename=filename,
prompt_mode=prompt_mode,
save_dir=temp_dir,
bbox=bbox,
fitz_preprocess=fitz_preprocess
)
# Parse the results
if not results:
raise ValueError("No results returned from parser")
result = results[0] # parse_image returns a list with a single result
# Read the result files
layout_image = None
cells_data = None
md_content = None
filtered = False
# Read the layout image
if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
layout_image = Image.open(result['layout_image_path'])
# Read the JSON data
if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
cells_data = json.load(f)
# Read the Markdown content
if 'md_content_path' in result and os.path.exists(result['md_content_path']):
with open(result['md_content_path'], 'r', encoding='utf-8') as f:
md_content = f.read()
# Check for the original response file (if JSON parsing fails)
if 'filtered' in result:
filtered = result['filtered']
return {
'layout_image': layout_image,
'cells_data': cells_data,
'md_content': md_content,
'filtered': filtered,
'temp_dir': temp_dir,
'session_id': session_id,
'result_paths': result
}
except Exception as e:
# Clean up the temporary directory on error
import shutil
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir, ignore_errors=True)
raise e
def process_annotation_data(annotation_data):
"""Processes annotation data, converting it to the format required by the model."""
if not annotation_data or not annotation_data.get('boxes'):
return None, None
# Get image and box data
image = annotation_data.get('image')
boxes = annotation_data.get('boxes', [])
if not boxes:
return image, None
# Ensure the image is in PIL Image format
if image is not None:
import numpy as np
if isinstance(image, np.ndarray):
image = Image.fromarray(image)
elif not isinstance(image, Image.Image):
# If it's another format, try to convert it
try:
image = Image.open(image) if isinstance(image, str) else Image.fromarray(image)
except Exception as e:
print(f"Image format conversion failed: {e}")
return None, None
# Get the coordinate information of the box (only one box)
box = boxes[0]
bbox = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
return image, bbox
# ==================== Core Processing Function ====================
def process_image_inference_with_annotation(annotation_data, test_image_input,
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
fitz_preprocess=False
):
"""Core function for image inference, supporting annotation data."""
global current_config, processing_results, dots_parser
# First, clean up previous processing results
if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
import shutil
try:
shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
except Exception as e:
print(f"Failed to clean up previous temporary directory: {e}")
# Reset processing results
processing_results = {
'original_image': None,
'processed_image': None,
'layout_result': None,
'markdown_content': None,
'cells_data': None,
'temp_dir': None,
'session_id': None,
'result_paths': None,
'annotation_data': annotation_data
}
# Update configuration
current_config.update({
'ip': server_ip,
'port_vllm': server_port,
'min_pixels': min_pixels,
'max_pixels': max_pixels
})
# Update parser configuration
dots_parser.ip = server_ip
dots_parser.port = server_port
dots_parser.min_pixels = min_pixels
dots_parser.max_pixels = max_pixels
# Determine the input source and process annotation data
image = None
bbox = None
# Prioritize processing annotation data
if annotation_data and annotation_data.get('image') is not None:
image, bbox = process_annotation_data(annotation_data)
if image is not None:
# If there's a bbox, force the use of 'prompt_grounding_ocr' mode
assert bbox is not None
prompt_mode = "prompt_grounding_ocr"
# If there's no annotation data, check the test image input
if image is None and test_image_input and test_image_input != "":
try:
image = read_image_v2(test_image_input)
except Exception as e:
return None, f"Failed to read test image: {e}", "", "", gr.update(value=None), ""
if image is None:
return None, "Please select a test image or add an image in the annotation component", "", "", gr.update(value=None), ""
if bbox is None:
return "Please select a bounding box by mouse", "Please select a bounding box by mouse", "", "", gr.update(value=None)
try:
# Process using DotsOCRParser, passing the bbox parameter
original_image = image
parse_result = parse_image_with_bbox(dots_parser, image, prompt_mode, bbox, fitz_preprocess)
# Extract parsing results
layout_image = parse_result['layout_image']
cells_data = parse_result['cells_data']
md_content = parse_result['md_content']
filtered = parse_result['filtered']
# Store the results
processing_results.update({
'original_image': original_image,
'processed_image': None,
'layout_result': layout_image,
'markdown_content': md_content,
'cells_data': cells_data,
'temp_dir': parse_result['temp_dir'],
'session_id': parse_result['session_id'],
'result_paths': parse_result['result_paths'],
'annotation_data': annotation_data
})
# Handle the case where parsing fails
if filtered:
info_text = f"""
**Image Information:**
- Original Dimensions: {original_image.width} x {original_image.height}
- Processing Mode: {'Region OCR' if bbox else 'Full Image OCR'}
- Processing Status: JSON parsing failed, using cleaned text output
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Session ID: {parse_result['session_id']}
- Box Coordinates: {bbox if bbox else 'None'}
"""
return (
md_content or "No markdown content generated",
info_text,
md_content or "No markdown content generated",
md_content or "No markdown content generated",
gr.update(visible=False),
""
)
# Handle the case where JSON parsing succeeds
num_elements = len(cells_data) if cells_data else 0
info_text = f"""
**Image Information:**
- Original Dimensions: {original_image.width} x {original_image.height}
- Processing Mode: {'Region OCR' if bbox else 'Full Image OCR'}
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Detected {num_elements} layout elements
- Session ID: {parse_result['session_id']}
- Box Coordinates: {bbox if bbox else 'None'}
"""
# Current page JSON output
current_json = ""
if cells_data:
try:
current_json = json.dumps(cells_data, ensure_ascii=False, indent=2)
except:
current_json = str(cells_data)
# Create a downloadable ZIP file
download_zip_path = None
if parse_result['temp_dir']:
download_zip_path = os.path.join(parse_result['temp_dir'], f"layout_results_{parse_result['session_id']}.zip")
try:
with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
for root, dirs, files in os.walk(parse_result['temp_dir']):
for file in files:
if file.endswith('.zip'):
continue
file_path = os.path.join(root, file)
arcname = os.path.relpath(file_path, parse_result['temp_dir'])
zipf.write(file_path, arcname)
except Exception as e:
print(f"Failed to create download ZIP: {e}")
download_zip_path = None
return (
md_content or "No markdown content generated",
info_text,
md_content or "No markdown content generated",
md_content or "No markdown content generated",
gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False),
current_json
)
except Exception as e:
return f"An error occurred during processing: {e}", f"An error occurred during processing: {e}", "", "", gr.update(value=None), ""
def load_image_to_annotator(test_image_input):
"""Loads an image into the annotation component."""
image = None
# Check the test image input
if test_image_input and test_image_input != "":
try:
image = read_image_v2(test_image_input)
except Exception as e:
return None
if image is None:
return None
# Return the format required by the annotation component
return {
"image": image,
"boxes": []
}
def clear_all_data():
"""Clears all data."""
global processing_results
# Clean up the temporary directory
if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
import shutil
try:
shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
except Exception as e:
print(f"Failed to clean up temporary directory: {e}")
# Reset processing results
processing_results = {
'original_image': None,
'processed_image': None,
'layout_result': None,
'markdown_content': None,
'cells_data': None,
'temp_dir': None,
'session_id': None,
'result_paths': None,
'annotation_data': None
}
return (
"", # Clear test image selection
None, # Clear annotation component
"Waiting for processing results...", # Reset info display
"## Waiting for processing results...", # Reset Markdown display
"🕐 Waiting for parsing results...", # Clear raw Markdown text
gr.update(visible=False), # Hide download button
"🕐 Waiting for parsing results..." # Clear JSON
)
def update_prompt_display(prompt_mode):
"""Updates the displayed prompt content."""
return dict_promptmode_to_prompt[prompt_mode]
# ==================== Gradio Interface ====================
def create_gradio_interface():
"""Creates the Gradio interface."""
# CSS styling to match the reference style
css = """
footer {
visibility: hidden;
}
#info_box {
padding: 10px;
background-color: #f8f9fa;
border-radius: 8px;
border: 1px solid #dee2e6;
margin: 10px 0;
font-size: 14px;
}
#markdown_tabs {
height: 100%;
}
#annotation_component {
border-radius: 8px;
}
"""
with gr.Blocks(theme="ocean", css=css, title='dots.ocr - Annotation') as demo:
# Title
gr.HTML("""
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
<h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr - Annotation Version</h1>
</div>
<div style="text-align: center; margin-bottom: 10px;">
<em>Supports image annotation, drawing boxes, and sending box information to the model for OCR.</em>
</div>
""")
with gr.Row():
# Left side: Input and Configuration
with gr.Column(scale=1, variant="compact"):
gr.Markdown("### 📁 Select Example")
test_images = get_test_images()
test_image_input = gr.Dropdown(
label="Select Example",
choices=[""] + test_images,
value="",
show_label=True
)
# Button to load image into the annotation component
load_btn = gr.Button("📷 Load Image to Annotation Area", variant="secondary")
prompt_mode = gr.Dropdown(
label="Select Prompt",
# choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr", "prompt_grounding_ocr"],
choices=["prompt_grounding_ocr"],
value="prompt_grounding_ocr",
show_label=True,
info="If a box is drawn, 'prompt_grounding_ocr' mode will be used automatically."
)
# Display the current prompt content
prompt_display = gr.Textbox(
label="Current Prompt Content",
# value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
value=dict_promptmode_to_prompt["prompt_grounding_ocr"],
lines=4,
max_lines=8,
interactive=False,
show_copy_button=True
)
gr.Markdown("### ⚙️ Actions")
process_btn = gr.Button("🔍 Parse", variant="primary")
clear_btn = gr.Button("🗑️ Clear", variant="secondary")
gr.Markdown("### 🛠️ Configuration")
fitz_preprocess = gr.Checkbox(
label="Enable fitz_preprocess",
value=False,
info="Performs fitz preprocessing on the image input, converting the image to a PDF and then to a 200dpi image."
)
with gr.Row():
server_ip = gr.Textbox(
label="Server IP",
value=DEFAULT_CONFIG['ip']
)
server_port = gr.Number(
label="Port",
value=DEFAULT_CONFIG['port_vllm'],
precision=0
)
with gr.Row():
min_pixels = gr.Number(
label="Min Pixels",
value=DEFAULT_CONFIG['min_pixels'],
precision=0
)
max_pixels = gr.Number(
label="Max Pixels",
value=DEFAULT_CONFIG['max_pixels'],
precision=0
)
# Right side: Result Display
with gr.Column(scale=6, variant="compact"):
with gr.Row():
# Image Annotation Area
with gr.Column(scale=3):
gr.Markdown("### 🎯 Image Annotation Area")
gr.Markdown("""
**Instructions:**
- Method 1: Select an example image on the left and click "Load Image to Annotation Area".
- Method 2: Upload an image directly in the annotation area below (drag and drop or click to upload).
- Use the mouse to draw a box on the image to select the region for recognition.
- Only one box can be drawn. To draw a new one, please delete the old one first.
- **Hotkey: Press the Delete key to remove the selected box.**
- After drawing a box, clicking Parse will automatically use the Region OCR mode.
""")
annotator = image_annotator(
value=None,
label="Image Annotation",
height=600,
show_label=False,
elem_id="annotation_component",
single_box=True, # Only allow one box; a new box will replace the old one
box_min_size=10,
interactive=True,
disable_edit_boxes=True, # Disable the edit dialog
label_list=["OCR Region"], # Set the default label
label_colors=[(255, 0, 0)], # Set color to red
use_default_label=True, # Use the default label
image_type="pil" # Ensure it returns a PIL Image format
)
# Information Display
info_display = gr.Markdown(
"Waiting for processing results...",
elem_id="info_box"
)
# Result Display Area
with gr.Column(scale=3):
gr.Markdown("### ✅ Results")
with gr.Tabs(elem_id="markdown_tabs"):
with gr.TabItem("Markdown Rendered View"):
md_output = gr.Markdown(
"## Please upload an image and click the Parse button for recognition...",
label="Markdown Preview",
max_height=1000,
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
],
show_copy_button=False,
elem_id="markdown_output"
)
with gr.TabItem("Markdown Raw Text"):
md_raw_output = gr.Textbox(
value="🕐 Waiting for parsing results...",
label="Markdown Raw Text",
max_lines=100,
lines=38,
show_copy_button=True,
elem_id="markdown_output",
show_label=False
)
with gr.TabItem("JSON Result"):
json_output = gr.Textbox(
value="🕐 Waiting for parsing results...",
label="JSON Result",
max_lines=100,
lines=38,
show_copy_button=True,
elem_id="markdown_output",
show_label=False
)
# Download Button
with gr.Row():
download_btn = gr.DownloadButton(
"⬇️ Download Results",
visible=False
)
# Event Binding
# When the prompt mode changes, update the displayed content
prompt_mode.change(
fn=update_prompt_display,
inputs=prompt_mode,
outputs=prompt_display,
show_progress=False
)
# Load image into the annotation component
load_btn.click(
fn=load_image_to_annotator,
inputs=[test_image_input],
outputs=annotator,
show_progress=False
)
# Process Inference
process_btn.click(
fn=process_image_inference_with_annotation,
inputs=[
annotator, test_image_input,
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
fitz_preprocess
],
outputs=[
md_output, info_display, md_raw_output, md_raw_output,
download_btn, json_output
],
show_progress=True
)
# Clear Data
clear_btn.click(
fn=clear_all_data,
outputs=[
test_image_input, annotator,
info_display, md_output, md_raw_output,
download_btn, json_output
],
show_progress=False
)
return demo
# ==================== Main Program ====================
if __name__ == "__main__":
demo = create_gradio_interface()
demo.queue().launch(
server_name="0.0.0.0",
server_port=7861, # Use a different port to avoid conflicts
debug=True
)
+71
View File
@@ -0,0 +1,71 @@
import os
if "LOCAL_RANK" not in os.environ:
os.environ["LOCAL_RANK"] = "0"
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from qwen_vl_utils import process_vision_info
from dots_ocr.utils import dict_promptmode_to_prompt
def inference(image_path, prompt, model, processor):
# image_path = "demo/demo_image1.jpg"
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": image_path
},
{"type": "text", "text": prompt}
]
}
]
# Preparation for inference
text = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to("cuda")
# Inference: Generation of the output
generated_ids = model.generate(**inputs, max_new_tokens=24000)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
if __name__ == "__main__":
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model_path = "./weights/DotsOCR"
model = AutoModelForCausalLM.from_pretrained(
model_path,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
image_path = "demo/demo_image1.jpg"
for prompt_mode, prompt in dict_promptmode_to_prompt.items():
print(f"prompt: {prompt}")
inference(image_path, prompt, model, processor)
BIN
View File
Binary file not shown.

After

Width:  |  Height:  |  Size: 755 KiB

BIN
View File
Binary file not shown.
+222
View File
@@ -0,0 +1,222 @@
"""
Layout Inference Web Application
A Streamlit-based layout inference tool that supports image uploads and multiple backend inference engines.
"""
import streamlit as st
import json
import os
import io
import tempfile
from PIL import Image
import requests
# Local utility imports
# from utils import infer
from dots_ocr.utils import dict_promptmode_to_prompt
from dots_ocr.utils.format_transformer import layoutjson2md
from dots_ocr.utils.layout_utils import draw_layout_on_image, post_process_cells
from dots_ocr.utils.image_utils import get_input_dimensions, get_image_by_fitz_doc
from dots_ocr.model.inference import inference_with_vllm
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
import os
from PIL import Image
from dots_ocr.utils.demo_utils.display import read_image
# ==================== Configuration ====================
DEFAULT_CONFIG = {
'ip': "127.0.0.1",
'port_vllm': 8000,
'min_pixels': MIN_PIXELS,
'max_pixels': MAX_PIXELS,
'test_images_dir': "./assets/showcase_origin",
}
# ==================== Utility Functions ====================
@st.cache_resource
def read_image_v2(img: str):
if img.startswith(("http://", "https://")):
with requests.get(img, stream=True) as response:
response.raise_for_status()
img = Image.open(io.BytesIO(response.content))
if isinstance(img, str):
# img = transform_image_path(img)
img, _, _ = read_image(img, use_native=True)
elif isinstance(img, Image.Image):
pass
else:
raise ValueError(f"Invalid image type: {type(img)}")
return img
# ==================== UI Components ====================
def create_config_sidebar():
"""Create configuration sidebar"""
st.sidebar.header("Configuration Parameters")
config = {}
config['prompt_key'] = st.sidebar.selectbox("Prompt Mode", list(dict_promptmode_to_prompt.keys()))
config['ip'] = st.sidebar.text_input("Server IP", DEFAULT_CONFIG['ip'])
config['port'] = st.sidebar.number_input("Port", min_value=1000, max_value=9999, value=DEFAULT_CONFIG['port_vllm'])
# config['eos_word'] = st.sidebar.text_input("EOS Word", DEFAULT_CONFIG['eos_word'])
# Image configuration
st.sidebar.subheader("Image Configuration")
config['min_pixels'] = st.sidebar.number_input("Min Pixels", value=DEFAULT_CONFIG['min_pixels'])
config['max_pixels'] = st.sidebar.number_input("Max Pixels", value=DEFAULT_CONFIG['max_pixels'])
return config
def get_image_input():
"""Get image input"""
st.markdown("#### Image Input")
input_mode = st.pills(label="Select input method", options=["Upload Image", "Enter Image URL/Path", "Select Test Image"], key="input_mode", label_visibility="collapsed")
if input_mode == "Upload Image":
# File uploader
uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
tmp_file.write(uploaded_file.getvalue())
return tmp_file.name
elif input_mode == 'Enter Image URL/Path':
# URL input
img_url_input = st.text_input("Enter Image URL/Path")
return img_url_input
elif input_mode == 'Select Test Image':
# Test image selection
test_images = []
test_dir = DEFAULT_CONFIG['test_images_dir']
if os.path.exists(test_dir):
test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]
img_url_test = st.selectbox("Select Test Image", [""] + test_images)
return img_url_test
else:
raise ValueError(f"Invalid input mode: {input_mode}")
return None
def process_and_display_results(output: str, image: Image.Image, config: dict):
"""Process and display inference results"""
prompt, response = output['prompt'], output['response']
try:
col1, col2 = st.columns(2)
# st.markdown('---')
cells = json.loads(response)
# image = Image.open(img_url)
# Post-processing
cells = post_process_cells(
image, cells,
image.width, image.height,
min_pixels=config['min_pixels'],
max_pixels=config['max_pixels']
)
# Calculate input dimensions
input_width, input_height = get_input_dimensions(
image,
min_pixels=config['min_pixels'],
max_pixels=config['max_pixels']
)
st.markdown('---')
st.write(f'Input Dimensions: {input_width} x {input_height}')
# st.write(f'Prompt: {prompt}')
# st.markdown(f'模型原始输出: <span style="color:blue">{result}</span>', unsafe_allow_html=True)
# st.write('模型原始输出:')
# st.write(response)
# st.write('后处理结果:', str(cells))
st.text_area('Original Model Output', response, height=200)
st.text_area('Post-processed Result', str(cells), height=200)
# 显示结果
# st.title("Layout推理结果")
with col1:
# st.markdown("##### 可视化结果")
new_image = draw_layout_on_image(
image, cells,
resized_height=None, resized_width=None,
# text_key='text',
fill_bbox=True, draw_bbox=True
)
st.markdown('##### Visualization Result')
st.image(new_image, width=new_image.width)
# st.write(f"尺寸: {new_image.width} x {new_image.height}")
with col2:
# st.markdown("##### Markdown格式")
md_code = layoutjson2md(image, cells, text_key='text')
# md_code = fix_streamlit_formula(md_code)
st.markdown('##### Markdown Format')
st.markdown(md_code, unsafe_allow_html=True)
except json.JSONDecodeError:
st.error("Model output is not a valid JSON format")
except Exception as e:
st.error(f"Error processing results: {e}")
# ==================== Main Application ====================
def main():
"""Main application function"""
st.set_page_config(page_title="Layout Inference Tool", layout="wide")
st.title("🔍 Layout Inference Tool")
# Configuration
config = create_config_sidebar()
prompt = dict_promptmode_to_prompt[config['prompt_key']]
st.sidebar.info(f"Current Prompt: {prompt}")
# Image input
img_url = get_image_input()
start_button = st.button('🚀 Start Inference', type="primary")
if img_url is not None and img_url.strip() != "":
try:
# processed_image = read_image_v2(img_url)
origin_image = read_image_v2(img_url)
st.write(f"Original Dimensions: {origin_image.width} x {origin_image.height}")
# processed_image = get_image_by_fitz_doc(origin_image, target_dpi=200)
processed_image = origin_image
except Exception as e:
st.error(f"Failed to read image: {e}")
return
else:
st.info("Please enter an image URL/path or upload an image")
return
output = None
# Inference button
if start_button:
with st.spinner(f"Inferring... Server: {config['ip']}:{config['port']}"):
response = inference_with_vllm(
processed_image, prompt, config['ip'], config['port'],
# config['min_pixels'], config['max_pixels']
)
output = {
'prompt': prompt,
'response': response,
}
else:
st.image(processed_image, width=500)
# Process results
if output:
process_and_display_results(output, processed_image, config)
if __name__ == "__main__":
main()
+41
View File
@@ -0,0 +1,41 @@
import argparse
import os
from openai import OpenAI
from transformers.utils.versions import require_version
from PIL import Image
import io
import base64
from dots_ocr.utils import dict_promptmode_to_prompt
from dots_ocr.model.inference import inference_with_vllm
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="localhost")
parser.add_argument("--port", type=str, default="8000")
parser.add_argument("--model_name", type=str, default="model")
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en")
args = parser.parse_args()
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
def main():
addr = f"http://{args.ip}:{args.port}/v1"
image_path = "demo/demo_image1.jpg"
prompt = dict_promptmode_to_prompt[args.prompt_mode]
image = Image.open(image_path)
response = inference_with_vllm(
image,
prompt,
ip="localhost",
port=8000,
temperature=0.1,
top_p=0.9,
)
print(f"response: {response}")
if __name__ == "__main__":
main()
+17
View File
@@ -0,0 +1,17 @@
# download model to /path/to/model
if [ -z "$NODOWNLOAD" ]; then
python3 tools/download_model.py
fi
# register model to vllm
hf_model_path=./weights/DotsOCR # Path to your downloaded model weights
export PYTHONPATH=$(dirname "$hf_model_path"):$PYTHONPATH
sed -i '/^from vllm\.entrypoints\.cli\.main import main$/a\
from DotsOCR import modeling_dots_ocr_vllm' `which vllm`
# launch vllm server
model_name=model
CUDA_VISIBLE_DEVICES=0 vllm serve ${hf_model_path} --tensor-parallel-size 1 --gpu-memory-utilization 0.95 --chat-template-content-format string --served-model-name ${model_name} --trust-remote-code
# # run python demo after launch vllm server
# python demo/demo_vllm.py
+4
View File
@@ -0,0 +1,4 @@
from vllm/vllm-openai:v0.9.1
RUN pip3 install flash_attn==2.8.0.post2
RUN pip3 install transformers==4.51.3
+1
View File
@@ -0,0 +1 @@
from .parser import DotsOCRParser
+50
View File
@@ -0,0 +1,50 @@
import json
import io
import base64
import math
from PIL import Image
import requests
from dots_ocr.utils.image_utils import PILimage_to_base64
from openai import OpenAI
import os
def inference_with_vllm(
image,
prompt,
ip="localhost",
port=8000,
temperature=0.1,
top_p=0.9,
max_completion_tokens=32768,
model_name='model',
):
addr = f"http://{ip}:{port}/v1"
client = OpenAI(api_key="{}".format(os.environ.get("API_KEY", "0")), base_url=addr)
messages = []
messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": PILimage_to_base64(image)},
},
{"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"} # if no "<|img|><|imgpad|><|endofimg|>" here,vllm v1 will add "\n" here
],
}
)
try:
response = client.chat.completions.create(
messages=messages,
model=model_name,
max_completion_tokens=max_completion_tokens,
temperature=temperature,
top_p=top_p)
response = response.choices[0].message.content
return response
except requests.exceptions.RequestException as e:
print(f"request error: {e}")
return None
+349
View File
@@ -0,0 +1,349 @@
import os
import json
from tqdm import tqdm
from multiprocessing.pool import ThreadPool, Pool
import argparse
from dots_ocr.model.inference import inference_with_vllm
from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize
from dots_ocr.utils.doc_utils import fitz_doc_to_image, load_images_from_pdf
from dots_ocr.utils.prompts import dict_promptmode_to_prompt
from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes
from dots_ocr.utils.format_transformer import layoutjson2md
class DotsOCRParser:
"""
parse image or pdf file
"""
def __init__(self,
ip='localhost',
port=8000,
model_name='model',
temperature=0.1,
top_p=1.0,
max_completion_tokens=16384,
num_thread=64,
dpi = 200,
output_dir="./output",
min_pixels=None,
max_pixels=None,
):
self.dpi = dpi
# default args for vllm server
self.ip = ip
self.port = port
self.model_name = model_name
# default args for inference
self.temperature = temperature
self.top_p = top_p
self.max_completion_tokens = max_completion_tokens
self.num_thread = num_thread
self.output_dir = output_dir
self.min_pixels = min_pixels
self.max_pixels = max_pixels
assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
def _inference_with_vllm(self, image, prompt):
response = inference_with_vllm(
image,
prompt,
model_name=self.model_name,
ip=self.ip,
port=self.port,
temperature=self.temperature,
top_p=self.top_p,
max_completion_tokens=self.max_completion_tokens,
)
return response
def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None):
prompt = dict_promptmode_to_prompt[prompt_mode]
if prompt_mode == 'prompt_grounding_ocr':
assert bbox is not None
bboxes = [bbox]
bbox = pre_process_bboxes(origin_image, bboxes, input_width=image.width, input_height=image.height, min_pixels=min_pixels, max_pixels=max_pixels)[0]
prompt = prompt + str(bbox)
return prompt
# def post_process_results(self, response, prompt_mode, save_dir, save_name, origin_image, image, min_pixels, max_pixels)
def _parse_single_image(
self,
origin_image,
prompt_mode,
save_dir,
save_name,
source="image",
page_idx=0,
bbox=None,
fitz_preprocess=False,
):
min_pixels, max_pixels = self.min_pixels, self.max_pixels
if prompt_mode == "prompt_grounding_ocr":
min_pixels = min_pixels or MIN_PIXELS # preprocess image to the final input
max_pixels = max_pixels or MAX_PIXELS
if min_pixels is not None: assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}"
if max_pixels is not None: assert max_pixels <= MAX_PIXELS, f"max_pixels should <+ {MAX_PIXELS}"
if source == 'image' and fitz_preprocess:
image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi)
image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
else:
image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
input_height, input_width = smart_resize(image.height, image.width)
prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
response = self._inference_with_vllm(image, prompt)
result = {'page_no': page_idx,
"input_height": input_height,
"input_width": input_width
}
if source == 'pdf':
save_name = f"{save_name}_page_{page_idx}"
if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
cells, filtered = post_process_output(
response,
prompt_mode,
origin_image,
image,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
if filtered and prompt_mode != 'prompt_layout_only_en': # model output json failed, use filtered process
json_file_path = os.path.join(save_dir, f"{save_name}.json")
with open(json_file_path, 'w') as w:
json.dump(response, w, ensure_ascii=False)
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
origin_image.save(image_layout_path)
result.update({
'layout_info_path': json_file_path,
'layout_image_path': image_layout_path,
})
md_file_path = os.path.join(save_dir, f"{save_name}.md")
with open(md_file_path, "w", encoding="utf-8") as md_file:
md_file.write(cells)
result.update({
'md_content_path': md_file_path
})
result.update({
'filtered': True
})
else:
try:
image_with_layout = draw_layout_on_image(origin_image, cells)
except Exception as e:
print(f"Error drawing layout on image: {e}")
image_with_layout = origin_image
json_file_path = os.path.join(save_dir, f"{save_name}.json")
with open(json_file_path, 'w') as w:
json.dump(cells, w, ensure_ascii=False)
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
image_with_layout.save(image_layout_path)
result.update({
'layout_info_path': json_file_path,
'layout_image_path': image_layout_path,
})
if prompt_mode != "prompt_layout_only_en": # no text md when detection only
md_content = layoutjson2md(origin_image, cells, text_key='text')
md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True) # used for clean output or metric of omnidocbench、olmbench
md_file_path = os.path.join(save_dir, f"{save_name}.md")
with open(md_file_path, "w", encoding="utf-8") as md_file:
md_file.write(md_content)
md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md")
with open(md_nohf_file_path, "w", encoding="utf-8") as md_file:
md_file.write(md_content_no_hf)
result.update({
'md_content_path': md_file_path,
'md_content_nohf_path': md_nohf_file_path,
})
else:
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
origin_image.save(image_layout_path)
result.update({
'layout_image_path': image_layout_path,
})
md_content = response
md_file_path = os.path.join(save_dir, f"{save_name}.md")
with open(md_file_path, "w", encoding="utf-8") as md_file:
md_file.write(md_content)
result.update({
'md_content_path': md_file_path,
})
return result
def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False):
origin_image = fetch_image(input_path)
result = self._parse_single_image(origin_image, prompt_mode, save_dir, filename, source="image", bbox=bbox, fitz_preprocess=fitz_preprocess)
result['file_path'] = input_path
return [result]
def parse_pdf(self, input_path, filename, prompt_mode, save_dir):
print(f"loading pdf: {input_path}")
images_origin = load_images_from_pdf(input_path)
total_pages = len(images_origin)
tasks = [
{
"origin_image": image,
"prompt_mode": prompt_mode,
"save_dir": save_dir,
"save_name": filename,
"source":"pdf",
"page_idx": i,
} for i, image in enumerate(images_origin)
]
def _execute_task(task_args):
return self._parse_single_image(**task_args)
num_thread = min(total_pages, self.num_thread)
print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
results = []
with ThreadPool(num_thread) as pool:
with tqdm(total=total_pages, desc="Processing PDF pages") as pbar:
for result in pool.imap_unordered(_execute_task, tasks):
results.append(result)
pbar.update(1)
results.sort(key=lambda x: x["page_no"])
for i in range(len(results)):
results[i]['file_path'] = input_path
return results
def parse_file(self,
input_path,
output_dir="",
prompt_mode="prompt_layout_all_en",
bbox=None,
fitz_preprocess=False
):
output_dir = output_dir or self.output_dir
output_dir = os.path.abspath(output_dir)
filename, file_ext = os.path.splitext(os.path.basename(input_path))
save_dir = os.path.join(output_dir, filename)
os.makedirs(save_dir, exist_ok=True)
if file_ext == '.pdf':
results = self.parse_pdf(input_path, filename, prompt_mode, save_dir)
elif file_ext in image_extensions:
results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess)
else:
raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} and pdf")
print(f"Parsing finished, results saving to {save_dir}")
with open(os.path.join(output_dir, os.path.basename(filename)+'.jsonl'), 'w') as w:
for result in results:
w.write(json.dumps(result, ensure_ascii=False) + '\n')
return results
def main():
prompts = list(dict_promptmode_to_prompt.keys())
parser = argparse.ArgumentParser(
description="dots.ocr Multilingual Document Layout Parser",
)
parser.add_argument(
"input_path", type=str,
help="Input PDF/image file path"
)
parser.add_argument(
"--output", type=str, default="./output",
help="Output directory (default: ./output)"
)
parser.add_argument(
"--prompt", choices=prompts, type=str, default="prompt_layout_all_en",
help="prompt to query the model, different prompts for different tasks"
)
parser.add_argument(
'--bbox',
type=int,
nargs=4,
metavar=('x1', 'y1', 'x2', 'y2'),
help='should give this argument if you want to prompt_grounding_ocr'
)
parser.add_argument(
"--ip", type=str, default="localhost",
help=""
)
parser.add_argument(
"--port", type=int, default=8000,
help=""
)
parser.add_argument(
"--model_name", type=str, default="model",
help=""
)
parser.add_argument(
"--temperature", type=float, default=0.1,
help=""
)
parser.add_argument(
"--top_p", type=float, default=1.0,
help=""
)
parser.add_argument(
"--dpi", type=int, default=200,
help=""
)
parser.add_argument(
"--max_completion_tokens", type=int, default=16384,
help=""
)
parser.add_argument(
"--num_thread", type=int, default=16,
help=""
)
# parser.add_argument(
# "--fitz_preprocess", type=bool, default=False,
# help="False will use tikz dpi upsample pipeline, good for images which has been render with low dpi, but maybe result in higher computational costs"
# )
parser.add_argument(
"--min_pixels", type=int, default=None,
help=""
)
parser.add_argument(
"--max_pixels", type=int, default=None,
help=""
)
args = parser.parse_args()
dots_ocr_parser = DotsOCRParser(
ip=args.ip,
port=args.port,
model_name=args.model_name,
temperature=args.temperature,
top_p=args.top_p,
max_completion_tokens=args.max_completion_tokens,
num_thread=args.num_thread,
dpi=args.dpi,
output_dir=args.output,
min_pixels=args.min_pixels,
max_pixels=args.max_pixels,
)
result = dots_ocr_parser.parse_file(
args.input_path,
prompt_mode=args.prompt,
bbox=args.bbox,
)
if __name__ == "__main__":
main()
+1
View File
@@ -0,0 +1 @@
from .prompts import dict_promptmode_to_prompt
+5
View File
@@ -0,0 +1,5 @@
MIN_PIXELS=3136
MAX_PIXELS=11289600
IMAGE_FACTOR=28
image_extensions = {'.jpg', '.jpeg', '.png'}
+61
View File
@@ -0,0 +1,61 @@
import os
from PIL import Image
def is_valid_image_path(image_path):
"""
Checks if the image path is valid.
Args:
image_path: The path to the image.
Returns:
bool: True if the path is valid, False otherwise.
"""
if not os.path.exists(image_path):
return False
# Check if the file extension is one of the common image formats.
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp']
_, extension = os.path.splitext(image_path)
if extension.lower() in image_extensions:
return True
else:
return False
def read_image(image_path, use_native=False):
"""
Reads an image and resizes it while maintaining aspect ratio.
Args:
image_path: The path to the image.
use_native: If True, the max dimension of the original image is used as the max size.
If False, max size is set to 1024.
Returns:
tuple: (resized_image, original_width, original_height)
"""
# Create a default 512x512 blue image as a fallback.
image = Image.new('RGB', (512, 512), color=(0, 0, 255))
if is_valid_image_path(image_path):
image = Image.open(image_path)
else:
raise FileNotFoundError(f"{image_path}: Image path does not exist")
w, h = image.size
if use_native:
max_size = max(w, h)
else:
max_size = 1024
if w > h:
new_w = max_size
new_h = int(h * max_size / w)
else:
new_h = max_size
new_w = int(w * max_size / h)
image = image.resize((new_w, new_h))
return image, w, h
+60
View File
@@ -0,0 +1,60 @@
import fitz
import numpy as np
import enum
from pydantic import BaseModel, Field
from PIL import Image
class SupportedPdfParseMethod(enum.Enum):
OCR = 'ocr'
TXT = 'txt'
class PageInfo(BaseModel):
"""The width and height of page
"""
w: float = Field(description='the width of page')
h: float = Field(description='the height of page')
def fitz_doc_to_image(doc, target_dpi=200, origin_dpi=None) -> dict:
"""Convert fitz.Document to image, Then convert the image to numpy array.
Args:
doc (_type_): pymudoc page
dpi (int, optional): reset the dpi of dpi. Defaults to 200.
Returns:
dict: {'img': numpy array, 'width': width, 'height': height }
"""
from PIL import Image
mat = fitz.Matrix(target_dpi / 72, target_dpi / 72)
pm = doc.get_pixmap(matrix=mat, alpha=False)
if pm.width > 4500 or pm.height > 4500:
mat = fitz.Matrix(72 / 72, 72 / 72) # use fitz default dpi
pm = doc.get_pixmap(matrix=mat, alpha=False)
image = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
return image
def load_images_from_pdf(pdf_file, dpi=200, start_page_id=0, end_page_id=None) -> list:
images = []
with fitz.open(pdf_file) as doc:
pdf_page_num = doc.page_count
end_page_id = (
end_page_id
if end_page_id is not None and end_page_id >= 0
else pdf_page_num - 1
)
if end_page_id > pdf_page_num - 1:
print('end_page_id is out of range, use images length')
end_page_id = pdf_page_num - 1
for index in range(0, doc.page_count):
if start_page_id <= index <= end_page_id:
page = doc[index]
img = fitz_doc_to_image(page, target_dpi=dpi)
images.append(img)
return images
+205
View File
@@ -0,0 +1,205 @@
import os
import sys
import json
import re
from PIL import Image
from dots_ocr.utils.image_utils import PILimage_to_base64
def has_latex_markdown(text: str) -> bool:
"""
Checks if a string contains LaTeX markdown patterns.
Args:
text (str): The string to check.
Returns:
bool: True if LaTeX markdown is found, otherwise False.
"""
if not isinstance(text, str):
return False
# Define regular expression patterns for LaTeX markdown
latex_patterns = [
r'\$\$.*?\$\$', # Block-level math formula $$...$$
r'\$[^$\n]+?\$', # Inline math formula $...$
r'\\begin\{.*?\}.*?\\end\{.*?\}', # LaTeX environment \begin{...}...\end{...}
r'\\[a-zA-Z]+\{.*?\}', # LaTeX command \command{...}
r'\\[a-zA-Z]+', # Simple LaTeX command \command
r'\\\[.*?\\\]', # Display math formula \[...\]
r'\\\(.*?\\\)', # Inline math formula \(...\)
]
# Check if any of the patterns match
for pattern in latex_patterns:
if re.search(pattern, text, re.DOTALL):
return True
return False
def clean_latex_preamble(latex_text: str) -> str:
"""
Removes LaTeX preamble commands like document class and package imports.
Args:
latex_text (str): The original LaTeX text.
Returns:
str: The cleaned LaTeX text without preamble commands.
"""
# Define patterns to be removed
patterns = [
r'\\documentclass\{[^}]+\}', # \documentclass{...}
r'\\usepackage\{[^}]+\}', # \usepackage{...}
r'\\usepackage\[[^\]]*\]\{[^}]+\}', # \usepackage[options]{...}
r'\\begin\{document\}', # \begin{document}
r'\\end\{document\}', # \end{document}
]
# Apply each pattern to clean the text
cleaned_text = latex_text
for pattern in patterns:
cleaned_text = re.sub(pattern, '', cleaned_text, flags=re.IGNORECASE)
return cleaned_text
def get_formula_in_markdown(text: str) -> str:
"""
Formats a string containing a formula into a standard Markdown block.
Args:
text (str): The input string, potentially containing a formula.
Returns:
str: The formatted string, ready for Markdown rendering.
"""
# Remove leading/trailing whitespace
text = text.strip()
# Check if it's already enclosed in $$
if text.startswith('$$') and text.endswith('$$'):
text_new = text[2:-2].strip()
if not '$' in text_new:
return f"$$\n{text_new}\n$$"
else:
return text
# Handle \[...\] format, convert to $$...$$
if text.startswith('\\[') and text.endswith('\\]'):
inner_content = text[2:-2].strip()
return f"$$\n{inner_content}\n$$"
# Check if it's enclosed in \[ \]
if len(re.findall(r'.*\\\[.*\\\].*', text)) > 0:
return text
# Handle inline formulas ($...$)
pattern = r'\$([^$]+)\$'
matches = re.findall(pattern, text)
if len(matches) > 0:
# It's an inline formula, return it as is
return text
# If no LaTeX markdown syntax is present, return directly
if not has_latex_markdown(text):
return text
# Handle unnecessary LaTeX formatting like \usepackage
if 'usepackage' in text:
text = clean_latex_preamble(text)
if text[0] == '`' and text[-1] == '`':
text = text[1:-1]
# Enclose the final text in a $$ block with newlines
text = f"$$\n{text}\n$$"
return text
def clean_text(text: str) -> str:
"""
Cleans text by removing extra whitespace.
Args:
text: The original text.
Returns:
str: The cleaned text.
"""
if not text:
return ""
# Remove leading and trailing whitespace
text = text.strip()
# Replace multiple consecutive whitespace characters with a single space
text = re.sub(r'\s+', ' ', text)
return text
def layoutjson2md(image: Image.Image, cells: list, text_key: str = 'text', no_page_hf: bool = False) -> str:
"""
Converts a layout JSON format to Markdown.
In the layout JSON, formulas are LaTeX, tables are HTML, and text is Markdown.
Args:
image: A PIL Image object.
cells: A list of dictionaries, each representing a layout cell.
text_key: The key for the text field in the cell dictionary.
no_page_header_footer: If True, skips page headers and footers.
Returns:
str: The text in Markdown format.
"""
text_items = []
for i, cell in enumerate(cells):
x1, y1, x2, y2 = [int(coord) for coord in cell['bbox']]
text = cell.get(text_key, "")
if no_page_hf and cell['category'] in ['Page-header', 'Page-footer']:
continue
if cell['category'] == 'Picture':
image_crop = image.crop((x1, y1, x2, y2))
image_base64 = PILimage_to_base64(image_crop)
text_items.append(f"![]({image_base64})")
elif cell['category'] == 'Formula':
text_items.append(get_formula_in_markdown(text))
else:
text = clean_text(text)
text_items.append(f"{text}")
markdown_text = '\n\n'.join(text_items)
return markdown_text
def fix_streamlit_formulas(md: str) -> str:
"""
Fixes the format of formulas in Markdown to ensure they display correctly in Streamlit.
It adds a newline after the opening $$ and before the closing $$ if they don't already exist.
Args:
md_text (str): The Markdown text to fix.
Returns:
str: The fixed Markdown text.
"""
# This inner function will be used by re.sub to perform the replacement
def replace_formula(match):
content = match.group(1)
# If the content already has surrounding newlines, don't add more.
if content.startswith('\n'):
content = content[1:]
if content.endswith('\n'):
content = content[:-1]
return f'$$\n{content}\n$$'
# Use regex to find all $$....$$ patterns and replace them using the helper function.
return re.sub(r'\$\$(.*?)\$\$', replace_formula, md, flags=re.DOTALL)
+196
View File
@@ -0,0 +1,196 @@
import math
import base64
from PIL import Image
from typing import Tuple
import os
from dots_ocr.utils.consts import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.doc_utils import fitz_doc_to_image
from io import BytesIO
import fitz
import requests
import copy
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int,
width: int,
factor: int = 28,
min_pixels: int = 3136,
max_pixels: int = 11289600,
):
"""Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > 200:
raise ValueError(
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = max(factor, floor_by_factor(height / beta, factor))
w_bar = max(factor, floor_by_factor(width / beta, factor))
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
if h_bar * w_bar > max_pixels: # max_pixels first to control the token length
beta = math.sqrt((h_bar * w_bar) / max_pixels)
h_bar = max(factor, floor_by_factor(h_bar / beta, factor))
w_bar = max(factor, floor_by_factor(w_bar / beta, factor))
return h_bar, w_bar
def PILimage_to_base64(image, format='PNG'):
buffered = BytesIO()
image.save(buffered, format=format)
base64_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
return f"data:image;base64,{base64_str}"
def to_rgb(pil_image: Image.Image) -> Image.Image:
if pil_image.mode == 'RGBA':
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
return white_background
else:
return pil_image.convert("RGB")
# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
def fetch_image(
image,
min_pixels=None,
max_pixels=None,
resized_height=None,
resized_width=None,
) -> Image.Image:
assert image is not None, f"image not found, maybe input format error: {image}"
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
# fix memory leak issue while using BytesIO
with requests.get(image, stream=True) as response:
response.raise_for_status()
with BytesIO(response.content) as bio:
image_obj = copy.deepcopy(Image.open(bio))
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
# fix memory leak issue while using BytesIO
with BytesIO(data) as bio:
image_obj = copy.deepcopy(Image.open(bio))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
image = to_rgb(image_obj)
## resize
if resized_height and resized_width:
resized_height, resized_width = smart_resize(
resized_height,
resized_width,
factor=IMAGE_FACTOR,
)
assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
image = image.resize((resized_width, resized_height))
elif min_pixels or max_pixels:
width, height = image.size
if not min_pixels:
min_pixels = MIN_PIXELS
if not max_pixels:
max_pixels = MAX_PIXELS
resized_height, resized_width = smart_resize(
height,
width,
factor=IMAGE_FACTOR,
min_pixels=min_pixels,
max_pixels=max_pixels,
)
assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
image = image.resize((resized_width, resized_height))
return image
def get_input_dimensions(
image: Image.Image,
min_pixels: int,
max_pixels: int,
factor: int = 28
) -> Tuple[int, int]:
"""
Gets the resized dimensions of the input image.
Args:
image: The original image.
min_pixels: The minimum number of pixels.
max_pixels: The maximum number of pixels.
factor: The resizing factor.
Returns:
The resized (width, height).
"""
input_height, input_width = smart_resize(
image.height,
image.width,
factor=factor,
min_pixels=min_pixels,
max_pixels=max_pixels
)
return input_width, input_height
def get_image_by_fitz_doc(image, target_dpi=200):
# get image through fitz, to get target dpi image, mainly for higher image
if not isinstance(image, Image.Image):
assert isinstance(image, str)
_, file_ext = os.path.splitext(image)
assert file_ext in {'.jpg', '.jpeg', '.png'}
if image.startswith("http://") or image.startswith("https://"):
with requests.get(image, stream=True) as response:
response.raise_for_status()
data_bytes = response.content
else:
with open(image, 'rb') as f:
data_bytes = f.read()
image = Image.open(BytesIO(data_bytes))
else:
data_bytes = BytesIO()
image.save(data_bytes, format='PNG')
origin_dpi = image.info.get('dpi', None)
pdf_bytes = fitz.open(stream=data_bytes).convert_to_pdf()
doc = fitz.open('pdf', pdf_bytes)
page = doc[0]
image_fitz = fitz_doc_to_image(page, target_dpi=target_dpi, origin_dpi=origin_dpi)
return image_fitz
+228
View File
@@ -0,0 +1,228 @@
from PIL import Image
from typing import Dict, List
import fitz
from io import BytesIO
import json
from dots_ocr.utils.image_utils import smart_resize
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
from dots_ocr.utils.output_cleaner import OutputCleaner
# Define a color map (using RGBA format)
dict_layout_type_to_color = {
"Text": (0, 128, 0, 256), # Green, translucent
"Picture": (255, 0, 255, 256), # Magenta, translucent
"Caption": (255, 165, 0, 256), # Orange, translucent
"Section-header": (0, 255, 255, 256), # Cyan, translucent
"Footnote": (0, 128, 0, 256), # Green, translucent
"Formula": (128, 128, 128, 256), # Gray, translucent
"Table": (255, 192, 203, 256), # Pink, translucent
"Title": (255, 0, 0, 256), # Red, translucent
"List-item": (0, 0, 255, 256), # Blue, translucent
"Page-header": (0, 128, 0, 256), # Green, translucent
"Page-footer": (128, 0, 128, 256), # Purple, translucent
"Other": (165, 42, 42, 256), # Brown, translucent
"Unknown": (0, 0, 0, 0),
}
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
"""
Draw transparent boxes on an image.
Args:
image: The source PIL Image.
cells: A list of cells containing bounding box information.
resized_height: The resized height.
resized_width: The resized width.
fill_bbox: Whether to fill the bounding box.
draw_bbox: Whether to draw the bounding box.
Returns:
PIL.Image: The image with drawings.
"""
# origin_image = Image.open(image_path)
original_width, original_height = image.size
# Create a new PDF document
doc = fitz.open()
# Get image information
img_bytes = BytesIO()
image.save(img_bytes, format='PNG')
# pix = fitz.Pixmap(image_path)
pix = fitz.Pixmap(img_bytes)
# Create a page
page = doc.new_page(width=pix.width, height=pix.height)
page.insert_image(
fitz.Rect(0, 0, pix.width, pix.height),
# filename=image_path
pixmap=pix
)
for i, cell in enumerate(cells):
bbox = cell['bbox']
layout_type = cell['category']
order = i
top_left = (bbox[0], bbox[1])
down_right = (bbox[2], bbox[3])
if resized_height and resized_width:
scale_x = resized_width / original_width
scale_y = resized_height / original_height
top_left = (int(bbox[0] / scale_x), int(bbox[1] / scale_y))
down_right = (int(bbox[2] / scale_x), int(bbox[3] / scale_y))
color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
color = [col/255 for col in color[:3]]
x0, y0, x1, y1 = top_left[0], top_left[1], down_right[0], down_right[1]
rect_coords = fitz.Rect(x0, y0, x1, y1)
if draw_bbox:
if fill_bbox:
page.draw_rect(
rect_coords,
color=None,
fill=color,
fill_opacity=0.3,
width=0.5,
overlay=True,
) # Draw the rectangle
else:
page.draw_rect(
rect_coords,
color=color,
fill=None,
fill_opacity=1,
width=0.5,
overlay=True,
) # Draw the rectangle
order_cate = f"{order}_{layout_type}"
page.insert_text(
(x1, y0 + 20), order_cate, fontsize=20, color=color
) # Insert the index in the top left corner of the rectangle
# Convert to a Pixmap (maintaining original dimensions)
mat = fitz.Matrix(1.0, 1.0)
pix = page.get_pixmap(matrix=mat)
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
def pre_process_bboxes(
origin_image,
bboxes,
input_width,
input_height,
factor: int = 28,
min_pixels: int = 3136,
max_pixels: int = 11289600
):
assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)
min_pixels = min_pixels or MIN_PIXELS
max_pixels = max_pixels or MAX_PIXELS
original_width, original_height = origin_image.size
input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
scale_x = original_width / input_width
scale_y = original_height / input_height
bboxes_out = []
for bbox in bboxes:
bbox_resized = [
int(float(bbox[0]) / scale_x),
int(float(bbox[1]) / scale_y),
int(float(bbox[2]) / scale_x),
int(float(bbox[3]) / scale_y)
]
bboxes_out.append(bbox_resized)
return bboxes_out
def post_process_cells(
origin_image: Image.Image,
cells: List[Dict],
input_width, # server input width, also has smart_resize in server
input_height,
factor: int = 28,
min_pixels: int = 3136,
max_pixels: int = 11289600
) -> List[Dict]:
"""
Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.
Args:
origin_image: The original PIL Image.
cells: A list of cells containing bounding box information.
input_width: The width of the input image sent to the server.
input_height: The height of the input image sent to the server.
factor: Resizing factor.
min_pixels: Minimum number of pixels.
max_pixels: Maximum number of pixels.
Returns:
A list of post-processed cells.
"""
assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
min_pixels = min_pixels or MIN_PIXELS
max_pixels = max_pixels or MAX_PIXELS
original_width, original_height = origin_image.size
input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
scale_x = input_width / original_width
scale_y = input_height / original_height
cells_out = []
for cell in cells:
bbox = cell['bbox']
bbox_resized = [
int(float(bbox[0]) / scale_x),
int(float(bbox[1]) / scale_y),
int(float(bbox[2]) / scale_x),
int(float(bbox[3]) / scale_y)
]
cell_copy = cell.copy()
cell_copy['bbox'] = bbox_resized
cells_out.append(cell_copy)
return cells_out
def is_legal_bbox(cells):
for cell in cells:
bbox = cell['bbox']
if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
return False
return True
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
if prompt_mode in ["prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"]:
return response
json_load_failed = False
cells = response
try:
cells = json.loads(cells)
cells = post_process_cells(
origin_image,
cells,
input_image.width,
input_image.height,
min_pixels=min_pixels,
max_pixels=max_pixels
)
return cells, False
except Exception as e:
print(f"cells post process error: {e}, when using {prompt_mode}")
json_load_failed = True
if json_load_failed:
cleaner = OutputCleaner()
response_clean = cleaner.clean_model_output(cells)
if isinstance(response_clean, list):
response_clean = "\n\n".join([cell['text'] for cell in response_clean if 'text' in cell])
return response_clean, True
+623
View File
@@ -0,0 +1,623 @@
#!/usr/bin/env python3
"""
Data Cleaning Script - Cleans all data using a simplified regex method and saves the results
Features:
1. Cleans all cases using a simplified regex method.
2. Saves the cleaned data for each case.
3. Ensures the relative order of dicts remains unchanged.
4. Generates a before-and-after cleaning report.
"""
import json
import re
import os
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
from collections import Counter
import traceback
@dataclass
class CleanedData:
"""Data structure for cleaned data"""
case_id: int
original_type: str # 'list' or 'str'
original_length: int
cleaned_data: List[Dict]
cleaning_operations: Dict[str, Any] # Records the cleaning operations performed
success: bool
class OutputCleaner:
"""Data Cleaner - Based on a simplified regex method"""
def __init__(self):
# Simplified regular expression patterns
self.dict_pattern = re.compile(r'\{[^{}]*?"bbox"\s*:\s*\[[^\]]*?\][^{}]*?\}', re.DOTALL)
self.bbox_pattern = re.compile(r'"bbox"\s*:\s*\[([^\]]+)\]')
self.missing_delimiter_pattern = re.compile(r'\}\s*\{(?!")')
self.cleaned_results: List[CleanedData] = []
def clean_list_data(self, data: List[Dict], case_id: int) -> CleanedData:
"""Cleans list-type data"""
print(f"🔧 Cleaning List data - Case {case_id}")
print(f" Original items: {len(data)}")
cleaned_data = []
operations = {
'type': 'list',
'bbox_fixes': 0,
'removed_items': 0,
'original_count': len(data)
}
for i, item in enumerate(data):
if not isinstance(item, dict):
operations['removed_items'] += 1
continue
# Check the bbox field
if 'bbox' in item:
bbox = item['bbox']
# Check bbox length - core logic
if isinstance(bbox, list) and len(bbox) == 3:
print(f" ⚠️ Item {i}: bbox has only 3 coordinates. Removing bbox, keeping category and text.")
# Keep only category and text, ensuring order is preserved
new_item = {}
if 'category' in item:
new_item['category'] = item['category']
if 'text' in item:
new_item['text'] = item['text']
if new_item: # Add only if there is valid content
cleaned_data.append(new_item)
operations['bbox_fixes'] += 1
else:
operations['removed_items'] += 1
continue
elif isinstance(bbox, list) and len(bbox) == 4:
# bbox is normal, add directly, preserving original order
cleaned_data.append(item.copy())
continue
else:
print(f" ❌ Item {i}: Abnormal bbox format, skipping.")
operations['removed_items'] += 1
continue
else:
# No bbox field, keep if category exists
if 'category' in item:
cleaned_data.append(item.copy())
continue
else:
operations['removed_items'] += 1
operations['final_count'] = len(cleaned_data)
print(f" ✅ Cleaning complete: {len(cleaned_data)} items, {operations['bbox_fixes']} bbox fixes, {operations['removed_items']} items removed")
return CleanedData(
case_id=case_id,
original_type='list',
original_length=len(data),
cleaned_data=cleaned_data,
cleaning_operations=operations,
success=True
)
def clean_string_data(self, data_str: str, case_id: int) -> CleanedData:
"""Cleans string-type data"""
print(f"🔧 Cleaning String data - Case {case_id}")
print(f" Original length: {len(data_str):,}")
operations = {
'type': 'str',
'original_length': len(data_str),
'delimiter_fixes': 0,
'tail_truncated': False,
'truncated_length': 0,
'duplicate_dicts_removed': 0,
'final_objects': 0
}
try:
# Step 1: Detect and fix missing delimiters
data_str, delimiter_fixes = self._fix_missing_delimiters(data_str)
operations['delimiter_fixes'] = delimiter_fixes
# Step 2: Truncate the last incomplete element
data_str, tail_truncated = self._truncate_last_incomplete_element(data_str)
operations['tail_truncated'] = tail_truncated
operations['truncated_length'] = len(data_str)
# Step 3: Remove duplicate complete dict objects, preserving order
data_str, duplicate_removes = self._remove_duplicate_complete_dicts_preserve_order(data_str)
operations['duplicate_dicts_removed'] = duplicate_removes
# Step 4: Ensure correct JSON format
data_str = self._ensure_json_format(data_str)
# Step 5: Try to parse the final result
final_data = self._parse_final_json(data_str)
if final_data is not None:
operations['final_objects'] = len(final_data)
print(f" ✅ Cleaning complete: {len(final_data)} objects")
return CleanedData(
case_id=case_id,
original_type='str',
original_length=operations['original_length'],
cleaned_data=final_data,
cleaning_operations=operations,
success=True
)
else:
raise Exception("Could not parse the cleaned data")
except Exception as e:
print(f" ❌ Cleaning failed: {e}")
return CleanedData(
case_id=case_id,
original_type='str',
original_length=operations['original_length'],
cleaned_data=[],
cleaning_operations=operations,
success=False
)
def _fix_missing_delimiters(self, text: str) -> Tuple[str, int]:
"""Fixes missing delimiters"""
fixes = 0
def replace_delimiter(match):
nonlocal fixes
fixes += 1
return '},{'
text = self.missing_delimiter_pattern.sub(replace_delimiter, text)
if fixes > 0:
print(f" ✅ Fixed {fixes} missing delimiters")
return text, fixes
def _truncate_last_incomplete_element(self, text: str) -> Tuple[str, bool]:
"""Truncates the last incomplete element"""
# For very long text (>50k) or text not ending with ']', directly truncate the last '{"bbox":'
needs_truncation = (
len(text) > 50000 or
not text.strip().endswith(']')
)
if needs_truncation:
# Check how many dict objects there are
bbox_count = text.count('{"bbox":')
# If there is only one dict object, do not truncate to avoid deleting the only object
if bbox_count <= 1:
print(f" ⚠️ Only {bbox_count} dict objects found, skipping truncation to avoid deleting all content")
return text, False
# Find the position of the last '{"bbox":'
last_bbox_pos = text.rfind('{"bbox":')
if last_bbox_pos > 0:
# Truncate before this position
truncated_text = text[:last_bbox_pos].rstrip()
# Remove trailing comma
if truncated_text.endswith(','):
truncated_text = truncated_text[:-1]
print(f" ✂️ Truncated the last incomplete element, length reduced from {len(text):,} to {len(truncated_text):,}")
return truncated_text, True
return text, False
def _remove_duplicate_complete_dicts_preserve_order(self, text: str) -> Tuple[str, int]:
"""Removes duplicate complete dict objects, preserving original order"""
# Extract all dict objects, preserving order
dict_matches = list(self.dict_pattern.finditer(text))
if not dict_matches:
return text, 0
print(f" 📊 Found {len(dict_matches)} dict objects")
# Deduplication while preserving order: only keep the first occurrence of a dict
unique_dicts = []
seen_dict_strings = set()
total_duplicates = 0
for match in dict_matches:
dict_str = match.group()
if dict_str not in seen_dict_strings:
unique_dicts.append(dict_str)
seen_dict_strings.add(dict_str)
else:
total_duplicates += 1
if total_duplicates > 0:
# Reconstruct the JSON array, preserving the original order
new_text = '[' + ', '.join(unique_dicts) + ']'
print(f" ✅ Removed {total_duplicates} duplicate dicts, keeping {len(unique_dicts)} unique dicts (order preserved)")
return new_text, total_duplicates
else:
print(f" ✅ No duplicate dict objects found")
return text, 0
def _ensure_json_format(self, text: str) -> str:
"""Ensures correct JSON format"""
text = text.strip()
if not text.startswith('['):
text = '[' + text
if not text.endswith(']'):
# Remove trailing comma
text = text.rstrip(',').rstrip()
text += ']'
return text
def _parse_final_json(self, text: str) -> Optional[List[Dict]]:
"""Tries to parse the final JSON"""
try:
data = json.loads(text)
if isinstance(data, list):
return data
except json.JSONDecodeError as e:
print(f" ❌ JSON parsing failed: {e}")
# fallback1: Extract valid dict objects
valid_dicts = []
for match in self.dict_pattern.finditer(text):
dict_str = match.group()
try:
dict_obj = json.loads(dict_str)
valid_dicts.append(dict_obj)
except:
continue
if valid_dicts:
print(f" ✅ Extracted {len(valid_dicts)} valid dicts")
return valid_dicts
# fallback2: Special handling for a single incomplete dict
return self._handle_single_incomplete_dict(text)
return None
def _handle_single_incomplete_dict(self, text: str) -> Optional[List[Dict]]:
"""Handles the special case of a single incomplete dict"""
# Check if it's a single incomplete dict case
if not text.strip().startswith('[{"bbox":'):
return None
try:
# Try to extract bbox coordinates
bbox_match = re.search(r'"bbox"\s*:\s*\[([^\]]+)\]', text)
if not bbox_match:
return None
bbox_str = bbox_match.group(1)
bbox_coords = [int(x.strip()) for x in bbox_str.split(',')]
if len(bbox_coords) != 4:
return None
# Try to extract category
category_match = re.search(r'"category"\s*:\s*"([^"]+)"', text)
category = category_match.group(1) if category_match else "Text"
# Try to extract the beginning of the text (first 10000 characters)
text_match = re.search(r'"text"\s*:\s*"([^"]{0,10000})', text)
if text_match:
text_content = text_match.group(1)
else:
text_content = ""
# Construct the fixed dict
fixed_dict = {
"bbox": bbox_coords,
"category": category
}
if text_content:
fixed_dict["text"] = text_content
print(f" 🔧 Special fix: single incomplete dict → {fixed_dict}")
return [fixed_dict]
except Exception as e:
print(f" ❌ Special fix failed: {e}")
return None
def remove_duplicate_category_text_pairs_and_bbox(self, data_list: List[dict], case_id: int) -> List[dict]:
"""Removes duplicate category-text pairs and duplicate bboxes"""
if not data_list or len(data_list) <= 1:
print(f" 📊 Data length {len(data_list)} <= 1, skipping deduplication check")
return data_list
print(f" 📊 Original data length: {len(data_list)}")
# 1. Count occurrences and positions of each category-text pair
category_text_pairs = {}
for i, item in enumerate(data_list):
if isinstance(item, dict) and 'category' in item and 'text' in item:
pair_key = (item.get('category', ''), item.get('text', ''))
if pair_key not in category_text_pairs:
category_text_pairs[pair_key] = []
category_text_pairs[pair_key].append(i)
# 2. Count occurrences and positions of each bbox
bbox_pairs = {}
for i, item in enumerate(data_list):
if isinstance(item, dict) and 'bbox' in item:
bbox = item.get('bbox')
if isinstance(bbox, list) and len(bbox) > 0:
bbox_key = tuple(bbox) # Convert to tuple to use as a dictionary key
if bbox_key not in bbox_pairs:
bbox_pairs[bbox_key] = []
bbox_pairs[bbox_key].append(i)
# 3. Identify items to be removed
duplicates_to_remove = set()
# 3a. Process category-text pairs that appear 5 or more times
for pair_key, positions in category_text_pairs.items():
if len(positions) >= 5:
category, text = pair_key
# Keep the first occurrence, remove subsequent duplicates
positions_to_remove = positions[1:]
duplicates_to_remove.update(positions_to_remove)
print(f" 🔍 Found duplicate category-text pair: category='{category}', first 50 chars of text='{text[:50]}...'")
print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
# 3b. Process bboxes that appear 2 or more times
for bbox_key, positions in bbox_pairs.items():
if len(positions) >= 2:
# Keep the first occurrence, remove subsequent duplicates
positions_to_remove = positions[1:]
duplicates_to_remove.update(positions_to_remove)
print(f" 🔍 Found duplicate bbox: {list(bbox_key)}")
print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
if not duplicates_to_remove:
print(f" ✅ No category-text pairs or bboxes found exceeding the duplication threshold")
return data_list
# 4. Remove duplicate items from the original data (preserving order)
cleaned_data = []
removed_count = 0
for i, item in enumerate(data_list):
if i not in duplicates_to_remove:
cleaned_data.append(item)
else:
removed_count += 1
print(f" ✅ Deduplication complete: Removed {removed_count} duplicate items")
print(f" 📊 Cleaned data length: {len(cleaned_data)}")
return cleaned_data
def clean_model_output(self, model_output: str):
try:
# Select cleaning method based on data type
if isinstance(model_output, list):
result = self.clean_list_data(model_output, case_id=0)
else:
result = self.clean_string_data(str(model_output), case_id=0)
# Add deduplication step: remove duplicate category-text pairs and bboxes
if result and hasattr(result, 'success') and result.success and result.cleaned_data:
original_data = result.cleaned_data
deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id=0)
# Update the cleaned_data in the CleanedData object
result.cleaned_data = deduplicated_data
return result.cleaned_data
except Exception as e:
print(f"❌ Case cleaning failed: {e}")
return model_output
def clean_all_data(self, jsonl_path: str) -> List[CleanedData]:
"""Cleans all data from a JSONL file"""
print(f"🚀 Starting to clean JSONL file: {jsonl_path}")
with open(jsonl_path, 'r', encoding='utf-8') as f:
lines = f.readlines()
datas = []
for i, line in enumerate(lines):
if line.strip():
try:
data = json.loads(line)
predict_field = data.get('predict')
case_id = i + 1
print(f"\n{'='*50}")
print(f"🎯 Cleaning Case {case_id}")
print(f"{'='*50}")
# Select cleaning method based on data type
if isinstance(predict_field, list):
print("📊 Data type: List")
result = self.clean_list_data(predict_field, case_id)
else:
print("📊 Data type: String")
result = self.clean_string_data(str(predict_field), case_id)
# Add deduplication step: remove duplicate category-text pairs and bboxes
if result and hasattr(result, 'success') and result.success and result.cleaned_data:
print("🔄 Checking for and removing duplicate category-text pairs and bboxes...")
original_data = result.cleaned_data
deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id)
# Update the cleaned_data in the CleanedData object
result.cleaned_data = deduplicated_data
data['predict_resized'] = result.cleaned_data
datas.append(data)
self.cleaned_results.append(result)
except Exception as e:
print(f"❌ Case {i+1} cleaning failed: {e}")
traceback.print_exc()
save_path = jsonl_path.replace('.jsonl', '_filtered.jsonl')
with open(save_path, 'w') as w:
for data in datas:
w.write(json.dumps(data, ensure_ascii=False) + '\n')
print(f"✅ Saved cleaned data to: {save_path}")
return self.cleaned_results
def save_cleaned_data(self, output_dir: str):
"""Saves the cleaned data"""
print(f"\n💾 Saving cleaned data to: {output_dir}")
os.makedirs(output_dir, exist_ok=True)
# 1. Save cleaned data for each case
for result in self.cleaned_results:
case_filename = f"cleaned_case_{result.case_id:02d}.json"
case_filepath = os.path.join(output_dir, case_filename)
# Save the cleaned data
with open(case_filepath, 'w', encoding='utf-8') as f:
json.dump(result.cleaned_data, f, ensure_ascii=False, indent=2)
print(f" ✅ Case {result.case_id}: {len(result.cleaned_data)} objects → {case_filename}")
# 2. Save all cleaned data to a single file
all_cleaned_data = []
for result in self.cleaned_results:
all_cleaned_data.append({
'case_id': result.case_id,
'original_type': result.original_type,
'original_length': result.original_length,
'cleaned_objects_count': len(result.cleaned_data),
'success': result.success,
'cleaning_operations': result.cleaning_operations,
'cleaned_data': result.cleaned_data
})
all_data_filepath = os.path.join(output_dir, "all_cleaned_data.json")
with open(all_data_filepath, 'w', encoding='utf-8') as f:
json.dump(all_cleaned_data, f, ensure_ascii=False, indent=2)
print(f" 📁 All data: {len(all_cleaned_data)} cases → all_cleaned_data.json")
# 3. Generate a cleaning report
self._generate_cleaning_report(output_dir)
def _generate_cleaning_report(self, output_dir: str):
"""Generates a cleaning report"""
report = []
report.append("📊 Data Cleaning Report")
report.append("=" * 60)
import datetime
report.append(f"Processing Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
report.append("")
# Overall statistics
total_cases = len(self.cleaned_results)
successful_cases = sum(1 for r in self.cleaned_results if r.success)
total_objects = sum(len(r.cleaned_data) for r in self.cleaned_results)
report.append("📈 Overall Statistics:")
report.append(f" Total Cases: {total_cases}")
report.append(f" Successfully Cleaned: {successful_cases}")
report.append(f" Success Rate: {successful_cases/total_cases*100:.1f}%")
report.append(f" Total Recovered Objects: {total_objects}")
report.append("")
# Detailed statistics
list_results = [r for r in self.cleaned_results if r.original_type == 'list']
str_results = [r for r in self.cleaned_results if r.original_type == 'str']
if list_results:
report.append("📋 List Type Cleaning Statistics:")
for r in list_results:
ops = r.cleaning_operations
report.append(f" Case {r.case_id}: {ops['original_count']}{ops['final_count']} objects")
if ops['bbox_fixes'] > 0:
report.append(f" - bbox fixes: {ops['bbox_fixes']}")
if ops['removed_items'] > 0:
report.append(f" - invalid items removed: {ops['removed_items']}")
report.append("")
if str_results:
report.append("📝 String Type Cleaning Statistics:")
for r in str_results:
ops = r.cleaning_operations
status = "" if r.success else ""
report.append(f" Case {r.case_id} {status}: {ops['original_length']:,} chars → {ops['final_objects']} objects")
details = []
if ops['delimiter_fixes'] > 0:
details.append(f"Delimiter fixes: {ops['delimiter_fixes']}")
if ops['tail_truncated']:
reduction = ops['original_length'] - ops['truncated_length']
details.append(f"Tail truncation: -{reduction:,} chars")
if ops['duplicate_dicts_removed'] > 0:
details.append(f"Duplicates removed: {ops['duplicate_dicts_removed']}")
if details:
report.append(f" - {', '.join(details)}")
report.append("")
# Note on data order
report.append("🔄 Data Order Guarantee:")
report.append(" ✅ The relative order of all dict objects is preserved during cleaning.")
report.append(" ✅ When deduplicating, the first occurrence of a dict is kept, and subsequent duplicates are removed.")
report.append(" ✅ The order of items in List-type data is fully preserved.")
# Save the report
report_filepath = os.path.join(output_dir, "cleaning_report.txt")
with open(report_filepath, 'w', encoding='utf-8') as f:
f.write('\n'.join(report))
print(f" 📋 Cleaning report: cleaning_report.txt")
# Also print to console
print(f"\n{chr(10).join(report)}")
def main():
"""Main function"""
# Create a data cleaner instance
cleaner = OutputCleaner()
# Input file
jsonl_path = "output_with_failcase.jsonl"
# Output directory
output_dir = "output_with_failcase_cleaned"
# Clean all data
results = cleaner.clean_all_data(jsonl_path)
# Save the cleaned data
cleaner.save_cleaned_data(output_dir)
print(f"\n🎉 Data cleaning complete!")
print(f"📁 Cleaned data saved in: {output_dir}")
if __name__ == "__main__":
main()
+34
View File
@@ -0,0 +1,34 @@
dict_promptmode_to_prompt = {
# prompt_layout_all_en: parse all layout info in json format.
"prompt_layout_all_en": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
1. Bbox format: [x1, y1, x2, y2]
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
3. Text Extraction & Formatting Rules:
- Picture: For the 'Picture' category, the text field should be omitted.
- Formula: Format its text as LaTeX.
- Table: Format its text as HTML.
- All Others (Text, Title, etc.): Format their text as Markdown.
4. Constraints:
- The output text must be the original text from the image, with no translation.
- All layout elements must be sorted according to human reading order.
5. Final Output: The entire output must be a single JSON object.
""",
# prompt_layout_only_en: layout detection
"prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
# prompt_layout_only_en: parse ocr text except the Page-header and Page-footer
"prompt_ocr": """Extract the text content from this image.""",
# prompt_grounding_ocr: extract text content in the given bounding box
"prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
# "prompt_table_html": """Convert the table in this image to HTML.""",
# "prompt_table_latex": """Convert the table in this image to LaTeX.""",
# "prompt_formula_latex": """Convert the formula in this image to LaTeX.""",
}
+11
View File
@@ -0,0 +1,11 @@
# 生产环境依赖
# streamlit
gradio
gradio_image_annotation
PyMuPDF
openai
qwen_vl_utils
transformers==4.51.3
huggingface_hub
flash-attn==2.8.0.post2
accelerate
Executable
+17
View File
@@ -0,0 +1,17 @@
from setuptools import setup, find_packages
# 从requirements.txt文件读取依赖
def parse_requirements(filename):
with open(filename, 'r', encoding='utf-8') as f:
return f.read().splitlines()
setup(
name='dots_ocr',
version='1.0',
packages=find_packages(),
include_package_data=True,
install_requires=parse_requirements('requirements.txt'),
description='dots.ocr: Multilingual Document Layout Parsing in one Vision-Language Model',
url="https://github.com/rednote-hilab/dots.ocr",
python_requires=">=3.10",
)
+19
View File
@@ -0,0 +1,19 @@
from argparse import ArgumentParser
import os
if __name__ == '__main__':
parser = ArgumentParser()
parser.add_argument('--type', '-t', type=str, default="huggingface")
parser.add_argument('--name', '-n', type=str, default="rednote-hilab/dots.ocr")
args = parser.parse_args()
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
print(f"Attention: The model save dir dots.ocr should be replace by a name without `.` like DotsOCR, util we merge our code to transformers.")
model_dir = os.path.join(script_dir, "weights/DotsOCR")
if not os.path.exists(model_dir):
os.makedirs(model_dir)
if args.type == "huggingface":
from huggingface_hub import snapshot_download
snapshot_download(repo_id=args.name, local_dir=model_dir, local_dir_use_symlinks=False, resume_download=True)
print(f"model downloaded to {model_dir}")