diff --git a/README.md b/README.md index 80a11c2..e8bdcc3 100755 --- a/README.md +++ b/README.md @@ -560,18 +560,18 @@ CUDA_VISIBLE_DEVICES=0 vllm serve rednote-hilab/dots.ocr-1.5 --tensor-parallel-s CUDA_VISIBLE_DEVICES=0 vllm serve rednote-hilab/dots.ocr-1.5-svg --tensor-parallel-size 1 --gpu-memory-utilization 0.9 --chat-template-content-format string --served-model-name model --trust-remote-code # vLLM API Demo -# See dots_ocr/model/inference.py for details on parameter and prompt settings +# See dots_ocr/model/inference.py and dots_ocr/utils/prompts.py for details on parameter and prompt settings # that help achieve the best output quality. ## document parsing -python3 ./demo/demo_vllm.py --prompt_mode prompt_layout_all_en +python3 ./demo/demo_vllm.py --prompt_mode prompt_layout_all_en ## web parsing - +python3 ./demo/demo_vllm.py --prompt_mode prompt_web_parsing --image_path ./assets/showcase_dots_ocr_1_5/origin/webpage_1.png ## scene spoting - +python3 ./demo/demo_vllm.py --prompt_mode prompt_scene_spotting --image_path ./assets/showcase_dots_ocr_1_5/origin/scene_1.jpg ## image parsing with svg code - +python3 ./demo/demo_vllm_svg.py --prompt_mode prompt_image_to_svg ## general qa - +python3 ./demo/demo_vllm_general.py ``` ### Hugginface inference diff --git a/assets/showcase_dots_ocr_1_5/origin/scene_1.jpg b/assets/showcase_dots_ocr_1_5/origin/scene_1.jpg new file mode 100644 index 0000000..0c3cc4d Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/scene_1.jpg differ diff --git a/assets/showcase_dots_ocr_1_5/origin/scene_2.jpg b/assets/showcase_dots_ocr_1_5/origin/scene_2.jpg new file mode 100644 index 0000000..91521a2 Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/scene_2.jpg differ diff --git a/assets/showcase_dots_ocr_1_5/origin/svg_1.png b/assets/showcase_dots_ocr_1_5/origin/svg_1.png new file mode 100644 index 0000000..1957326 Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/svg_1.png differ diff --git a/assets/showcase_dots_ocr_1_5/origin/svg_2.png b/assets/showcase_dots_ocr_1_5/origin/svg_2.png new file mode 100644 index 0000000..557fddd Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/svg_2.png differ diff --git a/assets/showcase_dots_ocr_1_5/origin/svg_4.png b/assets/showcase_dots_ocr_1_5/origin/svg_4.png new file mode 100644 index 0000000..27656bf Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/svg_4.png differ diff --git a/assets/showcase_dots_ocr_1_5/origin/svg_5.png b/assets/showcase_dots_ocr_1_5/origin/svg_5.png new file mode 100644 index 0000000..7ffccdf Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/svg_5.png differ diff --git a/assets/showcase_dots_ocr_1_5/origin/svg_6.png b/assets/showcase_dots_ocr_1_5/origin/svg_6.png new file mode 100644 index 0000000..4ad8aa5 Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/svg_6.png differ diff --git a/assets/showcase_dots_ocr_1_5/origin/webpage_1.png b/assets/showcase_dots_ocr_1_5/origin/webpage_1.png new file mode 100755 index 0000000..1f68324 Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/webpage_1.png differ diff --git a/assets/showcase_dots_ocr_1_5/origin/webpage_2.jpg b/assets/showcase_dots_ocr_1_5/origin/webpage_2.jpg new file mode 100755 index 0000000..0c37c81 Binary files /dev/null and b/assets/showcase_dots_ocr_1_5/origin/webpage_2.jpg differ diff --git a/demo/demo_image2.png b/demo/demo_image2.png new file mode 100644 index 0000000..557fddd Binary files /dev/null and b/demo/demo_image2.png differ diff --git a/demo/demo_image3.jpg b/demo/demo_image3.jpg new file mode 100644 index 0000000..0650edd Binary files /dev/null and b/demo/demo_image3.jpg differ diff --git a/demo/demo_vllm.py b/demo/demo_vllm.py index 166c521..3b6f4f8 100755 --- a/demo/demo_vllm.py +++ b/demo/demo_vllm.py @@ -10,8 +10,17 @@ from dots_ocr.model.inference import inference_with_vllm parser = argparse.ArgumentParser() parser.add_argument("--ip", type=str, default="localhost") parser.add_argument("--port", type=str, default="8000") -parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr") -parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en") +parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5") +parser.add_argument("--image_path", type=str, default="demo/demo_image1.jpg") +parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en",help=( + "Choose a task prompt: " + "prompt_layout_all_en=full document layout+OCR to JSON/MD; " + "prompt_layout_only_en=layout detection only; " + "prompt_grounding_ocr=OCR within a given bbox; " + "prompt_web_parsing=parse webpage screenshot layout into JSON; " + "prompt_scene_spotting=detect+recognize scene text (OCR boxes+texts); " + "prompt_image_to_svg=generate SVG code to reconstruct the image.") +) args = parser.parse_args() @@ -20,7 +29,7 @@ require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") def main(): addr = f"http://{args.ip}:{args.port}/v1" - image_path = "demo/demo_image1.jpg" + image_path = args.image_path prompt = dict_promptmode_to_prompt[args.prompt_mode] image = Image.open(image_path) response = inference_with_vllm( diff --git a/demo/demo_vllm_general.py b/demo/demo_vllm_general.py new file mode 100755 index 0000000..d230c57 --- /dev/null +++ b/demo/demo_vllm_general.py @@ -0,0 +1,40 @@ +import argparse + +from openai import OpenAI +from transformers.utils.versions import require_version +from PIL import Image +from dots_ocr.utils import dict_promptmode_to_prompt +from dots_ocr.model.inference import inference_with_vllm + + +parser = argparse.ArgumentParser() +parser.add_argument("--ip", type=str, default="localhost") +parser.add_argument("--port", type=str, default="8000") +parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5") +parser.add_argument("--custom_prompt", type=str, default="Please describe the content of this image.") + +args = parser.parse_args() + +require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") + + +def main(): + addr = f"http://{args.ip}:{args.port}/v1" + image_path = "demo/demo_image3.jpg" + prompt = args.custom_prompt + image = Image.open(image_path) + response = inference_with_vllm( + image, + prompt, + ip=args.ip, + port=args.port, + temperature=0.1, + top_p=0.9, + model_name=args.model_name, + system_prompt="You are a helpful assistant.", #general tasks need system_prompt + ) + print(f"response: {response}") + + +if __name__ == "__main__": + main() diff --git a/demo/demo_vllm_svg.py b/demo/demo_vllm_svg.py new file mode 100755 index 0000000..e7719b6 --- /dev/null +++ b/demo/demo_vllm_svg.py @@ -0,0 +1,43 @@ +import argparse + +from openai import OpenAI +from transformers.utils.versions import require_version +from PIL import Image +from dots_ocr.utils import dict_promptmode_to_prompt +from dots_ocr.model.inference import inference_with_vllm + + +parser = argparse.ArgumentParser() +parser.add_argument("--ip", type=str, default="localhost") +parser.add_argument("--port", type=str, default="8000") +parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr-1.5-svg") +parser.add_argument("--prompt_mode", type=str, default="prompt_image_to_svg") + +args = parser.parse_args() + +require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0") + + +def main(): + addr = f"http://{args.ip}:{args.port}/v1" + image_path = "demo/demo_image2.png" + image = Image.open(image_path) + prompt = dict_promptmode_to_prompt[args.prompt_mode] + + #prompt = Please generate the SVG code based on the image.viewBox="0 0 {img_width} {img_height}" + prompt = prompt.replace("{width}", str(image.width)).replace("{height}", str(image.height)) + + response = inference_with_vllm( + image, + prompt, + ip=args.ip, + port=args.port, + temperature=0.9, # SVG: low temperature often causes repetitive/looping output + top_p=1.0, + model_name=args.model_name, + ) + print(f"response: {response}") + + +if __name__ == "__main__": + main() diff --git a/dots_ocr/model/inference.py b/dots_ocr/model/inference.py index eb85007..34e4f16 100755 --- a/dots_ocr/model/inference.py +++ b/dots_ocr/model/inference.py @@ -14,11 +14,14 @@ def inference_with_vllm( top_p=0.9, max_completion_tokens=32768, model_name='rednote-hilab/dots.ocr', + system_prompt=None, ): addr = f"{protocol}://{ip}:{port}/v1" client = OpenAI(api_key="{}".format(os.environ.get("API_KEY", "0")), base_url=addr) messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) messages.append( { "role": "user", diff --git a/dots_ocr/utils/prompts.py b/dots_ocr/utils/prompts.py index 87714c3..b210347 100755 --- a/dots_ocr/utils/prompts.py +++ b/dots_ocr/utils/prompts.py @@ -22,12 +22,24 @@ dict_promptmode_to_prompt = { # prompt_layout_only_en: layout detection "prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""", - # prompt_layout_only_en: parse ocr text except the Page-header and Page-footer + # prompt_ocr: parse ocr text except the Page-header and Page-footer "prompt_ocr": """Extract the text content from this image.""", # prompt_grounding_ocr: extract text content in the given bounding box "prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""", + # prompt_web_parsing: parse all webpage layout info in json format. + "prompt_web_parsing": """Parsing the layout info of this webpage image with format json:\n""", + + # prompt_scene_spotting: scene spotting + "prompt_scene_spotting": """Detect and recognize the text in the image.""", + + # prompt_img2svg: generate the SVG code of the image + "prompt_image_to_svg": """Please generate the SVG code based on the image.viewBox="0 0 {width} {height}\"""", + + # prompt_free_qa: general prompt + "prompt_general": """ """, + # "prompt_table_html": """Convert the table in this image to HTML.""", # "prompt_table_latex": """Convert the table in this image to LaTeX.""", # "prompt_formula_latex": """Convert the formula in this image to LaTeX.""", diff --git a/requirements.txt b/requirements.txt index 15852ca..67a798f 100755 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ gradio_image_annotation PyMuPDF openai qwen_vl_utils -transformers==4.51.3 +transformers==4.56.1 huggingface_hub modelscope # flash-attn==2.8.0.post2 # to speed up inference need flash-attn