working streaming training
This commit is contained in:
parent
d9cbe1f9a8
commit
34316df6f5
@ -29,7 +29,7 @@ class DonutModelPLModuleStream(pl.LightningModule):
|
|||||||
|
|
||||||
pixel_values = batch['pixel_values']
|
pixel_values = batch['pixel_values']
|
||||||
labels = batch['labels']
|
labels = batch['labels']
|
||||||
answers = batch['target_sequence']
|
answers = batch['target_sequence'][0]
|
||||||
batch_size = pixel_values.shape[0]
|
batch_size = pixel_values.shape[0]
|
||||||
# we feed the prompt to the model
|
# we feed the prompt to the model
|
||||||
decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
|
decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
|
||||||
|
Loading…
Reference in New Issue
Block a user