229 lines
7.6 KiB
Python
Executable File
229 lines
7.6 KiB
Python
Executable File
from PIL import Image
|
|
from typing import Dict, List
|
|
|
|
import fitz
|
|
from io import BytesIO
|
|
import json
|
|
|
|
from dots_ocr.utils.image_utils import smart_resize
|
|
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
|
from dots_ocr.utils.output_cleaner import OutputCleaner
|
|
|
|
|
|
# Color map from layout category name to an RGBA tuple.
# NOTE(review): the alpha component of every named category is 256, which is one
# above the usual 8-bit maximum of 255; the value is preserved as-is here.
# draw_layout_on_image() in this file only reads the first three (RGB) channels.
_LAYOUT_ALPHA = 256

# (category, RGB) pairs, in the order they appear in the resulting mapping.
_LAYOUT_RGB = (
    ("Text", (0, 128, 0)),  # Green
    ("Picture", (255, 0, 255)),  # Magenta
    ("Caption", (255, 165, 0)),  # Orange
    ("Section-header", (0, 255, 255)),  # Cyan
    ("Footnote", (0, 128, 0)),  # Green
    ("Formula", (128, 128, 128)),  # Gray
    ("Table", (255, 192, 203)),  # Pink
    ("Title", (255, 0, 0)),  # Red
    ("List-item", (0, 0, 255)),  # Blue
    ("Page-header", (0, 128, 0)),  # Green
    ("Page-footer", (128, 0, 128)),  # Purple
    ("Other", (165, 42, 42)),  # Brown
)

dict_layout_type_to_color = {
    category: rgb + (_LAYOUT_ALPHA,) for category, rgb in _LAYOUT_RGB
}
# "Unknown" is the only entry whose four components are all zero.
dict_layout_type_to_color["Unknown"] = (0, 0, 0, 0)
|
|
|
|
|
|
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
    """
    Draw layout boxes and "<index>_<category>" labels on top of an image.

    The image is placed on a single-page in-memory PDF document, the boxes and
    labels are drawn on that page, and the page is rendered back to a PIL image.
    The input image itself is not modified.

    Args:
        image: The source PIL Image.
        cells: A list of dicts; each must provide 'bbox' (indexed as
            [x0, y0, x1, y1]) and 'category'.
        resized_height: The resized height. When both resized_height and
            resized_width are truthy, every bbox coordinate is divided by
            (resized size / original size) before drawing.
        resized_width: The resized width.
        fill_bbox: If true, draw a filled rectangle with 30% fill opacity;
            otherwise draw only the outline.
        draw_bbox: Whether to draw the bounding box rectangle.

    Returns:
        PIL.Image: A new RGB image with the drawings.
    """
    original_width, original_height = image.size

    # The scale factors are identical for every cell, so compute them once.
    rescale = bool(resized_height and resized_width)
    if rescale:
        scale_x = resized_width / original_width
        scale_y = resized_height / original_height

    # Serialize the image to PNG in memory so it can be loaded as a Pixmap.
    img_bytes = BytesIO()
    image.save(img_bytes, format='PNG')
    pix = fitz.Pixmap(img_bytes)

    # Create a new PDF document; it is closed in `finally` so the document is
    # released even if drawing fails part-way through.
    doc = fitz.open()
    try:
        # Create a page with the same size as the image and paint the image on it
        page = doc.new_page(width=pix.width, height=pix.height)
        page.insert_image(
            fitz.Rect(0, 0, pix.width, pix.height),
            pixmap=pix
        )

        for order, cell in enumerate(cells):
            bbox = cell['bbox']
            layout_type = cell['category']

            if rescale:
                x0, y0 = int(bbox[0] / scale_x), int(bbox[1] / scale_y)
                x1, y1 = int(bbox[2] / scale_x), int(bbox[3] / scale_y)
            else:
                x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]

            # Only the RGB channels are used, normalized to the 0..1 range.
            color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
            color = [col/255 for col in color[:3]]

            rect_coords = fitz.Rect(x0, y0, x1, y1)
            if draw_bbox:
                if fill_bbox:
                    # Filled rectangle without a border
                    page.draw_rect(
                        rect_coords,
                        color=None,
                        fill=color,
                        fill_opacity=0.3,
                        width=0.5,
                        overlay=True,
                    )
                else:
                    # Outline only
                    page.draw_rect(
                        rect_coords,
                        color=color,
                        fill=None,
                        fill_opacity=1,
                        width=0.5,
                        overlay=True,
                    )

            # Label with "<index>_<category>", anchored at the box's right edge
            # (x1), 20 units below its top.
            # NOTE(review): the label is inserted even when draw_bbox is False;
            # the original nesting could not be recovered from the source
            # formatting - confirm this is intended.
            order_cate = f"{order}_{layout_type}"
            page.insert_text(
                (x1, y0 + 20), order_cate, fontsize=20, color=color
            )

        # Render the page with an identity matrix (maintaining original dimensions).
        rendered = page.get_pixmap(matrix=fitz.Matrix(1.0, 1.0))

        # Build the PIL image before the document is closed.
        return Image.frombytes("RGB", [rendered.width, rendered.height], rendered.samples)
    finally:
        doc.close()
