train_len
This commit is contained in:
parent
ccb1c9c2ed
commit
30369a7885
@@ -171,8 +171,8 @@ def main(config, hug_token):
|
||||
dataset = load_dataset(config.dataset_path, streaming=True)
|
||||
val_dataset = dataset.pop('validation')
|
||||
train_dataset = interleave_datasets(list(dataset.values()))
|
||||
# train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')
|
||||
# val_length = list(val_dataset.info.splits.values())[-1].num_examples
|
||||
train_length = sum(split.num_examples for split in dataset[list(dataset.keys())[0]].info.splits.values() if split.name != 'validation')
|
||||
val_length = list(val_dataset.info.splits.values())[-1].num_examples
|
||||
|
||||
|
||||
train_dataset = train_dataset.map(proces_train, remove_columns = ['image', 'ground_truth'])
|
||||
@@ -181,8 +181,8 @@ def main(config, hug_token):
|
||||
# train_dataset = train_dataset.with_format('torch')
|
||||
# val_dataset = val_dataset.with_format('torch')
|
||||
|
||||
train_dataset = IterableWrapper(train_dataset)
|
||||
val_dataset = IterableWrapper(val_dataset)
|
||||
train_dataset = TestIterator(train_dataset, total_len=train_length)
|
||||
val_dataset = TestIterator(val_dataset, total_len=val_length)
|
||||
|
||||
model.config.pad_token_id = processor.tokenizer.pad_token_id
|
||||
model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
|
||||
|
Loading…
Reference in New Issue
Block a user