align
This commit is contained in:
+268
-54
@@ -35,21 +35,143 @@ DEFAULT_CONFIG = {
|
||||
'port_vllm': 8000,
|
||||
'min_pixels': MIN_PIXELS,
|
||||
'max_pixels': MAX_PIXELS,
|
||||
'test_images_dir': "./assets/showcase_origin",
|
||||
'test_images_dir': "./assets/showcase/origin",
|
||||
}
|
||||
|
||||
# ==================== Multi-Model Server Configuration ====================
|
||||
MODEL_SERVERS = {
|
||||
"dots.mocr": {
|
||||
'ip': "127.0.0.1",
|
||||
'port_vllm': 8000,
|
||||
'description': "dots.mocr"
|
||||
},
|
||||
"dots.mocr-svg": {
|
||||
'ip': "127.0.0.1",
|
||||
'port_vllm': 8000, # 请根据实际情况修改端口
|
||||
'description': "dots.mocr-svg"
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
#每个prompt的预处理写死
|
||||
PROMPT_TO_FITZ_PREPROCESS = {
|
||||
"prompt_layout_all_en": True, # 文档布局分析 - 启用预处理
|
||||
"prompt_layout_only_en": True, # 仅布局检测 - 启用预处理
|
||||
"prompt_ocr": True, # 仅文字识别 - 启用预处理
|
||||
"prompt_web_parsing": False, # 网页解析 - 禁用预处理
|
||||
"prompt_scene_spotting": False, # 场景检测 - 禁用预处理
|
||||
"prompt_image_to_svg": False, # SVG 转换 - 禁用预处理
|
||||
"prompt_general": False, # 自由问答 - 禁用预处理
|
||||
}
|
||||
|
||||
#不同任务需要不同temperature
|
||||
PROMPT_TO_TEMPERATURE = {
|
||||
"prompt_layout_all_en": 0.1, # 文档布局分析 - 低温度,更确定性
|
||||
"prompt_layout_only_en": 0.1, # 仅布局检测 - 低温度
|
||||
"prompt_ocr": 0.1, # OCR 识别 - 低温度
|
||||
"prompt_web_parsing": 0.1, # 网页解析 - 稍高一点
|
||||
"prompt_scene_spotting": 0.1, # 场景检测 - 中等温度
|
||||
"prompt_image_to_svg": 0.9, # SVG 转换 - 较低温度
|
||||
"prompt_general": 0.1, # 自由问答 - 高温度,更有创造性
|
||||
}
|
||||
|
||||
# 不同prompt_mode对应的模型
|
||||
PROMPT_TO_MODEL = {
|
||||
"prompt_image_to_svg": "dots.mocr-svg", # SVG任务使用SVG模型
|
||||
}
|
||||
|
||||
# ==================== Demo Case Configuration ====================
|
||||
# 根据文件名自动选择 prompt_mode 和预设的 custom_prompt
|
||||
DEMO_CASE_CONFIG = {
|
||||
# 格式: "文件名关键字": {"prompt_mode": "xxx", "custom_prompt": "xxx"}
|
||||
|
||||
# 布局分析类
|
||||
"doc": {"prompt_mode": "prompt_layout_all_en"},
|
||||
"formula": {"prompt_mode": "prompt_layout_all_en"},
|
||||
"table": {"prompt_mode": "prompt_layout_all_en"},
|
||||
|
||||
# 仅布局检测
|
||||
"detect": {"prompt_mode": "prompt_layout_only_en"},
|
||||
# OCR 识别
|
||||
"ocr": {"prompt_mode": "prompt_ocr"},
|
||||
|
||||
# 网页解析
|
||||
"webpage": {"prompt_mode": "prompt_web_parsing"},
|
||||
|
||||
# 场景文字检测
|
||||
"scene": {"prompt_mode": "prompt_scene_spotting"},
|
||||
|
||||
# SVG 转换
|
||||
"svg": {"prompt_mode": "prompt_image_to_svg"},
|
||||
|
||||
# QA 任务(带预设 prompt)
|
||||
"general_qa": {
|
||||
"prompt_mode": "prompt_general",
|
||||
"custom_prompt": "Across panels 1-12 plotting against clean accuracy, which variable appears most positively correlated with clean accuracy?"
|
||||
},
|
||||
|
||||
|
||||
}
|
||||
|
||||
# 默认配置(找不到匹配时使用)
|
||||
DEFAULT_DEMO_CONFIG = {"prompt_mode": "prompt_layout_all_en"}
|
||||
|
||||
def get_config_for_file(file_path):
|
||||
"""
|
||||
根据文件名自动匹配 prompt_mode 和 custom_prompt
|
||||
支持部分匹配(文件名包含关键字即可)
|
||||
"""
|
||||
if not file_path:
|
||||
return DEFAULT_DEMO_CONFIG.copy()
|
||||
|
||||
filename = os.path.basename(file_path).lower()
|
||||
|
||||
# 遍历配置字典,查找匹配的关键字
|
||||
for keyword, config in DEMO_CASE_CONFIG.items():
|
||||
if keyword.lower() in filename:
|
||||
return config.copy()
|
||||
|
||||
# 没有匹配则返回默认配置
|
||||
return DEFAULT_DEMO_CONFIG.copy()
|
||||
|
||||
# ==================== Global Variables ====================
|
||||
# Store current configuration
|
||||
current_config = DEFAULT_CONFIG.copy()
|
||||
|
||||
# Create DotsOCRParser instance
|
||||
dots_parser = DotsOCRParser(
|
||||
ip=DEFAULT_CONFIG['ip'],
|
||||
port=DEFAULT_CONFIG['port_vllm'],
|
||||
dpi=200,
|
||||
min_pixels=DEFAULT_CONFIG['min_pixels'],
|
||||
max_pixels=DEFAULT_CONFIG['max_pixels']
|
||||
)
|
||||
# Parser cache for multiple models
|
||||
_parser_cache = {}
|
||||
|
||||
def get_parser(model_name: str, min_pixels: int = None, max_pixels: int = None) -> DotsMOCRParser:
|
||||
"""
|
||||
Get or create a parser instance for the specified model.
|
||||
Uses cache to avoid recreating parsers for the same model.
|
||||
"""
|
||||
if model_name not in MODEL_SERVERS:
|
||||
raise ValueError(f"Unknown model: {model_name}")
|
||||
|
||||
model_config = MODEL_SERVERS[model_name]
|
||||
|
||||
# Create cache key based on model and pixel settings
|
||||
cache_key = model_name
|
||||
|
||||
# If parser exists in cache, update its settings and return
|
||||
if cache_key in _parser_cache:
|
||||
parser = _parser_cache[cache_key]
|
||||
parser.min_pixels = min_pixels or DEFAULT_CONFIG['min_pixels']
|
||||
parser.max_pixels = max_pixels or DEFAULT_CONFIG['max_pixels']
|
||||
return parser
|
||||
|
||||
# Create new parser instance
|
||||
parser = DotsMOCRParser(
|
||||
ip=model_config['ip'],
|
||||
port=model_config['port_vllm'],
|
||||
dpi=200,
|
||||
min_pixels=min_pixels or DEFAULT_CONFIG['min_pixels'],
|
||||
max_pixels=max_pixels or DEFAULT_CONFIG['max_pixels']
|
||||
)
|
||||
_parser_cache[cache_key] = parser
|
||||
return parser
|
||||
|
||||
def get_initial_session_state():
|
||||
return {
|
||||
@@ -71,7 +193,8 @@ def get_initial_session_state():
|
||||
"file_type": None,
|
||||
"is_parsed": False,
|
||||
"results": []
|
||||
}
|
||||
},
|
||||
'auto_custom_prompt': None,
|
||||
}
|
||||
|
||||
def read_image_v2(img):
|
||||
@@ -118,6 +241,46 @@ def load_file_for_preview(file_path, session_state):
|
||||
|
||||
return pages[0], f"<div id='page_info_box'>1 / {len(pages)}</div>", session_state
|
||||
|
||||
def on_test_image_select(file_path, session_state):
|
||||
"""选择测试图片时的回调:加载预览 + 自动设置 prompt_mode + 自动切换模型"""
|
||||
preview_image, page_info, session_state = load_file_for_preview(file_path, session_state)
|
||||
|
||||
if not file_path:
|
||||
return (
|
||||
preview_image,
|
||||
page_info,
|
||||
session_state,
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update()
|
||||
)
|
||||
|
||||
auto_config = get_config_for_file(file_path)
|
||||
prompt_mode_value = auto_config["prompt_mode"]
|
||||
custom_prompt_value = auto_config.get("custom_prompt", "")
|
||||
|
||||
session_state['auto_custom_prompt'] = custom_prompt_value if custom_prompt_value else None
|
||||
|
||||
is_free_qa = prompt_mode_value == 'prompt_general'
|
||||
if is_free_qa and custom_prompt_value:
|
||||
prompt_text = custom_prompt_value
|
||||
else:
|
||||
prompt_text = update_prompt_display(prompt_mode_value)
|
||||
|
||||
# 根据prompt_mode自动选择模型
|
||||
auto_model = PROMPT_TO_MODEL.get(prompt_mode_value, list(MODEL_SERVERS.keys())[0])
|
||||
|
||||
return (
|
||||
preview_image,
|
||||
page_info,
|
||||
session_state,
|
||||
gr.update(value=prompt_mode_value),
|
||||
gr.update(value=prompt_text, interactive=is_free_qa),
|
||||
gr.update(value=auto_model),
|
||||
)
|
||||
|
||||
|
||||
|
||||
def turn_page(direction, session_state):
|
||||
"""Page turning function"""
|
||||
pdf_cache = session_state['pdf_cache']
|
||||
@@ -152,20 +315,23 @@ def get_test_images():
|
||||
test_images = []
|
||||
test_dir = current_config['test_images_dir']
|
||||
if os.path.exists(test_dir):
|
||||
test_images = [os.path.join(test_dir, name) for name in os.listdir(test_dir)
|
||||
if name.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf'))]
|
||||
test_images = sorted([
|
||||
os.path.join(test_dir, name)
|
||||
for name in os.listdir(test_dir)
|
||||
if name.lower().endswith(('.png', '.jpg', '.jpeg', '.pdf'))
|
||||
])
|
||||
return test_images
|
||||
|
||||
def create_temp_session_dir():
|
||||
"""Creates a unique temporary directory for each processing request"""
|
||||
session_id = uuid.uuid4().hex[:8]
|
||||
temp_dir = os.path.join(tempfile.gettempdir(), f"dots_ocr_demo_{session_id}")
|
||||
temp_dir = os.path.join(tempfile.gettempdir(), f"dots_mocr_demo_{session_id}")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
return temp_dir, session_id
|
||||
|
||||
def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=False):
|
||||
def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=False, custom_prompt=None, temperature=None):
|
||||
"""
|
||||
Processes using the high-level API parse_image from DotsOCRParser
|
||||
Processes using the high-level API parse_image from DotsMOCRParser
|
||||
"""
|
||||
# Create a temporary session directory
|
||||
temp_dir, session_id = create_temp_session_dir()
|
||||
@@ -182,7 +348,9 @@ def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=
|
||||
filename=filename,
|
||||
prompt_mode=prompt_mode,
|
||||
save_dir=temp_dir,
|
||||
fitz_preprocess=fitz_preprocess
|
||||
fitz_preprocess=fitz_preprocess,
|
||||
custom_prompt=custom_prompt,
|
||||
temperature=temperature,
|
||||
)
|
||||
|
||||
# Parse the results
|
||||
@@ -223,7 +391,7 @@ def parse_image_with_high_level_api(parser, image, prompt_mode, fitz_preprocess=
|
||||
|
||||
def parse_pdf_with_high_level_api(parser, pdf_path, prompt_mode):
|
||||
"""
|
||||
Processes using the high-level API parse_pdf from DotsOCRParser
|
||||
Processes using the high-level API parse_pdf from DotsMOCRParser
|
||||
"""
|
||||
# Create a temporary session directory
|
||||
temp_dir, session_id = create_temp_session_dir()
|
||||
@@ -292,8 +460,9 @@ def parse_pdf_with_high_level_api(parser, pdf_path, prompt_mode):
|
||||
|
||||
# ==================== Core Processing Function ====================
|
||||
def process_image_inference(session_state, test_image_input, file_input,
|
||||
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
|
||||
fitz_preprocess=False
|
||||
prompt_mode, model_selector, # Changed: use model_selector instead of server_ip/port
|
||||
min_pixels, max_pixels,
|
||||
fitz_preprocess=False, custom_prompt=""
|
||||
):
|
||||
"""Core function to handle image/PDF inference"""
|
||||
# Use session_state instead of global variables
|
||||
@@ -310,18 +479,23 @@ def process_image_inference(session_state, test_image_input, file_input,
|
||||
session_state['processing_results'] = get_initial_session_state()['processing_results']
|
||||
processing_results = session_state['processing_results']
|
||||
|
||||
fitz_preprocess = PROMPT_TO_FITZ_PREPROCESS.get(prompt_mode, True)
|
||||
temperature = PROMPT_TO_TEMPERATURE.get(prompt_mode, 0.1)
|
||||
print(temperature)
|
||||
# Get the selected model configuration
|
||||
model_config = MODEL_SERVERS[model_selector]
|
||||
current_config.update({
|
||||
'ip': server_ip,
|
||||
'port_vllm': server_port,
|
||||
'ip': model_config['ip'],
|
||||
'port_vllm': model_config['port_vllm'],
|
||||
'min_pixels': min_pixels,
|
||||
'max_pixels': max_pixels
|
||||
})
|
||||
|
||||
# Update parser configuration
|
||||
dots_parser.ip = server_ip
|
||||
dots_parser.port = server_port
|
||||
dots_parser.min_pixels = min_pixels
|
||||
dots_parser.max_pixels = max_pixels
|
||||
# Get parser for the selected model
|
||||
try:
|
||||
dots_parser = get_parser(model_selector, min_pixels, max_pixels)
|
||||
except ValueError as e:
|
||||
return None, f"Error: {str(e)}", "", "", gr.update(value=None), None, "", session_state
|
||||
|
||||
input_file_path = file_input if file_input else test_image_input
|
||||
|
||||
@@ -348,7 +522,7 @@ def process_image_inference(session_state, test_image_input, file_input,
|
||||
})
|
||||
|
||||
total_elements = len(pdf_result['combined_cells_data'])
|
||||
info_text = f"**PDF Information:**\n- Total Pages: {pdf_result['total_pages']}\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Total Detected Elements: {total_elements}\n- Session ID: {pdf_result['session_id']}"
|
||||
info_text = f"**PDF Information:**\n- Total Pages: {pdf_result['total_pages']}\n- Model: {model_selector}\n- Server: {model_config['ip']}:{model_config['port_vllm']}\n- Total Detected Elements: {total_elements}\n- Session ID: {pdf_result['session_id']}"
|
||||
|
||||
current_page_layout_image = preview_image
|
||||
current_page_json = ""
|
||||
@@ -381,10 +555,11 @@ def process_image_inference(session_state, test_image_input, file_input,
|
||||
session_state['pdf_cache'] = get_initial_session_state()['pdf_cache']
|
||||
|
||||
original_image = image
|
||||
parse_result = parse_image_with_high_level_api(dots_parser, image, prompt_mode, fitz_preprocess)
|
||||
effective_custom_prompt = custom_prompt if prompt_mode == 'prompt_general' else None
|
||||
parse_result = parse_image_with_high_level_api(dots_parser, image, prompt_mode, fitz_preprocess, effective_custom_prompt, temperature)
|
||||
|
||||
if parse_result['filtered']:
|
||||
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Processing: JSON parsing failed, using cleaned text output\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Session ID: {parse_result['session_id']}"
|
||||
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Model: {model_selector}\n- Processing: JSON parsing failed, using cleaned text output\n- Server: {model_config['ip']}:{model_config['port_vllm']}\n- Session ID: {parse_result['session_id']}"
|
||||
processing_results.update({
|
||||
'original_image': original_image, 'markdown_content': parse_result['md_content'],
|
||||
'temp_dir': parse_result['temp_dir'], 'session_id': parse_result['session_id'],
|
||||
@@ -401,7 +576,7 @@ def process_image_inference(session_state, test_image_input, file_input,
|
||||
})
|
||||
|
||||
num_elements = len(parse_result['cells_data']) if parse_result['cells_data'] else 0
|
||||
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}\n- Server: {current_config['ip']}:{current_config['port_vllm']}\n- Detected {num_elements} layout elements\n- Session ID: {parse_result['session_id']}"
|
||||
info_text = f"**Image Information:**\n- Original Size: {original_image.width} x {original_image.height}\n- Model Input Size: {parse_result['input_width']} x {parse_result['input_height']}\n- Model: {model_selector}\n- Server: {model_config['ip']}:{model_config['port_vllm']}\n- Detected {num_elements} layout elements\n- Session ID: {parse_result['session_id']}"
|
||||
|
||||
current_json = json.dumps(parse_result['cells_data'], ensure_ascii=False, indent=2) if parse_result['cells_data'] else ""
|
||||
|
||||
@@ -452,6 +627,8 @@ def clear_all_data(session_state):
|
||||
|
||||
def update_prompt_display(prompt_mode):
|
||||
"""Updates the prompt display content"""
|
||||
if prompt_mode == 'prompt_general':
|
||||
return "" # free_qa 模式下清空,让用户输入
|
||||
return dict_promptmode_to_prompt[prompt_mode]
|
||||
|
||||
# ==================== Gradio Interface ====================
|
||||
@@ -515,18 +692,22 @@ def create_gradio_interface():
|
||||
#markdown_tabs {
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
#model_selector_box {
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
"""
|
||||
|
||||
with gr.Blocks(theme="ocean", css=css, title='dots.ocr') as demo:
|
||||
with gr.Blocks(theme="ocean", css=css, title='dots.mocr') as demo:
|
||||
session_state = gr.State(value=get_initial_session_state())
|
||||
|
||||
# Title
|
||||
gr.HTML("""
|
||||
<div style="display: flex; align-items: center; justify-content: center; margin-bottom: 20px;">
|
||||
<h1 style="margin: 0; font-size: 2em;">🔍 dots.ocr</h1>
|
||||
<h1 style="margin: 0; font-size: 2em;">🔍 dots.mocr</h1>
|
||||
</div>
|
||||
<div style="text-align: center; margin-bottom: 10px;">
|
||||
<em>Supports image/PDF layout analysis and structured output</em>
|
||||
<em>Recognize Any Human Scripts and Symbols</em>
|
||||
</div>
|
||||
""")
|
||||
|
||||
@@ -540,6 +721,15 @@ def create_gradio_interface():
|
||||
file_types=[".pdf", ".jpg", ".jpeg", ".png"],
|
||||
)
|
||||
|
||||
# ============ NEW: Model Selector ============
|
||||
model_selector = gr.Dropdown(
|
||||
label="🤖 Select Model",
|
||||
choices=list(MODEL_SERVERS.keys()),
|
||||
value=list(MODEL_SERVERS.keys())[0],
|
||||
elem_id="model_selector_box",
|
||||
info="Switch between different model servers"
|
||||
)
|
||||
|
||||
test_images = get_test_images()
|
||||
test_image_input = gr.Dropdown(
|
||||
label="Or Select an Example",
|
||||
@@ -550,7 +740,15 @@ def create_gradio_interface():
|
||||
gr.Markdown("### ⚙️ Prompt & Actions")
|
||||
prompt_mode = gr.Dropdown(
|
||||
label="Select Prompt",
|
||||
choices=["prompt_layout_all_en", "prompt_layout_only_en", "prompt_ocr"],
|
||||
choices=[
|
||||
"prompt_layout_all_en",
|
||||
"prompt_web_parsing",
|
||||
"prompt_scene_spotting",
|
||||
"prompt_image_to_svg",
|
||||
"prompt_general",
|
||||
"prompt_layout_only_en",
|
||||
"prompt_ocr",
|
||||
],
|
||||
value="prompt_layout_all_en",
|
||||
)
|
||||
|
||||
@@ -560,8 +758,7 @@ def create_gradio_interface():
|
||||
value=dict_promptmode_to_prompt[list(dict_promptmode_to_prompt.keys())[0]],
|
||||
lines=4,
|
||||
max_lines=8,
|
||||
interactive=False,
|
||||
show_copy_button=True
|
||||
interactive=False, # 默认不可编辑,free_qa 模式下改为可编辑
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
@@ -572,11 +769,9 @@ def create_gradio_interface():
|
||||
fitz_preprocess = gr.Checkbox(
|
||||
label="Enable fitz_preprocess for images",
|
||||
value=True,
|
||||
info="Processes image via a PDF-like pipeline (image->pdf->200dpi image). Recommended if your image DPI is low."
|
||||
info="Processes image via a PDF-like pipeline (image->pdf->200dpi image). Recommended if your image DPI is low.",
|
||||
visible=False, ###直接隐藏,调用模型前根据prompt mode 写死
|
||||
)
|
||||
with gr.Row():
|
||||
server_ip = gr.Textbox(label="Server IP", value=DEFAULT_CONFIG['ip'])
|
||||
server_port = gr.Number(label="Port", value=DEFAULT_CONFIG['port_vllm'], precision=0)
|
||||
with gr.Row():
|
||||
min_pixels = gr.Number(label="Min Pixels", value=DEFAULT_CONFIG['min_pixels'], precision=0)
|
||||
max_pixels = gr.Number(label="Max Pixels", value=DEFAULT_CONFIG['max_pixels'], precision=0)
|
||||
@@ -590,7 +785,7 @@ def create_gradio_interface():
|
||||
label="Layout Preview",
|
||||
visible=True,
|
||||
height=800,
|
||||
show_label=False
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
# Page navigation (shown during PDF preview)
|
||||
@@ -621,7 +816,6 @@ def create_gradio_interface():
|
||||
{"left": "$$", "right": "$$", "display": True},
|
||||
{"left": "$", "right": "$", "display": False}
|
||||
],
|
||||
show_copy_button=False,
|
||||
elem_id="markdown_output"
|
||||
)
|
||||
|
||||
@@ -631,7 +825,6 @@ def create_gradio_interface():
|
||||
label="Markdown Raw Text",
|
||||
max_lines=100,
|
||||
lines=38,
|
||||
show_copy_button=True,
|
||||
elem_id="markdown_output",
|
||||
show_label=False
|
||||
)
|
||||
@@ -642,7 +835,6 @@ def create_gradio_interface():
|
||||
label="Current Page JSON",
|
||||
max_lines=100,
|
||||
lines=38,
|
||||
show_copy_button=True,
|
||||
elem_id="markdown_output",
|
||||
show_label=False
|
||||
)
|
||||
@@ -653,12 +845,33 @@ def create_gradio_interface():
|
||||
"⬇️ Download Results",
|
||||
visible=False
|
||||
)
|
||||
|
||||
# When the prompt mode changes, update the display content
|
||||
|
||||
def update_prompt_and_interactive(prompt_mode, session_state):
|
||||
"""更新 prompt_display 并自动切换模型"""
|
||||
is_free_qa = prompt_mode == 'prompt_general'
|
||||
auto_custom_prompt = session_state.get('auto_custom_prompt')
|
||||
|
||||
if is_free_qa and auto_custom_prompt:
|
||||
prompt_text = auto_custom_prompt
|
||||
interactive = True
|
||||
else:
|
||||
prompt_text = update_prompt_display(prompt_mode)
|
||||
interactive = is_free_qa
|
||||
|
||||
# 根据prompt_mode自动选择模型
|
||||
auto_model = PROMPT_TO_MODEL.get(prompt_mode, list(MODEL_SERVERS.keys())[0])
|
||||
|
||||
return (
|
||||
gr.update(value=prompt_text, interactive=interactive),
|
||||
session_state,
|
||||
gr.update(value=auto_model),
|
||||
)
|
||||
|
||||
|
||||
prompt_mode.change(
|
||||
fn=update_prompt_display,
|
||||
inputs=prompt_mode,
|
||||
outputs=prompt_display,
|
||||
fn=update_prompt_and_interactive,
|
||||
inputs=[prompt_mode, session_state],
|
||||
outputs=[prompt_display, session_state, model_selector],
|
||||
)
|
||||
|
||||
# Show preview on file upload
|
||||
@@ -671,10 +884,9 @@ def create_gradio_interface():
|
||||
|
||||
# Also handle test image selection
|
||||
test_image_input.change(
|
||||
# fn=lambda path, state: load_file_for_preview(path, state),
|
||||
fn=load_file_for_preview,
|
||||
fn=on_test_image_select,
|
||||
inputs=[test_image_input, session_state],
|
||||
outputs=[result_image, page_info, session_state]
|
||||
outputs=[result_image, page_info, session_state, prompt_mode, prompt_display, model_selector],
|
||||
)
|
||||
|
||||
prev_btn.click(
|
||||
@@ -689,12 +901,14 @@ def create_gradio_interface():
|
||||
outputs=[result_image, page_info, current_page_json, session_state]
|
||||
)
|
||||
|
||||
# ============ MODIFIED: process_btn.click with model_selector ============
|
||||
process_btn.click(
|
||||
fn=process_image_inference,
|
||||
inputs=[
|
||||
session_state, test_image_input, file_input,
|
||||
prompt_mode, server_ip, server_port, min_pixels, max_pixels,
|
||||
fitz_preprocess
|
||||
prompt_mode, model_selector, # Changed: model_selector instead of server_ip/port
|
||||
min_pixels, max_pixels,
|
||||
fitz_preprocess, prompt_display
|
||||
],
|
||||
outputs=[
|
||||
result_image, info_display, md_output, md_raw_output,
|
||||
|
||||
+3
-2
@@ -53,13 +53,14 @@ def inference(image_path, prompt, model, processor):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||
model_path = "./weights/DotsOCR"
|
||||
# We recommend enabling flash_attention_2 or flash_attention_3 for better acceleration and memory saving, especially in multi-image and video scenarios.
|
||||
model_path = "./weights/DotsMOCR"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
# device_map="cpu", # ve里默认使用flash-attn,无法直接运行
|
||||
trust_remote_code=True
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
|
||||
|
||||
+1
-1
@@ -10,7 +10,7 @@ from dots_ocr.model.inference import inference_with_vllm
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ip", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=str, default="8000")
|
||||
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5")
|
||||
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.mocr")
|
||||
parser.add_argument("--image_path", type=str, default="demo/demo_image1.jpg")
|
||||
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en",help=(
|
||||
"Choose a task prompt: "
|
||||
|
||||
@@ -10,7 +10,7 @@ from dots_ocr.model.inference import inference_with_vllm
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ip", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=str, default="8000")
|
||||
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5")
|
||||
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.mocr")
|
||||
parser.add_argument("--custom_prompt", type=str, default="Please describe the content of this image.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -10,7 +10,7 @@ from dots_ocr.model.inference import inference_with_vllm
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ip", type=str, default="localhost")
|
||||
parser.add_argument("--port", type=str, default="8000")
|
||||
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5-svg")
|
||||
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.mocr")
|
||||
parser.add_argument("--prompt_mode", type=str, default="prompt_image_to_svg")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1,17 +1 @@
|
||||
# download model to /path/to/model
|
||||
if [ -z "$NODOWNLOAD" ]; then
|
||||
python3 tools/download_model.py
|
||||
fi
|
||||
|
||||
# register model to vllm
|
||||
hf_model_path=./weights/DotsOCR # Path to your downloaded model weights
|
||||
export PYTHONPATH=$(dirname "$hf_model_path"):$PYTHONPATH
|
||||
sed -i '/^from vllm\.entrypoints\.cli\.main import main$/a\
|
||||
from DotsOCR import modeling_dots_ocr_vllm' `which vllm`
|
||||
|
||||
# launch vllm server
|
||||
model_name=model
|
||||
CUDA_VISIBLE_DEVICES=0 vllm serve ${hf_model_path} --tensor-parallel-size 1 --gpu-memory-utilization 0.95 --chat-template-content-format string --served-model-name ${model_name} --trust-remote-code
|
||||
|
||||
# # run python demo after launch vllm server
|
||||
# python demo/demo_vllm.py
|
||||
CUDA_VISIBLE_DEVICES=0 nohup vllm serve dots.mocr --tensor-parallel-size 1 --gpu-memory-utilization 0.9 --chat-template-content-format string --served-model-name ${model_name} --trust-remote-code
|
||||
Reference in New Issue
Block a user