From ef0faa1f8c4ab49d364479c195488f32601b86fb Mon Sep 17 00:00:00 2001 From: zzombely Date: Tue, 14 Mar 2023 22:01:50 +0000 Subject: [PATCH] Custom wrapper --- train_stream.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_stream.py b/train_stream.py index eee6e62..86fb836 100644 --- a/train_stream.py +++ b/train_stream.py @@ -19,7 +19,7 @@ from torchdata.datapipes.iter import IterableWrapper import json -class TestIterator(IterableWrapper): +class CustomWrapperIterator(IterableWrapper): def __init__(self, iterable, deepcopy=True, total_len=None): super().__init__(iterable, deepcopy) self.total_len = total_len @@ -181,8 +181,8 @@ def main(config, hug_token): # train_dataset = train_dataset.with_format('torch') # val_dataset = val_dataset.with_format('torch') - train_dataset = TestIterator(train_dataset, total_len=train_length) - val_dataset = TestIterator(val_dataset, total_len=val_length) + train_dataset = CustomWrapperIterator(train_dataset, total_len=train_length) + val_dataset = CustomWrapperIterator(val_dataset, total_len=val_length) model.config.pad_token_id = processor.tokenizer.pad_token_id model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids([''])[0]