diff --git a/utils/donut_model_pl_stream.py b/utils/donut_model_pl_stream.py index 79fd91d..44c1d00 100644 --- a/utils/donut_model_pl_stream.py +++ b/utils/donut_model_pl_stream.py @@ -29,7 +29,7 @@ class DonutModelPLModuleStream(pl.LightningModule): pixel_values = batch['pixel_values'] labels = batch['labels'] - answers = batch['target_sequence'] + answers = batch['target_sequence'][0] batch_size = pixel_values.shape[0] # we feed the prompt to the model decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)