dots.ocr release
This commit is contained in:
Executable
+948
@@ -0,0 +1,948 @@
|
||||
"""
|
||||
Layout Inference Web Application with Gradio
|
||||
|
||||
A Gradio-based layout inference tool that supports image uploads and multiple backend inference engines.
|
||||
It adopts a reference-style interface design while preserving the original inference logic.
|
||||
"""
|
||||
|
||||
import gradio as gr
|
||||
import json
|
||||
import os
|
||||
import io
|
||||
import tempfile
|
||||
import base64
|
||||
import zipfile
|
||||
import uuid
|
||||
import re
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
# Local tool imports
|
||||
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||
from dots_ocr.utils.demo_utils.display import read_image
|
||||
from dots_ocr.utils.doc_utils import load_images_from_pdf
|
||||
|
||||
# Add DotsOCRParser import
|
||||
from dots_ocr.parser import DotsOCRParser
|
||||
|
||||
|
||||
# ==================== Configuration ====================
|
||||
DEFAULT_CONFIG = dict(
    ip="127.0.0.1",
    port_vllm=8000,
    min_pixels=MIN_PIXELS,
    max_pixels=MAX_PIXELS,
    test_images_dir="./assets/showcase_origin",
)

# ==================== Global Variables ====================
# Mutable copy of the defaults; updated with the UI values on every parse request
current_config = dict(DEFAULT_CONFIG)

# Single shared parser instance; its connection settings are overwritten per request
dots_parser = DotsOCRParser(
    ip=DEFAULT_CONFIG['ip'],
    port=DEFAULT_CONFIG['port_vllm'],
    dpi=200,
    min_pixels=DEFAULT_CONFIG['min_pixels'],
    max_pixels=DEFAULT_CONFIG['max_pixels'],
)

# Results of the most recent parse; every field starts out as None.
# 'pdf_results' holds the per-page results of a multi-page PDF.
processing_results = dict.fromkeys((
    'original_image',
    'processed_image',
    'layout_result',
    'markdown_content',
    'cells_data',
    'temp_dir',
    'session_id',
    'result_paths',
    'pdf_results',
))

