dots.ocr release
@@ -0,0 +1,123 @@
|
|||||||
|
# Byte-compiled / optimized / DLL files
|
||||||
|
weights/
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# C extensions
|
||||||
|
*.so
|
||||||
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
*.manifest
|
||||||
|
*.spec
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebook
|
||||||
|
.ipynb_checkpoints
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
.python-version
|
||||||
|
|
||||||
|
# celery beat schedule file
|
||||||
|
celerybeat-schedule
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# IDEs
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# MacOS
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# OCR related
|
||||||
|
#*.jpg
|
||||||
|
# *.jpeg
|
||||||
|
#*.png
|
||||||
|
#*.pdf
|
||||||
|
temp/
|
||||||
|
output/
|
||||||
|
# playground/
|
||||||
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 rednote-hilab
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
After Width: | Height: | Size: 63 KiB |
|
After Width: | Height: | Size: 66 KiB |
|
After Width: | Height: | Size: 2.8 MiB |
|
After Width: | Height: | Size: 1.2 MiB |
|
After Width: | Height: | Size: 1.7 MiB |
|
After Width: | Height: | Size: 1.0 MiB |
|
After Width: | Height: | Size: 1013 KiB |
|
After Width: | Height: | Size: 1.8 MiB |
|
After Width: | Height: | Size: 3.7 MiB |
|
After Width: | Height: | Size: 2.8 MiB |
|
After Width: | Height: | Size: 2.9 MiB |
|
After Width: | Height: | Size: 1.4 MiB |
|
After Width: | Height: | Size: 1.7 MiB |
|
After Width: | Height: | Size: 1.4 MiB |
|
After Width: | Height: | Size: 1.8 MiB |
|
After Width: | Height: | Size: 943 KiB |
|
After Width: | Height: | Size: 662 KiB |
|
After Width: | Height: | Size: 292 KiB |
|
After Width: | Height: | Size: 263 KiB |
|
After Width: | Height: | Size: 445 KiB |
|
After Width: | Height: | Size: 1.1 MiB |
|
After Width: | Height: | Size: 673 KiB |
|
After Width: | Height: | Size: 1.7 MiB |
|
After Width: | Height: | Size: 755 KiB |
|
After Width: | Height: | Size: 920 KiB |
|
After Width: | Height: | Size: 2.0 MiB |
|
After Width: | Height: | Size: 937 KiB |
|
After Width: | Height: | Size: 203 KiB |
@@ -0,0 +1,948 @@
|
|||||||
|
"""
|
||||||
|
Layout Inference Web Application with Gradio
|
||||||
|
|
||||||
|
A Gradio-based layout inference tool that supports image uploads and multiple backend inference engines.
|
||||||
|
It adopts a reference-style interface design while preserving the original inference logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
import base64
|
||||||
|
import zipfile
|
||||||
|
import uuid
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Local tool imports
|
||||||
|
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||||
|
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||||
|
from dots_ocr.utils.demo_utils.display import read_image
|
||||||
|
from dots_ocr.utils.doc_utils import load_images_from_pdf
|
||||||
|
|
||||||
|
# Add DotsOCRParser import
|
||||||
|
from dots_ocr.parser import DotsOCRParser
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Configuration ====================
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
'ip': "127.0.0.1",
|
||||||
|
'port_vllm': 8000,
|
||||||
|
'min_pixels': MIN_PIXELS,
|
||||||
|
'max_pixels': MAX_PIXELS,
|
||||||
|
'test_images_dir': "./assets/showcase_origin",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ==================== Global Variables ====================
|
||||||
|
# Store current configuration
|
||||||
|
current_config = DEFAULT_CONFIG.copy()
|
||||||
|
|
||||||
|
# Create DotsOCRParser instance
|
||||||
|
dots_parser = DotsOCRParser(
|
||||||
|
ip=DEFAULT_CONFIG['ip'],
|
||||||
|
port=DEFAULT_CONFIG['port_vllm'],
|
||||||
|
dpi=200,
|
||||||
|
min_pixels=DEFAULT_CONFIG['min_pixels'],
|
||||||
|
max_pixels=DEFAULT_CONFIG['max_pixels']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store processing results
|
||||||
|
processing_results = {
|
||||||
|
'original_image': None,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': None,
|
||||||
|
'markdown_content': None,
|
||||||
|
'cells_data': None,
|
||||||
|
'temp_dir': None,
|
||||||
|
'session_id': None,
|
||||||
|
'result_paths': None,
|
||||||
|
'pdf_results': None # Store multi-page PDF results
|
||||||
|
}
|
||||||
|
|
||||||
|
# PDF caching mechanism
|
||||||
|
pdf_cache = {
|
||||||
|
"images": [],
|
||||||
|
"current_page": 0,
|
||||||
|
"total_pages": 0,
|
||||||
|
"file_type": None, # 'image' or 'pdf'
|
||||||
|
"is_parsed": False, # Whether it has been parsed
|
||||||
|
"results": [] # Store parsing results for each page
|
||||||
|
}
|
||||||
|
|
||||||
|
def read_image_v2(img):
|
||||||
|
"""Reads an image, supports URLs and local paths"""
|
||||||
|
if isinstance(img, str) and img.startswith(("http://", "https://")):
|
||||||
|
with requests.get(img, stream=True) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
img = Image.open(io.BytesIO(response.content))
|
||||||
|
elif isinstance(img, str):
|
||||||
|
img, _, _ = read_image(img, use_native=True)
|
||||||
|
elif isinstance(img, Image.Image):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid image type: {type(img)}")
|
||||||
|
return img
|
||||||
|
|
||||||
|
def load_file_for_preview(file_path):
|
||||||
|
"""Loads a file for preview, supports PDF and image files"""
|
||||||
|
global pdf_cache
|
||||||
|
|
||||||
|
if not file_path or not os.path.exists(file_path):
|
||||||
|
return None, "<div id='page_info_box'>0 / 0</div>"
|
||||||
|
|
||||||
|
file_ext = os.path.splitext(file_path)[1].lower()
|
||||||
|
|
||||||
|
if file_ext == '.pdf':
|
||||||
|
try:
|
||||||
|
# Read PDF and convert to images (one image per page)
|
||||||
|
pages = load_images_from_pdf(file_path)
|
||||||
|
pdf_cache["file_type"] = "pdf"
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"<div id='page_info_box'>PDF loading failed: {str(e)}</div>"
|
||||||
|
elif file_ext in ['.jpg', '.jpeg', '.png']:
|
||||||
|
# For image files, read directly as a single-page image
|
||||||
|
try:
|
||||||
|
image = Image.open(file_path)
|
||||||
|
pages = [image]
|
||||||
|
pdf_cache["file_type"] = "image"
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"<div id='page_info_box'>Image loading failed: {str(e)}</div>"
|
||||||
|
else:
|
||||||
|
return None, "<div id='page_info_box'>Unsupported file format</div>"
|
||||||
|
|
||||||
|
pdf_cache["images"] = pages
|
||||||
|
pdf_cache["current_page"] = 0
|
||||||
|
pdf_cache["total_pages"] = len(pages)
|
||||||
|
pdf_cache["is_parsed"] = False
|
||||||
|
pdf_cache["results"] = []
|
||||||
|
|
||||||
|
return pages[0], f"<div id='page_info_box'>1 / {len(pages)}</div>"
|
||||||
|
|
||||||
|
def turn_page(direction):
|
||||||
|
"""Page turning function"""
|
||||||
|
global pdf_cache
|
||||||
|
|
||||||
|
if not pdf_cache["images"]:
|
||||||
|
return None, "<div id='page_info_box'>0 / 0</div>", "", ""
|
||||||
|
|
||||||
|
if direction == "prev":
|
||||||
|
pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
|
||||||
|
elif direction == "next":
|
||||||
|
pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)
|
||||||
|
|
||||||
|
index = pdf_cache["current_page"]
|
||||||
|
current_image = pdf_cache["images"][index] # Use the original image by default
|
||||||
|
page_info = f"<div id='page_info_box'>{index + 1} / {pdf_cache['total_pages']}</div>"
|
||||||
|
|
||||||
|
# If parsed, display the results for the current page
|
||||||
|
current_md = ""
|
||||||
|
current_md_raw = ""
|
||||||
|
current_json = ""
|
||||||
|
if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
|
||||||
|
result = pdf_cache["results"][index]
|
||||||
|
if 'md_content' in result:
|
||||||
|
# Get the raw markdown content
|
||||||
|
current_md_raw = result['md_content']
|
||||||
|
# Process the content after LaTeX rendering
|
||||||
|
current_md = result['md_content'] if result['md_content'] else ""
|
||||||
|
if 'cells_data' in result:
|
||||||
|
try:
|
||||||
|
current_json = json.dumps(result['cells_data'], ensure_ascii=False, indent=2)
|
||||||
|
except:
|
||||||
|
current_json = str(result.get('cells_data', ''))
|
||||||
|
# Use the image with layout boxes (if available)
|
||||||
|
if 'layout_image' in result and result['layout_image']:
|
||||||
|
current_image = result['layout_image']
|
||||||
|
|
||||||
|
return current_image, page_info, current_json
|
||||||
|
|
||||||
|
def get_test_images():
|
||||||
|
"""Gets the list of test images"""
|
||||||
|
test_images = []
|
||||||
|
test_dir = current_config['test_images_dir']
|
||||||
|
if os.path.exists(test_dir):
|
||||||
|
test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)
|
||||||
|
if name.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf'))]
|
||||||
|
return test_images
|
||||||
|
|
||||||
|
def convert_image_to_base64(image):
|
||||||
|
"""Converts a PIL image to base64 encoding"""
|
||||||
|
buffered = io.BytesIO()
|
||||||
|
image.save(buffered, format="PNG")
|
||||||
|
img_str = base64.b64encode(buffered.getvalue()).decode()
|
||||||
|
return f"data:image/png;base64,{img_str}"
|
||||||
|
|
||||||
|
def create_temp_session_dir():
|
||||||
|
"""Creates a unique temporary directory for each processing request"""
|
||||||
|
session_id = uuid.uuid4().hex[:8]
|
||||||
|
temp_dir = os.path.join(tempfile.gettempdir(), f"dots_ocr_demo_{session_id}")
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
return temp_dir, session_id
|
||||||
|
|
||||||
|
def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=False):
|
||||||
|
"""
|
||||||
|
Processes using the high-level API parse_image from DotsOCRParser
|
||||||
|
"""
|
||||||
|
# Create a temporary session directory
|
||||||
|
temp_dir, session_id = create_temp_session_dir()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Save the PIL Image as a temporary file
|
||||||
|
temp_image_path = os.path.join(temp_dir, f"input_{session_id}.png")
|
||||||
|
image.save(temp_image_path, "PNG")
|
||||||
|
|
||||||
|
# Use the high-level API parse_image
|
||||||
|
filename = f"demo_{session_id}"
|
||||||
|
results = parser.parse_image(
|
||||||
|
# input_path=temp_image_path,
|
||||||
|
input_path=image,
|
||||||
|
filename=filename,
|
||||||
|
prompt_mode=prompt_mode,
|
||||||
|
save_dir=temp_dir,
|
||||||
|
fitz_preprocess=fitz_preprocess
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the results
|
||||||
|
if not results:
|
||||||
|
raise ValueError("No results returned from parser")
|
||||||
|
|
||||||
|
result = results[0] # parse_image returns a list with a single result
|
||||||
|
|
||||||
|
# Read the result files
|
||||||
|
layout_image = None
|
||||||
|
cells_data = None
|
||||||
|
md_content = None
|
||||||
|
raw_response = None
|
||||||
|
filtered = False
|
||||||
|
|
||||||
|
# Read the layout image
|
||||||
|
if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
|
||||||
|
layout_image = Image.open(result['layout_image_path'])
|
||||||
|
|
||||||
|
# Read the JSON data
|
||||||
|
if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
|
||||||
|
with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
|
||||||
|
cells_data = json.load(f)
|
||||||
|
|
||||||
|
# Read the Markdown content
|
||||||
|
if 'md_content_path' in result and os.path.exists(result['md_content_path']):
|
||||||
|
with open(result['md_content_path'], 'r', encoding='utf-8') as f:
|
||||||
|
md_content = f.read()
|
||||||
|
|
||||||
|
# Check for the raw response file (when JSON parsing fails)
|
||||||
|
if 'filtered' in result:
|
||||||
|
filtered = result['filtered']
|
||||||
|
|
||||||
|
return {
|
||||||
|
'layout_image': layout_image,
|
||||||
|
'cells_data': cells_data,
|
||||||
|
'md_content': md_content,
|
||||||
|
'filtered': filtered,
|
||||||
|
'temp_dir': temp_dir,
|
||||||
|
'session_id': session_id,
|
||||||
|
'result_paths': result,
|
||||||
|
'input_width': result['input_width'],
|
||||||
|
'input_height': result['input_height'],
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Clean up the temporary directory on error
|
||||||
|
import shutil
|
||||||
|
if os.path.exists(temp_dir):
|
||||||
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def parse_pdf_with_high_level_api(parser, pdf_path, prompt_mode):
|
||||||
|
"""
|
||||||
|
Processes using the high-level API parse_pdf from DotsOCRParser
|
||||||
|
"""
|
||||||
|
# Create a temporary session directory
|
||||||
|
temp_dir, session_id = create_temp_session_dir()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Use the high-level API parse_pdf
|
||||||
|
filename = f"demo_{session_id}"
|
||||||
|
results = parser.parse_pdf(
|
||||||
|
input_path=pdf_path,
|
||||||
|
filename=filename,
|
||||||
|
prompt_mode=prompt_mode,
|
||||||
|
save_dir=temp_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the results
|
||||||
|
if not results:
|
||||||
|
raise ValueError("No results returned from parser")
|
||||||
|
|
||||||
|
# Handle multi-page results
|
||||||
|
parsed_results = []
|
||||||
|
all_md_content = []
|
||||||
|
all_cells_data = []
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
page_result = {
|
||||||
|
'page_no': result.get('page_no', i),
|
||||||
|
'layout_image': None,
|
||||||
|
'cells_data': None,
|
||||||
|
'md_content': None,
|
||||||
|
'filtered': False
|
||||||
|
}
|
||||||
|
|
||||||
|
# Read the layout image
|
||||||
|
if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
|
||||||
|
page_result['layout_image'] = Image.open(result['layout_image_path'])
|
||||||
|
|
||||||
|
# Read the JSON data
|
||||||
|
if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
|
||||||
|
with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
|
||||||
|
page_result['cells_data'] = json.load(f)
|
||||||
|
all_cells_data.extend(page_result['cells_data'])
|
||||||
|
|
||||||
|
# Read the Markdown content
|
||||||
|
if 'md_content_path' in result and os.path.exists(result['md_content_path']):
|
||||||
|
with open(result['md_content_path'], 'r', encoding='utf-8') as f:
|
||||||
|
page_content = f.read()
|
||||||
|
page_result['md_content'] = page_content
|
||||||
|
all_md_content.append(page_content)
|
||||||
|
|
||||||
|
# Check for the raw response file (when JSON parsing fails)
|
||||||
|
page_result['filtered'] = False
|
||||||
|
if 'filtered' in page_result:
|
||||||
|
page_result['filtered'] = page_result['filtered']
|
||||||
|
|
||||||
|
parsed_results.append(page_result)
|
||||||
|
|
||||||
|
# Merge the content of all pages
|
||||||
|
combined_md = "\n\n---\n\n".join(all_md_content) if all_md_content else ""
|
||||||
|
|
||||||
|
return {
|
||||||
|
'parsed_results': parsed_results,
|
||||||
|
'combined_md_content': combined_md,
|
||||||
|
'combined_cells_data': all_cells_data,
|
||||||
|
'temp_dir': temp_dir,
|
||||||
|
'session_id': session_id,
|
||||||
|
'total_pages': len(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Clean up the temporary directory on error
|
||||||
|
import shutil
|
||||||
|
if os.path.exists(temp_dir):
|
||||||
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# ==================== Core Processing Function ====================
|
||||||
|
def process_image_inference(test_image_input, file_input,
|
||||||
|
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
|
||||||
|
fitz_preprocess=False
|
||||||
|
):
|
||||||
|
"""Core function to handle image/PDF inference"""
|
||||||
|
global current_config, processing_results, dots_parser, pdf_cache
|
||||||
|
|
||||||
|
# First, clean up previous processing results to avoid confusion with the download button
|
||||||
|
if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
|
||||||
|
import shutil
|
||||||
|
try:
|
||||||
|
shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to clean up previous temporary directory: {e}")
|
||||||
|
|
||||||
|
# Reset processing results
|
||||||
|
processing_results = {
|
||||||
|
'original_image': None,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': None,
|
||||||
|
'markdown_content': None,
|
||||||
|
'cells_data': None,
|
||||||
|
'temp_dir': None,
|
||||||
|
'session_id': None,
|
||||||
|
'result_paths': None,
|
||||||
|
'pdf_results': None
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update configuration
|
||||||
|
current_config.update({
|
||||||
|
'ip': server_ip,
|
||||||
|
'port_vllm': server_port,
|
||||||
|
'min_pixels': min_pixels,
|
||||||
|
'max_pixels': max_pixels
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update parser configuration
|
||||||
|
dots_parser.ip = server_ip
|
||||||
|
dots_parser.port = server_port
|
||||||
|
dots_parser.min_pixels = min_pixels
|
||||||
|
dots_parser.max_pixels = max_pixels
|
||||||
|
|
||||||
|
# Determine the input source
|
||||||
|
input_file_path = None
|
||||||
|
image = None
|
||||||
|
|
||||||
|
# Prioritize file input (supports PDF)
|
||||||
|
if file_input is not None:
|
||||||
|
input_file_path = file_input
|
||||||
|
file_ext = os.path.splitext(input_file_path)[1].lower()
|
||||||
|
|
||||||
|
if file_ext == '.pdf':
|
||||||
|
# PDF file processing
|
||||||
|
try:
|
||||||
|
return process_pdf_file(input_file_path, prompt_mode)
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"PDF processing failed: {e}", "", "", gr.update(value=None), None, ""
|
||||||
|
elif file_ext in ['.jpg', '.jpeg', '.png']:
|
||||||
|
# Image file processing
|
||||||
|
try:
|
||||||
|
image = Image.open(input_file_path)
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Failed to read image file: {e}", "", "", gr.update(value=None), None, ""
|
||||||
|
|
||||||
|
# If no file input, check the test image input
|
||||||
|
if image is None:
|
||||||
|
if test_image_input and test_image_input != "":
|
||||||
|
file_ext = os.path.splitext(test_image_input)[1].lower()
|
||||||
|
if file_ext == '.pdf':
|
||||||
|
return process_pdf_file(test_image_input, prompt_mode)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
image = read_image_v2(test_image_input)
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Failed to read test image: {e}", "", "", gr.update(value=None), gr.update(value=None), None, ""
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
return None, "Please upload image/PDF file or select test image", "", "", gr.update(value=None), None, ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Clear PDF cache (for image processing)
|
||||||
|
pdf_cache["images"] = []
|
||||||
|
pdf_cache["current_page"] = 0
|
||||||
|
pdf_cache["total_pages"] = 0
|
||||||
|
pdf_cache["is_parsed"] = False
|
||||||
|
pdf_cache["results"] = []
|
||||||
|
|
||||||
|
# Process using the high-level API of DotsOCRParser
|
||||||
|
original_image = image
|
||||||
|
parse_result = parse_image_with_high_level_api(dots_parser, image, prompt_mode, fitz_preprocess)
|
||||||
|
|
||||||
|
# Extract parsing results
|
||||||
|
layout_image = parse_result['layout_image']
|
||||||
|
cells_data = parse_result['cells_data']
|
||||||
|
md_content = parse_result['md_content']
|
||||||
|
filtered = parse_result['filtered']
|
||||||
|
|
||||||
|
# Handle parsing failure case
|
||||||
|
if filtered:
|
||||||
|
# JSON parsing failed, only text content is available
|
||||||
|
info_text = f"""
|
||||||
|
**Image Information:**
|
||||||
|
- Original Size: {original_image.width} x {original_image.height}
|
||||||
|
- Processing: JSON parsing failed, using cleaned text output
|
||||||
|
- Server: {current_config['ip']}:{current_config['port_vllm']}
|
||||||
|
- Session ID: {parse_result['session_id']}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Store results
|
||||||
|
processing_results.update({
|
||||||
|
'original_image': original_image,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': None,
|
||||||
|
'markdown_content': md_content,
|
||||||
|
'cells_data': None,
|
||||||
|
'temp_dir': parse_result['temp_dir'],
|
||||||
|
'session_id': parse_result['session_id'],
|
||||||
|
'result_paths': parse_result['result_paths']
|
||||||
|
})
|
||||||
|
|
||||||
|
return (
|
||||||
|
original_image, # No layout image
|
||||||
|
info_text,
|
||||||
|
md_content,
|
||||||
|
md_content, # Display raw markdown text
|
||||||
|
gr.update(visible=False), # Hide download button
|
||||||
|
None, # Page info
|
||||||
|
"" # Current page JSON output
|
||||||
|
)
|
||||||
|
|
||||||
|
# JSON parsing successful case
|
||||||
|
# Save the raw markdown content (before LaTeX processing)
|
||||||
|
md_content_raw = md_content or "No markdown content generated"
|
||||||
|
|
||||||
|
# Store results
|
||||||
|
processing_results.update({
|
||||||
|
'original_image': original_image,
|
||||||
|
'processed_image': None, # High-level API does not return processed_image
|
||||||
|
'layout_result': layout_image,
|
||||||
|
'markdown_content': md_content,
|
||||||
|
'cells_data': cells_data,
|
||||||
|
'temp_dir': parse_result['temp_dir'],
|
||||||
|
'session_id': parse_result['session_id'],
|
||||||
|
'result_paths': parse_result['result_paths']
|
||||||
|
})
|
||||||
|
|
||||||
|
# Prepare display information
|
||||||
|
num_elements = len(cells_data) if cells_data else 0
|
||||||
|
info_text = f"""
|
||||||
|
**Image Information:**
|
||||||
|
- Original Size: {original_image.width} x {original_image.height}
|
||||||
|
- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}
|
||||||
|
- Server: {current_config['ip']}:{current_config['port_vllm']}
|
||||||
|
- Detected {num_elements} layout elements
|
||||||
|
- Session ID: {parse_result['session_id']}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Current page JSON output
|
||||||
|
current_json = ""
|
||||||
|
if cells_data:
|
||||||
|
try:
|
||||||
|
current_json = json.dumps(cells_data, ensure_ascii=False, indent=2)
|
||||||
|
except:
|
||||||
|
current_json = str(cells_data)
|
||||||
|
|
||||||
|
# Create the download ZIP file
|
||||||
|
download_zip_path = None
|
||||||
|
if parse_result['temp_dir']:
|
||||||
|
download_zip_path = os.path.join(parse_result['temp_dir'], f"layout_results_{parse_result['session_id']}.zip")
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||||
|
for root, dirs, files in os.walk(parse_result['temp_dir']):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith('.zip'):
|
||||||
|
continue
|
||||||
|
file_path = os.path.join(root, file)
|
||||||
|
arcname = os.path.relpath(file_path, parse_result['temp_dir'])
|
||||||
|
zipf.write(file_path, arcname)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to create download ZIP: {e}")
|
||||||
|
download_zip_path = None
|
||||||
|
|
||||||
|
return (
|
||||||
|
layout_image,
|
||||||
|
info_text,
|
||||||
|
md_content or "No markdown content generated",
|
||||||
|
md_content_raw, # Raw markdown text
|
||||||
|
gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False), # Set the download file
|
||||||
|
None, # Page info (not displayed for image processing)
|
||||||
|
current_json # Current page JSON
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Error during processing: {e}", "", "", gr.update(value=None), None, ""
|
||||||
|
|
||||||
|
def process_pdf_file(pdf_path, prompt_mode):
|
||||||
|
"""Dedicated function for processing PDF files"""
|
||||||
|
global pdf_cache, processing_results, dots_parser
|
||||||
|
|
||||||
|
try:
|
||||||
|
# First, load the PDF for preview
|
||||||
|
preview_image, page_info = load_file_for_preview(pdf_path)
|
||||||
|
|
||||||
|
# Parse the PDF using DotsOCRParser
|
||||||
|
pdf_result = parse_pdf_with_high_level_api(dots_parser, pdf_path, prompt_mode)
|
||||||
|
|
||||||
|
# Update the PDF cache
|
||||||
|
pdf_cache["is_parsed"] = True
|
||||||
|
pdf_cache["results"] = pdf_result['parsed_results']
|
||||||
|
|
||||||
|
# Handle LaTeX table rendering
|
||||||
|
combined_md = pdf_result['combined_md_content']
|
||||||
|
combined_md_raw = combined_md or "No markdown content generated" # Save the raw content
|
||||||
|
|
||||||
|
# Store results
|
||||||
|
processing_results.update({
|
||||||
|
'original_image': None,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': None,
|
||||||
|
'markdown_content': combined_md,
|
||||||
|
'cells_data': pdf_result['combined_cells_data'],
|
||||||
|
'temp_dir': pdf_result['temp_dir'],
|
||||||
|
'session_id': pdf_result['session_id'],
|
||||||
|
'result_paths': None,
|
||||||
|
'pdf_results': pdf_result['parsed_results']
|
||||||
|
})
|
||||||
|
|
||||||
|
# Prepare display information
|
||||||
|
total_elements = len(pdf_result['combined_cells_data'])
|
||||||
|
info_text = f"""
|
||||||
|
**PDF Information:**
|
||||||
|
- Total Pages: {pdf_result['total_pages']}
|
||||||
|
- Server: {current_config['ip']}:{current_config['port_vllm']}
|
||||||
|
- Total Detected Elements: {total_elements}
|
||||||
|
- Session ID: {pdf_result['session_id']}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Content of the current page (first page)
|
||||||
|
current_page_md = ""
|
||||||
|
current_page_md_raw = ""
|
||||||
|
current_page_json = ""
|
||||||
|
current_page_layout_image = preview_image # Use the original preview image by default
|
||||||
|
|
||||||
|
if pdf_cache["results"] and len(pdf_cache["results"]) > 0:
|
||||||
|
current_result = pdf_cache["results"][0]
|
||||||
|
if current_result['md_content']:
|
||||||
|
# Raw markdown content
|
||||||
|
current_page_md_raw = current_result['md_content']
|
||||||
|
# Process the content after LaTeX rendering
|
||||||
|
|
||||||
|
current_page_md = current_result['md_content']
|
||||||
|
if current_result['cells_data']:
|
||||||
|
try:
|
||||||
|
current_page_json = json.dumps(current_result['cells_data'], ensure_ascii=False, indent=2)
|
||||||
|
except:
|
||||||
|
current_page_json = str(current_result['cells_data'])
|
||||||
|
# Use the image with layout boxes (if available)
|
||||||
|
if 'layout_image' in current_result and current_result['layout_image']:
|
||||||
|
current_page_layout_image = current_result['layout_image']
|
||||||
|
|
||||||
|
# Create the download ZIP file
|
||||||
|
download_zip_path = None
|
||||||
|
if pdf_result['temp_dir']:
|
||||||
|
download_zip_path = os.path.join(pdf_result['temp_dir'], f"layout_results_{pdf_result['session_id']}.zip")
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||||
|
for root, dirs, files in os.walk(pdf_result['temp_dir']):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith('.zip'):
|
||||||
|
continue
|
||||||
|
file_path = os.path.join(root, file)
|
||||||
|
arcname = os.path.relpath(file_path, pdf_result['temp_dir'])
|
||||||
|
zipf.write(file_path, arcname)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to create download ZIP: {e}")
|
||||||
|
download_zip_path = None
|
||||||
|
|
||||||
|
return (
|
||||||
|
current_page_layout_image, # Use the image with layout boxes
|
||||||
|
info_text,
|
||||||
|
combined_md or "No markdown content generated", # Display the markdown for the entire PDF
|
||||||
|
combined_md_raw or "No markdown content generated", # Display the raw markdown for the entire PDF
|
||||||
|
gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False), # Set the download file
|
||||||
|
page_info,
|
||||||
|
current_page_json
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Reset the PDF cache
|
||||||
|
pdf_cache["images"] = []
|
||||||
|
pdf_cache["current_page"] = 0
|
||||||
|
pdf_cache["total_pages"] = 0
|
||||||
|
pdf_cache["is_parsed"] = False
|
||||||
|
pdf_cache["results"] = []
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def clear_all_data():
|
||||||
|
"""Clears all data"""
|
||||||
|
global processing_results, pdf_cache
|
||||||
|
|
||||||
|
# Clean up the temporary directory
|
||||||
|
if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
|
||||||
|
import shutil
|
||||||
|
try:
|
||||||
|
shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to clean up temporary directory: {e}")
|
||||||
|
|
||||||
|
# Reset processing results
|
||||||
|
processing_results = {
|
||||||
|
'original_image': None,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': None,
|
||||||
|
'markdown_content': None,
|
||||||
|
'cells_data': None,
|
||||||
|
'temp_dir': None,
|
||||||
|
'session_id': None,
|
||||||
|
'result_paths': None,
|
||||||
|
'pdf_results': None
|
||||||
|
}
|
||||||
|
|
||||||
|
# Reset the PDF cache
|
||||||
|
pdf_cache = {
|
||||||
|
"images": [],
|
||||||
|
"current_page": 0,
|
||||||
|
"total_pages": 0,
|
||||||
|
"file_type": None,
|
||||||
|
"is_parsed": False,
|
||||||
|
"results": []
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
None, # Clear file input
|
||||||
|
"", # Clear test image selection
|
||||||
|
None, # Clear result image
|
||||||
|
"Waiting for processing results...", # Reset info display
|
||||||
|
"## Waiting for processing results...", # Reset Markdown display
|
||||||
|
"🕐 Waiting for parsing result...", # Clear raw Markdown text
|
||||||
|
gr.update(visible=False), # Hide download button
|
||||||
|
"<div id='page_info_box'>0 / 0</div>", # Reset page info
|
||||||
|
"🕐 Waiting for parsing result..." # Clear current page JSON
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_prompt_display(prompt_mode):
|
||||||
|
"""Updates the prompt display content"""
|
||||||
|
return dict_promptmode_to_prompt[prompt_mode]
|
||||||
|
|
||||||
|
# ==================== Gradio Interface ====================
|
||||||
|
def create_gradio_interface():
|
||||||
|
"""Creates the Gradio interface"""
|
||||||
|
|
||||||
|
# CSS styles, matching the reference style
|
||||||
|
css = """
|
||||||
|
|
||||||
|
#parse_button {
|
||||||
|
background: #FF576D !important; /* !important 确保覆盖主题默认样式 */
|
||||||
|
border-color: #FF576D !important;
|
||||||
|
}
|
||||||
|
/* 鼠标悬停时的颜色 */
|
||||||
|
#parse_button:hover {
|
||||||
|
background: #F72C49 !important;
|
||||||
|
border-color: #F72C49 !important;
|
||||||
|
}
|
||||||
|
|
||||||
|
#page_info_html {
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
height: 100%;
|
||||||
|
margin: 0 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#page_info_box {
|
||||||
|
padding: 8px 20px;
|
||||||
|
font-size: 16px;
|
||||||
|
border: 1px solid #bbb;
|
||||||
|
border-radius: 8px;
|
||||||
|
background-color: #f8f8f8;
|
||||||
|
text-align: center;
|
||||||
|
min-width: 80px;
|
||||||
|
box-shadow: 0 1px 3px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#markdown_output {
|
||||||
|
min-height: 800px;
|
||||||
|
overflow: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
footer {
|
||||||
|
visibility: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
#info_box {
|
||||||
|
padding: 10px;
|
||||||
|
background-color: #f8f9fa;
|
||||||
|
border-radius: 8px;
|
||||||
|
border: 1px solid #dee2e6;
|
||||||
|
margin: 10px 0;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#result_image {
|
||||||
|
border-radius: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#markdown_tabs {
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
with gr.Blocks(theme="ocean", css=css, title='dots.ocr') as demo:
|
||||||
|
|
||||||
|
# Title
|
||||||
|
gr.HTML("""
|
||||||
|
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
|
||||||
|
<h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr</h1>
|
||||||
|
</div>
|
||||||
|
<div style="text-align: center; margin-bottom: 10px;">
|
||||||
|
<em>Supports image/PDF layout analysis and structured output</em>
|
||||||
|
</div>
|
||||||
|
""")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Left side: Input and Configuration
|
||||||
|
with gr.Column(scale=1, elem_id="left-panel"):
|
||||||
|
gr.Markdown("### 📥 Upload & Select")
|
||||||
|
file_input = gr.File(
|
||||||
|
label="Upload PDF/Image",
|
||||||
|
type="filepath",
|
||||||
|
file_types=[".pdf", ".jpg", ".jpeg", ".png"],
|
||||||
|
)
|
||||||
|
|
||||||
|
test_images = get_test_images()
|
||||||
|
test_image_input = gr.Dropdown(
|
||||||
|
label="Or Select an Example",
|
||||||
|
choices=[""] + test_images,
|
||||||
|
value="",
|
||||||
|
)
|
||||||
|
|
||||||
|
gr.Markdown("### ⚙️ Prompt & Actions")
|
||||||
|
prompt_mode = gr.Dropdown(
|
||||||
|
label="Select Prompt",
|
||||||
|
choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr"],
|
||||||
|
value="prompt_layout_all_en",
|
||||||
|
show_label=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Display current prompt content
|
||||||
|
prompt_display = gr.Textbox(
|
||||||
|
label="Current Prompt Content",
|
||||||
|
value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
|
||||||
|
lines=4,
|
||||||
|
max_lines=8,
|
||||||
|
interactive=False,
|
||||||
|
show_copy_button=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
process_btn = gr.Button("🔍 Parse", variant="primary", scale=2, elem_id="parse_button")
|
||||||
|
clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)
|
||||||
|
|
||||||
|
with gr.Accordion("🛠️ Advanced Configuration", open=False):
|
||||||
|
fitz_preprocess = gr.Checkbox(
|
||||||
|
label="Enable fitz_preprocess for images",
|
||||||
|
value=True,
|
||||||
|
info="Processes image via a PDF-like pipeline (image->pdf->200dpi image). Recommended if your image DPI is low."
|
||||||
|
)
|
||||||
|
with gr.Row():
|
||||||
|
server_ip = gr.Textbox(label="Server IP", value=DEFAULT_CONFIG['ip'])
|
||||||
|
server_port = gr.Number(label="Port", value=DEFAULT_CONFIG['port_vllm'], precision=0)
|
||||||
|
with gr.Row():
|
||||||
|
min_pixels = gr.Number(label="Min Pixels", value=DEFAULT_CONFIG['min_pixels'], precision=0)
|
||||||
|
max_pixels = gr.Number(label="Max Pixels", value=DEFAULT_CONFIG['max_pixels'], precision=0)
|
||||||
|
# Right side: Result Display
|
||||||
|
with gr.Column(scale=6, variant="compact"):
|
||||||
|
with gr.Row():
|
||||||
|
# Result Image
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.Markdown("### 👁️ File Preview")
|
||||||
|
result_image = gr.Image(
|
||||||
|
label="Layout Preview",
|
||||||
|
visible=True,
|
||||||
|
height=800,
|
||||||
|
show_label=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Page navigation (shown during PDF preview)
|
||||||
|
with gr.Row():
|
||||||
|
prev_btn = gr.Button("⬅ Previous", size="sm")
|
||||||
|
page_info = gr.HTML(
|
||||||
|
value="<div id='page_info_box'>0 / 0</div>",
|
||||||
|
elem_id="page_info_html"
|
||||||
|
)
|
||||||
|
next_btn = gr.Button("Next ➡", size="sm")
|
||||||
|
|
||||||
|
# Info Display
|
||||||
|
info_display = gr.Markdown(
|
||||||
|
"Waiting for processing results...",
|
||||||
|
elem_id="info_box"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Markdown Result
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.Markdown("### ✔️ Result Display")
|
||||||
|
|
||||||
|
with gr.Tabs(elem_id="markdown_tabs"):
|
||||||
|
with gr.TabItem("Markdown Render Preview"):
|
||||||
|
md_output = gr.Markdown(
|
||||||
|
"## Please click the parse button to parse or select for single-task recognition...",
|
||||||
|
label="Markdown Preview",
|
||||||
|
max_height=600,
|
||||||
|
latex_delimiters=[
|
||||||
|
{"left": "$$", "right": "$$", "display": True},
|
||||||
|
{"left": "$", "right": "$", "display": False},
|
||||||
|
],
|
||||||
|
show_copy_button=False,
|
||||||
|
elem_id="markdown_output"
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.TabItem("Markdown Raw Text"):
|
||||||
|
md_raw_output = gr.Textbox(
|
||||||
|
value="🕐 Waiting for parsing result...",
|
||||||
|
label="Markdown Raw Text",
|
||||||
|
max_lines=100,
|
||||||
|
lines=38,
|
||||||
|
show_copy_button=True,
|
||||||
|
elem_id="markdown_output",
|
||||||
|
show_label=False
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.TabItem("Current Page JSON"):
|
||||||
|
current_page_json = gr.Textbox(
|
||||||
|
value="🕐 Waiting for parsing result...",
|
||||||
|
label="Current Page JSON",
|
||||||
|
max_lines=100,
|
||||||
|
lines=38,
|
||||||
|
show_copy_button=True,
|
||||||
|
elem_id="markdown_output",
|
||||||
|
show_label=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download Button
|
||||||
|
with gr.Row():
|
||||||
|
download_btn = gr.DownloadButton(
|
||||||
|
"⬇️ Download Results",
|
||||||
|
visible=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# When the prompt mode changes, update the display content
|
||||||
|
prompt_mode.change(
|
||||||
|
fn=update_prompt_display,
|
||||||
|
inputs=prompt_mode,
|
||||||
|
outputs=prompt_display,
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Show preview on file upload
|
||||||
|
file_input.upload(
|
||||||
|
fn=load_file_for_preview,
|
||||||
|
inputs=file_input,
|
||||||
|
outputs=[result_image, page_info],
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Page navigation
|
||||||
|
prev_btn.click(
|
||||||
|
fn=lambda: turn_page("prev"),
|
||||||
|
outputs=[result_image, page_info, current_page_json],
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
next_btn.click(
|
||||||
|
fn=lambda: turn_page("next"),
|
||||||
|
outputs=[result_image, page_info, current_page_json],
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
process_btn.click(
|
||||||
|
fn=process_image_inference,
|
||||||
|
inputs=[
|
||||||
|
test_image_input, file_input,
|
||||||
|
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
|
||||||
|
fitz_preprocess
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
result_image, info_display, md_output, md_raw_output,
|
||||||
|
download_btn, page_info, current_page_json
|
||||||
|
],
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
clear_btn.click(
|
||||||
|
fn=clear_all_data,
|
||||||
|
outputs=[
|
||||||
|
file_input, test_image_input,
|
||||||
|
result_image, info_display, md_output, md_raw_output,
|
||||||
|
download_btn, page_info, current_page_json
|
||||||
|
],
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return demo
|
||||||
|
|
||||||
|
# ==================== Main Program ====================
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo = create_gradio_interface()
|
||||||
|
demo.queue().launch(
|
||||||
|
server_name="0.0.0.0",
|
||||||
|
server_port=7860,
|
||||||
|
debug=True
|
||||||
|
)
|
||||||
@@ -0,0 +1,666 @@
|
|||||||
|
"""
|
||||||
|
Layout Inference Web Application with Gradio - Annotation Version
|
||||||
|
|
||||||
|
A Gradio-based layout inference tool that supports image uploads and multiple backend inference engines.
|
||||||
|
This version adds an image annotation feature, allowing users to draw bounding boxes on an image and send both the image and the boxes to the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
import base64
|
||||||
|
import zipfile
|
||||||
|
import uuid
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
from gradio_image_annotation import image_annotator
|
||||||
|
|
||||||
|
# Local utility imports
|
||||||
|
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||||
|
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||||
|
from dots_ocr.utils.demo_utils.display import read_image
|
||||||
|
from dots_ocr.utils.doc_utils import load_images_from_pdf
|
||||||
|
|
||||||
|
# Add DotsOCRParser import
|
||||||
|
from dots_ocr.parser import DotsOCRParser
|
||||||
|
|
||||||
|
# ==================== Configuration ====================
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
'ip': "127.0.0.1",
|
||||||
|
'port_vllm': 8000,
|
||||||
|
'min_pixels': MIN_PIXELS,
|
||||||
|
'max_pixels': MAX_PIXELS,
|
||||||
|
'test_images_dir': "./assets/showcase_origin",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ==================== Global Variables ====================
|
||||||
|
# Store the current configuration
|
||||||
|
current_config = DEFAULT_CONFIG.copy()
|
||||||
|
|
||||||
|
# Create a DotsOCRParser instance
|
||||||
|
dots_parser = DotsOCRParser(
|
||||||
|
ip=DEFAULT_CONFIG['ip'],
|
||||||
|
port=DEFAULT_CONFIG['port_vllm'],
|
||||||
|
dpi=200,
|
||||||
|
min_pixels=DEFAULT_CONFIG['min_pixels'],
|
||||||
|
max_pixels=DEFAULT_CONFIG['max_pixels']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store processing results
|
||||||
|
processing_results = {
|
||||||
|
'original_image': None,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': None,
|
||||||
|
'markdown_content': None,
|
||||||
|
'cells_data': None,
|
||||||
|
'temp_dir': None,
|
||||||
|
'session_id': None,
|
||||||
|
'result_paths': None,
|
||||||
|
'annotation_data': None # Store annotation data
|
||||||
|
}
|
||||||
|
|
||||||
|
# ==================== Utility Functions ====================
|
||||||
|
def read_image_v2(img):
|
||||||
|
"""Reads an image, supporting URLs and local paths."""
|
||||||
|
if isinstance(img, str) and img.startswith(("http://", "https://")):
|
||||||
|
with requests.get(img, stream=True) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
img = Image.open(io.BytesIO(response.content))
|
||||||
|
elif isinstance(img, str):
|
||||||
|
img, _, _ = read_image(img, use_native=True)
|
||||||
|
elif isinstance(img, Image.Image):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid image type: {type(img)}")
|
||||||
|
return img
|
||||||
|
|
||||||
|
def get_test_images():
|
||||||
|
"""Gets the list of test images."""
|
||||||
|
test_images = []
|
||||||
|
test_dir = current_config['test_images_dir']
|
||||||
|
if os.path.exists(test_dir):
|
||||||
|
test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)
|
||||||
|
if name.lower().endswith(('.png', '.jpg', '.jpeg'))]
|
||||||
|
return test_images
|
||||||
|
|
||||||
|
def create_temp_session_dir():
|
||||||
|
"""Creates a unique temporary directory for each processing request."""
|
||||||
|
session_id = uuid.uuid4().hex[:8]
|
||||||
|
temp_dir = os.path.join(tempfile.gettempdir(), f"dots_ocr_demo_{session_id}")
|
||||||
|
os.makedirs(temp_dir, exist_ok=True)
|
||||||
|
return temp_dir, session_id
|
||||||
|
|
||||||
|
def parse_image_with_bbox(parser, image, prompt_mode, bbox=None, fitz_preprocess=False):
|
||||||
|
"""
|
||||||
|
Processes an image using DotsOCRParser, with support for the bbox parameter.
|
||||||
|
"""
|
||||||
|
# Create a temporary session directory
|
||||||
|
temp_dir, session_id = create_temp_session_dir()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Save the PIL Image to a temporary file
|
||||||
|
temp_image_path = os.path.join(temp_dir, f"input_{session_id}.png")
|
||||||
|
image.save(temp_image_path, "PNG")
|
||||||
|
|
||||||
|
# Use the high-level parse_image interface, passing the bbox parameter
|
||||||
|
filename = f"demo_{session_id}"
|
||||||
|
results = parser.parse_image(
|
||||||
|
input_path=temp_image_path,
|
||||||
|
filename=filename,
|
||||||
|
prompt_mode=prompt_mode,
|
||||||
|
save_dir=temp_dir,
|
||||||
|
bbox=bbox,
|
||||||
|
fitz_preprocess=fitz_preprocess
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the results
|
||||||
|
if not results:
|
||||||
|
raise ValueError("No results returned from parser")
|
||||||
|
|
||||||
|
result = results[0] # parse_image returns a list with a single result
|
||||||
|
|
||||||
|
# Read the result files
|
||||||
|
layout_image = None
|
||||||
|
cells_data = None
|
||||||
|
md_content = None
|
||||||
|
filtered = False
|
||||||
|
|
||||||
|
# Read the layout image
|
||||||
|
if 'layout_image_path' in result and os.path.exists(result['layout_image_path']):
|
||||||
|
layout_image = Image.open(result['layout_image_path'])
|
||||||
|
|
||||||
|
# Read the JSON data
|
||||||
|
if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
|
||||||
|
with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
|
||||||
|
cells_data = json.load(f)
|
||||||
|
|
||||||
|
# Read the Markdown content
|
||||||
|
if 'md_content_path' in result and os.path.exists(result['md_content_path']):
|
||||||
|
with open(result['md_content_path'], 'r', encoding='utf-8') as f:
|
||||||
|
md_content = f.read()
|
||||||
|
|
||||||
|
# Check for the original response file (if JSON parsing fails)
|
||||||
|
if 'filtered' in result:
|
||||||
|
filtered = result['filtered']
|
||||||
|
|
||||||
|
return {
|
||||||
|
'layout_image': layout_image,
|
||||||
|
'cells_data': cells_data,
|
||||||
|
'md_content': md_content,
|
||||||
|
'filtered': filtered,
|
||||||
|
'temp_dir': temp_dir,
|
||||||
|
'session_id': session_id,
|
||||||
|
'result_paths': result
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Clean up the temporary directory on error
|
||||||
|
import shutil
|
||||||
|
if os.path.exists(temp_dir):
|
||||||
|
shutil.rmtree(temp_dir, ignore_errors=True)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def process_annotation_data(annotation_data):
|
||||||
|
"""Processes annotation data, converting it to the format required by the model."""
|
||||||
|
if not annotation_data or not annotation_data.get('boxes'):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Get image and box data
|
||||||
|
image = annotation_data.get('image')
|
||||||
|
boxes = annotation_data.get('boxes', [])
|
||||||
|
|
||||||
|
if not boxes:
|
||||||
|
return image, None
|
||||||
|
|
||||||
|
# Ensure the image is in PIL Image format
|
||||||
|
if image is not None:
|
||||||
|
import numpy as np
|
||||||
|
if isinstance(image, np.ndarray):
|
||||||
|
image = Image.fromarray(image)
|
||||||
|
elif not isinstance(image, Image.Image):
|
||||||
|
# If it's another format, try to convert it
|
||||||
|
try:
|
||||||
|
image = Image.open(image) if isinstance(image, str) else Image.fromarray(image)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Image format conversion failed: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Get the coordinate information of the box (only one box)
|
||||||
|
box = boxes[0]
|
||||||
|
bbox = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]
|
||||||
|
|
||||||
|
return image, bbox
|
||||||
|
|
||||||
|
# ==================== Core Processing Function ====================
|
||||||
|
def process_image_inference_with_annotation(annotation_data, test_image_input,
|
||||||
|
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
|
||||||
|
fitz_preprocess=False
|
||||||
|
):
|
||||||
|
"""Core function for image inference, supporting annotation data."""
|
||||||
|
global current_config, processing_results, dots_parser
|
||||||
|
|
||||||
|
# First, clean up previous processing results
|
||||||
|
if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
|
||||||
|
import shutil
|
||||||
|
try:
|
||||||
|
shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to clean up previous temporary directory: {e}")
|
||||||
|
|
||||||
|
# Reset processing results
|
||||||
|
processing_results = {
|
||||||
|
'original_image': None,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': None,
|
||||||
|
'markdown_content': None,
|
||||||
|
'cells_data': None,
|
||||||
|
'temp_dir': None,
|
||||||
|
'session_id': None,
|
||||||
|
'result_paths': None,
|
||||||
|
'annotation_data': annotation_data
|
||||||
|
}
|
||||||
|
|
||||||
|
# Update configuration
|
||||||
|
current_config.update({
|
||||||
|
'ip': server_ip,
|
||||||
|
'port_vllm': server_port,
|
||||||
|
'min_pixels': min_pixels,
|
||||||
|
'max_pixels': max_pixels
|
||||||
|
})
|
||||||
|
|
||||||
|
# Update parser configuration
|
||||||
|
dots_parser.ip = server_ip
|
||||||
|
dots_parser.port = server_port
|
||||||
|
dots_parser.min_pixels = min_pixels
|
||||||
|
dots_parser.max_pixels = max_pixels
|
||||||
|
|
||||||
|
# Determine the input source and process annotation data
|
||||||
|
image = None
|
||||||
|
bbox = None
|
||||||
|
|
||||||
|
# Prioritize processing annotation data
|
||||||
|
if annotation_data and annotation_data.get('image') is not None:
|
||||||
|
image, bbox = process_annotation_data(annotation_data)
|
||||||
|
if image is not None:
|
||||||
|
# If there's a bbox, force the use of 'prompt_grounding_ocr' mode
|
||||||
|
assert bbox is not None
|
||||||
|
prompt_mode = "prompt_grounding_ocr"
|
||||||
|
|
||||||
|
# If there's no annotation data, check the test image input
|
||||||
|
if image is None and test_image_input and test_image_input != "":
|
||||||
|
try:
|
||||||
|
image = read_image_v2(test_image_input)
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Failed to read test image: {e}", "", "", gr.update(value=None), ""
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
return None, "Please select a test image or add an image in the annotation component", "", "", gr.update(value=None), ""
|
||||||
|
if bbox is None:
|
||||||
|
return "Please select a bounding box by mouse", "Please select a bounding box by mouse", "", "", gr.update(value=None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Process using DotsOCRParser, passing the bbox parameter
|
||||||
|
original_image = image
|
||||||
|
parse_result = parse_image_with_bbox(dots_parser, image, prompt_mode, bbox, fitz_preprocess)
|
||||||
|
|
||||||
|
# Extract parsing results
|
||||||
|
layout_image = parse_result['layout_image']
|
||||||
|
cells_data = parse_result['cells_data']
|
||||||
|
md_content = parse_result['md_content']
|
||||||
|
filtered = parse_result['filtered']
|
||||||
|
|
||||||
|
# Store the results
|
||||||
|
processing_results.update({
|
||||||
|
'original_image': original_image,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': layout_image,
|
||||||
|
'markdown_content': md_content,
|
||||||
|
'cells_data': cells_data,
|
||||||
|
'temp_dir': parse_result['temp_dir'],
|
||||||
|
'session_id': parse_result['session_id'],
|
||||||
|
'result_paths': parse_result['result_paths'],
|
||||||
|
'annotation_data': annotation_data
|
||||||
|
})
|
||||||
|
|
||||||
|
# Handle the case where parsing fails
|
||||||
|
if filtered:
|
||||||
|
info_text = f"""
|
||||||
|
**Image Information:**
|
||||||
|
- Original Dimensions: {original_image.width} x {original_image.height}
|
||||||
|
- Processing Mode: {'Region OCR' if bbox else 'Full Image OCR'}
|
||||||
|
- Processing Status: JSON parsing failed, using cleaned text output
|
||||||
|
- Server: {current_config['ip']}:{current_config['port_vllm']}
|
||||||
|
- Session ID: {parse_result['session_id']}
|
||||||
|
- Box Coordinates: {bbox if bbox else 'None'}
|
||||||
|
"""
|
||||||
|
|
||||||
|
return (
|
||||||
|
md_content or "No markdown content generated",
|
||||||
|
info_text,
|
||||||
|
md_content or "No markdown content generated",
|
||||||
|
md_content or "No markdown content generated",
|
||||||
|
gr.update(visible=False),
|
||||||
|
""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle the case where JSON parsing succeeds
|
||||||
|
num_elements = len(cells_data) if cells_data else 0
|
||||||
|
info_text = f"""
|
||||||
|
**Image Information:**
|
||||||
|
- Original Dimensions: {original_image.width} x {original_image.height}
|
||||||
|
- Processing Mode: {'Region OCR' if bbox else 'Full Image OCR'}
|
||||||
|
- Server: {current_config['ip']}:{current_config['port_vllm']}
|
||||||
|
- Detected {num_elements} layout elements
|
||||||
|
- Session ID: {parse_result['session_id']}
|
||||||
|
- Box Coordinates: {bbox if bbox else 'None'}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Current page JSON output
|
||||||
|
current_json = ""
|
||||||
|
if cells_data:
|
||||||
|
try:
|
||||||
|
current_json = json.dumps(cells_data, ensure_ascii=False, indent=2)
|
||||||
|
except:
|
||||||
|
current_json = str(cells_data)
|
||||||
|
|
||||||
|
# Create a downloadable ZIP file
|
||||||
|
download_zip_path = None
|
||||||
|
if parse_result['temp_dir']:
|
||||||
|
download_zip_path = os.path.join(parse_result['temp_dir'], f"layout_results_{parse_result['session_id']}.zip")
|
||||||
|
try:
|
||||||
|
with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
|
||||||
|
for root, dirs, files in os.walk(parse_result['temp_dir']):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith('.zip'):
|
||||||
|
continue
|
||||||
|
file_path = os.path.join(root, file)
|
||||||
|
arcname = os.path.relpath(file_path, parse_result['temp_dir'])
|
||||||
|
zipf.write(file_path, arcname)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to create download ZIP: {e}")
|
||||||
|
download_zip_path = None
|
||||||
|
|
||||||
|
return (
|
||||||
|
md_content or "No markdown content generated",
|
||||||
|
info_text,
|
||||||
|
md_content or "No markdown content generated",
|
||||||
|
md_content or "No markdown content generated",
|
||||||
|
gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False),
|
||||||
|
current_json
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"An error occurred during processing: {e}", f"An error occurred during processing: {e}", "", "", gr.update(value=None), ""
|
||||||
|
|
||||||
|
def load_image_to_annotator(test_image_input):
|
||||||
|
"""Loads an image into the annotation component."""
|
||||||
|
image = None
|
||||||
|
|
||||||
|
# Check the test image input
|
||||||
|
if test_image_input and test_image_input != "":
|
||||||
|
try:
|
||||||
|
image = read_image_v2(test_image_input)
|
||||||
|
except Exception as e:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if image is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Return the format required by the annotation component
|
||||||
|
return {
|
||||||
|
"image": image,
|
||||||
|
"boxes": []
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_all_data():
|
||||||
|
"""Clears all data."""
|
||||||
|
global processing_results
|
||||||
|
|
||||||
|
# Clean up the temporary directory
|
||||||
|
if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
|
||||||
|
import shutil
|
||||||
|
try:
|
||||||
|
shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Failed to clean up temporary directory: {e}")
|
||||||
|
|
||||||
|
# Reset processing results
|
||||||
|
processing_results = {
|
||||||
|
'original_image': None,
|
||||||
|
'processed_image': None,
|
||||||
|
'layout_result': None,
|
||||||
|
'markdown_content': None,
|
||||||
|
'cells_data': None,
|
||||||
|
'temp_dir': None,
|
||||||
|
'session_id': None,
|
||||||
|
'result_paths': None,
|
||||||
|
'annotation_data': None
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
"", # Clear test image selection
|
||||||
|
None, # Clear annotation component
|
||||||
|
"Waiting for processing results...", # Reset info display
|
||||||
|
"## Waiting for processing results...", # Reset Markdown display
|
||||||
|
"🕐 Waiting for parsing results...", # Clear raw Markdown text
|
||||||
|
gr.update(visible=False), # Hide download button
|
||||||
|
"🕐 Waiting for parsing results..." # Clear JSON
|
||||||
|
)
|
||||||
|
|
||||||
|
def update_prompt_display(prompt_mode):
|
||||||
|
"""Updates the displayed prompt content."""
|
||||||
|
return dict_promptmode_to_prompt[prompt_mode]
|
||||||
|
|
||||||
|
# ==================== Gradio Interface ====================
|
||||||
|
def create_gradio_interface():
|
||||||
|
"""Creates the Gradio interface."""
|
||||||
|
|
||||||
|
# CSS styling to match the reference style
|
||||||
|
css = """
|
||||||
|
footer {
|
||||||
|
visibility: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
#info_box {
|
||||||
|
padding: 10px;
|
||||||
|
background-color: #f8f9fa;
|
||||||
|
border-radius: 8px;
|
||||||
|
border: 1px solid #dee2e6;
|
||||||
|
margin: 10px 0;
|
||||||
|
font-size: 14px;
|
||||||
|
}
|
||||||
|
|
||||||
|
#markdown_tabs {
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
#annotation_component {
|
||||||
|
border-radius: 8px;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
with gr.Blocks(theme="ocean", css=css, title='dots.ocr - Annotation') as demo:
|
||||||
|
|
||||||
|
# Title
|
||||||
|
gr.HTML("""
|
||||||
|
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
|
||||||
|
<h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr - Annotation Version</h1>
|
||||||
|
</div>
|
||||||
|
<div style="text-align: center; margin-bottom: 10px;">
|
||||||
|
<em>Supports image annotation, drawing boxes, and sending box information to the model for OCR.</em>
|
||||||
|
</div>
|
||||||
|
""")
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
# Left side: Input and Configuration
|
||||||
|
with gr.Column(scale=1, variant="compact"):
|
||||||
|
gr.Markdown("### 📁 Select Example")
|
||||||
|
test_images = get_test_images()
|
||||||
|
test_image_input = gr.Dropdown(
|
||||||
|
label="Select Example",
|
||||||
|
choices=[""] + test_images,
|
||||||
|
value="",
|
||||||
|
show_label=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Button to load image into the annotation component
|
||||||
|
load_btn = gr.Button("📷 Load Image to Annotation Area", variant="secondary")
|
||||||
|
|
||||||
|
prompt_mode = gr.Dropdown(
|
||||||
|
label="Select Prompt",
|
||||||
|
# choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr", "prompt_grounding_ocr"],
|
||||||
|
choices=["prompt_grounding_ocr"],
|
||||||
|
value="prompt_grounding_ocr",
|
||||||
|
show_label=True,
|
||||||
|
info="If a box is drawn, 'prompt_grounding_ocr' mode will be used automatically."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Display the current prompt content
|
||||||
|
prompt_display = gr.Textbox(
|
||||||
|
label="Current Prompt Content",
|
||||||
|
# value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
|
||||||
|
value=dict_promptmode_to_prompt["prompt_grounding_ocr"],
|
||||||
|
lines=4,
|
||||||
|
max_lines=8,
|
||||||
|
interactive=False,
|
||||||
|
show_copy_button=True
|
||||||
|
)
|
||||||
|
|
||||||
|
gr.Markdown("### ⚙️ Actions")
|
||||||
|
process_btn = gr.Button("🔍 Parse", variant="primary")
|
||||||
|
clear_btn = gr.Button("🗑️ Clear", variant="secondary")
|
||||||
|
|
||||||
|
gr.Markdown("### 🛠️ Configuration")
|
||||||
|
|
||||||
|
fitz_preprocess = gr.Checkbox(
|
||||||
|
label="Enable fitz_preprocess",
|
||||||
|
value=False,
|
||||||
|
info="Performs fitz preprocessing on the image input, converting the image to a PDF and then to a 200dpi image."
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
server_ip = gr.Textbox(
|
||||||
|
label="Server IP",
|
||||||
|
value=DEFAULT_CONFIG['ip']
|
||||||
|
)
|
||||||
|
server_port = gr.Number(
|
||||||
|
label="Port",
|
||||||
|
value=DEFAULT_CONFIG['port_vllm'],
|
||||||
|
precision=0
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
min_pixels = gr.Number(
|
||||||
|
label="Min Pixels",
|
||||||
|
value=DEFAULT_CONFIG['min_pixels'],
|
||||||
|
precision=0
|
||||||
|
)
|
||||||
|
max_pixels = gr.Number(
|
||||||
|
label="Max Pixels",
|
||||||
|
value=DEFAULT_CONFIG['max_pixels'],
|
||||||
|
precision=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Right side: Result Display
|
||||||
|
with gr.Column(scale=6, variant="compact"):
|
||||||
|
with gr.Row():
|
||||||
|
# Image Annotation Area
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.Markdown("### 🎯 Image Annotation Area")
|
||||||
|
gr.Markdown("""
|
||||||
|
**Instructions:**
|
||||||
|
- Method 1: Select an example image on the left and click "Load Image to Annotation Area".
|
||||||
|
- Method 2: Upload an image directly in the annotation area below (drag and drop or click to upload).
|
||||||
|
- Use the mouse to draw a box on the image to select the region for recognition.
|
||||||
|
- Only one box can be drawn. To draw a new one, please delete the old one first.
|
||||||
|
- **Hotkey: Press the Delete key to remove the selected box.**
|
||||||
|
- After drawing a box, clicking Parse will automatically use the Region OCR mode.
|
||||||
|
""")
|
||||||
|
|
||||||
|
annotator = image_annotator(
|
||||||
|
value=None,
|
||||||
|
label="Image Annotation",
|
||||||
|
height=600,
|
||||||
|
show_label=False,
|
||||||
|
elem_id="annotation_component",
|
||||||
|
single_box=True, # Only allow one box; a new box will replace the old one
|
||||||
|
box_min_size=10,
|
||||||
|
interactive=True,
|
||||||
|
disable_edit_boxes=True, # Disable the edit dialog
|
||||||
|
label_list=["OCR Region"], # Set the default label
|
||||||
|
label_colors=[(255, 0, 0)], # Set color to red
|
||||||
|
use_default_label=True, # Use the default label
|
||||||
|
image_type="pil" # Ensure it returns a PIL Image format
|
||||||
|
)
|
||||||
|
|
||||||
|
# Information Display
|
||||||
|
info_display = gr.Markdown(
|
||||||
|
"Waiting for processing results...",
|
||||||
|
elem_id="info_box"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Result Display Area
|
||||||
|
with gr.Column(scale=3):
|
||||||
|
gr.Markdown("### ✅ Results")
|
||||||
|
|
||||||
|
with gr.Tabs(elem_id="markdown_tabs"):
|
||||||
|
with gr.TabItem("Markdown Rendered View"):
|
||||||
|
md_output = gr.Markdown(
|
||||||
|
"## Please upload an image and click the Parse button for recognition...",
|
||||||
|
label="Markdown Preview",
|
||||||
|
max_height=1000,
|
||||||
|
latex_delimiters=[
|
||||||
|
{"left": "$$", "right": "$$", "display": True},
|
||||||
|
{"left": "$", "right": "$", "display": False},
|
||||||
|
],
|
||||||
|
show_copy_button=False,
|
||||||
|
elem_id="markdown_output"
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.TabItem("Markdown Raw Text"):
|
||||||
|
md_raw_output = gr.Textbox(
|
||||||
|
value="🕐 Waiting for parsing results...",
|
||||||
|
label="Markdown Raw Text",
|
||||||
|
max_lines=100,
|
||||||
|
lines=38,
|
||||||
|
show_copy_button=True,
|
||||||
|
elem_id="markdown_output",
|
||||||
|
show_label=False
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.TabItem("JSON Result"):
|
||||||
|
json_output = gr.Textbox(
|
||||||
|
value="🕐 Waiting for parsing results...",
|
||||||
|
label="JSON Result",
|
||||||
|
max_lines=100,
|
||||||
|
lines=38,
|
||||||
|
show_copy_button=True,
|
||||||
|
elem_id="markdown_output",
|
||||||
|
show_label=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download Button
|
||||||
|
with gr.Row():
|
||||||
|
download_btn = gr.DownloadButton(
|
||||||
|
"⬇️ Download Results",
|
||||||
|
visible=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Event Binding
|
||||||
|
|
||||||
|
# When the prompt mode changes, update the displayed content
|
||||||
|
prompt_mode.change(
|
||||||
|
fn=update_prompt_display,
|
||||||
|
inputs=prompt_mode,
|
||||||
|
outputs=prompt_display,
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load image into the annotation component
|
||||||
|
load_btn.click(
|
||||||
|
fn=load_image_to_annotator,
|
||||||
|
inputs=[test_image_input],
|
||||||
|
outputs=annotator,
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Process Inference
|
||||||
|
process_btn.click(
|
||||||
|
fn=process_image_inference_with_annotation,
|
||||||
|
inputs=[
|
||||||
|
annotator, test_image_input,
|
||||||
|
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
|
||||||
|
fitz_preprocess
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
md_output, info_display, md_raw_output, md_raw_output,
|
||||||
|
download_btn, json_output
|
||||||
|
],
|
||||||
|
show_progress=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clear Data
|
||||||
|
clear_btn.click(
|
||||||
|
fn=clear_all_data,
|
||||||
|
outputs=[
|
||||||
|
test_image_input, annotator,
|
||||||
|
info_display, md_output, md_raw_output,
|
||||||
|
download_btn, json_output
|
||||||
|
],
|
||||||
|
show_progress=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return demo
|
||||||
|
|
||||||
|
# ==================== Main Program ====================
|
||||||
|
if __name__ == "__main__":
|
||||||
|
demo = create_gradio_interface()
|
||||||
|
demo.queue().launch(
|
||||||
|
server_name="0.0.0.0",
|
||||||
|
server_port=7861, # Use a different port to avoid conflicts
|
||||||
|
debug=True
|
||||||
|
)
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
import os
|
||||||
|
if "LOCAL_RANK" not in os.environ:
|
||||||
|
os.environ["LOCAL_RANK"] = "0"
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
||||||
|
from qwen_vl_utils import process_vision_info
|
||||||
|
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||||
|
|
||||||
|
def inference(image_path, prompt, model, processor):
|
||||||
|
# image_path = "demo/demo_image1.jpg"
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image",
|
||||||
|
"image": image_path
|
||||||
|
},
|
||||||
|
{"type": "text", "text": prompt}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Preparation for inference
|
||||||
|
text = processor.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True
|
||||||
|
)
|
||||||
|
image_inputs, video_inputs = process_vision_info(messages)
|
||||||
|
inputs = processor(
|
||||||
|
text=[text],
|
||||||
|
images=image_inputs,
|
||||||
|
videos=video_inputs,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
|
||||||
|
inputs = inputs.to("cuda")
|
||||||
|
|
||||||
|
# Inference: Generation of the output
|
||||||
|
generated_ids = model.generate(**inputs, max_new_tokens=24000)
|
||||||
|
generated_ids_trimmed = [
|
||||||
|
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||||
|
]
|
||||||
|
output_text = processor.batch_decode(
|
||||||
|
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
|
)
|
||||||
|
print(output_text)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||||
|
model_path = "./weights/DotsOCR"
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
attn_implementation="flash_attention_2",
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True
|
||||||
|
)
|
||||||
|
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
||||||
|
|
||||||
|
image_path = "demo/demo_image1.jpg"
|
||||||
|
for prompt_mode, prompt in dict_promptmode_to_prompt.items():
|
||||||
|
print(f"prompt: {prompt}")
|
||||||
|
inference(image_path, prompt, model, processor)
|
||||||
|
|
||||||
|
After Width: | Height: | Size: 755 KiB |
@@ -0,0 +1,222 @@
|
|||||||
|
"""
|
||||||
|
Layout Inference Web Application
|
||||||
|
|
||||||
|
A Streamlit-based layout inference tool that supports image uploads and multiple backend inference engines.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
import tempfile
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
|
||||||
|
# Local utility imports
|
||||||
|
|
||||||
|
# from utils import infer
|
||||||
|
|
||||||
|
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||||
|
from dots_ocr.utils.format_transformer import layoutjson2md
|
||||||
|
from dots_ocr.utils.layout_utils import draw_layout_on_image, post_process_cells
|
||||||
|
from dots_ocr.utils.image_utils import get_input_dimensions, get_image_by_fitz_doc
|
||||||
|
from dots_ocr.model.inference import inference_with_vllm
|
||||||
|
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||||
|
|
||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
from dots_ocr.utils.demo_utils.display import read_image
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== Configuration ====================
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
'ip': "127.0.0.1",
|
||||||
|
'port_vllm': 8000,
|
||||||
|
'min_pixels': MIN_PIXELS,
|
||||||
|
'max_pixels': MAX_PIXELS,
|
||||||
|
'test_images_dir': "./assets/showcase_origin",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ==================== Utility Functions ====================
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def read_image_v2(img: str):
|
||||||
|
if img.startswith(("http://", "https://")):
|
||||||
|
with requests.get(img, stream=True) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
img = Image.open(io.BytesIO(response.content))
|
||||||
|
|
||||||
|
if isinstance(img, str):
|
||||||
|
# img = transform_image_path(img)
|
||||||
|
img, _, _ = read_image(img, use_native=True)
|
||||||
|
elif isinstance(img, Image.Image):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid image type: {type(img)}")
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
# ==================== UI Components ====================
|
||||||
|
def create_config_sidebar():
|
||||||
|
"""Create configuration sidebar"""
|
||||||
|
st.sidebar.header("Configuration Parameters")
|
||||||
|
|
||||||
|
config = {}
|
||||||
|
config['prompt_key'] = st.sidebar.selectbox("Prompt Mode", list(dict_promptmode_to_prompt.keys()))
|
||||||
|
config['ip'] = st.sidebar.text_input("Server IP", DEFAULT_CONFIG['ip'])
|
||||||
|
config['port'] = st.sidebar.number_input("Port", min_value=1000, max_value=9999, value=DEFAULT_CONFIG['port_vllm'])
|
||||||
|
# config['eos_word'] = st.sidebar.text_input("EOS Word", DEFAULT_CONFIG['eos_word'])
|
||||||
|
|
||||||
|
# Image configuration
|
||||||
|
st.sidebar.subheader("Image Configuration")
|
||||||
|
config['min_pixels'] = st.sidebar.number_input("Min Pixels", value=DEFAULT_CONFIG['min_pixels'])
|
||||||
|
config['max_pixels'] = st.sidebar.number_input("Max Pixels", value=DEFAULT_CONFIG['max_pixels'])
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
def get_image_input():
|
||||||
|
"""Get image input"""
|
||||||
|
st.markdown("#### Image Input")
|
||||||
|
|
||||||
|
input_mode = st.pills(label="Select input method", options=["Upload Image", "Enter Image URL/Path", "Select Test Image"], key="input_mode", label_visibility="collapsed")
|
||||||
|
|
||||||
|
if input_mode == "Upload Image":
|
||||||
|
# File uploader
|
||||||
|
uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
|
||||||
|
if uploaded_file is not None:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
|
||||||
|
tmp_file.write(uploaded_file.getvalue())
|
||||||
|
return tmp_file.name
|
||||||
|
elif input_mode == 'Enter Image URL/Path':
|
||||||
|
# URL input
|
||||||
|
img_url_input = st.text_input("Enter Image URL/Path")
|
||||||
|
return img_url_input
|
||||||
|
|
||||||
|
elif input_mode == 'Select Test Image':
|
||||||
|
# Test image selection
|
||||||
|
test_images = []
|
||||||
|
test_dir = DEFAULT_CONFIG['test_images_dir']
|
||||||
|
if os.path.exists(test_dir):
|
||||||
|
test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]
|
||||||
|
img_url_test = st.selectbox("Select Test Image", [""] + test_images)
|
||||||
|
return img_url_test
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid input mode: {input_mode}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def process_and_display_results(output: str, image: Image.Image, config: dict):
|
||||||
|
"""Process and display inference results"""
|
||||||
|
prompt, response = output['prompt'], output['response']
|
||||||
|
|
||||||
|
try:
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
# st.markdown('---')
|
||||||
|
cells = json.loads(response)
|
||||||
|
# image = Image.open(img_url)
|
||||||
|
|
||||||
|
# Post-processing
|
||||||
|
cells = post_process_cells(
|
||||||
|
image, cells,
|
||||||
|
image.width, image.height,
|
||||||
|
min_pixels=config['min_pixels'],
|
||||||
|
max_pixels=config['max_pixels']
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate input dimensions
|
||||||
|
input_width, input_height = get_input_dimensions(
|
||||||
|
image,
|
||||||
|
min_pixels=config['min_pixels'],
|
||||||
|
max_pixels=config['max_pixels']
|
||||||
|
)
|
||||||
|
st.markdown('---')
|
||||||
|
st.write(f'Input Dimensions: {input_width} x {input_height}')
|
||||||
|
# st.write(f'Prompt: {prompt}')
|
||||||
|
# st.markdown(f'模型原始输出: <span style="color:blue">{result}</span>', unsafe_allow_html=True)
|
||||||
|
# st.write('模型原始输出:')
|
||||||
|
# st.write(response)
|
||||||
|
# st.write('后处理结果:', str(cells))
|
||||||
|
st.text_area('Original Model Output', response, height=200)
|
||||||
|
st.text_area('Post-processed Result', str(cells), height=200)
|
||||||
|
# 显示结果
|
||||||
|
# st.title("Layout推理结果")
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
# st.markdown("##### 可视化结果")
|
||||||
|
new_image = draw_layout_on_image(
|
||||||
|
image, cells,
|
||||||
|
resized_height=None, resized_width=None,
|
||||||
|
# text_key='text',
|
||||||
|
fill_bbox=True, draw_bbox=True
|
||||||
|
)
|
||||||
|
st.markdown('##### Visualization Result')
|
||||||
|
st.image(new_image, width=new_image.width)
|
||||||
|
# st.write(f"尺寸: {new_image.width} x {new_image.height}")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
# st.markdown("##### Markdown格式")
|
||||||
|
md_code = layoutjson2md(image, cells, text_key='text')
|
||||||
|
# md_code = fix_streamlit_formula(md_code)
|
||||||
|
st.markdown('##### Markdown Format')
|
||||||
|
st.markdown(md_code, unsafe_allow_html=True)
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
st.error("Model output is not a valid JSON format")
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Error processing results: {e}")
|
||||||
|
|
||||||
|
# ==================== Main Application ====================
|
||||||
|
def main():
|
||||||
|
"""Main application function"""
|
||||||
|
st.set_page_config(page_title="Layout Inference Tool", layout="wide")
|
||||||
|
st.title("🔍 Layout Inference Tool")
|
||||||
|
|
||||||
|
# Configuration
|
||||||
|
config = create_config_sidebar()
|
||||||
|
prompt = dict_promptmode_to_prompt[config['prompt_key']]
|
||||||
|
st.sidebar.info(f"Current Prompt: {prompt}")
|
||||||
|
|
||||||
|
# Image input
|
||||||
|
img_url = get_image_input()
|
||||||
|
start_button = st.button('🚀 Start Inference', type="primary")
|
||||||
|
|
||||||
|
if img_url is not None and img_url.strip() != "":
|
||||||
|
try:
|
||||||
|
# processed_image = read_image_v2(img_url)
|
||||||
|
origin_image = read_image_v2(img_url)
|
||||||
|
st.write(f"Original Dimensions: {origin_image.width} x {origin_image.height}")
|
||||||
|
# processed_image = get_image_by_fitz_doc(origin_image, target_dpi=200)
|
||||||
|
processed_image = origin_image
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"Failed to read image: {e}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
st.info("Please enter an image URL/path or upload an image")
|
||||||
|
return
|
||||||
|
|
||||||
|
output = None
|
||||||
|
# Inference button
|
||||||
|
if start_button:
|
||||||
|
with st.spinner(f"Inferring... Server: {config['ip']}:{config['port']}"):
|
||||||
|
|
||||||
|
response = inference_with_vllm(
|
||||||
|
processed_image, prompt, config['ip'], config['port'],
|
||||||
|
# config['min_pixels'], config['max_pixels']
|
||||||
|
)
|
||||||
|
output = {
|
||||||
|
'prompt': prompt,
|
||||||
|
'response': response,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
st.image(processed_image, width=500)
|
||||||
|
|
||||||
|
# Process results
|
||||||
|
if output:
|
||||||
|
process_and_display_results(output, processed_image, config)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||||
|
from dots_ocr.model.inference import inference_with_vllm
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--ip", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=str, default="8000")
|
||||||
|
parser.add_argument("--model_name", type=str, default="model")
|
||||||
|
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
addr = f"http://{args.ip}:{args.port}/v1"
|
||||||
|
image_path = "demo/demo_image1.jpg"
|
||||||
|
prompt = dict_promptmode_to_prompt[args.prompt_mode]
|
||||||
|
image = Image.open(image_path)
|
||||||
|
response = inference_with_vllm(
|
||||||
|
image,
|
||||||
|
prompt,
|
||||||
|
ip="localhost",
|
||||||
|
port=8000,
|
||||||
|
temperature=0.1,
|
||||||
|
top_p=0.9,
|
||||||
|
)
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
# download model to /path/to/model
|
||||||
|
if [ -z "$NODOWNLOAD" ]; then
|
||||||
|
python3 tools/download_model.py
|
||||||
|
fi
|
||||||
|
|
||||||
|
# register model to vllm
|
||||||
|
hf_model_path=./weights/DotsOCR # Path to your downloaded model weights
|
||||||
|
export PYTHONPATH=$(dirname "$hf_model_path"):$PYTHONPATH
|
||||||
|
sed -i '/^from vllm\.entrypoints\.cli\.main import main$/a\
|
||||||
|
from DotsOCR import modeling_dots_ocr_vllm' `which vllm`
|
||||||
|
|
||||||
|
# launch vllm server
|
||||||
|
model_name=model
|
||||||
|
CUDA_VISIBLE_DEVICES=0 vllm serve ${hf_model_path} --tensor-parallel-size 1 --gpu-memory-utilization 0.95 --chat-template-content-format string --served-model-name ${model_name} --trust-remote-code
|
||||||
|
|
||||||
|
# # run python demo after launch vllm server
|
||||||
|
# python demo/demo_vllm.py
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from vllm/vllm-openai:v0.9.1
|
||||||
|
|
||||||
|
RUN pip3 install flash_attn==2.8.0.post2
|
||||||
|
RUN pip3 install transformers==4.51.3
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
from .parser import DotsOCRParser
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
import json
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
import math
|
||||||
|
from PIL import Image
|
||||||
|
import requests
|
||||||
|
from dots_ocr.utils.image_utils import PILimage_to_base64
|
||||||
|
from openai import OpenAI
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def inference_with_vllm(
|
||||||
|
image,
|
||||||
|
prompt,
|
||||||
|
ip="localhost",
|
||||||
|
port=8000,
|
||||||
|
temperature=0.1,
|
||||||
|
top_p=0.9,
|
||||||
|
max_completion_tokens=32768,
|
||||||
|
model_name='model',
|
||||||
|
):
|
||||||
|
|
||||||
|
addr = f"http://{ip}:{port}/v1"
|
||||||
|
client = OpenAI(api_key="{}".format(os.environ.get("API_KEY", "0")), base_url=addr)
|
||||||
|
messages = []
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": PILimage_to_base64(image)},
|
||||||
|
},
|
||||||
|
{"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"} # if no "<|img|><|imgpad|><|endofimg|>" here,vllm v1 will add "\n" here
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
messages=messages,
|
||||||
|
model=model_name,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p)
|
||||||
|
response = response.choices[0].message.content
|
||||||
|
return response
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
print(f"request error: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
@@ -0,0 +1,349 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
from tqdm import tqdm
|
||||||
|
from multiprocessing.pool import ThreadPool, Pool
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
|
from dots_ocr.model.inference import inference_with_vllm
|
||||||
|
from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS
|
||||||
|
from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize
|
||||||
|
from dots_ocr.utils.doc_utils import fitz_doc_to_image, load_images_from_pdf
|
||||||
|
from dots_ocr.utils.prompts import dict_promptmode_to_prompt
|
||||||
|
from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes
|
||||||
|
from dots_ocr.utils.format_transformer import layoutjson2md
|
||||||
|
|
||||||
|
|
||||||
|
class DotsOCRParser:
|
||||||
|
"""
|
||||||
|
parse image or pdf file
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
ip='localhost',
|
||||||
|
port=8000,
|
||||||
|
model_name='model',
|
||||||
|
temperature=0.1,
|
||||||
|
top_p=1.0,
|
||||||
|
max_completion_tokens=16384,
|
||||||
|
num_thread=64,
|
||||||
|
dpi = 200,
|
||||||
|
output_dir="./output",
|
||||||
|
min_pixels=None,
|
||||||
|
max_pixels=None,
|
||||||
|
):
|
||||||
|
self.dpi = dpi
|
||||||
|
|
||||||
|
# default args for vllm server
|
||||||
|
self.ip = ip
|
||||||
|
self.port = port
|
||||||
|
self.model_name = model_name
|
||||||
|
# default args for inference
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.num_thread = num_thread
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.min_pixels = min_pixels
|
||||||
|
self.max_pixels = max_pixels
|
||||||
|
assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
|
||||||
|
assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
|
||||||
|
|
||||||
|
|
||||||
|
def _inference_with_vllm(self, image, prompt):
|
||||||
|
response = inference_with_vllm(
|
||||||
|
image,
|
||||||
|
prompt,
|
||||||
|
model_name=self.model_name,
|
||||||
|
ip=self.ip,
|
||||||
|
port=self.port,
|
||||||
|
temperature=self.temperature,
|
||||||
|
top_p=self.top_p,
|
||||||
|
max_completion_tokens=self.max_completion_tokens,
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None):
|
||||||
|
prompt = dict_promptmode_to_prompt[prompt_mode]
|
||||||
|
if prompt_mode == 'prompt_grounding_ocr':
|
||||||
|
assert bbox is not None
|
||||||
|
bboxes = [bbox]
|
||||||
|
bbox = pre_process_bboxes(origin_image, bboxes, input_width=image.width, input_height=image.height, min_pixels=min_pixels, max_pixels=max_pixels)[0]
|
||||||
|
prompt = prompt + str(bbox)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
# def post_process_results(self, response, prompt_mode, save_dir, save_name, origin_image, image, min_pixels, max_pixels)
|
||||||
|
def _parse_single_image(
|
||||||
|
self,
|
||||||
|
origin_image,
|
||||||
|
prompt_mode,
|
||||||
|
save_dir,
|
||||||
|
save_name,
|
||||||
|
source="image",
|
||||||
|
page_idx=0,
|
||||||
|
bbox=None,
|
||||||
|
fitz_preprocess=False,
|
||||||
|
):
|
||||||
|
min_pixels, max_pixels = self.min_pixels, self.max_pixels
|
||||||
|
if prompt_mode == "prompt_grounding_ocr":
|
||||||
|
min_pixels = min_pixels or MIN_PIXELS # preprocess image to the final input
|
||||||
|
max_pixels = max_pixels or MAX_PIXELS
|
||||||
|
if min_pixels is not None: assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}"
|
||||||
|
if max_pixels is not None: assert max_pixels <= MAX_PIXELS, f"max_pixels should <+ {MAX_PIXELS}"
|
||||||
|
|
||||||
|
if source == 'image' and fitz_preprocess:
|
||||||
|
image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi)
|
||||||
|
image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||||
|
else:
|
||||||
|
image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||||
|
input_height, input_width = smart_resize(image.height, image.width)
|
||||||
|
prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||||
|
response = self._inference_with_vllm(image, prompt)
|
||||||
|
result = {'page_no': page_idx,
|
||||||
|
"input_height": input_height,
|
||||||
|
"input_width": input_width
|
||||||
|
}
|
||||||
|
if source == 'pdf':
|
||||||
|
save_name = f"{save_name}_page_{page_idx}"
|
||||||
|
if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
|
||||||
|
cells, filtered = post_process_output(
|
||||||
|
response,
|
||||||
|
prompt_mode,
|
||||||
|
origin_image,
|
||||||
|
image,
|
||||||
|
min_pixels=min_pixels,
|
||||||
|
max_pixels=max_pixels,
|
||||||
|
)
|
||||||
|
if filtered and prompt_mode != 'prompt_layout_only_en': # model output json failed, use filtered process
|
||||||
|
json_file_path = os.path.join(save_dir, f"{save_name}.json")
|
||||||
|
with open(json_file_path, 'w') as w:
|
||||||
|
json.dump(response, w, ensure_ascii=False)
|
||||||
|
|
||||||
|
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
|
||||||
|
origin_image.save(image_layout_path)
|
||||||
|
result.update({
|
||||||
|
'layout_info_path': json_file_path,
|
||||||
|
'layout_image_path': image_layout_path,
|
||||||
|
})
|
||||||
|
|
||||||
|
md_file_path = os.path.join(save_dir, f"{save_name}.md")
|
||||||
|
with open(md_file_path, "w", encoding="utf-8") as md_file:
|
||||||
|
md_file.write(cells)
|
||||||
|
result.update({
|
||||||
|
'md_content_path': md_file_path
|
||||||
|
})
|
||||||
|
result.update({
|
||||||
|
'filtered': True
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
image_with_layout = draw_layout_on_image(origin_image, cells)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error drawing layout on image: {e}")
|
||||||
|
image_with_layout = origin_image
|
||||||
|
|
||||||
|
json_file_path = os.path.join(save_dir, f"{save_name}.json")
|
||||||
|
with open(json_file_path, 'w') as w:
|
||||||
|
json.dump(cells, w, ensure_ascii=False)
|
||||||
|
|
||||||
|
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
|
||||||
|
image_with_layout.save(image_layout_path)
|
||||||
|
result.update({
|
||||||
|
'layout_info_path': json_file_path,
|
||||||
|
'layout_image_path': image_layout_path,
|
||||||
|
})
|
||||||
|
if prompt_mode != "prompt_layout_only_en": # no text md when detection only
|
||||||
|
md_content = layoutjson2md(origin_image, cells, text_key='text')
|
||||||
|
md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True) # used for clean output or metric of omnidocbench、olmbench
|
||||||
|
md_file_path = os.path.join(save_dir, f"{save_name}.md")
|
||||||
|
with open(md_file_path, "w", encoding="utf-8") as md_file:
|
||||||
|
md_file.write(md_content)
|
||||||
|
md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md")
|
||||||
|
with open(md_nohf_file_path, "w", encoding="utf-8") as md_file:
|
||||||
|
md_file.write(md_content_no_hf)
|
||||||
|
result.update({
|
||||||
|
'md_content_path': md_file_path,
|
||||||
|
'md_content_nohf_path': md_nohf_file_path,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
|
||||||
|
origin_image.save(image_layout_path)
|
||||||
|
result.update({
|
||||||
|
'layout_image_path': image_layout_path,
|
||||||
|
})
|
||||||
|
|
||||||
|
md_content = response
|
||||||
|
md_file_path = os.path.join(save_dir, f"{save_name}.md")
|
||||||
|
with open(md_file_path, "w", encoding="utf-8") as md_file:
|
||||||
|
md_file.write(md_content)
|
||||||
|
result.update({
|
||||||
|
'md_content_path': md_file_path,
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False):
|
||||||
|
origin_image = fetch_image(input_path)
|
||||||
|
result = self._parse_single_image(origin_image, prompt_mode, save_dir, filename, source="image", bbox=bbox, fitz_preprocess=fitz_preprocess)
|
||||||
|
result['file_path'] = input_path
|
||||||
|
return [result]
|
||||||
|
|
||||||
|
def parse_pdf(self, input_path, filename, prompt_mode, save_dir):
|
||||||
|
print(f"loading pdf: {input_path}")
|
||||||
|
images_origin = load_images_from_pdf(input_path)
|
||||||
|
total_pages = len(images_origin)
|
||||||
|
tasks = [
|
||||||
|
{
|
||||||
|
"origin_image": image,
|
||||||
|
"prompt_mode": prompt_mode,
|
||||||
|
"save_dir": save_dir,
|
||||||
|
"save_name": filename,
|
||||||
|
"source":"pdf",
|
||||||
|
"page_idx": i,
|
||||||
|
} for i, image in enumerate(images_origin)
|
||||||
|
]
|
||||||
|
|
||||||
|
def _execute_task(task_args):
|
||||||
|
return self._parse_single_image(**task_args)
|
||||||
|
|
||||||
|
num_thread = min(total_pages, self.num_thread)
|
||||||
|
print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
|
||||||
|
|
||||||
|
results = []
|
||||||
|
with ThreadPool(num_thread) as pool:
|
||||||
|
with tqdm(total=total_pages, desc="Processing PDF pages") as pbar:
|
||||||
|
for result in pool.imap_unordered(_execute_task, tasks):
|
||||||
|
results.append(result)
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
results.sort(key=lambda x: x["page_no"])
|
||||||
|
for i in range(len(results)):
|
||||||
|
results[i]['file_path'] = input_path
|
||||||
|
return results
|
||||||
|
|
||||||
|
def parse_file(self,
|
||||||
|
input_path,
|
||||||
|
output_dir="",
|
||||||
|
prompt_mode="prompt_layout_all_en",
|
||||||
|
bbox=None,
|
||||||
|
fitz_preprocess=False
|
||||||
|
):
|
||||||
|
output_dir = output_dir or self.output_dir
|
||||||
|
output_dir = os.path.abspath(output_dir)
|
||||||
|
filename, file_ext = os.path.splitext(os.path.basename(input_path))
|
||||||
|
save_dir = os.path.join(output_dir, filename)
|
||||||
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
|
if file_ext == '.pdf':
|
||||||
|
results = self.parse_pdf(input_path, filename, prompt_mode, save_dir)
|
||||||
|
elif file_ext in image_extensions:
|
||||||
|
results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} and pdf")
|
||||||
|
|
||||||
|
print(f"Parsing finished, results saving to {save_dir}")
|
||||||
|
with open(os.path.join(output_dir, os.path.basename(filename)+'.jsonl'), 'w') as w:
|
||||||
|
for result in results:
|
||||||
|
w.write(json.dumps(result, ensure_ascii=False) + '\n')
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
prompts = list(dict_promptmode_to_prompt.keys())
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="dots.ocr Multilingual Document Layout Parser",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"input_path", type=str,
|
||||||
|
help="Input PDF/image file path"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--output", type=str, default="./output",
|
||||||
|
help="Output directory (default: ./output)"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt", choices=prompts, type=str, default="prompt_layout_all_en",
|
||||||
|
help="prompt to query the model, different prompts for different tasks"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
'--bbox',
|
||||||
|
type=int,
|
||||||
|
nargs=4,
|
||||||
|
metavar=('x1', 'y1', 'x2', 'y2'),
|
||||||
|
help='should give this argument if you want to prompt_grounding_ocr'
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--ip", type=str, default="localhost",
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--port", type=int, default=8000,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_name", type=str, default="model",
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--temperature", type=float, default=0.1,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--top_p", type=float, default=1.0,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--dpi", type=int, default=200,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_completion_tokens", type=int, default=16384,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_thread", type=int, default=16,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
# parser.add_argument(
|
||||||
|
# "--fitz_preprocess", type=bool, default=False,
|
||||||
|
# help="False will use tikz dpi upsample pipeline, good for images which has been render with low dpi, but maybe result in higher computational costs"
|
||||||
|
# )
|
||||||
|
parser.add_argument(
|
||||||
|
"--min_pixels", type=int, default=None,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_pixels", type=int, default=None,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
dots_ocr_parser = DotsOCRParser(
|
||||||
|
ip=args.ip,
|
||||||
|
port=args.port,
|
||||||
|
model_name=args.model_name,
|
||||||
|
temperature=args.temperature,
|
||||||
|
top_p=args.top_p,
|
||||||
|
max_completion_tokens=args.max_completion_tokens,
|
||||||
|
num_thread=args.num_thread,
|
||||||
|
dpi=args.dpi,
|
||||||
|
output_dir=args.output,
|
||||||
|
min_pixels=args.min_pixels,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = dots_ocr_parser.parse_file(
|
||||||
|
args.input_path,
|
||||||
|
prompt_mode=args.prompt,
|
||||||
|
bbox=args.bbox,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
from .prompts import dict_promptmode_to_prompt
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
MIN_PIXELS=3136
|
||||||
|
MAX_PIXELS=11289600
|
||||||
|
IMAGE_FACTOR=28
|
||||||
|
|
||||||
|
image_extensions = {'.jpg', '.jpeg', '.png'}
|
||||||
@@ -0,0 +1,61 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_image_path(image_path):
|
||||||
|
"""
|
||||||
|
Checks if the image path is valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: The path to the image.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the path is valid, False otherwise.
|
||||||
|
"""
|
||||||
|
if not os.path.exists(image_path):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if the file extension is one of the common image formats.
|
||||||
|
image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp']
|
||||||
|
_, extension = os.path.splitext(image_path)
|
||||||
|
if extension.lower() in image_extensions:
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def read_image(image_path, use_native=False):
|
||||||
|
"""
|
||||||
|
Reads an image and resizes it while maintaining aspect ratio.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: The path to the image.
|
||||||
|
use_native: If True, the max dimension of the original image is used as the max size.
|
||||||
|
If False, max size is set to 1024.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (resized_image, original_width, original_height)
|
||||||
|
"""
|
||||||
|
# Create a default 512x512 blue image as a fallback.
|
||||||
|
image = Image.new('RGB', (512, 512), color=(0, 0, 255))
|
||||||
|
|
||||||
|
if is_valid_image_path(image_path):
|
||||||
|
image = Image.open(image_path)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"{image_path}: Image path does not exist")
|
||||||
|
|
||||||
|
w, h = image.size
|
||||||
|
if use_native:
|
||||||
|
max_size = max(w, h)
|
||||||
|
else:
|
||||||
|
max_size = 1024
|
||||||
|
|
||||||
|
if w > h:
|
||||||
|
new_w = max_size
|
||||||
|
new_h = int(h * max_size / w)
|
||||||
|
else:
|
||||||
|
new_h = max_size
|
||||||
|
new_w = int(w * max_size / h)
|
||||||
|
|
||||||
|
image = image.resize((new_w, new_h))
|
||||||
|
return image, w, h
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
import fitz
|
||||||
|
import numpy as np
|
||||||
|
import enum
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class SupportedPdfParseMethod(enum.Enum):
|
||||||
|
OCR = 'ocr'
|
||||||
|
TXT = 'txt'
|
||||||
|
|
||||||
|
|
||||||
|
class PageInfo(BaseModel):
|
||||||
|
"""The width and height of page
|
||||||
|
"""
|
||||||
|
w: float = Field(description='the width of page')
|
||||||
|
h: float = Field(description='the height of page')
|
||||||
|
|
||||||
|
|
||||||
|
def fitz_doc_to_image(doc, target_dpi=200, origin_dpi=None) -> dict:
|
||||||
|
"""Convert fitz.Document to image, Then convert the image to numpy array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
doc (_type_): pymudoc page
|
||||||
|
dpi (int, optional): reset the dpi of dpi. Defaults to 200.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {'img': numpy array, 'width': width, 'height': height }
|
||||||
|
"""
|
||||||
|
from PIL import Image
|
||||||
|
mat = fitz.Matrix(target_dpi / 72, target_dpi / 72)
|
||||||
|
pm = doc.get_pixmap(matrix=mat, alpha=False)
|
||||||
|
|
||||||
|
if pm.width > 4500 or pm.height > 4500:
|
||||||
|
mat = fitz.Matrix(72 / 72, 72 / 72) # use fitz default dpi
|
||||||
|
pm = doc.get_pixmap(matrix=mat, alpha=False)
|
||||||
|
|
||||||
|
image = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
def load_images_from_pdf(pdf_file, dpi=200, start_page_id=0, end_page_id=None) -> list:
|
||||||
|
images = []
|
||||||
|
with fitz.open(pdf_file) as doc:
|
||||||
|
pdf_page_num = doc.page_count
|
||||||
|
end_page_id = (
|
||||||
|
end_page_id
|
||||||
|
if end_page_id is not None and end_page_id >= 0
|
||||||
|
else pdf_page_num - 1
|
||||||
|
)
|
||||||
|
if end_page_id > pdf_page_num - 1:
|
||||||
|
print('end_page_id is out of range, use images length')
|
||||||
|
end_page_id = pdf_page_num - 1
|
||||||
|
|
||||||
|
for index in range(0, doc.page_count):
|
||||||
|
if start_page_id <= index <= end_page_id:
|
||||||
|
page = doc[index]
|
||||||
|
img = fitz_doc_to_image(page, target_dpi=dpi)
|
||||||
|
images.append(img)
|
||||||
|
return images
|
||||||
@@ -0,0 +1,205 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from dots_ocr.utils.image_utils import PILimage_to_base64
|
||||||
|
|
||||||
|
|
||||||
|
def has_latex_markdown(text: str) -> bool:
|
||||||
|
"""
|
||||||
|
Checks if a string contains LaTeX markdown patterns.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The string to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if LaTeX markdown is found, otherwise False.
|
||||||
|
"""
|
||||||
|
if not isinstance(text, str):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Define regular expression patterns for LaTeX markdown
|
||||||
|
latex_patterns = [
|
||||||
|
r'\$\$.*?\$\$', # Block-level math formula $$...$$
|
||||||
|
r'\$[^$\n]+?\$', # Inline math formula $...$
|
||||||
|
r'\\begin\{.*?\}.*?\\end\{.*?\}', # LaTeX environment \begin{...}...\end{...}
|
||||||
|
r'\\[a-zA-Z]+\{.*?\}', # LaTeX command \command{...}
|
||||||
|
r'\\[a-zA-Z]+', # Simple LaTeX command \command
|
||||||
|
r'\\\[.*?\\\]', # Display math formula \[...\]
|
||||||
|
r'\\\(.*?\\\)', # Inline math formula \(...\)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Check if any of the patterns match
|
||||||
|
for pattern in latex_patterns:
|
||||||
|
if re.search(pattern, text, re.DOTALL):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def clean_latex_preamble(latex_text: str) -> str:
|
||||||
|
"""
|
||||||
|
Removes LaTeX preamble commands like document class and package imports.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
latex_text (str): The original LaTeX text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The cleaned LaTeX text without preamble commands.
|
||||||
|
"""
|
||||||
|
# Define patterns to be removed
|
||||||
|
patterns = [
|
||||||
|
r'\\documentclass\{[^}]+\}', # \documentclass{...}
|
||||||
|
r'\\usepackage\{[^}]+\}', # \usepackage{...}
|
||||||
|
r'\\usepackage\[[^\]]*\]\{[^}]+\}', # \usepackage[options]{...}
|
||||||
|
r'\\begin\{document\}', # \begin{document}
|
||||||
|
r'\\end\{document\}', # \end{document}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Apply each pattern to clean the text
|
||||||
|
cleaned_text = latex_text
|
||||||
|
for pattern in patterns:
|
||||||
|
cleaned_text = re.sub(pattern, '', cleaned_text, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
return cleaned_text
|
||||||
|
|
||||||
|
|
||||||
|
def get_formula_in_markdown(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Formats a string containing a formula into a standard Markdown block.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The input string, potentially containing a formula.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The formatted string, ready for Markdown rendering.
|
||||||
|
"""
|
||||||
|
# Remove leading/trailing whitespace
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
# Check if it's already enclosed in $$
|
||||||
|
if text.startswith('$$') and text.endswith('$$'):
|
||||||
|
text_new = text[2:-2].strip()
|
||||||
|
if not '$' in text_new:
|
||||||
|
return f"$$\n{text_new}\n$$"
|
||||||
|
else:
|
||||||
|
return text
|
||||||
|
|
||||||
|
# Handle \[...\] format, convert to $$...$$
|
||||||
|
if text.startswith('\\[') and text.endswith('\\]'):
|
||||||
|
inner_content = text[2:-2].strip()
|
||||||
|
return f"$$\n{inner_content}\n$$"
|
||||||
|
|
||||||
|
# Check if it's enclosed in \[ \]
|
||||||
|
if len(re.findall(r'.*\\\[.*\\\].*', text)) > 0:
|
||||||
|
return text
|
||||||
|
|
||||||
|
# Handle inline formulas ($...$)
|
||||||
|
pattern = r'\$([^$]+)\$'
|
||||||
|
matches = re.findall(pattern, text)
|
||||||
|
if len(matches) > 0:
|
||||||
|
# It's an inline formula, return it as is
|
||||||
|
return text
|
||||||
|
|
||||||
|
# If no LaTeX markdown syntax is present, return directly
|
||||||
|
if not has_latex_markdown(text):
|
||||||
|
return text
|
||||||
|
|
||||||
|
# Handle unnecessary LaTeX formatting like \usepackage
|
||||||
|
if 'usepackage' in text:
|
||||||
|
text = clean_latex_preamble(text)
|
||||||
|
|
||||||
|
if text[0] == '`' and text[-1] == '`':
|
||||||
|
text = text[1:-1]
|
||||||
|
|
||||||
|
# Enclose the final text in a $$ block with newlines
|
||||||
|
text = f"$$\n{text}\n$$"
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def clean_text(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Cleans text by removing extra whitespace.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The original text.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The cleaned text.
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Remove leading and trailing whitespace
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
# Replace multiple consecutive whitespace characters with a single space
|
||||||
|
text = re.sub(r'\s+', ' ', text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def layoutjson2md(image: Image.Image, cells: list, text_key: str = 'text', no_page_hf: bool = False) -> str:
|
||||||
|
"""
|
||||||
|
Converts a layout JSON format to Markdown.
|
||||||
|
|
||||||
|
In the layout JSON, formulas are LaTeX, tables are HTML, and text is Markdown.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: A PIL Image object.
|
||||||
|
cells: A list of dictionaries, each representing a layout cell.
|
||||||
|
text_key: The key for the text field in the cell dictionary.
|
||||||
|
no_page_header_footer: If True, skips page headers and footers.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The text in Markdown format.
|
||||||
|
"""
|
||||||
|
text_items = []
|
||||||
|
|
||||||
|
for i, cell in enumerate(cells):
|
||||||
|
x1, y1, x2, y2 = [int(coord) for coord in cell['bbox']]
|
||||||
|
text = cell.get(text_key, "")
|
||||||
|
|
||||||
|
if no_page_hf and cell['category'] in ['Page-header', 'Page-footer']:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if cell['category'] == 'Picture':
|
||||||
|
image_crop = image.crop((x1, y1, x2, y2))
|
||||||
|
image_base64 = PILimage_to_base64(image_crop)
|
||||||
|
text_items.append(f"")
|
||||||
|
elif cell['category'] == 'Formula':
|
||||||
|
text_items.append(get_formula_in_markdown(text))
|
||||||
|
else:
|
||||||
|
text = clean_text(text)
|
||||||
|
text_items.append(f"{text}")
|
||||||
|
|
||||||
|
markdown_text = '\n\n'.join(text_items)
|
||||||
|
return markdown_text
|
||||||
|
|
||||||
|
|
||||||
|
def fix_streamlit_formulas(md: str) -> str:
|
||||||
|
"""
|
||||||
|
Fixes the format of formulas in Markdown to ensure they display correctly in Streamlit.
|
||||||
|
It adds a newline after the opening $$ and before the closing $$ if they don't already exist.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
md_text (str): The Markdown text to fix.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The fixed Markdown text.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# This inner function will be used by re.sub to perform the replacement
|
||||||
|
def replace_formula(match):
|
||||||
|
content = match.group(1)
|
||||||
|
# If the content already has surrounding newlines, don't add more.
|
||||||
|
if content.startswith('\n'):
|
||||||
|
content = content[1:]
|
||||||
|
if content.endswith('\n'):
|
||||||
|
content = content[:-1]
|
||||||
|
return f'$$\n{content}\n$$'
|
||||||
|
|
||||||
|
# Use regex to find all $$....$$ patterns and replace them using the helper function.
|
||||||
|
return re.sub(r'\$\$(.*?)\$\$', replace_formula, md, flags=re.DOTALL)
|
||||||
@@ -0,0 +1,196 @@
|
|||||||
|
import math
|
||||||
|
import base64
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Tuple
|
||||||
|
import os
|
||||||
|
from dots_ocr.utils.consts import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS
|
||||||
|
from dots_ocr.utils.doc_utils import fitz_doc_to_image
|
||||||
|
from io import BytesIO
|
||||||
|
import fitz
|
||||||
|
import requests
|
||||||
|
import copy
|
||||||
|
|
||||||
|
|
||||||
|
def round_by_factor(number: int, factor: int) -> int:
|
||||||
|
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||||
|
return round(number / factor) * factor
|
||||||
|
|
||||||
|
|
||||||
|
def ceil_by_factor(number: int, factor: int) -> int:
|
||||||
|
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||||
|
return math.ceil(number / factor) * factor
|
||||||
|
|
||||||
|
|
||||||
|
def floor_by_factor(number: int, factor: int) -> int:
|
||||||
|
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||||
|
return math.floor(number / factor) * factor
|
||||||
|
|
||||||
|
|
||||||
|
def smart_resize(
|
||||||
|
height: int,
|
||||||
|
width: int,
|
||||||
|
factor: int = 28,
|
||||||
|
min_pixels: int = 3136,
|
||||||
|
max_pixels: int = 11289600,
|
||||||
|
):
|
||||||
|
"""Rescales the image so that the following conditions are met:
|
||||||
|
|
||||||
|
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||||
|
|
||||||
|
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||||
|
|
||||||
|
3. The aspect ratio of the image is maintained as closely as possible.
|
||||||
|
|
||||||
|
"""
|
||||||
|
if max(height, width) / min(height, width) > 200:
|
||||||
|
raise ValueError(
|
||||||
|
f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
|
||||||
|
)
|
||||||
|
h_bar = max(factor, round_by_factor(height, factor))
|
||||||
|
w_bar = max(factor, round_by_factor(width, factor))
|
||||||
|
if h_bar * w_bar > max_pixels:
|
||||||
|
beta = math.sqrt((height * width) / max_pixels)
|
||||||
|
h_bar = max(factor, floor_by_factor(height / beta, factor))
|
||||||
|
w_bar = max(factor, floor_by_factor(width / beta, factor))
|
||||||
|
elif h_bar * w_bar < min_pixels:
|
||||||
|
beta = math.sqrt(min_pixels / (height * width))
|
||||||
|
h_bar = ceil_by_factor(height * beta, factor)
|
||||||
|
w_bar = ceil_by_factor(width * beta, factor)
|
||||||
|
if h_bar * w_bar > max_pixels: # max_pixels first to control the token length
|
||||||
|
beta = math.sqrt((h_bar * w_bar) / max_pixels)
|
||||||
|
h_bar = max(factor, floor_by_factor(h_bar / beta, factor))
|
||||||
|
w_bar = max(factor, floor_by_factor(w_bar / beta, factor))
|
||||||
|
return h_bar, w_bar
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def PILimage_to_base64(image, format='PNG'):
|
||||||
|
buffered = BytesIO()
|
||||||
|
image.save(buffered, format=format)
|
||||||
|
base64_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||||
|
return f"data:image;base64,{base64_str}"
|
||||||
|
|
||||||
|
|
||||||
|
def to_rgb(pil_image: Image.Image) -> Image.Image:
|
||||||
|
if pil_image.mode == 'RGBA':
|
||||||
|
white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
|
||||||
|
white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
|
||||||
|
return white_background
|
||||||
|
else:
|
||||||
|
return pil_image.convert("RGB")
|
||||||
|
|
||||||
|
|
||||||
|
# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
|
||||||
|
def fetch_image(
|
||||||
|
image,
|
||||||
|
min_pixels=None,
|
||||||
|
max_pixels=None,
|
||||||
|
resized_height=None,
|
||||||
|
resized_width=None,
|
||||||
|
) -> Image.Image:
|
||||||
|
assert image is not None, f"image not found, maybe input format error: {image}"
|
||||||
|
image_obj = None
|
||||||
|
if isinstance(image, Image.Image):
|
||||||
|
image_obj = image
|
||||||
|
elif image.startswith("http://") or image.startswith("https://"):
|
||||||
|
# fix memory leak issue while using BytesIO
|
||||||
|
with requests.get(image, stream=True) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
with BytesIO(response.content) as bio:
|
||||||
|
image_obj = copy.deepcopy(Image.open(bio))
|
||||||
|
elif image.startswith("file://"):
|
||||||
|
image_obj = Image.open(image[7:])
|
||||||
|
elif image.startswith("data:image"):
|
||||||
|
if "base64," in image:
|
||||||
|
_, base64_data = image.split("base64,", 1)
|
||||||
|
data = base64.b64decode(base64_data)
|
||||||
|
# fix memory leak issue while using BytesIO
|
||||||
|
with BytesIO(data) as bio:
|
||||||
|
image_obj = copy.deepcopy(Image.open(bio))
|
||||||
|
else:
|
||||||
|
image_obj = Image.open(image)
|
||||||
|
if image_obj is None:
|
||||||
|
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
|
||||||
|
image = to_rgb(image_obj)
|
||||||
|
## resize
|
||||||
|
if resized_height and resized_width:
|
||||||
|
resized_height, resized_width = smart_resize(
|
||||||
|
resized_height,
|
||||||
|
resized_width,
|
||||||
|
factor=IMAGE_FACTOR,
|
||||||
|
)
|
||||||
|
assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
|
||||||
|
image = image.resize((resized_width, resized_height))
|
||||||
|
elif min_pixels or max_pixels:
|
||||||
|
width, height = image.size
|
||||||
|
if not min_pixels:
|
||||||
|
min_pixels = MIN_PIXELS
|
||||||
|
if not max_pixels:
|
||||||
|
max_pixels = MAX_PIXELS
|
||||||
|
resized_height, resized_width = smart_resize(
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
factor=IMAGE_FACTOR,
|
||||||
|
min_pixels=min_pixels,
|
||||||
|
max_pixels=max_pixels,
|
||||||
|
)
|
||||||
|
assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
|
||||||
|
image = image.resize((resized_width, resized_height))
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def get_input_dimensions(
|
||||||
|
image: Image.Image,
|
||||||
|
min_pixels: int,
|
||||||
|
max_pixels: int,
|
||||||
|
factor: int = 28
|
||||||
|
) -> Tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Gets the resized dimensions of the input image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The original image.
|
||||||
|
min_pixels: The minimum number of pixels.
|
||||||
|
max_pixels: The maximum number of pixels.
|
||||||
|
factor: The resizing factor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The resized (width, height).
|
||||||
|
"""
|
||||||
|
input_height, input_width = smart_resize(
|
||||||
|
image.height,
|
||||||
|
image.width,
|
||||||
|
factor=factor,
|
||||||
|
min_pixels=min_pixels,
|
||||||
|
max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
return input_width, input_height
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_by_fitz_doc(image, target_dpi=200):
|
||||||
|
# get image through fitz, to get target dpi image, mainly for higher image
|
||||||
|
if not isinstance(image, Image.Image):
|
||||||
|
assert isinstance(image, str)
|
||||||
|
_, file_ext = os.path.splitext(image)
|
||||||
|
assert file_ext in {'.jpg', '.jpeg', '.png'}
|
||||||
|
|
||||||
|
if image.startswith("http://") or image.startswith("https://"):
|
||||||
|
with requests.get(image, stream=True) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
data_bytes = response.content
|
||||||
|
else:
|
||||||
|
with open(image, 'rb') as f:
|
||||||
|
data_bytes = f.read()
|
||||||
|
|
||||||
|
image = Image.open(BytesIO(data_bytes))
|
||||||
|
else:
|
||||||
|
data_bytes = BytesIO()
|
||||||
|
image.save(data_bytes, format='PNG')
|
||||||
|
|
||||||
|
origin_dpi = image.info.get('dpi', None)
|
||||||
|
pdf_bytes = fitz.open(stream=data_bytes).convert_to_pdf()
|
||||||
|
doc = fitz.open('pdf', pdf_bytes)
|
||||||
|
page = doc[0]
|
||||||
|
image_fitz = fitz_doc_to_image(page, target_dpi=target_dpi, origin_dpi=origin_dpi)
|
||||||
|
|
||||||
|
return image_fitz
|
||||||
@@ -0,0 +1,228 @@
|
|||||||
|
from PIL import Image
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import fitz
|
||||||
|
from io import BytesIO
|
||||||
|
import json
|
||||||
|
|
||||||
|
from dots_ocr.utils.image_utils import smart_resize
|
||||||
|
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||||
|
from dots_ocr.utils.output_cleaner import OutputCleaner
|
||||||
|
|
||||||
|
|
||||||
|
# Define a color map (using RGBA format)
|
||||||
|
dict_layout_type_to_color = {
|
||||||
|
"Text": (0, 128, 0, 256), # Green, translucent
|
||||||
|
"Picture": (255, 0, 255, 256), # Magenta, translucent
|
||||||
|
"Caption": (255, 165, 0, 256), # Orange, translucent
|
||||||
|
"Section-header": (0, 255, 255, 256), # Cyan, translucent
|
||||||
|
"Footnote": (0, 128, 0, 256), # Green, translucent
|
||||||
|
"Formula": (128, 128, 128, 256), # Gray, translucent
|
||||||
|
"Table": (255, 192, 203, 256), # Pink, translucent
|
||||||
|
"Title": (255, 0, 0, 256), # Red, translucent
|
||||||
|
"List-item": (0, 0, 255, 256), # Blue, translucent
|
||||||
|
"Page-header": (0, 128, 0, 256), # Green, translucent
|
||||||
|
"Page-footer": (128, 0, 128, 256), # Purple, translucent
|
||||||
|
"Other": (165, 42, 42, 256), # Brown, translucent
|
||||||
|
"Unknown": (0, 0, 0, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
|
||||||
|
"""
|
||||||
|
Draw transparent boxes on an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: The source PIL Image.
|
||||||
|
cells: A list of cells containing bounding box information.
|
||||||
|
resized_height: The resized height.
|
||||||
|
resized_width: The resized width.
|
||||||
|
fill_bbox: Whether to fill the bounding box.
|
||||||
|
draw_bbox: Whether to draw the bounding box.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL.Image: The image with drawings.
|
||||||
|
"""
|
||||||
|
# origin_image = Image.open(image_path)
|
||||||
|
original_width, original_height = image.size
|
||||||
|
|
||||||
|
# Create a new PDF document
|
||||||
|
doc = fitz.open()
|
||||||
|
|
||||||
|
# Get image information
|
||||||
|
img_bytes = BytesIO()
|
||||||
|
image.save(img_bytes, format='PNG')
|
||||||
|
# pix = fitz.Pixmap(image_path)
|
||||||
|
pix = fitz.Pixmap(img_bytes)
|
||||||
|
|
||||||
|
# Create a page
|
||||||
|
page = doc.new_page(width=pix.width, height=pix.height)
|
||||||
|
page.insert_image(
|
||||||
|
fitz.Rect(0, 0, pix.width, pix.height),
|
||||||
|
# filename=image_path
|
||||||
|
pixmap=pix
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, cell in enumerate(cells):
|
||||||
|
bbox = cell['bbox']
|
||||||
|
layout_type = cell['category']
|
||||||
|
order = i
|
||||||
|
|
||||||
|
top_left = (bbox[0], bbox[1])
|
||||||
|
down_right = (bbox[2], bbox[3])
|
||||||
|
if resized_height and resized_width:
|
||||||
|
scale_x = resized_width / original_width
|
||||||
|
scale_y = resized_height / original_height
|
||||||
|
top_left = (int(bbox[0] / scale_x), int(bbox[1] / scale_y))
|
||||||
|
down_right = (int(bbox[2] / scale_x), int(bbox[3] / scale_y))
|
||||||
|
|
||||||
|
color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
|
||||||
|
color = [col/255 for col in color[:3]]
|
||||||
|
|
||||||
|
x0, y0, x1, y1 = top_left[0], top_left[1], down_right[0], down_right[1]
|
||||||
|
rect_coords = fitz.Rect(x0, y0, x1, y1)
|
||||||
|
if draw_bbox:
|
||||||
|
if fill_bbox:
|
||||||
|
page.draw_rect(
|
||||||
|
rect_coords,
|
||||||
|
color=None,
|
||||||
|
fill=color,
|
||||||
|
fill_opacity=0.3,
|
||||||
|
width=0.5,
|
||||||
|
overlay=True,
|
||||||
|
) # Draw the rectangle
|
||||||
|
else:
|
||||||
|
page.draw_rect(
|
||||||
|
rect_coords,
|
||||||
|
color=color,
|
||||||
|
fill=None,
|
||||||
|
fill_opacity=1,
|
||||||
|
width=0.5,
|
||||||
|
overlay=True,
|
||||||
|
) # Draw the rectangle
|
||||||
|
order_cate = f"{order}_{layout_type}"
|
||||||
|
page.insert_text(
|
||||||
|
(x1, y0 + 20), order_cate, fontsize=20, color=color
|
||||||
|
) # Insert the index in the top left corner of the rectangle
|
||||||
|
|
||||||
|
# Convert to a Pixmap (maintaining original dimensions)
|
||||||
|
mat = fitz.Matrix(1.0, 1.0)
|
||||||
|
pix = page.get_pixmap(matrix=mat)
|
||||||
|
|
||||||
|
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||||
|
|
||||||
|
|
||||||
|
def pre_process_bboxes(
|
||||||
|
origin_image,
|
||||||
|
bboxes,
|
||||||
|
input_width,
|
||||||
|
input_height,
|
||||||
|
factor: int = 28,
|
||||||
|
min_pixels: int = 3136,
|
||||||
|
max_pixels: int = 11289600
|
||||||
|
):
|
||||||
|
assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)
|
||||||
|
min_pixels = min_pixels or MIN_PIXELS
|
||||||
|
max_pixels = max_pixels or MAX_PIXELS
|
||||||
|
original_width, original_height = origin_image.size
|
||||||
|
|
||||||
|
input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||||
|
|
||||||
|
scale_x = original_width / input_width
|
||||||
|
scale_y = original_height / input_height
|
||||||
|
|
||||||
|
bboxes_out = []
|
||||||
|
for bbox in bboxes:
|
||||||
|
bbox_resized = [
|
||||||
|
int(float(bbox[0]) / scale_x),
|
||||||
|
int(float(bbox[1]) / scale_y),
|
||||||
|
int(float(bbox[2]) / scale_x),
|
||||||
|
int(float(bbox[3]) / scale_y)
|
||||||
|
]
|
||||||
|
bboxes_out.append(bbox_resized)
|
||||||
|
|
||||||
|
return bboxes_out
|
||||||
|
|
||||||
|
def post_process_cells(
|
||||||
|
origin_image: Image.Image,
|
||||||
|
cells: List[Dict],
|
||||||
|
input_width, # server input width, also has smart_resize in server
|
||||||
|
input_height,
|
||||||
|
factor: int = 28,
|
||||||
|
min_pixels: int = 3136,
|
||||||
|
max_pixels: int = 11289600
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
origin_image: The original PIL Image.
|
||||||
|
cells: A list of cells containing bounding box information.
|
||||||
|
input_width: The width of the input image sent to the server.
|
||||||
|
input_height: The height of the input image sent to the server.
|
||||||
|
factor: Resizing factor.
|
||||||
|
min_pixels: Minimum number of pixels.
|
||||||
|
max_pixels: Maximum number of pixels.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of post-processed cells.
|
||||||
|
"""
|
||||||
|
assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
|
||||||
|
min_pixels = min_pixels or MIN_PIXELS
|
||||||
|
max_pixels = max_pixels or MAX_PIXELS
|
||||||
|
original_width, original_height = origin_image.size
|
||||||
|
|
||||||
|
input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||||
|
|
||||||
|
scale_x = input_width / original_width
|
||||||
|
scale_y = input_height / original_height
|
||||||
|
|
||||||
|
cells_out = []
|
||||||
|
for cell in cells:
|
||||||
|
bbox = cell['bbox']
|
||||||
|
bbox_resized = [
|
||||||
|
int(float(bbox[0]) / scale_x),
|
||||||
|
int(float(bbox[1]) / scale_y),
|
||||||
|
int(float(bbox[2]) / scale_x),
|
||||||
|
int(float(bbox[3]) / scale_y)
|
||||||
|
]
|
||||||
|
cell_copy = cell.copy()
|
||||||
|
cell_copy['bbox'] = bbox_resized
|
||||||
|
cells_out.append(cell_copy)
|
||||||
|
|
||||||
|
return cells_out
|
||||||
|
|
||||||
|
def is_legal_bbox(cells):
|
||||||
|
for cell in cells:
|
||||||
|
bbox = cell['bbox']
|
||||||
|
if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
|
||||||
|
if prompt_mode in ["prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"]:
|
||||||
|
return response
|
||||||
|
|
||||||
|
json_load_failed = False
|
||||||
|
cells = response
|
||||||
|
try:
|
||||||
|
cells = json.loads(cells)
|
||||||
|
cells = post_process_cells(
|
||||||
|
origin_image,
|
||||||
|
cells,
|
||||||
|
input_image.width,
|
||||||
|
input_image.height,
|
||||||
|
min_pixels=min_pixels,
|
||||||
|
max_pixels=max_pixels
|
||||||
|
)
|
||||||
|
return cells, False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"cells post process error: {e}, when using {prompt_mode}")
|
||||||
|
json_load_failed = True
|
||||||
|
|
||||||
|
if json_load_failed:
|
||||||
|
cleaner = OutputCleaner()
|
||||||
|
response_clean = cleaner.clean_model_output(cells)
|
||||||
|
if isinstance(response_clean, list):
|
||||||
|
response_clean = "\n\n".join([cell['text'] for cell in response_clean if 'text' in cell])
|
||||||
|
return response_clean, True
|
||||||
@@ -0,0 +1,623 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Data Cleaning Script - Cleans all data using a simplified regex method and saves the results
|
||||||
|
|
||||||
|
Features:
|
||||||
|
1. Cleans all cases using a simplified regex method.
|
||||||
|
2. Saves the cleaned data for each case.
|
||||||
|
3. Ensures the relative order of dicts remains unchanged.
|
||||||
|
4. Generates a before-and-after cleaning report.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Tuple, Optional, Any
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from collections import Counter
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CleanedData:
|
||||||
|
"""Data structure for cleaned data"""
|
||||||
|
case_id: int
|
||||||
|
original_type: str # 'list' or 'str'
|
||||||
|
original_length: int
|
||||||
|
cleaned_data: List[Dict]
|
||||||
|
cleaning_operations: Dict[str, Any] # Records the cleaning operations performed
|
||||||
|
success: bool
|
||||||
|
|
||||||
|
|
||||||
|
class OutputCleaner:
|
||||||
|
"""Data Cleaner - Based on a simplified regex method"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Simplified regular expression patterns
|
||||||
|
self.dict_pattern = re.compile(r'\{[^{}]*?"bbox"\s*:\s*\[[^\]]*?\][^{}]*?\}', re.DOTALL)
|
||||||
|
self.bbox_pattern = re.compile(r'"bbox"\s*:\s*\[([^\]]+)\]')
|
||||||
|
self.missing_delimiter_pattern = re.compile(r'\}\s*\{(?!")')
|
||||||
|
|
||||||
|
self.cleaned_results: List[CleanedData] = []
|
||||||
|
|
||||||
|
def clean_list_data(self, data: List[Dict], case_id: int) -> CleanedData:
|
||||||
|
"""Cleans list-type data"""
|
||||||
|
|
||||||
|
print(f"🔧 Cleaning List data - Case {case_id}")
|
||||||
|
print(f" Original items: {len(data)}")
|
||||||
|
|
||||||
|
cleaned_data = []
|
||||||
|
operations = {
|
||||||
|
'type': 'list',
|
||||||
|
'bbox_fixes': 0,
|
||||||
|
'removed_items': 0,
|
||||||
|
'original_count': len(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, item in enumerate(data):
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
operations['removed_items'] += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check the bbox field
|
||||||
|
if 'bbox' in item:
|
||||||
|
bbox = item['bbox']
|
||||||
|
|
||||||
|
# Check bbox length - core logic
|
||||||
|
if isinstance(bbox, list) and len(bbox) == 3:
|
||||||
|
print(f" ⚠️ Item {i}: bbox has only 3 coordinates. Removing bbox, keeping category and text.")
|
||||||
|
# Keep only category and text, ensuring order is preserved
|
||||||
|
new_item = {}
|
||||||
|
if 'category' in item:
|
||||||
|
new_item['category'] = item['category']
|
||||||
|
if 'text' in item:
|
||||||
|
new_item['text'] = item['text']
|
||||||
|
if new_item: # Add only if there is valid content
|
||||||
|
cleaned_data.append(new_item)
|
||||||
|
operations['bbox_fixes'] += 1
|
||||||
|
else:
|
||||||
|
operations['removed_items'] += 1
|
||||||
|
continue
|
||||||
|
elif isinstance(bbox, list) and len(bbox) == 4:
|
||||||
|
# bbox is normal, add directly, preserving original order
|
||||||
|
cleaned_data.append(item.copy())
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
print(f" ❌ Item {i}: Abnormal bbox format, skipping.")
|
||||||
|
operations['removed_items'] += 1
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# No bbox field, keep if category exists
|
||||||
|
if 'category' in item:
|
||||||
|
cleaned_data.append(item.copy())
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
operations['removed_items'] += 1
|
||||||
|
|
||||||
|
operations['final_count'] = len(cleaned_data)
|
||||||
|
print(f" ✅ Cleaning complete: {len(cleaned_data)} items, {operations['bbox_fixes']} bbox fixes, {operations['removed_items']} items removed")
|
||||||
|
|
||||||
|
return CleanedData(
|
||||||
|
case_id=case_id,
|
||||||
|
original_type='list',
|
||||||
|
original_length=len(data),
|
||||||
|
cleaned_data=cleaned_data,
|
||||||
|
cleaning_operations=operations,
|
||||||
|
success=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def clean_string_data(self, data_str: str, case_id: int) -> CleanedData:
|
||||||
|
"""Cleans string-type data"""
|
||||||
|
|
||||||
|
print(f"🔧 Cleaning String data - Case {case_id}")
|
||||||
|
print(f" Original length: {len(data_str):,}")
|
||||||
|
|
||||||
|
operations = {
|
||||||
|
'type': 'str',
|
||||||
|
'original_length': len(data_str),
|
||||||
|
'delimiter_fixes': 0,
|
||||||
|
'tail_truncated': False,
|
||||||
|
'truncated_length': 0,
|
||||||
|
'duplicate_dicts_removed': 0,
|
||||||
|
'final_objects': 0
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Step 1: Detect and fix missing delimiters
|
||||||
|
data_str, delimiter_fixes = self._fix_missing_delimiters(data_str)
|
||||||
|
operations['delimiter_fixes'] = delimiter_fixes
|
||||||
|
|
||||||
|
# Step 2: Truncate the last incomplete element
|
||||||
|
data_str, tail_truncated = self._truncate_last_incomplete_element(data_str)
|
||||||
|
operations['tail_truncated'] = tail_truncated
|
||||||
|
operations['truncated_length'] = len(data_str)
|
||||||
|
|
||||||
|
# Step 3: Remove duplicate complete dict objects, preserving order
|
||||||
|
data_str, duplicate_removes = self._remove_duplicate_complete_dicts_preserve_order(data_str)
|
||||||
|
operations['duplicate_dicts_removed'] = duplicate_removes
|
||||||
|
|
||||||
|
# Step 4: Ensure correct JSON format
|
||||||
|
data_str = self._ensure_json_format(data_str)
|
||||||
|
|
||||||
|
# Step 5: Try to parse the final result
|
||||||
|
final_data = self._parse_final_json(data_str)
|
||||||
|
|
||||||
|
if final_data is not None:
|
||||||
|
operations['final_objects'] = len(final_data)
|
||||||
|
print(f" ✅ Cleaning complete: {len(final_data)} objects")
|
||||||
|
|
||||||
|
return CleanedData(
|
||||||
|
case_id=case_id,
|
||||||
|
original_type='str',
|
||||||
|
original_length=operations['original_length'],
|
||||||
|
cleaned_data=final_data,
|
||||||
|
cleaning_operations=operations,
|
||||||
|
success=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise Exception("Could not parse the cleaned data")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Cleaning failed: {e}")
|
||||||
|
return CleanedData(
|
||||||
|
case_id=case_id,
|
||||||
|
original_type='str',
|
||||||
|
original_length=operations['original_length'],
|
||||||
|
cleaned_data=[],
|
||||||
|
cleaning_operations=operations,
|
||||||
|
success=False
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fix_missing_delimiters(self, text: str) -> Tuple[str, int]:
|
||||||
|
"""Fixes missing delimiters"""
|
||||||
|
|
||||||
|
fixes = 0
|
||||||
|
|
||||||
|
def replace_delimiter(match):
|
||||||
|
nonlocal fixes
|
||||||
|
fixes += 1
|
||||||
|
return '},{'
|
||||||
|
|
||||||
|
text = self.missing_delimiter_pattern.sub(replace_delimiter, text)
|
||||||
|
|
||||||
|
if fixes > 0:
|
||||||
|
print(f" ✅ Fixed {fixes} missing delimiters")
|
||||||
|
|
||||||
|
return text, fixes
|
||||||
|
|
||||||
|
def _truncate_last_incomplete_element(self, text: str) -> Tuple[str, bool]:
|
||||||
|
"""Truncates the last incomplete element"""
|
||||||
|
|
||||||
|
# For very long text (>50k) or text not ending with ']', directly truncate the last '{"bbox":'
|
||||||
|
needs_truncation = (
|
||||||
|
len(text) > 50000 or
|
||||||
|
not text.strip().endswith(']')
|
||||||
|
)
|
||||||
|
|
||||||
|
if needs_truncation:
|
||||||
|
# Check how many dict objects there are
|
||||||
|
bbox_count = text.count('{"bbox":')
|
||||||
|
|
||||||
|
# If there is only one dict object, do not truncate to avoid deleting the only object
|
||||||
|
if bbox_count <= 1:
|
||||||
|
print(f" ⚠️ Only {bbox_count} dict objects found, skipping truncation to avoid deleting all content")
|
||||||
|
return text, False
|
||||||
|
|
||||||
|
# Find the position of the last '{"bbox":'
|
||||||
|
last_bbox_pos = text.rfind('{"bbox":')
|
||||||
|
|
||||||
|
if last_bbox_pos > 0:
|
||||||
|
# Truncate before this position
|
||||||
|
truncated_text = text[:last_bbox_pos].rstrip()
|
||||||
|
|
||||||
|
# Remove trailing comma
|
||||||
|
if truncated_text.endswith(','):
|
||||||
|
truncated_text = truncated_text[:-1]
|
||||||
|
|
||||||
|
print(f" ✂️ Truncated the last incomplete element, length reduced from {len(text):,} to {len(truncated_text):,}")
|
||||||
|
return truncated_text, True
|
||||||
|
|
||||||
|
return text, False
|
||||||
|
|
||||||
|
def _remove_duplicate_complete_dicts_preserve_order(self, text: str) -> Tuple[str, int]:
|
||||||
|
"""Removes duplicate complete dict objects, preserving original order"""
|
||||||
|
|
||||||
|
# Extract all dict objects, preserving order
|
||||||
|
dict_matches = list(self.dict_pattern.finditer(text))
|
||||||
|
|
||||||
|
if not dict_matches:
|
||||||
|
return text, 0
|
||||||
|
|
||||||
|
print(f" 📊 Found {len(dict_matches)} dict objects")
|
||||||
|
|
||||||
|
# Deduplication while preserving order: only keep the first occurrence of a dict
|
||||||
|
unique_dicts = []
|
||||||
|
seen_dict_strings = set()
|
||||||
|
total_duplicates = 0
|
||||||
|
|
||||||
|
for match in dict_matches:
|
||||||
|
dict_str = match.group()
|
||||||
|
|
||||||
|
if dict_str not in seen_dict_strings:
|
||||||
|
unique_dicts.append(dict_str)
|
||||||
|
seen_dict_strings.add(dict_str)
|
||||||
|
else:
|
||||||
|
total_duplicates += 1
|
||||||
|
|
||||||
|
if total_duplicates > 0:
|
||||||
|
# Reconstruct the JSON array, preserving the original order
|
||||||
|
new_text = '[' + ', '.join(unique_dicts) + ']'
|
||||||
|
print(f" ✅ Removed {total_duplicates} duplicate dicts, keeping {len(unique_dicts)} unique dicts (order preserved)")
|
||||||
|
return new_text, total_duplicates
|
||||||
|
else:
|
||||||
|
print(f" ✅ No duplicate dict objects found")
|
||||||
|
return text, 0
|
||||||
|
|
||||||
|
def _ensure_json_format(self, text: str) -> str:
|
||||||
|
"""Ensures correct JSON format"""
|
||||||
|
|
||||||
|
text = text.strip()
|
||||||
|
|
||||||
|
if not text.startswith('['):
|
||||||
|
text = '[' + text
|
||||||
|
|
||||||
|
if not text.endswith(']'):
|
||||||
|
# Remove trailing comma
|
||||||
|
text = text.rstrip(',').rstrip()
|
||||||
|
text += ']'
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _parse_final_json(self, text: str) -> Optional[List[Dict]]:
|
||||||
|
"""Tries to parse the final JSON"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(text)
|
||||||
|
if isinstance(data, list):
|
||||||
|
return data
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
print(f" ❌ JSON parsing failed: {e}")
|
||||||
|
|
||||||
|
# fallback1: Extract valid dict objects
|
||||||
|
valid_dicts = []
|
||||||
|
|
||||||
|
for match in self.dict_pattern.finditer(text):
|
||||||
|
dict_str = match.group()
|
||||||
|
try:
|
||||||
|
dict_obj = json.loads(dict_str)
|
||||||
|
valid_dicts.append(dict_obj)
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if valid_dicts:
|
||||||
|
print(f" ✅ Extracted {len(valid_dicts)} valid dicts")
|
||||||
|
return valid_dicts
|
||||||
|
|
||||||
|
# fallback2: Special handling for a single incomplete dict
|
||||||
|
return self._handle_single_incomplete_dict(text)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _handle_single_incomplete_dict(self, text: str) -> Optional[List[Dict]]:
|
||||||
|
"""Handles the special case of a single incomplete dict"""
|
||||||
|
|
||||||
|
# Check if it's a single incomplete dict case
|
||||||
|
if not text.strip().startswith('[{"bbox":'):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Try to extract bbox coordinates
|
||||||
|
bbox_match = re.search(r'"bbox"\s*:\s*\[([^\]]+)\]', text)
|
||||||
|
if not bbox_match:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bbox_str = bbox_match.group(1)
|
||||||
|
bbox_coords = [int(x.strip()) for x in bbox_str.split(',')]
|
||||||
|
|
||||||
|
if len(bbox_coords) != 4:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Try to extract category
|
||||||
|
category_match = re.search(r'"category"\s*:\s*"([^"]+)"', text)
|
||||||
|
category = category_match.group(1) if category_match else "Text"
|
||||||
|
|
||||||
|
# Try to extract the beginning of the text (first 10000 characters)
|
||||||
|
text_match = re.search(r'"text"\s*:\s*"([^"]{0,10000})', text)
|
||||||
|
if text_match:
|
||||||
|
text_content = text_match.group(1)
|
||||||
|
else:
|
||||||
|
text_content = ""
|
||||||
|
|
||||||
|
# Construct the fixed dict
|
||||||
|
fixed_dict = {
|
||||||
|
"bbox": bbox_coords,
|
||||||
|
"category": category
|
||||||
|
}
|
||||||
|
|
||||||
|
if text_content:
|
||||||
|
fixed_dict["text"] = text_content
|
||||||
|
|
||||||
|
print(f" 🔧 Special fix: single incomplete dict → {fixed_dict}")
|
||||||
|
return [fixed_dict]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f" ❌ Special fix failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def remove_duplicate_category_text_pairs_and_bbox(self, data_list: List[dict], case_id: int) -> List[dict]:
|
||||||
|
"""Removes duplicate category-text pairs and duplicate bboxes"""
|
||||||
|
|
||||||
|
if not data_list or len(data_list) <= 1:
|
||||||
|
print(f" 📊 Data length {len(data_list)} <= 1, skipping deduplication check")
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
print(f" 📊 Original data length: {len(data_list)}")
|
||||||
|
|
||||||
|
# 1. Count occurrences and positions of each category-text pair
|
||||||
|
category_text_pairs = {}
|
||||||
|
for i, item in enumerate(data_list):
|
||||||
|
if isinstance(item, dict) and 'category' in item and 'text' in item:
|
||||||
|
pair_key = (item.get('category', ''), item.get('text', ''))
|
||||||
|
if pair_key not in category_text_pairs:
|
||||||
|
category_text_pairs[pair_key] = []
|
||||||
|
category_text_pairs[pair_key].append(i)
|
||||||
|
|
||||||
|
# 2. Count occurrences and positions of each bbox
|
||||||
|
bbox_pairs = {}
|
||||||
|
for i, item in enumerate(data_list):
|
||||||
|
if isinstance(item, dict) and 'bbox' in item:
|
||||||
|
bbox = item.get('bbox')
|
||||||
|
if isinstance(bbox, list) and len(bbox) > 0:
|
||||||
|
bbox_key = tuple(bbox) # Convert to tuple to use as a dictionary key
|
||||||
|
if bbox_key not in bbox_pairs:
|
||||||
|
bbox_pairs[bbox_key] = []
|
||||||
|
bbox_pairs[bbox_key].append(i)
|
||||||
|
|
||||||
|
# 3. Identify items to be removed
|
||||||
|
duplicates_to_remove = set()
|
||||||
|
|
||||||
|
# 3a. Process category-text pairs that appear 5 or more times
|
||||||
|
for pair_key, positions in category_text_pairs.items():
|
||||||
|
if len(positions) >= 5:
|
||||||
|
category, text = pair_key
|
||||||
|
# Keep the first occurrence, remove subsequent duplicates
|
||||||
|
positions_to_remove = positions[1:]
|
||||||
|
duplicates_to_remove.update(positions_to_remove)
|
||||||
|
|
||||||
|
print(f" 🔍 Found duplicate category-text pair: category='{category}', first 50 chars of text='{text[:50]}...'")
|
||||||
|
print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
|
||||||
|
|
||||||
|
# 3b. Process bboxes that appear 2 or more times
|
||||||
|
for bbox_key, positions in bbox_pairs.items():
|
||||||
|
if len(positions) >= 2:
|
||||||
|
# Keep the first occurrence, remove subsequent duplicates
|
||||||
|
positions_to_remove = positions[1:]
|
||||||
|
duplicates_to_remove.update(positions_to_remove)
|
||||||
|
|
||||||
|
print(f" 🔍 Found duplicate bbox: {list(bbox_key)}")
|
||||||
|
print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
|
||||||
|
|
||||||
|
if not duplicates_to_remove:
|
||||||
|
print(f" ✅ No category-text pairs or bboxes found exceeding the duplication threshold")
|
||||||
|
return data_list
|
||||||
|
|
||||||
|
# 4. Remove duplicate items from the original data (preserving order)
|
||||||
|
cleaned_data = []
|
||||||
|
removed_count = 0
|
||||||
|
for i, item in enumerate(data_list):
|
||||||
|
if i not in duplicates_to_remove:
|
||||||
|
cleaned_data.append(item)
|
||||||
|
else:
|
||||||
|
removed_count += 1
|
||||||
|
|
||||||
|
print(f" ✅ Deduplication complete: Removed {removed_count} duplicate items")
|
||||||
|
print(f" 📊 Cleaned data length: {len(cleaned_data)}")
|
||||||
|
|
||||||
|
return cleaned_data
|
||||||
|
|
||||||
|
def clean_model_output(self, model_output: str):
|
||||||
|
try:
|
||||||
|
# Select cleaning method based on data type
|
||||||
|
if isinstance(model_output, list):
|
||||||
|
result = self.clean_list_data(model_output, case_id=0)
|
||||||
|
else:
|
||||||
|
result = self.clean_string_data(str(model_output), case_id=0)
|
||||||
|
|
||||||
|
# Add deduplication step: remove duplicate category-text pairs and bboxes
|
||||||
|
if result and hasattr(result, 'success') and result.success and result.cleaned_data:
|
||||||
|
original_data = result.cleaned_data
|
||||||
|
deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id=0)
|
||||||
|
# Update the cleaned_data in the CleanedData object
|
||||||
|
result.cleaned_data = deduplicated_data
|
||||||
|
return result.cleaned_data
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Case cleaning failed: {e}")
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
def clean_all_data(self, jsonl_path: str) -> List[CleanedData]:
|
||||||
|
"""Cleans all data from a JSONL file"""
|
||||||
|
|
||||||
|
print(f"🚀 Starting to clean JSONL file: {jsonl_path}")
|
||||||
|
|
||||||
|
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
|
||||||
|
datas = []
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if line.strip():
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
predict_field = data.get('predict')
|
||||||
|
case_id = i + 1
|
||||||
|
|
||||||
|
print(f"\n{'='*50}")
|
||||||
|
print(f"🎯 Cleaning Case {case_id}")
|
||||||
|
print(f"{'='*50}")
|
||||||
|
|
||||||
|
# Select cleaning method based on data type
|
||||||
|
if isinstance(predict_field, list):
|
||||||
|
print("📊 Data type: List")
|
||||||
|
result = self.clean_list_data(predict_field, case_id)
|
||||||
|
else:
|
||||||
|
print("📊 Data type: String")
|
||||||
|
result = self.clean_string_data(str(predict_field), case_id)
|
||||||
|
|
||||||
|
# Add deduplication step: remove duplicate category-text pairs and bboxes
|
||||||
|
if result and hasattr(result, 'success') and result.success and result.cleaned_data:
|
||||||
|
print("🔄 Checking for and removing duplicate category-text pairs and bboxes...")
|
||||||
|
original_data = result.cleaned_data
|
||||||
|
deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id)
|
||||||
|
# Update the cleaned_data in the CleanedData object
|
||||||
|
result.cleaned_data = deduplicated_data
|
||||||
|
data['predict_resized'] = result.cleaned_data
|
||||||
|
|
||||||
|
datas.append(data)
|
||||||
|
self.cleaned_results.append(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Case {i+1} cleaning failed: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
save_path = jsonl_path.replace('.jsonl', '_filtered.jsonl')
|
||||||
|
with open(save_path, 'w') as w:
|
||||||
|
for data in datas:
|
||||||
|
w.write(json.dumps(data, ensure_ascii=False) + '\n')
|
||||||
|
print(f"✅ Saved cleaned data to: {save_path}")
|
||||||
|
|
||||||
|
return self.cleaned_results
|
||||||
|
|
||||||
|
def save_cleaned_data(self, output_dir: str):
|
||||||
|
"""Saves the cleaned data"""
|
||||||
|
|
||||||
|
print(f"\n💾 Saving cleaned data to: {output_dir}")
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 1. Save cleaned data for each case
|
||||||
|
for result in self.cleaned_results:
|
||||||
|
case_filename = f"cleaned_case_{result.case_id:02d}.json"
|
||||||
|
case_filepath = os.path.join(output_dir, case_filename)
|
||||||
|
|
||||||
|
# Save the cleaned data
|
||||||
|
with open(case_filepath, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(result.cleaned_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print(f" ✅ Case {result.case_id}: {len(result.cleaned_data)} objects → {case_filename}")
|
||||||
|
|
||||||
|
# 2. Save all cleaned data to a single file
|
||||||
|
all_cleaned_data = []
|
||||||
|
for result in self.cleaned_results:
|
||||||
|
all_cleaned_data.append({
|
||||||
|
'case_id': result.case_id,
|
||||||
|
'original_type': result.original_type,
|
||||||
|
'original_length': result.original_length,
|
||||||
|
'cleaned_objects_count': len(result.cleaned_data),
|
||||||
|
'success': result.success,
|
||||||
|
'cleaning_operations': result.cleaning_operations,
|
||||||
|
'cleaned_data': result.cleaned_data
|
||||||
|
})
|
||||||
|
|
||||||
|
all_data_filepath = os.path.join(output_dir, "all_cleaned_data.json")
|
||||||
|
with open(all_data_filepath, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(all_cleaned_data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
print(f" 📁 All data: {len(all_cleaned_data)} cases → all_cleaned_data.json")
|
||||||
|
|
||||||
|
# 3. Generate a cleaning report
|
||||||
|
self._generate_cleaning_report(output_dir)
|
||||||
|
|
||||||
|
def _generate_cleaning_report(self, output_dir: str):
|
||||||
|
"""Generates a cleaning report"""
|
||||||
|
|
||||||
|
report = []
|
||||||
|
report.append("📊 Data Cleaning Report")
|
||||||
|
report.append("=" * 60)
|
||||||
|
import datetime
|
||||||
|
report.append(f"Processing Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
||||||
|
report.append("")
|
||||||
|
|
||||||
|
# Overall statistics
|
||||||
|
total_cases = len(self.cleaned_results)
|
||||||
|
successful_cases = sum(1 for r in self.cleaned_results if r.success)
|
||||||
|
total_objects = sum(len(r.cleaned_data) for r in self.cleaned_results)
|
||||||
|
|
||||||
|
report.append("📈 Overall Statistics:")
|
||||||
|
report.append(f" Total Cases: {total_cases}")
|
||||||
|
report.append(f" Successfully Cleaned: {successful_cases}")
|
||||||
|
report.append(f" Success Rate: {successful_cases/total_cases*100:.1f}%")
|
||||||
|
report.append(f" Total Recovered Objects: {total_objects}")
|
||||||
|
report.append("")
|
||||||
|
|
||||||
|
# Detailed statistics
|
||||||
|
list_results = [r for r in self.cleaned_results if r.original_type == 'list']
|
||||||
|
str_results = [r for r in self.cleaned_results if r.original_type == 'str']
|
||||||
|
|
||||||
|
if list_results:
|
||||||
|
report.append("📋 List Type Cleaning Statistics:")
|
||||||
|
for r in list_results:
|
||||||
|
ops = r.cleaning_operations
|
||||||
|
report.append(f" Case {r.case_id}: {ops['original_count']} → {ops['final_count']} objects")
|
||||||
|
if ops['bbox_fixes'] > 0:
|
||||||
|
report.append(f" - bbox fixes: {ops['bbox_fixes']}")
|
||||||
|
if ops['removed_items'] > 0:
|
||||||
|
report.append(f" - invalid items removed: {ops['removed_items']}")
|
||||||
|
report.append("")
|
||||||
|
|
||||||
|
if str_results:
|
||||||
|
report.append("📝 String Type Cleaning Statistics:")
|
||||||
|
for r in str_results:
|
||||||
|
ops = r.cleaning_operations
|
||||||
|
status = "✅" if r.success else "❌"
|
||||||
|
report.append(f" Case {r.case_id} {status}: {ops['original_length']:,} chars → {ops['final_objects']} objects")
|
||||||
|
details = []
|
||||||
|
if ops['delimiter_fixes'] > 0:
|
||||||
|
details.append(f"Delimiter fixes: {ops['delimiter_fixes']}")
|
||||||
|
if ops['tail_truncated']:
|
||||||
|
reduction = ops['original_length'] - ops['truncated_length']
|
||||||
|
details.append(f"Tail truncation: -{reduction:,} chars")
|
||||||
|
if ops['duplicate_dicts_removed'] > 0:
|
||||||
|
details.append(f"Duplicates removed: {ops['duplicate_dicts_removed']}")
|
||||||
|
if details:
|
||||||
|
report.append(f" - {', '.join(details)}")
|
||||||
|
report.append("")
|
||||||
|
|
||||||
|
# Note on data order
|
||||||
|
report.append("🔄 Data Order Guarantee:")
|
||||||
|
report.append(" ✅ The relative order of all dict objects is preserved during cleaning.")
|
||||||
|
report.append(" ✅ When deduplicating, the first occurrence of a dict is kept, and subsequent duplicates are removed.")
|
||||||
|
report.append(" ✅ The order of items in List-type data is fully preserved.")
|
||||||
|
|
||||||
|
# Save the report
|
||||||
|
report_filepath = os.path.join(output_dir, "cleaning_report.txt")
|
||||||
|
with open(report_filepath, 'w', encoding='utf-8') as f:
|
||||||
|
f.write('\n'.join(report))
|
||||||
|
|
||||||
|
print(f" 📋 Cleaning report: cleaning_report.txt")
|
||||||
|
|
||||||
|
# Also print to console
|
||||||
|
print(f"\n{chr(10).join(report)}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function"""
|
||||||
|
|
||||||
|
# Create a data cleaner instance
|
||||||
|
cleaner = OutputCleaner()
|
||||||
|
|
||||||
|
# Input file
|
||||||
|
jsonl_path = "output_with_failcase.jsonl"
|
||||||
|
|
||||||
|
# Output directory
|
||||||
|
output_dir = "output_with_failcase_cleaned"
|
||||||
|
|
||||||
|
# Clean all data
|
||||||
|
results = cleaner.clean_all_data(jsonl_path)
|
||||||
|
|
||||||
|
# Save the cleaned data
|
||||||
|
cleaner.save_cleaned_data(output_dir)
|
||||||
|
|
||||||
|
print(f"\n🎉 Data cleaning complete!")
|
||||||
|
print(f"📁 Cleaned data saved in: {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
dict_promptmode_to_prompt = {
|
||||||
|
# prompt_layout_all_en: parse all layout info in json format.
|
||||||
|
"prompt_layout_all_en": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
|
||||||
|
|
||||||
|
1. Bbox format: [x1, y1, x2, y2]
|
||||||
|
|
||||||
|
2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
|
||||||
|
|
||||||
|
3. Text Extraction & Formatting Rules:
|
||||||
|
- Picture: For the 'Picture' category, the text field should be omitted.
|
||||||
|
- Formula: Format its text as LaTeX.
|
||||||
|
- Table: Format its text as HTML.
|
||||||
|
- All Others (Text, Title, etc.): Format their text as Markdown.
|
||||||
|
|
||||||
|
4. Constraints:
|
||||||
|
- The output text must be the original text from the image, with no translation.
|
||||||
|
- All layout elements must be sorted according to human reading order.
|
||||||
|
|
||||||
|
5. Final Output: The entire output must be a single JSON object.
|
||||||
|
""",
|
||||||
|
|
||||||
|
# prompt_layout_only_en: layout detection
|
||||||
|
"prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
|
||||||
|
|
||||||
|
# prompt_layout_only_en: parse ocr text except the Page-header and Page-footer
|
||||||
|
"prompt_ocr": """Extract the text content from this image.""",
|
||||||
|
|
||||||
|
# prompt_grounding_ocr: extract text content in the given bounding box
|
||||||
|
"prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
|
||||||
|
|
||||||
|
# "prompt_table_html": """Convert the table in this image to HTML.""",
|
||||||
|
# "prompt_table_latex": """Convert the table in this image to LaTeX.""",
|
||||||
|
# "prompt_formula_latex": """Convert the formula in this image to LaTeX.""",
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
# 生产环境依赖
|
||||||
|
# streamlit
|
||||||
|
gradio
|
||||||
|
gradio_image_annotation
|
||||||
|
PyMuPDF
|
||||||
|
openai
|
||||||
|
qwen_vl_utils
|
||||||
|
transformers==4.51.3
|
||||||
|
huggingface_hub
|
||||||
|
flash-attn==2.8.0.post2
|
||||||
|
accelerate
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
from setuptools import setup, find_packages
|
||||||
|
|
||||||
|
# 从requirements.txt文件读取依赖
|
||||||
|
def parse_requirements(filename):
|
||||||
|
with open(filename, 'r', encoding='utf-8') as f:
|
||||||
|
return f.read().splitlines()
|
||||||
|
|
||||||
|
setup(
|
||||||
|
name='dots_ocr',
|
||||||
|
version='1.0',
|
||||||
|
packages=find_packages(),
|
||||||
|
include_package_data=True,
|
||||||
|
install_requires=parse_requirements('requirements.txt'),
|
||||||
|
description='dots.ocr: Multilingual Document Layout Parsing in one Vision-Language Model',
|
||||||
|
url="https://github.com/rednote-hilab/dots.ocr",
|
||||||
|
python_requires=">=3.10",
|
||||||
|
)
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
from argparse import ArgumentParser
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument('--type', '-t', type=str, default="huggingface")
|
||||||
|
parser.add_argument('--name', '-n', type=str, default="rednote-hilab/dots.ocr")
|
||||||
|
args = parser.parse_args()
|
||||||
|
script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
print(f"Attention: The model save dir dots.ocr should be replace by a name without `.` like DotsOCR, util we merge our code to transformers.")
|
||||||
|
model_dir = os.path.join(script_dir, "weights/DotsOCR")
|
||||||
|
if not os.path.exists(model_dir):
|
||||||
|
os.makedirs(model_dir)
|
||||||
|
if args.type == "huggingface":
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
snapshot_download(repo_id=args.name, local_dir=model_dir, local_dir_use_symlinks=False, resume_download=True)
|
||||||
|
|
||||||
|
print(f"model downloaded to {model_dir}")
|
||||||