From 2a7d5e1f5a6e51305e8c06d907c19974ce43f0fc Mon Sep 17 00:00:00 2001 From: liferecords Date: Tue, 5 Aug 2025 14:18:02 +0800 Subject: [PATCH] support infer with transformers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit add `--use_hf true` to use transformers,but the speed is very slower than vllm --- dots_ocr/parser.py | 81 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 3 deletions(-) diff --git a/dots_ocr/parser.py b/dots_ocr/parser.py index 1f1a1d8..f6648f1 100755 --- a/dots_ocr/parser.py +++ b/dots_ocr/parser.py @@ -31,6 +31,7 @@ class DotsOCRParser: output_dir="./output", min_pixels=None, max_pixels=None, + use_hf=False, ): self.dpi = dpi @@ -46,9 +47,72 @@ class DotsOCRParser: self.output_dir = output_dir self.min_pixels = min_pixels self.max_pixels = max_pixels + + self.use_hf = use_hf + if self.use_hf: + self._load_hf_model() + print(f"use hf model, num_thread will be set to 1") + else: + print(f"use vllm model, num_thread will be set to {self.num_thread}") assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS + def _load_hf_model(self): + import torch + from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer + from qwen_vl_utils import process_vision_info + + model_path = "./weights/DotsOCR" + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True + ) + self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True,use_fast=True) + self.process_vision_info = process_vision_info + + def _inference_with_hf(self, image, prompt): + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": image + }, + {"type": "text", "text": prompt} + ] + } + ] + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + image_inputs, video_inputs = self.process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + inputs = inputs.to("cuda") + + # Inference: Generation of the output + generated_ids = self.model.generate(**inputs, max_new_tokens=24000) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + response = self.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + return response def _inference_with_vllm(self, image, prompt): response = inference_with_vllm( @@ -98,7 +162,10 @@ class DotsOCRParser: image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels) input_height, input_width = smart_resize(image.height, image.width) prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels) - response = self._inference_with_vllm(image, prompt) + if self.use_hf: + response = self._inference_with_hf(image, prompt) + else: + response = self._inference_with_vllm(image, prompt) result = {'page_no': page_idx, "input_height": input_height, "input_width": input_width @@ -206,7 +273,10 @@ class DotsOCRParser: def _execute_task(task_args): return self._parse_single_image(**task_args) - num_thread = min(total_pages, self.num_thread) + if self.use_hf: + num_thread = 1 + else: + num_thread = min(total_pages, self.num_thread) print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...") results = [] @@ -321,6 +391,10 @@ def main(): "--max_pixels", type=int, default=None, help="" ) + parser.add_argument( + "--use_hf", type=bool, default=False, + help="" + ) args = parser.parse_args() dots_ocr_parser = DotsOCRParser( @@ -335,6 +409,7 @@ def main(): output_dir=args.output, min_pixels=args.min_pixels, max_pixels=args.max_pixels, + use_hf=args.use_hf, ) result = dots_ocr_parser.parse_file( @@ -346,4 +421,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main()