# Preview / paging state for the currently loaded file
pdf_cache = dict(
    images=[],
    current_page=0,
    total_pages=0,
    file_type=None,   # 'image' or 'pdf'
    is_parsed=False,  # Whether it has been parsed
    results=[],       # Parsing results for each page
)
|
||||
|
||||
def read_image_v2(img):
    """Return a PIL image for *img*, which may be a URL, a local path, or an image.

    Args:
        img: An ``http://`` / ``https://`` URL string, a local path string,
            or an existing ``PIL.Image.Image`` (returned unchanged).

    Returns:
        The ``PIL.Image.Image`` for the input.

    Raises:
        ValueError: If *img* is neither a string nor a PIL image.
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    if isinstance(img, Image.Image):
        return img
    if not isinstance(img, str):
        raise ValueError(f"Invalid image type: {type(img)}")

    if img.startswith(("http://", "https://")):
        # Without a timeout an unresponsive host would block this worker forever.
        with requests.get(img, stream=True, timeout=30) as response:
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))

    # Local path: delegate to the project reader; only its first return value is used.
    image, _, _ = read_image(img, use_native=True)
    return image
|
||||
|
||||
def load_file_for_preview(file_path):
    """Load a PDF or image file into the preview cache.

    On success the module-level ``pdf_cache`` is reset to the new file's pages
    (current page 0, not parsed, no results).

    Args:
        file_path: Path to a ``.pdf``, ``.jpg``, ``.jpeg`` or ``.png`` file.

    Returns:
        A ``(first_page_image, page_info_html)`` tuple. The image is ``None``
        when the path is missing, the format is unsupported, loading fails,
        or the file contains no pages.
    """
    global pdf_cache

    if not file_path or not os.path.exists(file_path):
        return None, "<div id='page_info_box'>0 / 0</div>"

    file_ext = os.path.splitext(file_path)[1].lower()

    if file_ext == '.pdf':
        try:
            # Read PDF and convert to images (one image per page)
            pages = load_images_from_pdf(file_path)
        except Exception as e:
            return None, f"<div id='page_info_box'>PDF loading failed: {str(e)}</div>"
        file_type = "pdf"
    elif file_ext in ['.jpg', '.jpeg', '.png']:
        try:
            # An image file is treated as a single-page document
            pages = [Image.open(file_path)]
        except Exception as e:
            return None, f"<div id='page_info_box'>Image loading failed: {str(e)}</div>"
        file_type = "image"
    else:
        return None, "<div id='page_info_box'>Unsupported file format</div>"

    pdf_cache["file_type"] = file_type
    pdf_cache["images"] = pages
    pdf_cache["current_page"] = 0
    pdf_cache["total_pages"] = len(pages)
    pdf_cache["is_parsed"] = False
    pdf_cache["results"] = []

    if not pages:
        # A PDF with no pages must not crash the preview with an IndexError
        return None, "<div id='page_info_box'>0 / 0</div>"

    return pages[0], f"<div id='page_info_box'>1 / {len(pages)}</div>"
|
||||
|
||||
def turn_page(direction):
    """Move the preview one page and return what the UI should show.

    Args:
        direction: ``"prev"`` or ``"next"``; any other value keeps the
            current page.

    Returns:
        A 3-tuple ``(image, page_info_html, page_json)``, one value per Gradio
        output wired to the navigation buttons. Once the file has been parsed
        the image is the page's layout image (when available) and
        ``page_json`` is that page's cell data; otherwise the original page
        image and an empty string.
    """
    global pdf_cache

    if not pdf_cache["images"]:
        # Must be the same arity as the normal return below (three outputs)
        return None, "<div id='page_info_box'>0 / 0</div>", ""

    if direction == "prev":
        pdf_cache["current_page"] = max(0, pdf_cache["current_page"] - 1)
    elif direction == "next":
        pdf_cache["current_page"] = min(pdf_cache["total_pages"] - 1, pdf_cache["current_page"] + 1)

    index = pdf_cache["current_page"]
    current_image = pdf_cache["images"][index]  # Use the original image by default
    page_info = f"<div id='page_info_box'>{index + 1} / {pdf_cache['total_pages']}</div>"

    current_json = ""
    if pdf_cache["is_parsed"] and index < len(pdf_cache["results"]):
        result = pdf_cache["results"][index]
        if 'cells_data' in result:
            try:
                current_json = json.dumps(result['cells_data'], ensure_ascii=False, indent=2)
            except (TypeError, ValueError):
                # Not JSON-serializable: fall back to the plain string form
                current_json = str(result.get('cells_data', ''))
        # Use the image with layout boxes (if available)
        if 'layout_image' in result and result['layout_image']:
            current_image = result['layout_image']

    return current_image, page_info, current_json
|
||||
|
||||
def get_test_images():
    """Return the example files found in the configured test-images directory.

    Returns:
        Paths (directory joined with file name) of every ``.png``, ``.jpg``,
        ``.jpeg`` or ``.pdf`` entry, sorted by file name so the list order is
        the same on every platform. Empty when the directory does not exist.
    """
    test_dir = current_config['test_images_dir']
    if not os.path.isdir(test_dir):
        return []
    return [
        os.path.join(test_dir, name)
        for name in sorted(os.listdir(test_dir))
        if name.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf'))
    ]
|
||||
|
||||
def convert_image_to_base64(image):
    """Encode *image* as a PNG ``data:`` URI.

    The image is written to an in-memory buffer through its ``save`` method
    with ``format="PNG"`` and the bytes are base64-encoded.
    """
    png_buffer = io.BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode()
    return "data:image/png;base64," + encoded
|
||||
|
||||
def create_temp_session_dir():
    """Create a per-request scratch directory under the system temp dir.

    Returns:
        ``(directory_path, session_id)`` where ``session_id`` is the first
        eight hex characters of a random UUID and the directory is named
        ``dots_ocr_demo_<session_id>``.
    """
    token = uuid.uuid4().hex[:8]
    scratch_dir = os.path.join(tempfile.gettempdir(), "dots_ocr_demo_" + token)
    os.makedirs(scratch_dir, exist_ok=True)
    return scratch_dir, token
|
||||
|
||||
def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=False):
    """Parse one image with ``parser.parse_image`` and load its output files.

    Args:
        parser: Object exposing ``parse_image(input_path=..., filename=...,
            prompt_mode=..., save_dir=..., fitz_preprocess=...)``.
        image: PIL image to parse.
        prompt_mode: Prompt key forwarded to the parser.
        fitz_preprocess: Forwarded to the parser unchanged.

    Returns:
        A dict with ``layout_image``, ``cells_data``, ``md_content``,
        ``filtered``, ``temp_dir``, ``session_id``, ``result_paths`` (the raw
        parser result), ``input_width`` and ``input_height``.

    Raises:
        ValueError: If the parser returns no results.
        Exception: Any other failure is re-raised after the session's
            temporary directory has been removed.
    """
    temp_dir, session_id = create_temp_session_dir()

    try:
        # A PNG copy of the input is written into the session directory; the
        # parser itself receives the in-memory image.
        temp_image_path = os.path.join(temp_dir, f"input_{session_id}.png")
        image.save(temp_image_path, "PNG")

        results = parser.parse_image(
            input_path=image,
            filename=f"demo_{session_id}",
            prompt_mode=prompt_mode,
            save_dir=temp_dir,
            fitz_preprocess=fitz_preprocess
        )
        if not results:
            raise ValueError("No results returned from parser")

        result = results[0]  # parse_image returns a list with a single result

        layout_image = None
        cells_data = None
        md_content = None

        layout_image_path = result.get('layout_image_path')
        if layout_image_path and os.path.exists(layout_image_path):
            # Load the pixels now and release the file handle: the session
            # directory is deleted later while the image is still in use.
            with Image.open(layout_image_path) as layout_image:
                layout_image.load()

        layout_info_path = result.get('layout_info_path')
        if layout_info_path and os.path.exists(layout_info_path):
            with open(layout_info_path, 'r', encoding='utf-8') as f:
                cells_data = json.load(f)

        md_content_path = result.get('md_content_path')
        if md_content_path and os.path.exists(md_content_path):
            with open(md_content_path, 'r', encoding='utf-8') as f:
                md_content = f.read()

        return {
            'layout_image': layout_image,
            'cells_data': cells_data,
            'md_content': md_content,
            # Flag reported by the parser; callers treat True as "JSON parsing failed"
            'filtered': result.get('filtered', False),
            'temp_dir': temp_dir,
            'session_id': session_id,
            'result_paths': result,
            'input_width': result['input_width'],
            'input_height': result['input_height'],
        }

    except Exception:
        # Clean up the temporary directory on error, then propagate unchanged
        import shutil
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
|
||||
|
||||
def parse_pdf_with_high_level_api(parser, pdf_path, prompt_mode):
    """Parse a PDF with ``parser.parse_pdf`` and load each page's output files.

    Args:
        parser: Object exposing ``parse_pdf(input_path=..., filename=...,
            prompt_mode=..., save_dir=...)``.
        pdf_path: Path of the PDF to parse.
        prompt_mode: Prompt key forwarded to the parser.

    Returns:
        A dict with ``parsed_results`` (one dict per page holding ``page_no``,
        ``layout_image``, ``cells_data``, ``md_content``, ``filtered``),
        ``combined_md_content`` (page markdown joined by ``---`` separators),
        ``combined_cells_data`` (all pages' cells in one list), ``temp_dir``,
        ``session_id`` and ``total_pages``.

    Raises:
        ValueError: If the parser returns no results.
        Exception: Any other failure is re-raised after the session's
            temporary directory has been removed.
    """
    temp_dir, session_id = create_temp_session_dir()

    try:
        results = parser.parse_pdf(
            input_path=pdf_path,
            filename=f"demo_{session_id}",
            prompt_mode=prompt_mode,
            save_dir=temp_dir
        )
        if not results:
            raise ValueError("No results returned from parser")

        parsed_results = []
        all_md_content = []
        all_cells_data = []

        for i, result in enumerate(results):
            page_result = {
                'page_no': result.get('page_no', i),
                'layout_image': None,
                'cells_data': None,
                'md_content': None,
                # Read from the parser's result for this page (the flag was
                # previously looked up on page_result itself, so it was always False)
                'filtered': result.get('filtered', False),
            }

            layout_image_path = result.get('layout_image_path')
            if layout_image_path and os.path.exists(layout_image_path):
                # Load the pixels now and release the file handle: the session
                # directory is deleted later while the image is still in use.
                with Image.open(layout_image_path) as layout_image:
                    layout_image.load()
                page_result['layout_image'] = layout_image

            layout_info_path = result.get('layout_info_path')
            if layout_info_path and os.path.exists(layout_info_path):
                with open(layout_info_path, 'r', encoding='utf-8') as f:
                    page_result['cells_data'] = json.load(f)
                # Only a list of cells can be merged into the combined list
                if isinstance(page_result['cells_data'], list):
                    all_cells_data.extend(page_result['cells_data'])

            md_content_path = result.get('md_content_path')
            if md_content_path and os.path.exists(md_content_path):
                with open(md_content_path, 'r', encoding='utf-8') as f:
                    page_content = f.read()
                page_result['md_content'] = page_content
                all_md_content.append(page_content)

            parsed_results.append(page_result)

        # Merge the content of all pages
        combined_md = "\n\n---\n\n".join(all_md_content) if all_md_content else ""

        return {
            'parsed_results': parsed_results,
            'combined_md_content': combined_md,
            'combined_cells_data': all_cells_data,
            'temp_dir': temp_dir,
            'session_id': session_id,
            'total_pages': len(results)
        }

    except Exception:
        # Clean up the temporary directory on error, then propagate unchanged
        import shutil
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise
|
||||
|
||||
# ==================== Core Processing Function ====================
|
||||
def _blank_processing_results():
    """Return a fresh processing-results record with every field unset."""
    return {
        'original_image': None,
        'processed_image': None,
        'layout_result': None,
        'markdown_content': None,
        'cells_data': None,
        'temp_dir': None,
        'session_id': None,
        'result_paths': None,
        'pdf_results': None
    }


def _inference_error(message):
    """Build the seven-value output tuple shown when inference cannot proceed."""
    return None, message, "", "", gr.update(value=None), None, ""


def _zip_session_dir(temp_dir, session_id):
    """Zip every non-ZIP file under *temp_dir*; return the archive path or None on failure."""
    zip_path = os.path.join(temp_dir, f"layout_results_{session_id}.zip")
    try:
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            for root, _dirs, files in os.walk(temp_dir):
                for file in files:
                    if file.endswith('.zip'):
                        continue  # never pack the archive (or older archives) into itself
                    file_path = os.path.join(root, file)
                    zipf.write(file_path, os.path.relpath(file_path, temp_dir))
    except Exception as e:
        print(f"Failed to create download ZIP: {e}")
        return None
    return zip_path


def process_image_inference(test_image_input, file_input,
                            prompt_mode, server_ip, server_port, min_pixels, max_pixels,
                            fitz_preprocess=False
                            ):
    """Run layout inference on the uploaded file or the selected example.

    An uploaded file takes priority over the example selection. PDFs are
    delegated to ``process_pdf_file``; images are parsed with
    ``parse_image_with_high_level_api``.

    Args:
        test_image_input: Path of the selected example, or an empty value.
        file_input: Path of the uploaded file, or ``None``.
        prompt_mode: Prompt key forwarded to the parser.
        server_ip, server_port, min_pixels, max_pixels: Values written into
            ``current_config`` and onto the shared parser before parsing.
        fitz_preprocess: Forwarded to the image parser.

    Returns:
        Always a 7-tuple, one value per Gradio output: result image, info
        markdown, rendered markdown, raw markdown, download-button update,
        page info, current-page JSON. Failures are reported through the info
        text rather than raised.
    """
    global current_config, processing_results, dots_parser, pdf_cache

    # First, clean up previous processing results to avoid confusion with the download button
    previous_dir = processing_results.get('temp_dir')
    if previous_dir and os.path.exists(previous_dir):
        import shutil
        shutil.rmtree(previous_dir, ignore_errors=True)
    processing_results = _blank_processing_results()

    # Update configuration and the shared parser
    current_config.update({
        'ip': server_ip,
        'port_vllm': server_port,
        'min_pixels': min_pixels,
        'max_pixels': max_pixels
    })
    dots_parser.ip = server_ip
    dots_parser.port = server_port
    dots_parser.min_pixels = min_pixels
    dots_parser.max_pixels = max_pixels

    image = None

    # Prioritize file input (supports PDF)
    if file_input is not None:
        file_ext = os.path.splitext(file_input)[1].lower()
        if file_ext == '.pdf':
            try:
                return process_pdf_file(file_input, prompt_mode)
            except Exception as e:
                return _inference_error(f"PDF processing failed: {e}")
        elif file_ext in ['.jpg', '.jpeg', '.png']:
            try:
                image = Image.open(file_input)
            except Exception as e:
                return _inference_error(f"Failed to read image file: {e}")

    # If no usable file input, fall back to the selected example
    if image is None and test_image_input:
        if os.path.splitext(test_image_input)[1].lower() == '.pdf':
            # Same error handling as an uploaded PDF, so a failure reaches the UI
            try:
                return process_pdf_file(test_image_input, prompt_mode)
            except Exception as e:
                return _inference_error(f"PDF processing failed: {e}")
        try:
            image = read_image_v2(test_image_input)
        except Exception as e:
            return _inference_error(f"Failed to read test image: {e}")

    if image is None:
        return _inference_error("Please upload image/PDF file or select test image")

    try:
        # Clear PDF cache (for image processing)
        pdf_cache["images"] = []
        pdf_cache["current_page"] = 0
        pdf_cache["total_pages"] = 0
        pdf_cache["is_parsed"] = False
        pdf_cache["results"] = []

        original_image = image
        parse_result = parse_image_with_high_level_api(dots_parser, image, prompt_mode, fitz_preprocess)

        layout_image = parse_result['layout_image']
        cells_data = parse_result['cells_data']
        md_content = parse_result['md_content']
        filtered = parse_result['filtered']

        if filtered:
            # JSON parsing failed, only text content is available
            info_text = f"""
**Image Information:**
- Original Size: {original_image.width} x {original_image.height}
- Processing: JSON parsing failed, using cleaned text output
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Session ID: {parse_result['session_id']}
"""
            processing_results.update({
                'original_image': original_image,
                'processed_image': None,
                'layout_result': None,
                'markdown_content': md_content,
                'cells_data': None,
                'temp_dir': parse_result['temp_dir'],
                'session_id': parse_result['session_id'],
                'result_paths': parse_result['result_paths']
            })
            return (
                original_image,            # No layout image
                info_text,
                md_content,
                md_content,                # Display raw markdown text
                gr.update(visible=False),  # Hide download button
                None,                      # Page info
                ""                         # Current page JSON output
            )

        # JSON parsing successful case
        md_content_raw = md_content or "No markdown content generated"

        processing_results.update({
            'original_image': original_image,
            'processed_image': None,  # High-level API does not return processed_image
            'layout_result': layout_image,
            'markdown_content': md_content,
            'cells_data': cells_data,
            'temp_dir': parse_result['temp_dir'],
            'session_id': parse_result['session_id'],
            'result_paths': parse_result['result_paths']
        })

        num_elements = len(cells_data) if cells_data else 0
        info_text = f"""
**Image Information:**
- Original Size: {original_image.width} x {original_image.height}
- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Detected {num_elements} layout elements
- Session ID: {parse_result['session_id']}
"""

        current_json = ""
        if cells_data:
            try:
                current_json = json.dumps(cells_data, ensure_ascii=False, indent=2)
            except (TypeError, ValueError):
                # Not JSON-serializable: fall back to the plain string form
                current_json = str(cells_data)

        download_zip_path = None
        if parse_result['temp_dir']:
            download_zip_path = _zip_session_dir(parse_result['temp_dir'], parse_result['session_id'])

        return (
            layout_image,
            info_text,
            md_content or "No markdown content generated",
            md_content_raw,  # Raw markdown text
            gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False),
            None,            # Page info (not displayed for image processing)
            current_json     # Current page JSON
        )

    except Exception as e:
        return _inference_error(f"Error during processing: {e}")
|
||||
|
||||
def process_pdf_file(pdf_path, prompt_mode):
    """Preview and parse a PDF, then build the UI outputs for its first page.

    Loads the preview, parses the document with the shared parser, records
    the per-page results in ``pdf_cache`` and ``processing_results``, and
    packs the session directory into a ZIP for download.

    Returns:
        A 7-tuple: first-page image (layout image when available), info
        markdown, combined markdown, combined raw markdown, download-button
        update, page info HTML, first-page JSON.

    Raises:
        Exception: Any failure is re-raised after the paging fields of
            ``pdf_cache`` have been reset.
    """
    global pdf_cache, processing_results, dots_parser

    try:
        # Preview first, so the page images are cached before parsing starts
        preview_image, page_info = load_file_for_preview(pdf_path)

        pdf_result = parse_pdf_with_high_level_api(dots_parser, pdf_path, prompt_mode)
        page_results = pdf_result['parsed_results']
        session_dir = pdf_result['temp_dir']
        session_id = pdf_result['session_id']

        pdf_cache["is_parsed"] = True
        pdf_cache["results"] = page_results

        combined_md = pdf_result['combined_md_content']
        combined_md_raw = combined_md or "No markdown content generated"

        processing_results.update({
            'original_image': None,
            'processed_image': None,
            'layout_result': None,
            'markdown_content': combined_md,
            'cells_data': pdf_result['combined_cells_data'],
            'temp_dir': session_dir,
            'session_id': session_id,
            'result_paths': None,
            'pdf_results': page_results
        })

        total_elements = len(pdf_result['combined_cells_data'])
        info_text = f"""
**PDF Information:**
- Total Pages: {pdf_result['total_pages']}
- Server: {current_config['ip']}:{current_config['port_vllm']}
- Total Detected Elements: {total_elements}
- Session ID: {pdf_result['session_id']}
"""

        # First page: JSON and, when present, the image with layout boxes
        first_page_json = ""
        first_page_image = preview_image
        if pdf_cache["results"]:
            first_page = pdf_cache["results"][0]
            first_cells = first_page['cells_data']
            if first_cells:
                try:
                    first_page_json = json.dumps(first_cells, ensure_ascii=False, indent=2)
                except:
                    first_page_json = str(first_cells)
            if 'layout_image' in first_page and first_page['layout_image']:
                first_page_image = first_page['layout_image']

        # Pack everything in the session directory except ZIP files
        download_zip_path = None
        if session_dir:
            download_zip_path = os.path.join(session_dir, f"layout_results_{session_id}.zip")
            try:
                with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
                    for folder, _subdirs, names in os.walk(session_dir):
                        for name in names:
                            if not name.endswith('.zip'):
                                member_path = os.path.join(folder, name)
                                archive.write(member_path, os.path.relpath(member_path, session_dir))
            except Exception as e:
                print(f"Failed to create download ZIP: {e}")
                download_zip_path = None

        if download_zip_path:
            download_update = gr.update(value=download_zip_path, visible=True)
        else:
            download_update = gr.update(visible=False)

        return (
            first_page_image,
            info_text,
            combined_md or "No markdown content generated",
            combined_md_raw or "No markdown content generated",
            download_update,
            page_info,
            first_page_json
        )

    except Exception as e:
        # Drop the half-initialised paging state before propagating
        for key, blank in (("images", []), ("current_page", 0), ("total_pages", 0),
                           ("is_parsed", False), ("results", [])):
            pdf_cache[key] = blank
        raise e
|
||||
|
||||
def clear_all_data():
    """Discard the last results and return the UI's initial output values.

    Removes the previous session's temporary directory, rebinds
    ``processing_results`` and ``pdf_cache`` to fresh empty records, and
    returns nine values: file input, example selection, result image, info
    text, rendered markdown, raw markdown, download-button update, page info
    and current-page JSON.
    """
    global processing_results, pdf_cache

    stale_dir = processing_results.get('temp_dir')
    if stale_dir and os.path.exists(stale_dir):
        import shutil
        try:
            shutil.rmtree(stale_dir, ignore_errors=True)
        except Exception as e:
            print(f"Failed to clean up temporary directory: {e}")

    processing_results = dict.fromkeys((
        'original_image',
        'processed_image',
        'layout_result',
        'markdown_content',
        'cells_data',
        'temp_dir',
        'session_id',
        'result_paths',
        'pdf_results',
    ))

    pdf_cache = dict(
        images=[],
        current_page=0,
        total_pages=0,
        file_type=None,
        is_parsed=False,
        results=[],
    )

    waiting_text = "🕐 Waiting for parsing result..."
    return (
        None,                                    # file input
        "",                                      # example selection
        None,                                    # result image
        "Waiting for processing results...",     # info display
        "## Waiting for processing results...",  # rendered markdown
        waiting_text,                            # raw markdown
        gr.update(visible=False),                # download button hidden
        "<div id='page_info_box'>0 / 0</div>",   # page info
        waiting_text                             # current page JSON
    )
|
||||
|
||||
def update_prompt_display(prompt_mode):
    """Return the prompt text registered for *prompt_mode*.

    Raises:
        KeyError: If *prompt_mode* is not a key of ``dict_promptmode_to_prompt``.
    """
    prompt_text = dict_promptmode_to_prompt[prompt_mode]
    return prompt_text
|
||||
|
||||
# ==================== Gradio Interface ====================
|
||||
def create_gradio_interface():
    """Build the Gradio Blocks app: inputs on the left, results on the right.

    Returns:
        The ``gr.Blocks`` instance with all components created and all event
        handlers attached (not yet launched).
    """
    # NOTE(review): the container nesting below is reconstructed from the
    # statement order and section comments — confirm against the rendered layout.

    # CSS styles, matching the reference style
    custom_css = """

    #parse_button {
        background: #FF576D !important; /* !important 确保覆盖主题默认样式 */
        border-color: #FF576D !important;
    }
    /* 鼠标悬停时的颜色 */
    #parse_button:hover {
        background: #F72C49 !important;
        border-color: #F72C49 !important;
    }

    #page_info_html {
        display: flex;
        align-items: center;
        justify-content: center;
        height: 100%;
        margin: 0 12px;
    }

    #page_info_box {
        padding: 8px 20px;
        font-size: 16px;
        border: 1px solid #bbb;
        border-radius: 8px;
        background-color: #f8f8f8;
        text-align: center;
        min-width: 80px;
        box-shadow: 0 1px 3px rgba(0,0,0,0.1);
    }

    #markdown_output {
        min-height: 800px;
        overflow: auto;
    }

    footer {
        visibility: hidden;
    }

    #info_box {
        padding: 10px;
        background-color: #f8f9fa;
        border-radius: 8px;
        border: 1px solid #dee2e6;
        margin: 10px 0;
        font-size: 14px;
    }

    #result_image {
        border-radius: 8px;
    }

    #markdown_tabs {
        height: 100%;
    }
    """

    header_html = """
        <div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
            <h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr</h1>
        </div>
        <div style="text-align: center; margin-bottom: 10px;">
            <em>Supports image/PDF layout analysis and structured output</em>
        </div>
        """

    waiting_text = "🕐 Waiting for parsing result..."
    math_delimiters = [
        {"left": "$$", "right": "$$", "display": True},
        {"left": "$", "right": "$", "display": False},
    ]

    with gr.Blocks(theme="ocean", css=custom_css, title='dots.ocr') as app:

        # Title
        gr.HTML(header_html)

        with gr.Row():
            # Left side: input and configuration
            with gr.Column(scale=1, elem_id="left-panel"):
                gr.Markdown("### 📥 Upload & Select")
                file_input = gr.File(
                    label="Upload PDF/Image",
                    type="filepath",
                    file_types=[".pdf", ".jpg", ".jpeg", ".png"],
                )

                example_files = get_test_images()
                test_image_input = gr.Dropdown(
                    label="Or Select an Example",
                    choices=[""] + example_files,
                    value="",
                )

                gr.Markdown("### ⚙️ Prompt & Actions")
                prompt_mode = gr.Dropdown(
                    label="Select Prompt",
                    choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr"],
                    value="prompt_layout_all_en",
                    show_label=True
                )

                # Starts out showing the first registered prompt
                first_prompt_text = list(dict_promptmode_to_prompt.values())[0]
                prompt_display = gr.Textbox(
                    label="Current Prompt Content",
                    value=first_prompt_text,
                    lines=4,
                    max_lines=8,
                    interactive=False,
                    show_copy_button=True
                )

                with gr.Row():
                    process_btn = gr.Button("🔍 Parse", variant="primary", scale=2, elem_id="parse_button")
                    clear_btn = gr.Button("🗑️ Clear", variant="secondary", scale=1)

                with gr.Accordion("🛠️ Advanced Configuration", open=False):
                    fitz_preprocess = gr.Checkbox(
                        label="Enable fitz_preprocess for images",
                        value=True,
                        info="Processes image via a PDF-like pipeline (image->pdf->200dpi image). Recommended if your image DPI is low."
                    )
                    with gr.Row():
                        server_ip = gr.Textbox(label="Server IP", value=DEFAULT_CONFIG['ip'])
                        server_port = gr.Number(label="Port", value=DEFAULT_CONFIG['port_vllm'], precision=0)
                    with gr.Row():
                        min_pixels = gr.Number(label="Min Pixels", value=DEFAULT_CONFIG['min_pixels'], precision=0)
                        max_pixels = gr.Number(label="Max Pixels", value=DEFAULT_CONFIG['max_pixels'], precision=0)

            # Right side: result display
            with gr.Column(scale=6, variant="compact"):
                with gr.Row():
                    # Preview image, page navigation and info text
                    with gr.Column(scale=3):
                        gr.Markdown("### 👁️ File Preview")
                        result_image = gr.Image(
                            label="Layout Preview",
                            visible=True,
                            height=800,
                            show_label=False
                        )

                        # Page navigation (shown during PDF preview)
                        with gr.Row():
                            prev_btn = gr.Button("⬅ Previous", size="sm")
                            page_info = gr.HTML(
                                value="<div id='page_info_box'>0 / 0</div>",
                                elem_id="page_info_html"
                            )
                            next_btn = gr.Button("Next ➡", size="sm")

                        info_display = gr.Markdown(
                            "Waiting for processing results...",
                            elem_id="info_box"
                        )

                    # Rendered markdown, raw markdown and per-page JSON tabs
                    with gr.Column(scale=3):
                        gr.Markdown("### ✔️ Result Display")

                        with gr.Tabs(elem_id="markdown_tabs"):
                            with gr.TabItem("Markdown Render Preview"):
                                md_output = gr.Markdown(
                                    "## Please click the parse button to parse or select for single-task recognition...",
                                    label="Markdown Preview",
                                    max_height=600,
                                    latex_delimiters=math_delimiters,
                                    show_copy_button=False,
                                    elem_id="markdown_output"
                                )

                            with gr.TabItem("Markdown Raw Text"):
                                md_raw_output = gr.Textbox(
                                    value=waiting_text,
                                    label="Markdown Raw Text",
                                    max_lines=100,
                                    lines=38,
                                    show_copy_button=True,
                                    elem_id="markdown_output",
                                    show_label=False
                                )

                            with gr.TabItem("Current Page JSON"):
                                current_page_json = gr.Textbox(
                                    value=waiting_text,
                                    label="Current Page JSON",
                                    max_lines=100,
                                    lines=38,
                                    show_copy_button=True,
                                    elem_id="markdown_output",
                                    show_label=False
                                )

                # Download button (hidden until a result archive exists)
                with gr.Row():
                    download_btn = gr.DownloadButton(
                        "⬇️ Download Results",
                        visible=False
                    )

        # ---- Event wiring ----
        paging_outputs = [result_image, page_info, current_page_json]
        result_outputs = [
            result_image, info_display, md_output, md_raw_output,
            download_btn, page_info, current_page_json
        ]

        # Keep the prompt text in sync with the selected prompt mode
        prompt_mode.change(
            fn=update_prompt_display,
            inputs=prompt_mode,
            outputs=prompt_display,
            show_progress=False
        )

        # Show a preview as soon as a file is uploaded
        file_input.upload(
            fn=load_file_for_preview,
            inputs=file_input,
            outputs=[result_image, page_info],
            show_progress=False
        )

        # Page navigation
        prev_btn.click(
            fn=lambda: turn_page("prev"),
            outputs=list(paging_outputs),
            show_progress=False
        )
        next_btn.click(
            fn=lambda: turn_page("next"),
            outputs=list(paging_outputs),
            show_progress=False
        )

        process_btn.click(
            fn=process_image_inference,
            inputs=[
                test_image_input, file_input,
                prompt_mode, server_ip, server_port, min_pixels, max_pixels,
                fitz_preprocess
            ],
            outputs=list(result_outputs),
            show_progress=True
        )

        clear_btn.click(
            fn=clear_all_data,
            outputs=[file_input, test_image_input] + result_outputs,
            show_progress=False
        )

    return app
|
||||
|
||||
# ==================== Main Program ====================
|
||||
if __name__ == "__main__":
    # Build the UI, enable request queuing, and serve on all interfaces at port 7860
    demo = create_gradio_interface()
    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        debug=True,
    )
    demo.queue().launch(**launch_options)
|
||||
Executable
+666
@@ -0,0 +1,666 @@
|
||||
"""
|
||||
Layout Inference Web Application with Gradio - Annotation Version
|
||||
|
||||
A Gradio-based layout inference tool that supports image uploads and multiple backend inference engines.
|
||||
This version adds an image annotation feature, allowing users to draw bounding boxes on an image and send both the image and the boxes to the model.
|
||||
"""
|
||||
|
||||
import gradio as gr
|
||||
import json
|
||||
import os
|
||||
import io
|
||||
import tempfile
|
||||
import base64
|
||||
import zipfile
|
||||
import uuid
|
||||
import re
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import requests
|
||||
from gradio_image_annotation import image_annotator
|
||||
|
||||
# Local utility imports
|
||||
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||
from dots_ocr.utils.demo_utils.display import read_image
|
||||
from dots_ocr.utils.doc_utils import load_images_from_pdf
|
||||
|
||||
# Add DotsOCRParser import
|
||||
from dots_ocr.parser import DotsOCRParser
|
||||
|
||||
# ==================== Configuration ====================
|
||||
# Defaults used to seed the UI widgets, the parser, and the example-image lookup.
DEFAULT_CONFIG = {
    'ip': "127.0.0.1",
    'port_vllm': 8000,
    'min_pixels': MIN_PIXELS,
    'max_pixels': MAX_PIXELS,
    'test_images_dir': "./assets/showcase_origin",
}

# ==================== Global Variables ====================
# Store the current configuration (shallow copy of the defaults; updated in
# place by the inference handler on every request)
current_config = DEFAULT_CONFIG.copy()

# Create a DotsOCRParser instance (module-level, shared by all requests; the
# inference handler overwrites its ip/port/min_pixels/max_pixels attributes)
dots_parser = DotsOCRParser(
    ip=DEFAULT_CONFIG['ip'],
    port=DEFAULT_CONFIG['port_vllm'],
    dpi=200,
    min_pixels=DEFAULT_CONFIG['min_pixels'],
    max_pixels=DEFAULT_CONFIG['max_pixels']
)

# Store processing results of the most recent request. 'temp_dir' is the
# directory that the next request (or the Clear button) removes.
processing_results = {
    'original_image': None,
    'processed_image': None,
    'layout_result': None,
    'markdown_content': None,
    'cells_data': None,
    'temp_dir': None,
    'session_id': None,
    'result_paths': None,
    'annotation_data': None  # Store annotation data
}
|
||||
|
||||
# ==================== Utility Functions ====================
|
||||
def read_image_v2(img):
    """Return an image for a URL, a local path, or an existing PIL image.

    Args:
        img: An ``http://``/``https://`` URL, any other string (handed to
            ``read_image`` as a path), or a ``PIL.Image.Image``.

    Returns:
        The downloaded/opened image, or ``img`` itself when it already is a
        PIL image.

    Raises:
        ValueError: If ``img`` is neither a string nor a PIL image.
    """
    # Already an image: nothing to load.
    if isinstance(img, Image.Image):
        return img

    if not isinstance(img, str):
        raise ValueError(f"Invalid image type: {type(img)}")

    # Plain strings are treated as paths and delegated to the project reader.
    if not img.startswith(("http://", "https://")):
        loaded, _, _ = read_image(img, use_native=True)
        return loaded

    # Remote image: fetch the bytes and decode them from memory.
    with requests.get(img, stream=True) as response:
        response.raise_for_status()
        payload = io.BytesIO(response.content)
        return Image.open(payload)
|
||||
|
||||
def get_test_images():
    """List the example images available in the configured test directory.

    Returns:
        Paths (directory joined with file name) of every entry in
        ``current_config['test_images_dir']`` whose lower-cased name ends in
        ``.png``, ``.jpg`` or ``.jpeg``; an empty list when the directory does
        not exist.
    """
    test_dir = current_config['test_images_dir']
    if not os.path.exists(test_dir):
        return []

    image_suffixes = ('.png', '.jpg', '.jpeg')
    return [
        os.path.join(test_dir, entry)
        for entry in os.listdir(test_dir)
        if entry.lower().endswith(image_suffixes)
    ]
|
||||
|
||||
def create_temp_session_dir():
    """Create a per-request working directory under the system temp folder.

    Returns:
        A ``(temp_dir, session_id)`` tuple. ``session_id`` is the first eight
        hex characters of a random UUID and ``temp_dir`` is
        ``<tempdir>/dots_ocr_demo_<session_id>``, created if missing.
    """
    token = uuid.uuid4().hex
    session_id = token[:8]
    folder_name = f"dots_ocr_demo_{session_id}"
    temp_dir = os.path.join(tempfile.gettempdir(), folder_name)
    os.makedirs(temp_dir, exist_ok=True)
    return temp_dir, session_id
|
||||
|
||||
def parse_image_with_bbox(parser, image, prompt_mode, bbox=None, fitz_preprocess=False):
    """Run the parser on one image (optionally restricted to a box) and load its outputs.

    The image is written as a PNG into a fresh session directory, handed to
    ``parser.parse_image`` and the files named in the first result entry are
    read back.

    Args:
        parser: Object exposing ``parse_image(input_path=..., filename=...,
            prompt_mode=..., save_dir=..., bbox=..., fitz_preprocess=...)``.
        image: Object with a PIL-style ``save(path, "PNG")`` method.
        prompt_mode: Prompt key forwarded to the parser.
        bbox: Optional box forwarded to the parser unchanged.
        fitz_preprocess: Forwarded to the parser unchanged.

    Returns:
        Dict with keys ``layout_image``, ``cells_data``, ``md_content``,
        ``filtered``, ``temp_dir``, ``session_id`` and ``result_paths``.

    Raises:
        ValueError: If the parser returns no results.
        Exception: Any other failure is re-raised after the session directory
            has been removed.
    """
    temp_dir, session_id = create_temp_session_dir()

    try:
        # The parser works on files, so persist the in-memory image first.
        temp_image_path = os.path.join(temp_dir, f"input_{session_id}.png")
        image.save(temp_image_path, "PNG")

        results = parser.parse_image(
            input_path=temp_image_path,
            filename=f"demo_{session_id}",
            prompt_mode=prompt_mode,
            save_dir=temp_dir,
            bbox=bbox,
            fitz_preprocess=fitz_preprocess
        )
        if not results:
            raise ValueError("No results returned from parser")

        # Only the first entry of the returned list is used.
        result = results[0]

        def _has_file(key):
            # True when the result names a path under `key` and that path exists.
            return key in result and os.path.exists(result[key])

        layout_image = Image.open(result['layout_image_path']) if _has_file('layout_image_path') else None

        cells_data = None
        if _has_file('layout_info_path'):
            with open(result['layout_info_path'], 'r', encoding='utf-8') as handle:
                cells_data = json.load(handle)

        md_content = None
        if _has_file('md_content_path'):
            with open(result['md_content_path'], 'r', encoding='utf-8') as handle:
                md_content = handle.read()

        # 'filtered' is taken from the result when present, otherwise False.
        filtered = result['filtered'] if 'filtered' in result else False

        return {
            'layout_image': layout_image,
            'cells_data': cells_data,
            'md_content': md_content,
            'filtered': filtered,
            'temp_dir': temp_dir,
            'session_id': session_id,
            'result_paths': result
        }

    except Exception:
        # Clean up the temporary directory on error, then propagate.
        import shutil
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)
        raise
|
||||
|
||||
def process_annotation_data(annotation_data):
    """Convert the annotator payload into an ``(image, bbox)`` pair.

    Args:
        annotation_data: Mapping with an ``'image'`` entry and a ``'boxes'``
            list whose items carry ``xmin``/``ymin``/``xmax``/``ymax``; may be
            ``None`` or empty.

    Returns:
        ``(None, None)`` when there is no payload, no box has been drawn, or
        the image cannot be converted. Otherwise ``(image, bbox)`` where
        ``image`` is a PIL image (or ``None`` if the payload had no image) and
        ``bbox`` is ``[xmin, ymin, xmax, ymax]`` of the first box.
    """
    if not annotation_data:
        return None, None

    boxes = annotation_data.get('boxes')
    if not boxes:
        # Without a box there is nothing to ground the OCR on.
        return None, None

    image = annotation_data.get('image')

    # Ensure the image is in PIL Image format
    if image is not None:
        import numpy as np
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            # Unknown container: strings are opened as paths, anything else is
            # tried as an array-like.
            try:
                image = Image.open(image) if isinstance(image, str) else Image.fromarray(image)
            except Exception as e:
                print(f"Image format conversion failed: {e}")
                return None, None

    # Only the first box is used.
    box = boxes[0]
    bbox = [box['xmin'], box['ymin'], box['xmax'], box['ymax']]

    return image, bbox
|
||||
|
||||
# ==================== Core Processing Function ====================
|
||||
def process_image_inference_with_annotation(annotation_data, test_image_input,
                                            prompt_mode, server_ip, server_port, min_pixels, max_pixels,
                                            fitz_preprocess=False
                                            ):
    """Core function for image inference, supporting annotation data.

    Args:
        annotation_data: Value of the annotation component; read as a mapping
            with ``'image'`` and ``'boxes'`` entries (may be ``None``).
        test_image_input: Path/URL of the selected example image, or ``""``.
        prompt_mode: Prompt key passed to the parser; replaced by
            ``"prompt_grounding_ocr"`` when a box is available.
        server_ip, server_port, min_pixels, max_pixels: Copied into
            ``current_config`` and onto ``dots_parser`` before parsing.
        fitz_preprocess: Forwarded to the parser.

    Returns:
        Always a 6-tuple: rendered-Markdown text, info text, raw Markdown
        (twice), a download-button update and the JSON text. Every exit path
        returns the same number of values because the click handler binds six
        output components.
    """
    global current_config, processing_results, dots_parser

    # First, clean up previous processing results
    if processing_results.get('temp_dir') and os.path.exists(processing_results['temp_dir']):
        import shutil
        try:
            shutil.rmtree(processing_results['temp_dir'], ignore_errors=True)
        except Exception as e:
            print(f"Failed to clean up previous temporary directory: {e}")

    # Reset processing results
    processing_results = {
        'original_image': None,
        'processed_image': None,
        'layout_result': None,
        'markdown_content': None,
        'cells_data': None,
        'temp_dir': None,
        'session_id': None,
        'result_paths': None,
        'annotation_data': annotation_data
    }

    # Update configuration
    current_config.update({
        'ip': server_ip,
        'port_vllm': server_port,
        'min_pixels': min_pixels,
        'max_pixels': max_pixels
    })

    # Update parser configuration
    dots_parser.ip = server_ip
    dots_parser.port = server_port
    dots_parser.min_pixels = min_pixels
    dots_parser.max_pixels = max_pixels

    # Determine the input source and process annotation data
    image = None
    bbox = None

    # Prioritize processing annotation data
    if annotation_data and annotation_data.get('image') is not None:
        image, bbox = process_annotation_data(annotation_data)
        # A drawn box forces the grounding prompt. A missing box is reported
        # to the user further down instead of tripping an assertion.
        if image is not None and bbox is not None:
            prompt_mode = "prompt_grounding_ocr"

    # If there's no annotation data, check the test image input
    if image is None and test_image_input and test_image_input != "":
        try:
            image = read_image_v2(test_image_input)
        except Exception as e:
            return None, f"Failed to read test image: {e}", "", "", gr.update(value=None), ""

    if image is None:
        return None, "Please select a test image or add an image in the annotation component", "", "", gr.update(value=None), ""
    if bbox is None:
        # Six values, like every other exit (the JSON slot was missing before).
        return "Please select a bounding box by mouse", "Please select a bounding box by mouse", "", "", gr.update(value=None), ""

    try:
        # Process using DotsOCRParser, passing the bbox parameter
        original_image = image
        parse_result = parse_image_with_bbox(dots_parser, image, prompt_mode, bbox, fitz_preprocess)

        # Extract parsing results
        layout_image = parse_result['layout_image']
        cells_data = parse_result['cells_data']
        md_content = parse_result['md_content']
        filtered = parse_result['filtered']

        # Store the results
        processing_results.update({
            'original_image': original_image,
            'processed_image': None,
            'layout_result': layout_image,
            'markdown_content': md_content,
            'cells_data': cells_data,
            'temp_dir': parse_result['temp_dir'],
            'session_id': parse_result['session_id'],
            'result_paths': parse_result['result_paths'],
            'annotation_data': annotation_data
        })

        md_text = md_content or "No markdown content generated"

        # Pieces of the info panel shared by the success and fallback paths.
        head_lines = [
            "**Image Information:**",
            f"- Original Dimensions: {original_image.width} x {original_image.height}",
            f"- Processing Mode: {'Region OCR' if bbox else 'Full Image OCR'}",
        ]
        server_line = f"- Server: {current_config['ip']}:{current_config['port_vllm']}"
        tail_lines = [
            f"- Session ID: {parse_result['session_id']}",
            f"- Box Coordinates: {bbox if bbox else 'None'}",
        ]

        # Handle the case where parsing fails
        if filtered:
            status_line = "- Processing Status: JSON parsing failed, using cleaned text output"
            info_text = "\n" + "\n".join(head_lines + [status_line, server_line] + tail_lines) + "\n"
            return (
                md_text,
                info_text,
                md_text,
                md_text,
                gr.update(visible=False),
                ""
            )

        # Handle the case where JSON parsing succeeds
        num_elements = len(cells_data) if cells_data else 0
        count_line = f"- Detected {num_elements} layout elements"
        info_text = "\n" + "\n".join(head_lines + [server_line, count_line] + tail_lines) + "\n"

        # Current page JSON output
        current_json = ""
        if cells_data:
            try:
                current_json = json.dumps(cells_data, ensure_ascii=False, indent=2)
            except (TypeError, ValueError):
                # Not JSON-serialisable: fall back to the plain repr.
                current_json = str(cells_data)

        # Create a downloadable ZIP file
        download_zip_path = None
        if parse_result['temp_dir']:
            download_zip_path = os.path.join(parse_result['temp_dir'], f"layout_results_{parse_result['session_id']}.zip")
            try:
                with zipfile.ZipFile(download_zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
                    for root, dirs, files in os.walk(parse_result['temp_dir']):
                        for file in files:
                            # Skip archives so the ZIP being written is not added to itself.
                            if file.endswith('.zip'):
                                continue
                            file_path = os.path.join(root, file)
                            arcname = os.path.relpath(file_path, parse_result['temp_dir'])
                            zipf.write(file_path, arcname)
            except Exception as e:
                print(f"Failed to create download ZIP: {e}")
                download_zip_path = None

        return (
            md_text,
            info_text,
            md_text,
            md_text,
            gr.update(value=download_zip_path, visible=True) if download_zip_path else gr.update(visible=False),
            current_json
        )

    except Exception as e:
        return f"An error occurred during processing: {e}", f"An error occurred during processing: {e}", "", "", gr.update(value=None), ""
|
||||
|
||||
def load_image_to_annotator(test_image_input):
    """Build the annotator value for the selected example image.

    Args:
        test_image_input: Path/URL of the example image, or an empty value.

    Returns:
        ``{"image": <loaded image>, "boxes": []}`` on success; ``None`` when
        nothing is selected or the image cannot be read.
    """
    # Nothing selected: leave the annotator empty.
    if not test_image_input or test_image_input == "":
        return None

    try:
        loaded = read_image_v2(test_image_input)
    except Exception:
        return None

    if loaded is None:
        return None

    # Format expected by the annotation component: image plus (empty) box list.
    return {"image": loaded, "boxes": []}
|
||||
|
||||
def clear_all_data():
    """Drop the last request's files and state and return reset values for the UI.

    Returns:
        A 7-tuple, in this order: example selection, annotator value, info
        text, rendered Markdown, raw Markdown, download-button update, JSON
        text.
    """
    global processing_results

    # Remove the working directory left by the last request, if any.
    stale_dir = processing_results.get('temp_dir')
    if stale_dir and os.path.exists(stale_dir):
        import shutil
        try:
            shutil.rmtree(stale_dir, ignore_errors=True)
        except Exception as e:
            print(f"Failed to clean up temporary directory: {e}")

    # Fresh state: every slot back to None.
    result_keys = (
        'original_image',
        'processed_image',
        'layout_result',
        'markdown_content',
        'cells_data',
        'temp_dir',
        'session_id',
        'result_paths',
        'annotation_data',
    )
    processing_results = dict.fromkeys(result_keys, None)

    waiting_text = "🕐 Waiting for parsing results..."
    cleared_selection = ""
    cleared_annotator = None
    info_reset = "Waiting for processing results..."
    markdown_reset = "## Waiting for processing results..."
    hidden_download = gr.update(visible=False)

    return (
        cleared_selection,
        cleared_annotator,
        info_reset,
        markdown_reset,
        waiting_text,
        hidden_download,
        waiting_text,
    )
|
||||
|
||||
def update_prompt_display(prompt_mode):
    """Return the prompt text registered for ``prompt_mode``.

    Raises:
        KeyError: If ``prompt_mode`` is not a key of ``dict_promptmode_to_prompt``.
    """
    prompt_text = dict_promptmode_to_prompt[prompt_mode]
    return prompt_text
|
||||
|
||||
# ==================== Gradio Interface ====================
|
||||
def create_gradio_interface():
    """Creates the Gradio interface.

    Builds a ``gr.Blocks`` app with a configuration column (example picker,
    prompt, actions, server settings), an image annotation area, and a results
    area (rendered Markdown, raw Markdown, JSON, download button), then wires
    the prompt dropdown and the Load/Parse/Clear buttons to their handlers.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """

    # CSS styling to match the reference style
    css = """
    footer {
        visibility: hidden;
    }

    #info_box {
        padding: 10px;
        background-color: #f8f9fa;
        border-radius: 8px;
        border: 1px solid #dee2e6;
        margin: 10px 0;
        font-size: 14px;
    }

    #markdown_tabs {
        height: 100%;
    }

    #annotation_component {
        border-radius: 8px;
    }
    """

    with gr.Blocks(theme="ocean", css=css, title='dots.ocr - Annotation') as demo:

        # Title
        gr.HTML("""
            <div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
                <h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr - Annotation Version</h1>
            </div>
            <div style="text-align: center; margin-bottom: 10px;">
                <em>Supports image annotation, drawing boxes, and sending box information to the model for OCR.</em>
            </div>
        """)

        with gr.Row():
            # Left side: Input and Configuration
            with gr.Column(scale=1, variant="compact"):
                gr.Markdown("### 📁 Select Example")
                # The example list is read once, when the interface is built.
                test_images = get_test_images()
                test_image_input = gr.Dropdown(
                    label="Select Example",
                    choices=[""] + test_images,
                    value="",
                    show_label=True
                )

                # Button to load image into the annotation component
                load_btn = gr.Button("📷 Load Image to Annotation Area", variant="secondary")

                prompt_mode = gr.Dropdown(
                    label="Select Prompt",
                    # choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr", "prompt_grounding_ocr"],
                    choices=["prompt_grounding_ocr"],
                    value="prompt_grounding_ocr",
                    show_label=True,
                    info="If a box is drawn, 'prompt_grounding_ocr' mode will be used automatically."
                )

                # Display the current prompt content
                prompt_display = gr.Textbox(
                    label="Current Prompt Content",
                    # value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
                    value=dict_promptmode_to_prompt["prompt_grounding_ocr"],
                    lines=4,
                    max_lines=8,
                    interactive=False,
                    show_copy_button=True
                )

                gr.Markdown("### ⚙️ Actions")
                process_btn = gr.Button("🔍 Parse", variant="primary")
                clear_btn = gr.Button("🗑️ Clear", variant="secondary")

                gr.Markdown("### 🛠️ Configuration")

                fitz_preprocess = gr.Checkbox(
                    label="Enable fitz_preprocess",
                    value=False,
                    info="Performs fitz preprocessing on the image input, converting the image to a PDF and then to a 200dpi image."
                )

                with gr.Row():
                    server_ip = gr.Textbox(
                        label="Server IP",
                        value=DEFAULT_CONFIG['ip']
                    )
                    server_port = gr.Number(
                        label="Port",
                        value=DEFAULT_CONFIG['port_vllm'],
                        precision=0
                    )

                with gr.Row():
                    min_pixels = gr.Number(
                        label="Min Pixels",
                        value=DEFAULT_CONFIG['min_pixels'],
                        precision=0
                    )
                    max_pixels = gr.Number(
                        label="Max Pixels",
                        value=DEFAULT_CONFIG['max_pixels'],
                        precision=0
                    )

            # Right side: Result Display
            with gr.Column(scale=6, variant="compact"):
                with gr.Row():
                    # Image Annotation Area
                    with gr.Column(scale=3):
                        gr.Markdown("### 🎯 Image Annotation Area")
                        gr.Markdown("""
                        **Instructions:**
                        - Method 1: Select an example image on the left and click "Load Image to Annotation Area".
                        - Method 2: Upload an image directly in the annotation area below (drag and drop or click to upload).
                        - Use the mouse to draw a box on the image to select the region for recognition.
                        - Only one box can be drawn. To draw a new one, please delete the old one first.
                        - **Hotkey: Press the Delete key to remove the selected box.**
                        - After drawing a box, clicking Parse will automatically use the Region OCR mode.
                        """)

                        annotator = image_annotator(
                            value=None,
                            label="Image Annotation",
                            height=600,
                            show_label=False,
                            elem_id="annotation_component",
                            single_box=True,  # Only allow one box; a new box will replace the old one
                            box_min_size=10,
                            interactive=True,
                            disable_edit_boxes=True,  # Disable the edit dialog
                            label_list=["OCR Region"],  # Set the default label
                            label_colors=[(255, 0, 0)],  # Set color to red
                            use_default_label=True,  # Use the default label
                            image_type="pil"  # Ensure it returns a PIL Image format
                        )

                        # Information Display
                        info_display = gr.Markdown(
                            "Waiting for processing results...",
                            elem_id="info_box"
                        )

                    # Result Display Area
                    with gr.Column(scale=3):
                        gr.Markdown("### ✅ Results")

                        with gr.Tabs(elem_id="markdown_tabs"):
                            with gr.TabItem("Markdown Rendered View"):
                                md_output = gr.Markdown(
                                    "## Please upload an image and click the Parse button for recognition...",
                                    label="Markdown Preview",
                                    max_height=1000,
                                    latex_delimiters=[
                                        {"left": "$$", "right": "$$", "display": True},
                                        {"left": "$", "right": "$", "display": False},
                                    ],
                                    show_copy_button=False,
                                    elem_id="markdown_output"
                                )

                            # NOTE(review): the two textboxes below reuse
                            # elem_id="markdown_output" from the Markdown view
                            # above; element ids are normally unique - confirm
                            # nothing targets this id before changing it.
                            with gr.TabItem("Markdown Raw Text"):
                                md_raw_output = gr.Textbox(
                                    value="🕐 Waiting for parsing results...",
                                    label="Markdown Raw Text",
                                    max_lines=100,
                                    lines=38,
                                    show_copy_button=True,
                                    elem_id="markdown_output",
                                    show_label=False
                                )

                            with gr.TabItem("JSON Result"):
                                json_output = gr.Textbox(
                                    value="🕐 Waiting for parsing results...",
                                    label="JSON Result",
                                    max_lines=100,
                                    lines=38,
                                    show_copy_button=True,
                                    elem_id="markdown_output",
                                    show_label=False
                                )

                # Download Button
                with gr.Row():
                    download_btn = gr.DownloadButton(
                        "⬇️ Download Results",
                        visible=False
                    )

        # Event Binding

        # When the prompt mode changes, update the displayed content
        prompt_mode.change(
            fn=update_prompt_display,
            inputs=prompt_mode,
            outputs=prompt_display,
            show_progress=False
        )

        # Load image into the annotation component
        load_btn.click(
            fn=load_image_to_annotator,
            inputs=[test_image_input],
            outputs=annotator,
            show_progress=False
        )

        # Process Inference
        # NOTE(review): md_raw_output appears twice in `outputs` because the
        # handler returns six values (Markdown text three times); both slots
        # receive the same text. Changing this list requires changing the
        # handler's return arity as well.
        process_btn.click(
            fn=process_image_inference_with_annotation,
            inputs=[
                annotator, test_image_input,
                prompt_mode, server_ip, server_port, min_pixels, max_pixels,
                fitz_preprocess
            ],
            outputs=[
                md_output, info_display, md_raw_output, md_raw_output,
                download_btn, json_output
            ],
            show_progress=True
        )

        # Clear Data (seven outputs, matching the 7-tuple from clear_all_data)
        clear_btn.click(
            fn=clear_all_data,
            outputs=[
                test_image_input, annotator,
                info_display, md_output, md_raw_output,
                download_btn, json_output
            ],
            show_progress=False
        )

    return demo
|
||||
|
||||
# ==================== Main Program ====================
|
||||
if __name__ == "__main__":
    # Script entry point: build the annotation UI, enable the request queue,
    # and start the web server.
    demo = create_gradio_interface()
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7861,  # Use a different port to avoid conflicts
        debug=True
    )
|
||||
Executable
+71
@@ -0,0 +1,71 @@
|
||||
import os

# Provide a default LOCAL_RANK of "0" when the launcher did not set one;
# an existing value is left untouched.
os.environ.setdefault("LOCAL_RANK", "0")
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
||||
from qwen_vl_utils import process_vision_info
|
||||
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||
|
||||
def inference(image_path, prompt, model, processor):
    """Run one image/prompt pair through the model and print the decoded text.

    Args:
        image_path: Image reference placed in the chat message's ``"image"`` field.
        prompt: Text prompt placed next to the image.
        model: Model exposing ``generate`` and ``device``.
        processor: Processor exposing ``apply_chat_template``, ``__call__``
            and ``batch_decode``.

    Returns:
        The list of decoded strings produced by ``processor.batch_decode``
        (also printed, as before).
    """
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path
                },
                {"type": "text", "text": prompt}
            ]
        }
    ]

    # Preparation for inference
    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Follow the model's own placement instead of assuming a device named
    # "cuda"; the caller loads the model with device_map="auto".
    inputs = inputs.to(model.device)

    # Inference: Generation of the output
    generated_ids = model.generate(**inputs, max_new_tokens=24000)
    # Strip the prompt tokens so only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)
    return output_text
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
    model_path = "./weights/DotsOCR"
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # Run the same demo image through every registered prompt; the key
    # (prompt_mode) is not used, only the prompt text.
    image_path = "demo/demo_image1.jpg"
    for prompt_mode, prompt in dict_promptmode_to_prompt.items():
        print(f"prompt: {prompt}")
        inference(image_path, prompt, model, processor)
|
||||
|
||||
Executable
BIN
Binary file not shown.
|
After Width: | Height: | Size: 755 KiB |
Executable
BIN
Binary file not shown.
Executable
+222
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
Layout Inference Web Application
|
||||
|
||||
A Streamlit-based layout inference tool that supports image uploads and multiple backend inference engines.
|
||||
"""
|
||||
|
||||
import streamlit as st
|
||||
import json
|
||||
import os
|
||||
import io
|
||||
import tempfile
|
||||
from PIL import Image
|
||||
import requests
|
||||
|
||||
# Local utility imports
|
||||
|
||||
# from utils import infer
|
||||
|
||||
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||
from dots_ocr.utils.format_transformer import layoutjson2md
|
||||
from dots_ocr.utils.layout_utils import draw_layout_on_image, post_process_cells
|
||||
from dots_ocr.utils.image_utils import get_input_dimensions, get_image_by_fitz_doc
|
||||
from dots_ocr.model.inference import inference_with_vllm
|
||||
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||
|
||||
import os
|
||||
from PIL import Image
|
||||
from dots_ocr.utils.demo_utils.display import read_image
|
||||
|
||||
|
||||
|
||||
# ==================== Configuration ====================
|
||||
# Defaults used to pre-fill the sidebar widgets and to locate the test images.
DEFAULT_CONFIG = {
    'ip': "127.0.0.1",
    'port_vllm': 8000,
    'min_pixels': MIN_PIXELS,
    'max_pixels': MAX_PIXELS,
    'test_images_dir': "./assets/showcase_origin",
}
|
||||
|
||||
# ==================== Utility Functions ====================
|
||||
|
||||
|
||||
@st.cache_resource
def read_image_v2(img: str):
    """Load an image from a URL or a local path (result cached by Streamlit).

    Args:
        img: An ``http://``/``https://`` URL, or any other string, which is
            handed to ``read_image`` as a path.

    Returns:
        The image decoded from the downloaded bytes, or the first element of
        the tuple returned by ``read_image``.

    Raises:
        ValueError: If the value is neither a string nor a PIL image after the
            optional download step.
    """
    is_remote = img.startswith(("http://", "https://"))
    if is_remote:
        with requests.get(img, stream=True) as response:
            response.raise_for_status()
            img = Image.open(io.BytesIO(response.content))

    # After a download `img` is already an image and is returned unchanged.
    if isinstance(img, Image.Image):
        return img

    if isinstance(img, str):
        loaded, _, _ = read_image(img, use_native=True)
        return loaded

    raise ValueError(f"Invalid image type: {type(img)}")
