This commit is contained in:
zhangwei13
2026-03-24 22:43:01 +08:00
parent d9ea2a4108
commit 36d7248878
59 changed files with 396 additions and 390 deletions
+268 -54
View File
@@ -35,21 +35,143 @@ DEFAULT_CONFIG = {
'port_vllm': 8000,
'min_pixels': MIN_PIXELS,
'max_pixels': MAX_PIXELS,
'test_images_dir': "./assets/showcase_origin",
'test_images_dir': "./assets/showcase/origin",
}
# ==================== Multi-Model Server Configuration ====================
MODEL_SERVERS = {
"dots.mocr": {
'ip': "127.0.0.1",
'port_vllm': 8000,
'description': "dots.mocr"
},
"dots.mocr-svg": {
'ip': "127.0.0.1",
'port_vllm': 8000, # 请根据实际情况修改端口
'description': "dots.mocr-svg"
},
}
#每个prompt的预处理写死
PROMPT_TO_FITZ_PREPROCESS = {
"prompt_layout_all_en": True, # 文档布局分析 - 启用预处理
"prompt_layout_only_en": True, # 仅布局检测 - 启用预处理
"prompt_ocr": True, # 仅文字识别 - 启用预处理
"prompt_web_parsing": False, # 网页解析 - 禁用预处理
"prompt_scene_spotting": False, # 场景检测 - 禁用预处理
"prompt_image_to_svg": False, # SVG 转换 - 禁用预处理
"prompt_general": False, # 自由问答 - 禁用预处理
}
#不同任务需要不同temperature
PROMPT_TO_TEMPERATURE = {
"prompt_layout_all_en": 0.1, # 文档布局分析 - 低温度,更确定性
"prompt_layout_only_en": 0.1, # 仅布局检测 - 低温度
"prompt_ocr": 0.1, # OCR 识别 - 低温度
"prompt_web_parsing": 0.1, # 网页解析 - 稍高一点
"prompt_scene_spotting": 0.1, # 场景检测 - 中等温度
"prompt_image_to_svg": 0.9, # SVG 转换 - 较低温度
"prompt_general": 0.1, # 自由问答 - 高温度,更有创造性
}
# 不同prompt_mode对应的模型
PROMPT_TO_MODEL = {
"prompt_image_to_svg": "dots.mocr-svg", # SVG任务使用SVG模型
}
# ==================== Demo Case Configuration ====================
# 根据文件名自动选择 prompt_mode 和预设的 custom_prompt
DEMO_CASE_CONFIG = {
# 格式: "文件名关键字": {"prompt_mode": "xxx", "custom_prompt": "xxx"}
# 布局分析类
"doc": {"prompt_mode": "prompt_layout_all_en"},
"formula": {"prompt_mode": "prompt_layout_all_en"},
"table": {"prompt_mode": "prompt_layout_all_en"},
# 仅布局检测
"detect": {"prompt_mode": "prompt_layout_only_en"},
# OCR 识别
"ocr": {"prompt_mode": "prompt_ocr"},
# 网页解析
"webpage": {"prompt_mode": "prompt_web_parsing"},
# 场景文字检测
"scene": {"prompt_mode": "prompt_scene_spotting"},
# SVG 转换
"svg": {"prompt_mode": "prompt_image_to_svg"},
# QA 任务(带预设 prompt
"general_qa": {
"prompt_mode": "prompt_general",
"custom_prompt": "Across panels 1-12 plotting against clean accuracy, which variable appears most positively correlated with clean accuracy?"
},
}
# 默认配置(找不到匹配时使用)
DEFAULT_DEMO_CONFIG = {"prompt_mode": "prompt_layout_all_en"}
def get_config_for_file(file_path):
"""
根据文件名自动匹配 prompt_mode 和 custom_prompt
支持部分匹配(文件名包含关键字即可)
"""
if not file_path:
return DEFAULT_DEMO_CONFIG.copy()
filename = os.path.basename(file_path).lower()
# 遍历配置字典,查找匹配的关键字
for keyword, config in DEMO_CASE_CONFIG.items():
if keyword.lower() in filename:
return config.copy()
# 没有匹配则返回默认配置
return DEFAULT_DEMO_CONFIG.copy()
# ==================== Global Variables ====================
# Store current configuration
current_config = DEFAULT_CONFIG.copy()
# Create DotsOCRParser instance
dots_parser = DotsOCRParser(
ip=DEFAULT_CONFIG['ip'],
port=DEFAULT_CONFIG['port_vllm'],
dpi=200,
min_pixels=DEFAULT_CONFIG['min_pixels'],
max_pixels=DEFAULT_CONFIG['max_pixels']
)
# Parser cache for multiple models
_parser_cache = {}
def get_parser(model_name: str, min_pixels: int = None, max_pixels: int = None) -> DotsMOCRParser:
"""
Get or create a parser instance for the specified model.
Uses cache to avoid recreating parsers for the same model.
"""
if model_name not in MODEL_SERVERS:
raise ValueError(f"Unknown model: {model_name}")
model_config = MODEL_SERVERS[model_name]
# Create cache key based on model and pixel settings
cache_key = model_name
# If parser exists in cache, update its settings and return
if cache_key in _parser_cache:
parser = _parser_cache[cache_key]
parser.min_pixels = min_pixels or DEFAULT_CONFIG['min_pixels']
parser.max_pixels = max_pixels or DEFAULT_CONFIG['max_pixels']
return parser
# Create new parser instance
parser = DotsMOCRParser(
ip=model_config['ip'],
port=model_config['port_vllm'],
dpi=200,
min_pixels=min_pixels or DEFAULT_CONFIG['min_pixels'],
max_pixels=max_pixels or DEFAULT_CONFIG['max_pixels']
)
_parser_cache[cache_key] = parser
return parser
def get_initial_session_state():
return {
@@ -71,7 +193,8 @@ def get_initial_session_state():
"file_type": None,
"is_parsed": False,
"results": []
}
},
'auto_custom_prompt': None,
}
def read_image_v2(img):
@@ -118,6 +241,46 @@ def load_file_for_preview(file_path, session_state):
return pages[0], f"<div id='page_info_box'>1 / {len(pages)}</div>", session_state
def on_test_image_select(file_path, session_state):
"""选择测试图片时的回调:加载预览 + 自动设置 prompt_mode + 自动切换模型"""
preview_image, page_info, session_state = load_file_for_preview(file_path, session_state)
if not file_path:
return (
preview_image,
page_info,
session_state,
gr.update(),
gr.update(),
gr.update()
)
auto_config = get_config_for_file(file_path)
prompt_mode_value = auto_config["prompt_mode"]
custom_prompt_value = auto_config.get("custom_prompt", "")
session_state['auto_custom_prompt'] = custom_prompt_value if custom_prompt_value else None
is_free_qa = prompt_mode_value == 'prompt_general'
if is_free_qa and custom_prompt_value:
prompt_text = custom_prompt_value
else:
prompt_text = update_prompt_display(prompt_mode_value)
# 根据prompt_mode自动选择模型
auto_model = PROMPT_TO_MODEL.get(prompt_mode_value, list(MODEL_SERVERS.keys())[0])
return (
preview_image,
page_info,
session_state,
gr.update(value=prompt_mode_value),
gr.update(value=prompt_text, interactive=is_free_qa),
gr.update(value=auto_model),
)
def turn_page(direction, session_state):
"""Page turning function"""
pdf_cache = session_state['pdf_cache']
@@ -152,20 +315,23 @@ def get_test_images():
test_images = []
test_dir = current_config['test_images_dir']
if os.path.exists(test_dir):
test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)
if name.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf'))]
test_images = sorted([
os.path.join(test_dir, name)
for name in os.listdir(test_dir)
if name.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf'))
])
return test_images
def create_temp_session_dir():
"""Creates a unique temporary directory for each processing request"""
session_id = uuid.uuid4().hex[:8]
temp_dir = os.path.join(tempfile.gettempdir(), f"dots_ocr_demo_{session_id}")
temp_dir = os.path.join(tempfile.gettempdir(), f"dots_mocr_demo_{session_id}")
os.makedirs(temp_dir, exist_ok=True)
return temp_dir, session_id
def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=False):
def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=False, custom_prompt=None, temperature=None):
"""
Processes using the high-level API parse_image from DotsOCRParser
Processes using the high-level API parse_image from DotsMOCRParser
"""
# Create a temporary session directory
temp_dir, session_id = create_temp_session_dir()
@@ -182,7 +348,9 @@ def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=
filename=filename,
prompt_mode=prompt_mode,
save_dir=temp_dir,
fitz_preprocess=fitz_preprocess
fitz_preprocess=fitz_preprocess,
custom_prompt=custom_prompt,
temperature=temperature,
)
# Parse the results
@@ -223,7 +391,7 @@ def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=
def parse_pdf_with_high_level_api(parser, pdf_path, prompt_mode):
"""
Processes using the high-level API parse_pdf from DotsOCRParser
Processes using the high-level API parse_pdf from DotsMOCRParser
"""
# Create a temporary session directory
temp_dir, session_id = create_temp_session_dir()
@@ -292,8 +460,9 @@ def parse_pdf_with_high_level_api(parser, pdf_path, prompt_mode):
# ==================== Core Processing Function ====================
def process_image_inference(session_state, test_image_input, file_input,
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
fitz_preprocess=False
prompt_mode, model_selector, # Changed: use model_selector instead of server_ip/port
min_pixels, max_pixels,
fitz_preprocess=False, custom_prompt=""
):
"""Core function to handle image/PDF inference"""
# Use session_state instead of global variables
@@ -310,18 +479,23 @@ def process_image_inference(session_state, test_image_input, file_input,
session_state['processing_results'] = get_initial_session_state()['processing_results']
processing_results = session_state['processing_results']
fitz_preprocess = PROMPT_TO_FITZ_PREPROCESS.get(prompt_mode, True)
temperature = PROMPT_TO_TEMPERATURE.get(prompt_mode, 0.1)
print(temperature)
# Get the selected model configuration
model_config = MODEL_SERVERS[model_selector]
current_config.update({
'ip': server_ip,
'port_vllm': server_port,
'ip': model_config['ip'],
'port_vllm': model_config['port_vllm'],
'min_pixels': min_pixels,
'max_pixels': max_pixels
})
# Update parser configuration
dots_parser.ip = server_ip
dots_parser.port = server_port
dots_parser.min_pixels = min_pixels
dots_parser.max_pixels = max_pixels
# Get parser for the selected model
try:
dots_parser = get_parser(model_selector, min_pixels, max_pixels)
except ValueError as e:
return None, f"Error: {str(e)}", "", "", gr.update(value=None), None, "", session_state
input_file_path = file_input if file_input else test_image_input
@@ -348,7 +522,7 @@ def process_image_inference(session_state, test_image_input, file_input,
})
total_elements = len(pdf_result['combined_cells_data'])
info_text = f"**PDF Information:**\n- Total Pages: {pdf_result['total_pages']}\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Total Detected Elements: {total_elements}\n- Session ID: {pdf_result['session_id']}"
info_text = f"**PDF Information:**\n- Total Pages: {pdf_result['total_pages']}\n- Model: {model_selector}\n- Server: {model_config['ip']}:{model_config['port_vllm']}\n- Total Detected Elements: {total_elements}\n- Session ID: {pdf_result['session_id']}"
current_page_layout_image = preview_image
current_page_json = ""
@@ -381,10 +555,11 @@ def process_image_inference(session_state, test_image_input, file_input,
session_state['pdf_cache'] = get_initial_session_state()['pdf_cache']
original_image = image
parse_result = parse_image_with_high_level_api(dots_parser, image, prompt_mode, fitz_preprocess)
effective_custom_prompt = custom_prompt if prompt_mode == 'prompt_general' else None
parse_result = parse_image_with_high_level_api(dots_parser, image, prompt_mode, fitz_preprocess, effective_custom_prompt, temperature)
if parse_result['filtered']:
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Processing: JSON parsing failed, using cleaned text output\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Session ID: {parse_result['session_id']}"
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Model: {model_selector}\n- Processing: JSON parsing failed, using cleaned text output\n- Server: {model_config['ip']}:{model_config['port_vllm']}\n- Session ID: {parse_result['session_id']}"
processing_results.update({
'original_image': original_image, 'markdown_content': parse_result['md_content'],
'temp_dir': parse_result['temp_dir'], 'session_id': parse_result['session_id'],
@@ -401,7 +576,7 @@ def process_image_inference(session_state, test_image_input, file_input,
})
num_elements = len(parse_result['cells_data']) if parse_result['cells_data'] else 0
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Detected {num_elements} layout elements\n- Session ID: {parse_result['session_id']}"
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}\n- Model: {model_selector}\n- Server: {model_config['ip']}:{model_config['port_vllm']}\n- Detected {num_elements} layout elements\n- Session ID: {parse_result['session_id']}"
current_json = json.dumps(parse_result['cells_data'], ensure_ascii=False, indent=2) if parse_result['cells_data'] else ""
@@ -452,6 +627,8 @@ def clear_all_data(session_state):
def update_prompt_display(prompt_mode):
"""Updates the prompt display content"""
if prompt_mode == 'prompt_general':
return "" # free_qa 模式下清空,让用户输入
return dict_promptmode_to_prompt[prompt_mode]
# ==================== Gradio Interface ====================
@@ -515,18 +692,22 @@ def create_gradio_interface():
#markdown_tabs {
height: 100%;
}
#model_selector_box {
margin-bottom: 8px;
}
"""
with gr.Blocks(theme="ocean", css=css, title='dots.ocr') as demo:
with gr.Blocks(theme="ocean", css=css, title='dots.mocr') as demo:
session_state = gr.State(value=get_initial_session_state())
# Title
gr.HTML("""
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
<h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr</h1>
<h1 style="margin: 0; font-size: 2em;">🔍 dots.mocr</h1>
</div>
<div style="text-align: center; margin-bottom: 10px;">
<em>Supports image/PDF layout analysis and structured output</em>
<em>Recognize Any Human Scripts and Symbols</em>
</div>
""")
@@ -540,6 +721,15 @@ def create_gradio_interface():
file_types=[".pdf", ".jpg", ".jpeg", ".png"],
)
# ============ NEW: Model Selector ============
model_selector = gr.Dropdown(
label="🤖 Select Model",
choices=list(MODEL_SERVERS.keys()),
value=list(MODEL_SERVERS.keys())[0],
elem_id="model_selector_box",
info="Switch between different model servers"
)
test_images = get_test_images()
test_image_input = gr.Dropdown(
label="Or Select an Example",
@@ -550,7 +740,15 @@ def create_gradio_interface():
gr.Markdown("### ⚙️ Prompt & Actions")
prompt_mode = gr.Dropdown(
label="Select Prompt",
choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr"],
choices=[
"prompt_layout_all_en",
"prompt_web_parsing",
"prompt_scene_spotting",
"prompt_image_to_svg",
"prompt_general",
"prompt_layout_only_en",
"prompt_ocr",
],
value="prompt_layout_all_en",
)
@@ -560,8 +758,7 @@ def create_gradio_interface():
value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
lines=4,
max_lines=8,
interactive=False,
show_copy_button=True
interactive=False, # 默认不可编辑,free_qa 模式下改为可编辑
)
with gr.Row():
@@ -572,11 +769,9 @@ def create_gradio_interface():
fitz_preprocess = gr.Checkbox(
label="Enable fitz_preprocess for images",
value=True,
info="Processes image via a PDF-like pipeline (image->pdf->200dpi image). Recommended if your image DPI is low."
info="Processes image via a PDF-like pipeline (image->pdf->200dpi image). Recommended if your image DPI is low.",
visible=False, ###直接隐藏,调用模型前根据prompt mode 写死
)
with gr.Row():
server_ip = gr.Textbox(label="Server IP", value=DEFAULT_CONFIG['ip'])
server_port = gr.Number(label="Port", value=DEFAULT_CONFIG['port_vllm'], precision=0)
with gr.Row():
min_pixels = gr.Number(label="Min Pixels", value=DEFAULT_CONFIG['min_pixels'], precision=0)
max_pixels = gr.Number(label="Max Pixels", value=DEFAULT_CONFIG['max_pixels'], precision=0)
@@ -590,7 +785,7 @@ def create_gradio_interface():
label="Layout Preview",
visible=True,
height=800,
show_label=False
show_label=False,
)
# Page navigation (shown during PDF preview)
@@ -621,7 +816,6 @@ def create_gradio_interface():
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False}
],
show_copy_button=False,
elem_id="markdown_output"
)
@@ -631,7 +825,6 @@ def create_gradio_interface():
label="Markdown Raw Text",
max_lines=100,
lines=38,
show_copy_button=True,
elem_id="markdown_output",
show_label=False
)
@@ -642,7 +835,6 @@ def create_gradio_interface():
label="Current Page JSON",
max_lines=100,
lines=38,
show_copy_button=True,
elem_id="markdown_output",
show_label=False
)
@@ -653,12 +845,33 @@ def create_gradio_interface():
"⬇️ Download Results",
visible=False
)
# When the prompt mode changes, update the display content
def update_prompt_and_interactive(prompt_mode, session_state):
"""更新 prompt_display 并自动切换模型"""
is_free_qa = prompt_mode == 'prompt_general'
auto_custom_prompt = session_state.get('auto_custom_prompt')
if is_free_qa and auto_custom_prompt:
prompt_text = auto_custom_prompt
interactive = True
else:
prompt_text = update_prompt_display(prompt_mode)
interactive = is_free_qa
# 根据prompt_mode自动选择模型
auto_model = PROMPT_TO_MODEL.get(prompt_mode, list(MODEL_SERVERS.keys())[0])
return (
gr.update(value=prompt_text, interactive=interactive),
session_state,
gr.update(value=auto_model),
)
prompt_mode.change(
fn=update_prompt_display,
inputs=prompt_mode,
outputs=prompt_display,
fn=update_prompt_and_interactive,
inputs=[prompt_mode, session_state],
outputs=[prompt_display, session_state, model_selector],
)
# Show preview on file upload
@@ -671,10 +884,9 @@ def create_gradio_interface():
# Also handle test image selection
test_image_input.change(
# fn=lambda path, state: load_file_for_preview(path, state),
fn=load_file_for_preview,
fn=on_test_image_select,
inputs=[test_image_input, session_state],
outputs=[result_image, page_info, session_state]
outputs=[result_image, page_info, session_state, prompt_mode, prompt_display, model_selector],
)
prev_btn.click(
@@ -689,12 +901,14 @@ def create_gradio_interface():
outputs=[result_image, page_info, current_page_json, session_state]
)
# ============ MODIFIED: process_btn.click with model_selector ============
process_btn.click(
fn=process_image_inference,
inputs=[
session_state, test_image_input, file_input,
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
fitz_preprocess
prompt_mode, model_selector, # Changed: model_selector instead of server_ip/port
min_pixels, max_pixels,
fitz_preprocess, prompt_display
],
outputs=[
result_image, info_display, md_output, md_raw_output,
+3 -2
View File
@@ -53,13 +53,14 @@ def inference(image_path, prompt, model, processor):
if __name__ == "__main__":
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model_path = "./weights/DotsOCR"
# We recommend enabling flash_attention_2 or flash_attention_3 for better acceleration and memory saving, especially in multi-image and video scenarios.
model_path = "./weights/DotsMOCR"
model = AutoModelForCausalLM.from_pretrained(
model_path,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
device_map="auto",
# device_map="cpu", # ve里默认使用flash-attn,无法直接运行
trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+1 -1
View File
@@ -10,7 +10,7 @@ from dots_ocr.model.inference import inference_with_vllm
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="localhost")
parser.add_argument("--port", type=str, default="8000")
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5")
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.mocr")
parser.add_argument("--image_path", type=str, default="demo/demo_image1.jpg")
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en",help=(
"Choose a task prompt: "
+1 -1
View File
@@ -10,7 +10,7 @@ from dots_ocr.model.inference import inference_with_vllm
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="localhost")
parser.add_argument("--port", type=str, default="8000")
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5")
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.mocr")
parser.add_argument("--custom_prompt", type=str, default="Please describe the content of this image.")
args = parser.parse_args()
+1 -1
View File
@@ -10,7 +10,7 @@ from dots_ocr.model.inference import inference_with_vllm
parser = argparse.ArgumentParser()
parser.add_argument("--ip", type=str, default="localhost")
parser.add_argument("--port", type=str, default="8000")
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5-svg")
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.mocr")
parser.add_argument("--prompt_mode", type=str, default="prompt_image_to_svg")
args = parser.parse_args()
+1 -17
View File
@@ -1,17 +1 @@
# download model to /path/to/model
if [ -z "$NODOWNLOAD" ]; then
python3 tools/download_model.py
fi
# register model to vllm
hf_model_path=./weights/DotsOCR # Path to your downloaded model weights
export PYTHONPATH=$(dirname "$hf_model_path"):$PYTHONPATH
sed -i '/^from vllm\.entrypoints\.cli\.main import main$/a\
from DotsOCR import modeling_dots_ocr_vllm' `which vllm`
# launch vllm server
model_name=model
CUDA_VISIBLE_DEVICES=0 vllm serve ${hf_model_path} --tensor-parallel-size 1 --gpu-memory-utilization 0.95 --chat-template-content-format string --served-model-name ${model_name} --trust-remote-code
# # run python demo after launch vllm server
# python demo/demo_vllm.py
CUDA_VISIBLE_DEVICES=0 nohup vllm serve dots.mocr --tensor-parallel-size 1 --gpu-memory-utilization 0.9 --chat-template-content-format string --served-model-name ${model_name} --trust-remote-code