|
|
|
|
|
|
def pre_process_bboxes(
    origin_image,
    bboxes,
    input_width,
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
):
    """
    Rescale bounding boxes by the ratio between the smart-resized input size and the original image size.

    The target size is obtained by passing (input_height, input_width) through
    smart_resize. Each coordinate is then divided by
    (original size / target size) along its axis and truncated to int.

    Args:
        origin_image: The original image; only its `.size` (width, height) is read.
        bboxes: A non-empty list of boxes, each indexed as [x0, y0, x1, y1].
        input_width: The input width given to smart_resize.
        input_height: The input height given to smart_resize.
        factor: Accepted but not used by this function.
        min_pixels: Minimum number of pixels; falls back to MIN_PIXELS when falsy.
        max_pixels: Maximum number of pixels; falls back to MAX_PIXELS when falsy.

    Returns:
        A list of [x0, y0, x1, y1] integer lists, one per input box.
    """
    assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)

    if not min_pixels:
        min_pixels = MIN_PIXELS
    if not max_pixels:
        max_pixels = MAX_PIXELS

    src_width, src_height = origin_image.size
    target_height, target_width = smart_resize(
        input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels
    )

    # Per-axis divisors: even bbox indices are x coordinates, odd ones are y.
    divisors = (src_width / target_width, src_height / target_height)

    return [
        [int(float(bbox[idx]) / divisors[idx % 2]) for idx in range(4)]
        for bbox in bboxes
    ]
|
|
|
|
def post_process_cells(
    origin_image: Image.Image,
    cells: List[Dict],
    input_width,  # server input width, also has smart_resize in server
    input_height,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 11289600
) -> List[Dict]:
    """
    Convert cell bounding boxes from the resized dimensions back to the original image dimensions.

    The resized dimensions are obtained by passing (input_height, input_width)
    through smart_resize. Each bbox coordinate is divided by
    (resized size / original size) along its axis and truncated to int.

    Args:
        origin_image: The original PIL Image.
        cells: A non-empty list of dicts, each with a 'bbox' indexed as [x0, y0, x1, y1].
        input_width: The width of the input image sent to the server.
        input_height: The height of the input image sent to the server.
        factor: Resizing factor. Accepted but not used by this function.
        min_pixels: Minimum number of pixels; falls back to MIN_PIXELS when falsy.
        max_pixels: Maximum number of pixels; falls back to MAX_PIXELS when falsy.

    Returns:
        A list of shallow copies of the input cells whose 'bbox' is replaced by
        the rescaled box. The input cells are not modified.
    """
    assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)

    if not min_pixels:
        min_pixels = MIN_PIXELS
    if not max_pixels:
        max_pixels = MAX_PIXELS

    src_width, src_height = origin_image.size
    resized_height, resized_width = smart_resize(
        input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels
    )

    # Per-axis divisors: even bbox indices are x coordinates, odd ones are y.
    divisors = (resized_width / src_width, resized_height / src_height)

    def _with_original_bbox(cell):
        # Compute the new box first, then attach it to a shallow copy of the cell.
        box = cell['bbox']
        restored = [int(float(box[idx]) / divisors[idx % 2]) for idx in range(4)]
        updated = cell.copy()
        updated['bbox'] = restored
        return updated

    return [_with_original_bbox(cell) for cell in cells]
|
|
|
|
def is_legal_bbox(cells):
    """
    Check that every cell's bounding box has positive width and height.

    Args:
        cells: An iterable of dicts, each with a 'bbox' indexed as [x0, y0, x1, y1].

    Returns:
        False as soon as a box with x1 <= x0 or y1 <= y0 is found, True otherwise
        (including when `cells` is empty).
    """
    boxes = (cell['bbox'] for cell in cells)
    has_degenerate_box = any(
        box[2] <= box[0] or box[3] <= box[1] for box in boxes
    )
    return not has_degenerate_box
|
|
|
|
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
    """
    Post-process a model response according to the prompt mode.

    Args:
        response: The raw model response.
        prompt_mode: Name of the prompt mode that produced the response.
        origin_image: The original image, forwarded to post_process_cells.
        input_image: The image sent to the model; its `.width` and `.height` are
            forwarded to post_process_cells.
        min_pixels: Forwarded to post_process_cells.
        max_pixels: Forwarded to post_process_cells.

    Returns:
        - For "prompt_ocr", "prompt_table_html", "prompt_table_latex" and
          "prompt_formula_latex": `response` unchanged (a bare value, not a tuple).
        - Otherwise, when the response parses as JSON and the cells are
          post-processed without error: `(cells, False)`.
        - Otherwise: `(cleaned_output, True)`, where the cleaned output comes from
          OutputCleaner and, if it is a list, is reduced to the cells' 'text'
          values joined by blank lines.
    """
    passthrough_modes = ("prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex")
    if prompt_mode in passthrough_modes:
        return response

    cells = response
    try:
        cells = json.loads(cells)
        cells = post_process_cells(
            origin_image,
            cells,
            input_image.width,
            input_image.height,
            min_pixels=min_pixels,
            max_pixels=max_pixels
        )
    except Exception as e:
        print(f"cells post process error: {e}, when using {prompt_mode}")
    else:
        return cells, False

    # Fallback path. `cells` is the parsed JSON if parsing succeeded but the
    # bbox post-processing failed; otherwise it is still the raw response.
    response_clean = OutputCleaner().clean_model_output(cells)
    if isinstance(response_clean, list):
        response_clean = "\n\n".join(cell['text'] for cell in response_clean if 'text' in cell)
    return response_clean, True
|