|
||||
|
||||
|
||||
# ==================== UI Components ====================
|
||||
def create_config_sidebar():
    """Create configuration sidebar.

    Renders the prompt, server and image-size widgets in the sidebar.

    Returns:
        Dict with keys ``prompt_key``, ``ip``, ``port``, ``min_pixels`` and
        ``max_pixels`` holding the current widget values.
    """
    st.sidebar.header("Configuration Parameters")

    config = {}
    config['prompt_key'] = st.sidebar.selectbox("Prompt Mode", list(dict_promptmode_to_prompt.keys()))
    config['ip'] = st.sidebar.text_input("Server IP", DEFAULT_CONFIG['ip'])
    # Accept the whole TCP port range; the earlier 1000-9999 limit rejected
    # valid ports such as 80 or 30000.
    config['port'] = st.sidebar.number_input("Port", min_value=1, max_value=65535, value=DEFAULT_CONFIG['port_vllm'])

    # Image configuration
    st.sidebar.subheader("Image Configuration")
    config['min_pixels'] = st.sidebar.number_input("Min Pixels", value=DEFAULT_CONFIG['min_pixels'])
    config['max_pixels'] = st.sidebar.number_input("Max Pixels", value=DEFAULT_CONFIG['max_pixels'])

    return config
|
||||
|
||||
def get_image_input():
    """Get image input.

    Shows a selector for the input method and the matching widget.

    Returns:
        A path or URL string for the chosen image, or ``None`` when no input
        method is selected yet or no file has been uploaded.

    Raises:
        ValueError: If the selector yields a value that is not one of the
            offered options.
    """
    st.markdown("#### Image Input")

    input_mode = st.pills(label="Select input method", options=["Upload Image", "Enter Image URL/Path", "Select Test Image"], key="input_mode", label_visibility="collapsed")

    # Nothing is selected on first render; treat that as "no input yet"
    # rather than as an invalid mode.
    if input_mode is None:
        return None

    if input_mode == "Upload Image":
        # File uploader
        uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
        if uploaded_file is not None:
            # Persist the upload to disk so the caller receives a path;
            # delete=False keeps the file after the handle is closed.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                return tmp_file.name
    elif input_mode == 'Enter Image URL/Path':
        # URL input
        img_url_input = st.text_input("Enter Image URL/Path")
        return img_url_input

    elif input_mode == 'Select Test Image':
        # Test image selection
        test_images = []
        test_dir = DEFAULT_CONFIG['test_images_dir']
        if os.path.exists(test_dir):
            test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)]
        img_url_test = st.selectbox("Select Test Image", [""] + test_images)
        return img_url_test
    else:
        raise ValueError(f"Invalid input mode: {input_mode}")

    # Reached only in upload mode while no file has been provided.
    return None
