support infer with transformers
add `--use_hf true` to use transformers,but the speed is very slower than vllm
This commit is contained in:
+77
-2
@@ -31,6 +31,7 @@ class DotsOCRParser:
|
||||
output_dir="./output",
|
||||
min_pixels=None,
|
||||
max_pixels=None,
|
||||
use_hf=False,
|
||||
):
|
||||
self.dpi = dpi
|
||||
|
||||
@@ -46,9 +47,72 @@ class DotsOCRParser:
|
||||
self.output_dir = output_dir
|
||||
self.min_pixels = min_pixels
|
||||
self.max_pixels = max_pixels
|
||||
|
||||
self.use_hf = use_hf
|
||||
if self.use_hf:
|
||||
self._load_hf_model()
|
||||
print(f"use hf model, num_thread will be set to 1")
|
||||
else:
|
||||
print(f"use vllm model, num_thread will be set to {self.num_thread}")
|
||||
assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
|
||||
assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
|
||||
|
||||
def _load_hf_model(self):
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
model_path = "./weights/DotsOCR"
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="auto",
|
||||
trust_remote_code=True
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True,use_fast=True)
|
||||
self.process_vision_info = process_vision_info
|
||||
|
||||
def _inference_with_hf(self, image, prompt):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": image
|
||||
},
|
||||
{"type": "text", "text": prompt}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Preparation for inference
|
||||
text = self.processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
image_inputs, video_inputs = self.process_vision_info(messages)
|
||||
inputs = self.processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
generated_ids = self.model.generate(**inputs, max_new_tokens=24000)
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
response = self.processor.batch_decode(
|
||||
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)[0]
|
||||
return response
|
||||
|
||||
def _inference_with_vllm(self, image, prompt):
|
||||
response = inference_with_vllm(
|
||||
@@ -98,7 +162,10 @@ class DotsOCRParser:
|
||||
image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
input_height, input_width = smart_resize(image.height, image.width)
|
||||
prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
|
||||
response = self._inference_with_vllm(image, prompt)
|
||||
if self.use_hf:
|
||||
response = self._inference_with_hf(image, prompt)
|
||||
else:
|
||||
response = self._inference_with_vllm(image, prompt)
|
||||
result = {'page_no': page_idx,
|
||||
"input_height": input_height,
|
||||
"input_width": input_width
|
||||
@@ -206,7 +273,10 @@ class DotsOCRParser:
|
||||
def _execute_task(task_args):
|
||||
return self._parse_single_image(**task_args)
|
||||
|
||||
num_thread = min(total_pages, self.num_thread)
|
||||
if self.use_hf:
|
||||
num_thread = 1
|
||||
else:
|
||||
num_thread = min(total_pages, self.num_thread)
|
||||
print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
|
||||
|
||||
results = []
|
||||
@@ -321,6 +391,10 @@ def main():
|
||||
"--max_pixels", type=int, default=None,
|
||||
help=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_hf", type=bool, default=False,
|
||||
help=""
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dots_ocr_parser = DotsOCRParser(
|
||||
@@ -335,6 +409,7 @@ def main():
|
||||
output_dir=args.output,
|
||||
min_pixels=args.min_pixels,
|
||||
max_pixels=args.max_pixels,
|
||||
use_hf=args.use_hf,
|
||||
)
|
||||
|
||||
result = dots_ocr_parser.parse_file(
|
||||
|
||||
Reference in New Issue
Block a user