updata inference demo
This commit is contained in:
Binary file not shown.
|
After Width: | Height: | Size: 42 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 145 KiB |
+9
-1
@@ -11,7 +11,15 @@ parser = argparse.ArgumentParser()
|
|||||||
parser.add_argument("--ip", type=str, default="localhost")
|
parser.add_argument("--ip", type=str, default="localhost")
|
||||||
parser.add_argument("--port", type=str, default="8000")
|
parser.add_argument("--port", type=str, default="8000")
|
||||||
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr")
|
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr")
|
||||||
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en")
|
parser.add_argument("--prompt_mode", type=str, default="prompt_layout_all_en",help=(
|
||||||
|
"Choose a task prompt: "
|
||||||
|
"prompt_layout_all_en=full document layout+OCR to JSON/MD; "
|
||||||
|
"prompt_layout_only_en=layout detection only; "
|
||||||
|
"prompt_grounding_ocr=OCR within a given bbox; "
|
||||||
|
"prompt_web_parsing=parse webpage screenshot layout into JSON; "
|
||||||
|
"prompt_scene_spotting=detect+recognize scene text (OCR boxes+texts); "
|
||||||
|
"prompt_image_to_svg=generate SVG code to reconstruct the image.")
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|||||||
Executable
+40
@@ -0,0 +1,40 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
from PIL import Image
|
||||||
|
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||||
|
from dots_ocr.model.inference import inference_with_vllm
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--ip", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=str, default="8000")
|
||||||
|
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr")
|
||||||
|
parser.add_argument("--custom_prompt", type=str, default="Please describe the content of this image.")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
addr = f"http://{args.ip}:{args.port}/v1"
|
||||||
|
image_path = "demo/demo_image3.jpg"
|
||||||
|
prompt = args.custom_prompt
|
||||||
|
image = Image.open(image_path)
|
||||||
|
response = inference_with_vllm(
|
||||||
|
image,
|
||||||
|
prompt,
|
||||||
|
ip=args.ip,
|
||||||
|
port=args.port,
|
||||||
|
temperature=0.1,
|
||||||
|
top_p=0.9,
|
||||||
|
model_name=args.model_name,
|
||||||
|
system_prompt="You are a helpful assistant.", #general tasks need system_prompt
|
||||||
|
)
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Executable
+43
@@ -0,0 +1,43 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from transformers.utils.versions import require_version
|
||||||
|
from PIL import Image
|
||||||
|
from dots_ocr.utils import dict_promptmode_to_prompt
|
||||||
|
from dots_ocr.model.inference import inference_with_vllm
|
||||||
|
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--ip", type=str, default="localhost")
|
||||||
|
parser.add_argument("--port", type=str, default="8000")
|
||||||
|
parser.add_argument("--model_name", type=str, default="rednote-hilab/dots.ocr")
|
||||||
|
parser.add_argument("--prompt_mode", type=str, default="prompt_image_to_svg")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
addr = f"http://{args.ip}:{args.port}/v1"
|
||||||
|
image_path = "demo/demo_image2.png"
|
||||||
|
image = Image.open(image_path)
|
||||||
|
prompt = dict_promptmode_to_prompt[args.prompt_mode]
|
||||||
|
|
||||||
|
#prompt = Please generate the SVG code based on the image.viewBox="0 0 {img_width} {img_height}"
|
||||||
|
prompt = prompt.replace("{width}", str(image.width)).replace("{height}", str(image.height))
|
||||||
|
|
||||||
|
response = inference_with_vllm(
|
||||||
|
image,
|
||||||
|
prompt,
|
||||||
|
ip=args.ip,
|
||||||
|
port=args.port,
|
||||||
|
temperature=0.9, # SVG: low temperature often causes repetitive/looping output
|
||||||
|
top_p=1.0,
|
||||||
|
model_name=args.model_name,
|
||||||
|
)
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -14,11 +14,14 @@ def inference_with_vllm(
|
|||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
max_completion_tokens=32768,
|
max_completion_tokens=32768,
|
||||||
model_name='rednote-hilab/dots.ocr',
|
model_name='rednote-hilab/dots.ocr',
|
||||||
|
system_prompt=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
addr = f"{protocol}://{ip}:{port}/v1"
|
addr = f"{protocol}://{ip}:{port}/v1"
|
||||||
client = OpenAI(api_key="{}".format(os.environ.get("API_KEY", "0")), base_url=addr)
|
client = OpenAI(api_key="{}".format(os.environ.get("API_KEY", "0")), base_url=addr)
|
||||||
messages = []
|
messages = []
|
||||||
|
if system_prompt:
|
||||||
|
messages.append({"role": "system", "content": system_prompt})
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|||||||
@@ -22,12 +22,24 @@ dict_promptmode_to_prompt = {
|
|||||||
# prompt_layout_only_en: layout detection
|
# prompt_layout_only_en: layout detection
|
||||||
"prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
|
"prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
|
||||||
|
|
||||||
# prompt_layout_only_en: parse ocr text except the Page-header and Page-footer
|
# prompt_ocr: parse ocr text except the Page-header and Page-footer
|
||||||
"prompt_ocr": """Extract the text content from this image.""",
|
"prompt_ocr": """Extract the text content from this image.""",
|
||||||
|
|
||||||
# prompt_grounding_ocr: extract text content in the given bounding box
|
# prompt_grounding_ocr: extract text content in the given bounding box
|
||||||
"prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
|
"prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
|
||||||
|
|
||||||
|
# prompt_web_parsing: parse all webpage layout info in json format.
|
||||||
|
"prompt_web_parsing": """Parsing the layout info of this webpage image with format json:\n""",
|
||||||
|
|
||||||
|
# prompt_scene_spotting: scene spotting
|
||||||
|
"prompt_scene_spotting": """Detect and recognize the text in the image.""",
|
||||||
|
|
||||||
|
# prompt_img2svg: generate the SVG code of the image
|
||||||
|
"prompt_image_to_svg": """Please generate the SVG code based on the image.viewBox="0 0 {width} {height}\"""",
|
||||||
|
|
||||||
|
# prompt_free_qa: general prompt
|
||||||
|
"prompt_general": """ """,
|
||||||
|
|
||||||
# "prompt_table_html": """Convert the table in this image to HTML.""",
|
# "prompt_table_html": """Convert the table in this image to HTML.""",
|
||||||
# "prompt_table_latex": """Convert the table in this image to LaTeX.""",
|
# "prompt_table_latex": """Convert the table in this image to LaTeX.""",
|
||||||
# "prompt_formula_latex": """Convert the formula in this image to LaTeX.""",
|
# "prompt_formula_latex": """Convert the formula in this image to LaTeX.""",
|
||||||
|
|||||||
Reference in New Issue
Block a user