|
||||
|
||||
|
||||
|
||||
def process_and_display_results(output: dict, image: Image.Image, config: dict):
    """Process and display inference results.

    Parses the model response as JSON, post-processes the cells, and renders
    the raw/post-processed text, the layout visualization and the Markdown
    conversion. Failures are reported with ``st.error`` instead of raising.

    Args:
        output: Mapping whose ``'response'`` entry holds the raw model output.
        image: The image that was sent to the model.
        config: Mapping providing ``'min_pixels'`` and ``'max_pixels'``.
    """
    response = output['response']

    try:
        # Columns are created first so they sit above the text areas below.
        col1, col2 = st.columns(2)
        cells = json.loads(response)

        # Post-processing
        cells = post_process_cells(
            image, cells,
            image.width, image.height,
            min_pixels=config['min_pixels'],
            max_pixels=config['max_pixels']
        )

        # Calculate input dimensions
        input_width, input_height = get_input_dimensions(
            image,
            min_pixels=config['min_pixels'],
            max_pixels=config['max_pixels']
        )
        st.markdown('---')
        st.write(f'Input Dimensions: {input_width} x {input_height}')
        st.text_area('Original Model Output', response, height=200)
        st.text_area('Post-processed Result', str(cells), height=200)

        # Show the results
        with col1:
            new_image = draw_layout_on_image(
                image, cells,
                resized_height=None, resized_width=None,
                fill_bbox=True, draw_bbox=True
            )
            st.markdown('##### Visualization Result')
            st.image(new_image, width=new_image.width)

        with col2:
            md_code = layoutjson2md(image, cells, text_key='text')
            st.markdown('##### Markdown Format')
            # NOTE(review): model-derived Markdown is rendered with HTML
            # enabled - confirm this is acceptable for untrusted documents.
            st.markdown(md_code, unsafe_allow_html=True)

    except json.JSONDecodeError:
        st.error("Model output is not a valid JSON format")
    except Exception as e:
        st.error(f"Error processing results: {e}")
