# donut/single_eval.py
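"""FastAPI service that wraps a fine-tuned Donut (VisionEncoderDecoder) model
and returns the text decoded from a single uploaded image."""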

import io
import os
import re
import time

import torch
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
print("Set up config")
image_size = [768, 1280]
config_vision = VisionEncoderDecoderConfig.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2")
config_vision.encoder.image_size = image_size # (height, width)
config_vision.decoder.max_length = 768
processor = DonutProcessor.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2")
model = VisionEncoderDecoderModel.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2", config=config_vision)
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False
# dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)
print("Print ipconfig")
os.system("ipconfig")
print("Starting server")
app = FastAPI()

@app.get("/test")
async def test():
    return {"message": "Test"}

@app.post("/process")
async def process_image(file: UploadFile = File(...)):
    request_object_content = await file.read()
    input_image = Image.open(io.BytesIO(request_object_content))

    # prepare encoder inputs
    pixel_values = processor(input_image.convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # prepare decoder inputs
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    decoder_input_ids = decoder_input_ids.to(device)

    print("Start processing")
    # autoregressively generate the sequence (greedy decoding, <unk> tokens suppressed)
    start_time = time.time()
    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    processing_time = time.time() - start_time

    # strip special tokens, drop the task start token, then turn into JSON
    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
    seq = processor.token2json(seq)
    return {"data": seq["text_sequence"], "processing_time": f"{processing_time} seconds"}