from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
import re
import torch
from PIL import Image
import time
from fastapi import FastAPI, UploadFile, File
import io
import os
from sys import platform

image_size = [1920, 2560]  # (height, width)

print("Set up config")
config_vision = VisionEncoderDecoderConfig.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2")
config_vision.encoder.image_size = image_size  # (height, width)
config_vision.decoder.max_length = 768

processor = DonutProcessor.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2")
model = VisionEncoderDecoderModel.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2", config=config_vision)

processor.image_processor.size = image_size[::-1]  # the image processor expects (width, height)
processor.image_processor.do_align_long_axis = False

# dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)

print("Print ipconfig")
# `sys.platform` is a string (e.g. "linux", "win32"), so test it with startswith
if platform.startswith("linux"):
    os.system("ip r")
else:
    os.system("ipconfig")

print("Starting server")
app = FastAPI()


@app.get("/test")
async def test():
    return {"message": "Test"}


@app.post("/process")
async def process_image(file: UploadFile = File(...)):
    request_object_content = await file.read()
    input_image = Image.open(io.BytesIO(request_object_content))

    # prepare encoder inputs
    pixel_values = processor(input_image.convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # prepare decoder inputs
    # note: Donut models are usually primed with a task start prompt;
    # an empty prompt yields empty decoder input ids
    task_prompt = ""
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    decoder_input_ids = decoder_input_ids.to(device)

    print("Start processing")
    # autoregressively generate sequence
    start_time = time.time()
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    processing_time = time.time() - start_time

    # turn into JSON
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    seq = processor.token2json(seq)

    return {"data": seq["text_sequence"], "processing_time": f"{processing_time} seconds"}
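

# --- Example client call (a sketch, not part of the server itself) ---
# Assuming the app is served locally, e.g. `uvicorn main:app --port 8000`
# (the module name `main`, the port, and the filename `scan.png` are
# assumptions for illustration), an image can be sent to /process like so:
#
#   import requests
#   with open("scan.png", "rb") as f:
#       response = requests.post("http://localhost:8000/process", files={"file": f})
#   print(response.json())  # {"data": "...", "processing_time": "... seconds"}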