Support inference with transformers

Add `--use_hf true` to run inference with transformers; note that this is much slower than vLLM.
This commit is contained in:
liferecords
2025-08-05 14:18:02 +08:00
committed by GitHub
parent ae0ca3a8be
commit 2a7d5e1f5a
+78 -3
View File
@@ -31,6 +31,7 @@ class DotsOCRParser:
output_dir="./output",
min_pixels=None,
max_pixels=None,
use_hf=False,
):
self.dpi = dpi
@@ -46,9 +47,72 @@ class DotsOCRParser:
self.output_dir = output_dir
self.min_pixels = min_pixels
self.max_pixels = max_pixels
self.use_hf = use_hf
if self.use_hf:
self._load_hf_model()
print(f"use hf model, num_thread will be set to 1")
else:
print(f"use vllm model, num_thread will be set to {self.num_thread}")
assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
def _load_hf_model(self, model_path="./weights/DotsOCR"):
    """Load the model and processor through Hugging Face transformers.

    Populates ``self.model``, ``self.processor`` and
    ``self.process_vision_info``, which ``_inference_with_hf`` reads.

    Args:
        model_path: Checkpoint location handed to ``from_pretrained`` for
            both the model and the processor. Defaults to
            ``./weights/DotsOCR``, the path that used to be hard-coded.
    """
    # Imported inside the method, so these packages are only needed when
    # this method actually runs.
    import torch
    from transformers import AutoModelForCausalLM, AutoProcessor
    from qwen_vl_utils import process_vision_info

    # NOTE: trust_remote_code=True executes Python code shipped with the
    # checkpoint; only point model_path at weights you trust.
    self.model = AutoModelForCausalLM.from_pretrained(
        model_path,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    self.processor = AutoProcessor.from_pretrained(
        model_path,
        trust_remote_code=True,
        use_fast=True,
    )
    # Kept on the instance so _inference_with_hf can call it without
    # re-importing qwen_vl_utils.
    self.process_vision_info = process_vision_info
def _inference_with_hf(self, image, prompt, max_new_tokens=24000):
    """Run one image/prompt pair through the transformers model.

    Requires ``self.model``, ``self.processor`` and
    ``self.process_vision_info`` to be set (see ``_load_hf_model``).

    Args:
        image: Image placed in the chat message under the ``"image"`` key.
        prompt: Text instruction sent alongside the image.
        max_new_tokens: Upper bound on generated tokens. Defaults to 24000,
            the value that used to be hard-coded.

    Returns:
        The decoded text of the newly generated tokens (prompt tokens are
        stripped) for the single request.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Render the chat template to text only; tokenization happens in the
    # processor call below together with the vision inputs.
    chat_text = self.processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, video_inputs = self.process_vision_info(messages)
    inputs = self.processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    # The model is loaded with device_map="auto", so follow wherever it was
    # actually placed instead of assuming the default "cuda" device.
    inputs = inputs.to(self.model.device)

    generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
    # generate() returns prompt + continuation; drop the prompt prefix so
    # only the new tokens are decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    response = self.processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return response
def _inference_with_vllm(self, image, prompt):
response = inference_with_vllm(
@@ -98,7 +162,10 @@ class DotsOCRParser:
image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
input_height, input_width = smart_resize(image.height, image.width)
prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
response = self._inference_with_vllm(image, prompt)
if self.use_hf:
response = self._inference_with_hf(image, prompt)
else:
response = self._inference_with_vllm(image, prompt)
result = {'page_no': page_idx,
"input_height": input_height,
"input_width": input_width
@@ -206,7 +273,10 @@ class DotsOCRParser:
def _execute_task(task_args):
return self._parse_single_image(**task_args)
num_thread = min(total_pages, self.num_thread)
if self.use_hf:
num_thread = 1
else:
num_thread = min(total_pages, self.num_thread)
print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
results = []
@@ -321,6 +391,10 @@ def main():
"--max_pixels", type=int, default=None,
help=""
)
# Selects the Hugging Face transformers backend instead of vLLM
# (usage: `--use_hf true`).
parser.add_argument(
    "--use_hf",
    # argparse's type=bool calls bool() on the raw string, so any non-empty
    # value -- including "false" -- became True. Parse the text explicitly:
    # only recognised "off" spellings disable the flag, every other value
    # keeps enabling it as before.
    type=lambda value: str(value).strip().lower()
    not in ("", "0", "false", "f", "no", "n", "off"),
    default=False,
    help=""
)
args = parser.parse_args()
dots_ocr_parser = DotsOCRParser(
@@ -335,6 +409,7 @@ def main():
output_dir=args.output,
min_pixels=args.min_pixels,
max_pixels=args.max_pixels,
use_hf=args.use_hf,
)
result = dots_ocr_parser.parse_file(
@@ -346,4 +421,4 @@ def main():
if __name__ == "__main__":
main()
main()