gitignore, reqs, train_stream test

This commit is contained in:
zzombely 2023-03-14 21:43:21 +00:00
parent 7fadabea91
commit ccb1c9c2ed
3 changed files with 20 additions and 7 deletions

1
.gitignore vendored
View File

@ -5,3 +5,4 @@ wandb
__pycache__/
checkpoint
.vscode
donut_env

View File

@ -14,7 +14,7 @@ beautifulsoup4==4.11.1
bleach==5.0.1
blend-modes==2.1.0
cachetools==5.2.0
certifi
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==2.1.1
click==8.1.3
@ -26,7 +26,7 @@ decorator==5.1.1
defusedxml==0.7.1
dill==0.3.6
docker-pycreds==0.4.0
donut-python
donut-python==1.0.9
entrypoints==0.4
evaluate==0.3.0
fastapi==0.87.0
@ -109,8 +109,9 @@ parso==0.8.3
pathtools==0.1.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.3.0
Pillow==9.4.0
pkgutil_resolve_name==1.3.10
portalocker==2.7.0
prometheus-client==0.15.0
promise==2.3
prompt-toolkit==3.0.33
@ -175,9 +176,10 @@ tifffile==2021.11.2
timm==0.6.11
tinycss2==1.2.1
tokenizers==0.13.2
torch==1.13.0
torchmetrics==0.10.3
torchvision==0.14.0
torch==1.13.1
torchdata==0.5.1
torchmetrics==0.11.4
torchvision==0.14.1
tornado==6.2
tqdm==4.64.1
traitlets==5.5.0

View File

@ -19,6 +19,16 @@ from torchdata.datapipes.iter import IterableWrapper
import json
class TestIterator(IterableWrapper):
def __init__(self, iterable, deepcopy=True, total_len=None):
super().__init__(iterable, deepcopy)
self.total_len = total_len
def __len__(self):
if self.total_len:
return self.total_len
return super().__len__()
def main(config, hug_token):
config_vision = VisionEncoderDecoderConfig.from_pretrained(