dots.ocr release
This commit is contained in:
Executable
+228
@@ -0,0 +1,228 @@
|
||||
from PIL import Image
|
||||
from typing import Dict, List
|
||||
|
||||
import fitz
|
||||
from io import BytesIO
|
||||
import json
|
||||
|
||||
from dots_ocr.utils.image_utils import smart_resize
|
||||
from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
|
||||
from dots_ocr.utils.output_cleaner import OutputCleaner
|
||||
|
||||
|
||||
# Define a color map (using RGBA format)
|
||||
dict_layout_type_to_color = {
|
||||
"Text": (0, 128, 0, 256), # Green, translucent
|
||||
"Picture": (255, 0, 255, 256), # Magenta, translucent
|
||||
"Caption": (255, 165, 0, 256), # Orange, translucent
|
||||
"Section-header": (0, 255, 255, 256), # Cyan, translucent
|
||||
"Footnote": (0, 128, 0, 256), # Green, translucent
|
||||
"Formula": (128, 128, 128, 256), # Gray, translucent
|
||||
"Table": (255, 192, 203, 256), # Pink, translucent
|
||||
"Title": (255, 0, 0, 256), # Red, translucent
|
||||
"List-item": (0, 0, 255, 256), # Blue, translucent
|
||||
"Page-header": (0, 128, 0, 256), # Green, translucent
|
||||
"Page-footer": (128, 0, 128, 256), # Purple, translucent
|
||||
"Other": (165, 42, 42, 256), # Brown, translucent
|
||||
"Unknown": (0, 0, 0, 0),
|
||||
}
|
||||
|
||||
|
||||
def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
|
||||
"""
|
||||
Draw transparent boxes on an image.
|
||||
|
||||
Args:
|
||||
image: The source PIL Image.
|
||||
cells: A list of cells containing bounding box information.
|
||||
resized_height: The resized height.
|
||||
resized_width: The resized width.
|
||||
fill_bbox: Whether to fill the bounding box.
|
||||
draw_bbox: Whether to draw the bounding box.
|
||||
|
||||
Returns:
|
||||
PIL.Image: The image with drawings.
|
||||
"""
|
||||
# origin_image = Image.open(image_path)
|
||||
original_width, original_height = image.size
|
||||
|
||||
# Create a new PDF document
|
||||
doc = fitz.open()
|
||||
|
||||
# Get image information
|
||||
img_bytes = BytesIO()
|
||||
image.save(img_bytes, format='PNG')
|
||||
# pix = fitz.Pixmap(image_path)
|
||||
pix = fitz.Pixmap(img_bytes)
|
||||
|
||||
# Create a page
|
||||
page = doc.new_page(width=pix.width, height=pix.height)
|
||||
page.insert_image(
|
||||
fitz.Rect(0, 0, pix.width, pix.height),
|
||||
# filename=image_path
|
||||
pixmap=pix
|
||||
)
|
||||
|
||||
for i, cell in enumerate(cells):
|
||||
bbox = cell['bbox']
|
||||
layout_type = cell['category']
|
||||
order = i
|
||||
|
||||
top_left = (bbox[0], bbox[1])
|
||||
down_right = (bbox[2], bbox[3])
|
||||
if resized_height and resized_width:
|
||||
scale_x = resized_width / original_width
|
||||
scale_y = resized_height / original_height
|
||||
top_left = (int(bbox[0] / scale_x), int(bbox[1] / scale_y))
|
||||
down_right = (int(bbox[2] / scale_x), int(bbox[3] / scale_y))
|
||||
|
||||
color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
|
||||
color = [col/255 for col in color[:3]]
|
||||
|
||||
x0, y0, x1, y1 = top_left[0], top_left[1], down_right[0], down_right[1]
|
||||
rect_coords = fitz.Rect(x0, y0, x1, y1)
|
||||
if draw_bbox:
|
||||
if fill_bbox:
|
||||
page.draw_rect(
|
||||
rect_coords,
|
||||
color=None,
|
||||
fill=color,
|
||||
fill_opacity=0.3,
|
||||
width=0.5,
|
||||
overlay=True,
|
||||
) # Draw the rectangle
|
||||
else:
|
||||
page.draw_rect(
|
||||
rect_coords,
|
||||
color=color,
|
||||
fill=None,
|
||||
fill_opacity=1,
|
||||
width=0.5,
|
||||
overlay=True,
|
||||
) # Draw the rectangle
|
||||
order_cate = f"{order}_{layout_type}"
|
||||
page.insert_text(
|
||||
(x1, y0 + 20), order_cate, fontsize=20, color=color
|
||||
) # Insert the index in the top left corner of the rectangle
|
||||
|
||||
# Convert to a Pixmap (maintaining original dimensions)
|
||||
mat = fitz.Matrix(1.0, 1.0)
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
|
||||
return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
|
||||
|
||||
|
||||
def pre_process_bboxes(
|
||||
origin_image,
|
||||
bboxes,
|
||||
input_width,
|
||||
input_height,
|
||||
factor: int = 28,
|
||||
min_pixels: int = 3136,
|
||||
max_pixels: int = 11289600
|
||||
):
|
||||
assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)
|
||||
min_pixels = min_pixels or MIN_PIXELS
|
||||
max_pixels = max_pixels or MAX_PIXELS
|
||||
original_width, original_height = origin_image.size
|
||||
|
||||
input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
scale_x = original_width / input_width
|
||||
scale_y = original_height / input_height
|
||||
|
||||
bboxes_out = []
|
||||
for bbox in bboxes:
|
||||
bbox_resized = [
|
||||
int(float(bbox[0]) / scale_x),
|
||||
int(float(bbox[1]) / scale_y),
|
||||
int(float(bbox[2]) / scale_x),
|
||||
int(float(bbox[3]) / scale_y)
|
||||
]
|
||||
bboxes_out.append(bbox_resized)
|
||||
|
||||
return bboxes_out
|
||||
|
||||
def post_process_cells(
|
||||
origin_image: Image.Image,
|
||||
cells: List[Dict],
|
||||
input_width, # server input width, also has smart_resize in server
|
||||
input_height,
|
||||
factor: int = 28,
|
||||
min_pixels: int = 3136,
|
||||
max_pixels: int = 11289600
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.
|
||||
|
||||
Args:
|
||||
origin_image: The original PIL Image.
|
||||
cells: A list of cells containing bounding box information.
|
||||
input_width: The width of the input image sent to the server.
|
||||
input_height: The height of the input image sent to the server.
|
||||
factor: Resizing factor.
|
||||
min_pixels: Minimum number of pixels.
|
||||
max_pixels: Maximum number of pixels.
|
||||
|
||||
Returns:
|
||||
A list of post-processed cells.
|
||||
"""
|
||||
assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
|
||||
min_pixels = min_pixels or MIN_PIXELS
|
||||
max_pixels = max_pixels or MAX_PIXELS
|
||||
original_width, original_height = origin_image.size
|
||||
|
||||
input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
|
||||
scale_x = input_width / original_width
|
||||
scale_y = input_height / original_height
|
||||
|
||||
cells_out = []
|
||||
for cell in cells:
|
||||
bbox = cell['bbox']
|
||||
bbox_resized = [
|
||||
int(float(bbox[0]) / scale_x),
|
||||
int(float(bbox[1]) / scale_y),
|
||||
int(float(bbox[2]) / scale_x),
|
||||
int(float(bbox[3]) / scale_y)
|
||||
]
|
||||
cell_copy = cell.copy()
|
||||
cell_copy['bbox'] = bbox_resized
|
||||
cells_out.append(cell_copy)
|
||||
|
||||
return cells_out
|
||||
|
||||
def is_legal_bbox(cells):
|
||||
for cell in cells:
|
||||
bbox = cell['bbox']
|
||||
if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
|
||||
if prompt_mode in ["prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"]:
|
||||
return response
|
||||
|
||||
json_load_failed = False
|
||||
cells = response
|
||||
try:
|
||||
cells = json.loads(cells)
|
||||
cells = post_process_cells(
|
||||
origin_image,
|
||||
cells,
|
||||
input_image.width,
|
||||
input_image.height,
|
||||
min_pixels=min_pixels,
|
||||
max_pixels=max_pixels
|
||||
)
|
||||
return cells, False
|
||||
except Exception as e:
|
||||
print(f"cells post process error: {e}, when using {prompt_mode}")
|
||||
json_load_failed = True
|
||||
|
||||
if json_load_failed:
|
||||
cleaner = OutputCleaner()
|
||||
response_clean = cleaner.clean_model_output(cells)
|
||||
if isinstance(response_clean, list):
|
||||
response_clean = "\n\n".join([cell['text'] for cell in response_clean if 'text' in cell])
|
||||
return response_clean, True
|
||||
Reference in New Issue
Block a user