diff --git a/train_stream.py b/train_stream.py index eee6e62..86fb836 100644 --- a/train_stream.py +++ b/train_stream.py @@ -19,7 +19,7 @@ from torchdata.datapipes.iter import IterableWrapper import json -class TestIterator(IterableWrapper): +class CustomWrapperIterator(IterableWrapper): def __init__(self, iterable, deepcopy=True, total_len=None): super().__init__(iterable, deepcopy) self.total_len = total_len @@ -181,8 +181,8 @@ def main(config, hug_token): # train_dataset = train_dataset.with_format('torch') # val_dataset = val_dataset.with_format('torch') - train_dataset = TestIterator(train_dataset, total_len=train_length) - val_dataset = TestIterator(val_dataset, total_len=val_length) + train_dataset = CustomWrapperIterator(train_dataset, total_len=train_length) + val_dataset = CustomWrapperIterator(val_dataset, total_len=val_length) model.config.pad_token_id = processor.tokenizer.pad_token_id model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0]