working streaming training
This commit is contained in:
parent
d9cbe1f9a8
commit
34316df6f5
@ -29,7 +29,7 @@ class DonutModelPLModuleStream(pl.LightningModule):
|
||||
|
||||
pixel_values = batch['pixel_values']
|
||||
labels = batch['labels']
|
||||
answers = batch['target_sequence']
|
||||
answers = batch['target_sequence'][0]
|
||||
batch_size = pixel_values.shape[0]
|
||||
# we feed the prompt to the model
|
||||
decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
|
||||
|
Loading…
Reference in New Issue
Block a user