diff --git a/train_stream.py b/train_stream.py
index bb339de..eee6e62 100644
--- a/train_stream.py
+++ b/train_stream.py
@@ -171,8 +171,8 @@ def main(config, hug_token):
     dataset = load_dataset(config.dataset_path, streaming=True)
     val_dataset = dataset.pop('validation')
     train_dataset = interleave_datasets(list(dataset.values()))
-    # train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')
-    # val_length = list(val_dataset.info.splits.values())[-1].num_examples
+    train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')
+    val_length = list(val_dataset.info.splits.values())[-1].num_examples
 
     train_dataset = train_dataset.map(proces_train, remove_columns = ['image', 'ground_truth'])
 
@@ -181,8 +181,8 @@ def main(config, hug_token):
     # train_dataset = train_dataset.with_format('torch')
     # val_dataset = val_dataset.with_format('torch')
 
-    train_dataset = IterableWrapper(train_dataset)
-    val_dataset = IterableWrapper(val_dataset)
+    train_dataset = TestIterator(train_dataset, total_len=train_length)
+    val_dataset = TestIterator(val_dataset, total_len=val_length)
 
     model.config.pad_token_id = processor.tokenizer.pad_token_id
     model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0]
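The uncommented `train_length` line works because each split's `info.splits` dict carries the `SplitInfo` metadata for every split in the dataset, so summing `num_examples` over the non-validation entries recovers the total training example count even in streaming mode. The definition of `TestIterator` is not part of this diff; below is a minimal sketch of what it might look like, assuming its purpose is to wrap a streaming dataset and expose a known length (streaming datasets implement no `__len__`, which breaks step counting in the Trainer). Only the class name and the `total_len` keyword come from the diff; the body is an assumption.

```python
# Hypothetical sketch of TestIterator; only the name and total_len argument
# appear in the diff, everything else here is an assumption.
from torch.utils.data import IterableDataset


class TestIterator(IterableDataset):
    def __init__(self, dataset, total_len):
        self.dataset = dataset      # the wrapped streaming (iterable) dataset
        self.total_len = total_len  # example count recovered from split metadata

    def __iter__(self):
        # Delegate iteration straight to the underlying streaming dataset.
        return iter(self.dataset)

    def __len__(self):
        # Streaming datasets have no __len__; reporting the precomputed split
        # size lets the DataLoader/Trainer derive steps per epoch.
        return self.total_len
```

Under this reading, passing `total_len=train_length` / `total_len=val_length` is what makes the previously commented-out length computations necessary again, which matches the two halves of this diff.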