|
||||
|
||||
# ==================== Main Application ====================
|
||||
def main():
    """Render the Streamlit layout-inference page.

    Builds the sidebar configuration, loads the image chosen by the user and,
    once the start button has been pressed, passes the image and prompt to
    ``inference_with_vllm`` and hands the response to
    ``process_and_display_results``.
    """
    st.set_page_config(page_title="Layout Inference Tool", layout="wide")
    st.title("🔍 Layout Inference Tool")

    # The sidebar selection decides which prompt text is used.
    config = create_config_sidebar()
    prompt = dict_promptmode_to_prompt[config['prompt_key']]
    st.sidebar.info(f"Current Prompt: {prompt}")

    # Both widgets are created before any early return below.
    img_url = get_image_input()
    start_button = st.button('🚀 Start Inference', type="primary")

    # Guard: nothing can be shown until an image source is provided.
    if img_url is None or img_url.strip() == "":
        st.info("Please enter an image URL/path or upload an image")
        return

    try:
        source_image = read_image_v2(img_url)
        st.write(f"Original Dimensions: {source_image.width} x {source_image.height}")
    except Exception as exc:
        st.error(f"Failed to read image: {exc}")
        return

    # The loaded image is forwarded as-is; no resampling is applied here.
    processed_image = source_image

    # Without a button press only the preview is shown.
    if not start_button:
        st.image(processed_image, width=500)
        return

    with st.spinner(f"Inferring... Server: {config['ip']}:{config['port']}"):
        response = inference_with_vllm(
            processed_image, prompt, config['ip'], config['port'],
        )

    output = {
        'prompt': prompt,
        'response': response,
    }
    process_and_display_results(output, processed_image, config)


