Custom wrapper
This commit is contained in:
parent
30369a7885
commit
ef0faa1f8c
@ -19,7 +19,7 @@ from torchdata.datapipes.iter import IterableWrapper
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
|
|
||||||
class TestIterator(IterableWrapper):
|
class CustomWrapperIterator(IterableWrapper):
|
||||||
def __init__(self, iterable, deepcopy=True, total_len=None):
|
def __init__(self, iterable, deepcopy=True, total_len=None):
|
||||||
super().__init__(iterable, deepcopy)
|
super().__init__(iterable, deepcopy)
|
||||||
self.total_len = total_len
|
self.total_len = total_len
|
||||||
@ -181,8 +181,8 @@ def main(config, hug_token):
|
|||||||
# train_dataset = train_dataset.with_format('torch')
|
# train_dataset = train_dataset.with_format('torch')
|
||||||
# val_dataset = val_dataset.with_format('torch')
|
# val_dataset = val_dataset.with_format('torch')
|
||||||
|
|
||||||
train_dataset = TestIterator(train_dataset, total_len=train_length)
|
train_dataset = CustomWrapperIterator(train_dataset, total_len=train_length)
|
||||||
val_dataset = TestIterator(val_dataset, total_len=val_length)
|
val_dataset = CustomWrapperIterator(val_dataset, total_len=val_length)
|
||||||
|
|
||||||
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
||||||
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
|
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
|
||||||
|
Loading…
Reference in New Issue
Block a user