fixing donut

This commit is contained in:
Michał Kozłowski 2023-01-24 18:13:25 +01:00
parent e867ce77dc
commit 3ad876ea69
4 changed files with 11 additions and 82 deletions

View File

@ -1,77 +0,0 @@
from transformers import DonutProcessor, VisionEncoderDecoderModel, VisionEncoderDecoderConfig
import re
import torch
from PIL import Image
import time
from fastapi import FastAPI, UploadFile, File
import io
import os
from sys import platform
image_size = [1920, 2560]
print("Set up config")
config_vision = VisionEncoderDecoderConfig.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2")
config_vision.encoder.image_size = image_size # (height, width)
config_vision.decoder.max_length = 768
processor = DonutProcessor.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2")
model = VisionEncoderDecoderModel.from_pretrained("Zombely/plwiki-proto-fine-tuned-v3.2", config=config_vision)
processor.image_processor.size = image_size[::-1] # should be (width, height)
processor.image_processor.do_align_long_axis = False
# dataset = load_dataset(config.validation_dataset_path, split=config.validation_dataset_split)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.eval()
model.to(device)
print("Print ipconfig")
if platform == "linux":
os.system("ip r")
else:
os.system("ipconfig")
print("Starting server")
app = FastAPI()
@app.get("/test")
async def test():
return {"message": "Test"}
@app.post("/process")
async def process_image(file: UploadFile= File(...)):
request_object_content = await file.read()
input_image = Image.open(io.BytesIO(request_object_content))
# prepare encoder inputs
pixel_values = processor(input_image.convert("RGB"), return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device)
# prepare decoder inputs
task_prompt = "<s_cord-v2>"
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
decoder_input_ids = decoder_input_ids.to(device)
print("Start processing")
# autoregressively generate sequence
start_time = time.time()
outputs = model.generate(
pixel_values,
decoder_input_ids=decoder_input_ids,
max_length=model.decoder.config.max_position_embeddings,
early_stopping=True,
pad_token_id=processor.tokenizer.pad_token_id,
eos_token_id=processor.tokenizer.eos_token_id,
use_cache=True,
num_beams=1,
bad_words_ids=[[processor.tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
processing_time = (time.time() - start_time)
# turn into JSON
seq = processor.batch_decode(outputs.sequences)[0]
seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
seq = re.sub(r"<.*?>", "", seq, count=1).strip() # remove first task start token
seq = processor.token2json(seq)
return {"data": seq['text_sequence'], "processing_time": f"{processing_time} seconds"}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 355 KiB

View File

@ -13,6 +13,8 @@ from utils.donut_dataset import DonutDataset
from utils.donut_model_pl import DonutModelPLModule
from utils.callbacks import PushToHubCallback
import warnings
from datasets import load_dataset
@ -32,8 +34,12 @@ def main(config, hug_token):
added_tokens = []
dataset = load_dataset(config.dataset_path, split='train', streaming=True)
train_dataset = dataset.skip(100)
validation_dataset = dataset.take(100)
train_dataset = DonutDataset(
config.dataset_path,
train_dataset,
processor=processor,
model=model,
max_length=config.max_length,
@ -45,7 +51,7 @@ def main(config, hug_token):
)
val_dataset = DonutDataset(
config.dataset_path,
validation_dataset,
processor=processor,
model=model,
max_length=config.max_length,

View File

@ -24,7 +24,7 @@ class DonutDataset(Dataset):
def __init__(
self,
dataset_name_or_path: str,
dataset: Dataset,
max_length: int,
processor: DonutProcessor,
model: VisionEncoderDecoderModel,
@ -47,8 +47,8 @@ class DonutDataset(Dataset):
self.sort_json_key = sort_json_key
self.added_tokens = added_tokens
self.dataset = load_dataset(dataset_name_or_path, split=self.split)
self.dataset_length = len(self.dataset)
self.dataset = dataset
self.dataset_length = len(list(self.dataset))
self.gt_token_sequences = []
for sample in self.dataset: