Merge pull request #33 from yjmm10/master
support infer with transformer
This commit is contained in:
@@ -1151,6 +1151,9 @@ python3 dots_ocr/parser.py demo/demo_image1.jpg --prompt prompt_ocr
|
|||||||
python3 dots_ocr/parser.py demo/demo_image1.jpg --prompt prompt_grounding_ocr --bbox 163 241 1536 705
|
python3 dots_ocr/parser.py demo/demo_image1.jpg --prompt prompt_grounding_ocr --bbox 163 241 1536 705
|
||||||
|
|
||||||
```
|
```
|
||||||
|
**Based on Transformers**: you can parse an image or a PDF file using the same commands as above; just add `--use_hf true`.
|
||||||
|
|
||||||
|
> Note: Transformers is slower than vLLM. If you want to use demo/* with Transformers, just pass `use_hf=True` to the parser: `DotsOCRParser(..., use_hf=True)`.
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Output Results</b></summary>
|
<summary><b>Output Results</b></summary>
|
||||||
|
|||||||
+77
-2
@@ -31,6 +31,7 @@ class DotsOCRParser:
|
|||||||
output_dir="./output",
|
output_dir="./output",
|
||||||
min_pixels=None,
|
min_pixels=None,
|
||||||
max_pixels=None,
|
max_pixels=None,
|
||||||
|
use_hf=False,
|
||||||
):
|
):
|
||||||
self.dpi = dpi
|
self.dpi = dpi
|
||||||
|
|
||||||
@@ -46,9 +47,72 @@ class DotsOCRParser:
|
|||||||
self.output_dir = output_dir
|
self.output_dir = output_dir
|
||||||
self.min_pixels = min_pixels
|
self.min_pixels = min_pixels
|
||||||
self.max_pixels = max_pixels
|
self.max_pixels = max_pixels
|
||||||
|
|
||||||
|
self.use_hf = use_hf
|
||||||
|
if self.use_hf:
|
||||||
|
self._load_hf_model()
|
||||||
|
print(f"use hf model, num_thread will be set to 1")
|
||||||
|
else:
|
||||||
|
print(f"use vllm model, num_thread will be set to {self.num_thread}")
|
||||||
assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
|
assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
|
||||||
assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
|
assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
|
||||||
|
|
||||||
|
def _load_hf_model(self, model_path="./weights/DotsOCR"):
    """Load the model and processor used for Transformers-based inference.

    Sets three attributes on the instance:
    ``self.model`` (the causal LM), ``self.processor`` (the matching
    processor) and ``self.process_vision_info`` (the helper from
    ``qwen_vl_utils`` that extracts image/video inputs from chat messages).

    Args:
        model_path: Value passed to both ``from_pretrained`` calls.
            Defaults to ``"./weights/DotsOCR"``, the path that was
            previously hard-coded here.
    """
    # Imported inside the method, so these packages are only needed when
    # this method is actually called.
    import torch
    from transformers import AutoModelForCausalLM, AutoProcessor
    from qwen_vl_utils import process_vision_info

    self.model = AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    self.processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=True,
    )
    # Kept on the instance so the inference method can call it without
    # re-importing qwen_vl_utils.
    self.process_vision_info = process_vision_info
|
||||||
|
|
||||||
|
def _inference_with_hf(self, image, prompt, max_new_tokens=24000):
    """Run one image/prompt pair through the Transformers model.

    Requires ``self.model``, ``self.processor`` and
    ``self.process_vision_info`` to have been set beforehand.

    Args:
        image: Image placed in the ``"image"`` field of the user message.
        prompt: Text placed in the ``"text"`` field of the user message.
        max_new_tokens: Upper bound on generated tokens, forwarded to
            ``self.model.generate``. Defaults to 24000.

    Returns:
        The decoded text of the first (only) generated sequence, with the
        prompt tokens stripped and special tokens skipped.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Preparation for inference
    text = self.processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, video_inputs = self.process_vision_info(messages)
    inputs = self.processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )

    # Follow the device the model was actually placed on instead of
    # hard-coding "cuda"; the model is loaded with device_map="auto".
    inputs = inputs.to(self.model.device)

    # Inference: generation of the output
    generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + completion; drop the prompt prefix of
    # each sequence so only newly generated tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    decoded = self.processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
|
||||||
|
|
||||||
def _inference_with_vllm(self, image, prompt):
|
def _inference_with_vllm(self, image, prompt):
|
||||||
response = inference_with_vllm(
|
response = inference_with_vllm(
|
||||||
@@ -98,7 +162,10 @@ class DotsOCRParser:
|
|||||||
image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
|
image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||||
input_height, input_width = smart_resize(image.height, image.width)
|
input_height, input_width = smart_resize(image.height, image.width)
|
||||||
prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
|
prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||||
response = self._inference_with_vllm(image, prompt)
|
if self.use_hf:
|
||||||
|
response = self._inference_with_hf(image, prompt)
|
||||||
|
else:
|
||||||
|
response = self._inference_with_vllm(image, prompt)
|
||||||
result = {'page_no': page_idx,
|
result = {'page_no': page_idx,
|
||||||
"input_height": input_height,
|
"input_height": input_height,
|
||||||
"input_width": input_width
|
"input_width": input_width
|
||||||
@@ -206,7 +273,10 @@ class DotsOCRParser:
|
|||||||
def _execute_task(task_args):
|
def _execute_task(task_args):
|
||||||
return self._parse_single_image(**task_args)
|
return self._parse_single_image(**task_args)
|
||||||
|
|
||||||
num_thread = min(total_pages, self.num_thread)
|
if self.use_hf:
|
||||||
|
num_thread = 1
|
||||||
|
else:
|
||||||
|
num_thread = min(total_pages, self.num_thread)
|
||||||
print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
|
print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@@ -321,6 +391,10 @@ def main():
|
|||||||
"--max_pixels", type=int, default=None,
|
"--max_pixels", type=int, default=None,
|
||||||
help=""
|
help=""
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--use_hf", type=bool, default=False,
|
||||||
|
help=""
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
dots_ocr_parser = DotsOCRParser(
|
dots_ocr_parser = DotsOCRParser(
|
||||||
@@ -335,6 +409,7 @@ def main():
|
|||||||
output_dir=args.output,
|
output_dir=args.output,
|
||||||
min_pixels=args.min_pixels,
|
min_pixels=args.min_pixels,
|
||||||
max_pixels=args.max_pixels,
|
max_pixels=args.max_pixels,
|
||||||
|
use_hf=args.use_hf,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = dots_ocr_parser.parse_file(
|
result = dots_ocr_parser.parse_file(
|
||||||
|
|||||||
Reference in New Issue
Block a user