if __name__ == "__main__":
    main()
|
||||
Executable
+41
@@ -0,0 +1,41 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from openai import OpenAI
|
||||
from transformers.utils.versions import require_version
|
||||
from PIL import Image
|
||||
import io
|
||||
import base64
|
||||
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||
from dots_ocr.model.inference import inference_with_vllm
|
||||
|
||||
|
||||
# Command-line options for the demo: each one is a string with a default.
parser = argparse.ArgumentParser()
for _flag, _default in (
    ("--ip", "localhost"),
    ("--port", "8000"),
    ("--model_name", "model"),
    ("--prompt_mode", "prompt_layout_all_en"),
):
    parser.add_argument(_flag, type=str, default=_default)

# Parsed at import time; main() reads this module-level namespace.
args = parser.parse_args()

# Check the installed openai client version before doing any work.
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
|
||||
|
||||
|
||||
def main():
    """Send the bundled demo image to the inference server and print the reply.

    The server host and port are taken from the ``--ip`` / ``--port``
    command-line options and the prompt text is looked up from
    ``--prompt_mode``. Previously the call hard-coded ``localhost:8000``,
    so those options were silently ignored.
    """
    image_path = "demo/demo_image1.jpg"
    prompt = dict_promptmode_to_prompt[args.prompt_mode]

    # Context manager so the underlying file handle is released after the call.
    with Image.open(image_path) as image:
        response = inference_with_vllm(
            image,
            prompt,
            ip=args.ip,
            # --port is parsed as a string; the original call passed an int.
            port=int(args.port),
            temperature=0.1,
            top_p=0.9,
        )
    # NOTE(review): --model_name is parsed but not forwarded here; confirm
    # whether inference_with_vllm accepts a model name before wiring it in.
    print(f"response: {response}")


