restore donut
This commit is contained in:
parent
1c22eaabf9
commit
ab051e21b1
8
train.py
8
train.py
@ -34,12 +34,8 @@ def main(config, hug_token):
|
|||||||
|
|
||||||
added_tokens = []
|
added_tokens = []
|
||||||
|
|
||||||
dataset = load_dataset(config.dataset_path, split='train')
|
|
||||||
validation_dataset = dataset.take(100)
|
|
||||||
train_dataset = dataset.skip(10000)
|
|
||||||
|
|
||||||
train_dataset = DonutDataset(
|
train_dataset = DonutDataset(
|
||||||
train_dataset,
|
config.dataset_path,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
model=model,
|
model=model,
|
||||||
max_length=config.max_length,
|
max_length=config.max_length,
|
||||||
@ -51,7 +47,7 @@ def main(config, hug_token):
|
|||||||
)
|
)
|
||||||
|
|
||||||
val_dataset = DonutDataset(
|
val_dataset = DonutDataset(
|
||||||
validation_dataset,
|
config.dataset_path,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
model=model,
|
model=model,
|
||||||
max_length=config.max_length,
|
max_length=config.max_length,
|
||||||
|
@ -24,7 +24,7 @@ class DonutDataset(Dataset):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
dataset: Dataset,
|
dataset_name_or_path: str,
|
||||||
max_length: int,
|
max_length: int,
|
||||||
processor: DonutProcessor,
|
processor: DonutProcessor,
|
||||||
model: VisionEncoderDecoderModel,
|
model: VisionEncoderDecoderModel,
|
||||||
@ -47,7 +47,7 @@ class DonutDataset(Dataset):
|
|||||||
self.sort_json_key = sort_json_key
|
self.sort_json_key = sort_json_key
|
||||||
self.added_tokens = added_tokens
|
self.added_tokens = added_tokens
|
||||||
|
|
||||||
self.dataset = dataset
|
self.dataset = load_dataset(dataset_name_or_path, split=self.split)
|
||||||
self.dataset_length = len(self.dataset)
|
self.dataset_length = len(self.dataset)
|
||||||
|
|
||||||
self.gt_token_sequences = []
|
self.gt_token_sequences = []
|
||||||
|
Loading…
Reference in New Issue
Block a user