This commit is contained in:
zhangwei13
2026-03-24 22:43:01 +08:00
parent d9ea2a4108
commit 36d7248878
59 changed files with 396 additions and 390 deletions
+3 -2
View File
@@ -53,13 +53,14 @@ def inference(image_path, prompt, model, processor):
if __name__ == "__main__":
# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
model_path = "./weights/DotsOCR"
# We recommend enabling flash_attention_2 or flash_attention_3 for better acceleration and memory saving, especially in multi-image and video scenarios.
model_path = "./weights/DotsMOCR"
model = AutoModelForCausalLM.from_pretrained(
model_path,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
device_map="auto",
# device_map="cpu", # in "ve", flash-attn is used by default, so this cannot be run directly
trust_remote_code=True
)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)