if __name__ == "__main__":
    main()
|
||||
Executable
+17
@@ -0,0 +1,17 @@
|
||||
# Download the model weights unless NODOWNLOAD is set to a non-empty value.
if [ -z "$NODOWNLOAD" ]; then
    python3 tools/download_model.py
fi

# Register the model with vllm.
hf_model_path=./weights/DotsOCR  # Path to your downloaded model weights
# Put the parent directory of the weights on PYTHONPATH so that
# "from DotsOCR import ..." resolves inside the vllm entry script.
export PYTHONPATH="$(dirname "$hf_model_path"):$PYTHONPATH"

# Locate the vllm entry script; stop early if vllm is not installed.
vllm_bin=$(command -v vllm)
if [ -z "$vllm_bin" ]; then
    echo "vllm executable not found in PATH" >&2
    exit 1
fi

# Patch the entry script only once: re-running this launcher must not
# append a duplicate import line on every start.
if ! grep -q '^from DotsOCR import modeling_dots_ocr_vllm$' "$vllm_bin"; then
    sed -i '/^from vllm\.entrypoints\.cli\.main import main$/a\
from DotsOCR import modeling_dots_ocr_vllm' "$vllm_bin"
fi

# Launch the vllm server.
model_name=model
CUDA_VISIBLE_DEVICES=0 vllm serve "${hf_model_path}" --tensor-parallel-size 1 --gpu-memory-utilization 0.95 --chat-template-content-format string --served-model-name "${model_name}" --trust-remote-code

# # run python demo after launch vllm server
# python demo/demo_vllm.py
|
||||
Reference in New Issue
Block a user