diff --git a/README.md b/README.md index 3812c76..26e1c9e 100755 --- a/README.md +++ b/README.md @@ -1151,6 +1151,9 @@ python3 dots_ocr/parser.py demo/demo_image1.jpg --prompt prompt_ocr python3 dots_ocr/parser.py demo/demo_image1.jpg --prompt prompt_grounding_ocr --bbox 163 241 1536 705 ``` +**Based on Transformers**, you can parse an image or a pdf file using the same commands above, just add `--use_hf true`. + +> Notice: transformers is slower than vllm, if you want to use demo/* with transformers,just add `use_hf=True` in `DotsOCRParser(..,use_hf=True)`
Output Results diff --git a/dots_ocr/parser.py b/dots_ocr/parser.py index 8f78efb..8ce95e1 100755 --- a/dots_ocr/parser.py +++ b/dots_ocr/parser.py @@ -31,6 +31,7 @@ class DotsOCRParser: output_dir="./output", min_pixels=None, max_pixels=None, + use_hf=False, ): self.dpi = dpi @@ -46,9 +47,72 @@ class DotsOCRParser: self.output_dir = output_dir self.min_pixels = min_pixels self.max_pixels = max_pixels + + self.use_hf = use_hf + if self.use_hf: + self._load_hf_model() + print(f"use hf model, num_thread will be set to 1") + else: + print(f"use vllm model, num_thread will be set to {self.num_thread}") assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS + def _load_hf_model(self): + import torch + from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer + from qwen_vl_utils import process_vision_info + + model_path = "./weights/DotsOCR" + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True + ) + self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True,use_fast=True) + self.process_vision_info = process_vision_info + + def _inference_with_hf(self, image, prompt): + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": image + }, + {"type": "text", "text": prompt} + ] + } + ] + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + image_inputs, video_inputs = self.process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + inputs = inputs.to("cuda") + + # Inference: Generation of the output + generated_ids = self.model.generate(**inputs, max_new_tokens=24000) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + response = self.processor.batch_decode( + generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + return response def _inference_with_vllm(self, image, prompt): response = inference_with_vllm( @@ -98,7 +162,10 @@ class DotsOCRParser: image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels) input_height, input_width = smart_resize(image.height, image.width) prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels) - response = self._inference_with_vllm(image, prompt) + if self.use_hf: + response = self._inference_with_hf(image, prompt) + else: + response = self._inference_with_vllm(image, prompt) result = {'page_no': page_idx, "input_height": input_height, "input_width": input_width @@ -206,7 +273,10 @@ class DotsOCRParser: def _execute_task(task_args): return self._parse_single_image(**task_args) - num_thread = min(total_pages, self.num_thread) + if self.use_hf: + num_thread = 1 + else: + num_thread = min(total_pages, self.num_thread) print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...") results = [] @@ -321,6 +391,10 @@ def main(): "--max_pixels", type=int, default=None, help="" ) + parser.add_argument( + "--use_hf", type=bool, default=False, + help="" + ) args = parser.parse_args() dots_ocr_parser = DotsOCRParser( @@ -335,6 +409,7 @@ def main(): output_dir=args.output, min_pixels=args.min_pixels, max_pixels=args.max_pixels, + use_hf=args.use_hf, ) result = dots_ocr_parser.parse_file(