working streaming training

This commit is contained in:
michal.kozlowski 2023-03-17 15:18:40 +01:00
parent d9cbe1f9a8
commit 34316df6f5

View File

@ -29,7 +29,7 @@ class DonutModelPLModuleStream(pl.LightningModule):
pixel_values = batch['pixel_values'] pixel_values = batch['pixel_values']
labels = batch['labels'] labels = batch['labels']
answers = batch['target_sequence'] answers = batch['target_sequence'][0]
batch_size = pixel_values.shape[0] batch_size = pixel_values.shape[0]
# we feed the prompt to the model # we feed the prompt to the model
decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device) decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)