import argparse
import json
import os
import shutil
import time

import pandas as pd
import requests
from datasets import load_dataset
from huggingface_hub import login
from PIL import Image
from tqdm import tqdm

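# Identify the bot and give a contact address, as Wikimedia's User-Agent policy requests.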
headers = {'User-Agent': 'ImageDownloadOcrBot/1.0 (no.rp.mk.info@gmail.com) requests/2.28.1'}


class WikiImage:
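    """Download images listed in a TSV file and push them to the Hugging Face Hub in fixed-size splits."""
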
    def __init__(self, input_file_path: str, dataset_name: str, output_folder: str = 'temp_images', split_number: int = 1):
        self.input_file_path = input_file_path
        self.split_number = split_number
        self.max_dataset_len = 10000
        self.output_folder = output_folder
        self.dataset_name = dataset_name
        print("Loading input file")
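        # Skip rows already covered by earlier splits so a crawl can resume partway through the file.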
        self.dataframe = pd.read_csv(self.input_file_path, sep='\t')[(self.split_number - 1) * self.max_dataset_len:]
        if os.path.exists(self.output_folder):
            print("Removing old output folder")
            if os.path.exists('/home/zombely/.cache/huggingface/datasets'):
                shutil.rmtree('/home/zombely/.cache/huggingface/datasets')
            shutil.rmtree(self.output_folder)
        os.mkdir(self.output_folder)
        self.pbar = tqdm(self.dataframe.iterrows(), total=len(self.dataframe), desc=f"Split: {self.split_number}")
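        # Log in to the Hugging Face Hub with the token from the HUG_TOKEN environment variable.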
        login(os.environ.get("HUG_TOKEN"), True)

    def image_save(self, row):
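        """Download a single image, convert it to RGB, save it as PNG and append its ground truth to metadata.jsonl."""
        # Throttle requests so the image host is not hammered.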
        time.sleep(0.3)
        image_request = requests.get(f"https:{row[1]['image_url']}", stream=True, headers=headers)
        if image_request.status_code in [500, 404]:
            print(f"Image {row[1]['title']} is not reachable")
            return
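        # Any other non-200 response is most likely rate limiting: back off and retry once.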
        if image_request.status_code != 200:
            time.sleep(80)
            image_request = requests.get(f"https:{row[1]['image_url']}", stream=True, headers=headers)
            assert image_request.status_code == 200, f"Unexpected response status, status_code: {image_request.status_code}, full info: {image_request.__dict__}"

        image = Image.open(image_request.raw)
        if image.mode != "RGB":
            image = image.convert("RGB")
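        # Strip the "Strona:" prefix ("Page:" on Polish Wikisource) and replace slashes so the title is a safe file name.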
        title = row[1]['title'].replace("Strona:", "").replace("/", "-")
        image.save(f"{self.output_folder}/{title}.png")

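        # Append ground truth in the imagefolder metadata format; the misspelled "text_sequance" key is kept so the dataset schema does not change.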
        with open(f"{self.output_folder}/metadata.jsonl", mode='a', encoding='utf-8') as f:
            json.dump({"file_name": f"{title}.png", "ground_truth": json.dumps({"gt_parse": {"text_sequance": row[1]['text'].replace('"', "'")}}, ensure_ascii=False)}, f, ensure_ascii=False)
            f.write("\n")

    def push_dataset(self, split_name: str):
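        """Push everything accumulated in output_folder to the Hub as `split_name`, then reset local state for the next chunk."""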
|
|
print(f"Pushing split: {split_name}")
|
|
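        # load_dataset on a local imagefolder returns a single 'train' split; rename it to the requested split name.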
        dataset = load_dataset(self.output_folder)
        dataset[split_name] = dataset.pop('train')
        dataset.push_to_hub(f'Zombely/{self.dataset_name}')
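        # Remove the staging folder and the datasets cache so the next chunk starts from scratch.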
        shutil.rmtree(self.output_folder)
        shutil.rmtree('/home/zombely/.cache/huggingface/datasets')
        os.mkdir(self.output_folder)
        del dataset
        print("Upload finished")

    def crawl(self):
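        """Download every image, pushing a numbered train split after each max_dataset_len images and the remainder as validation."""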
        print("Start download")
        for index, row in enumerate(self.pbar):
            self.image_save(row)
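            # A full chunk is on disk: upload it and start the next split.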
            if (index + 1) % self.max_dataset_len == 0:
                self.push_dataset(f'train_{self.split_number}')
                self.split_number += 1
                self.pbar.set_description(f'Split: {self.split_number}')

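        # Whatever remains after the last full chunk is pushed as the validation split.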
        self.push_dataset('validation')

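# Example invocation (the script name and file paths here are illustrative):
#   HUG_TOKEN=hf_... python wiki_image.py --input_file_path pages.tsv --dataset_name my-ocr-dataset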
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file_path", type=str, required=True)
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument("--output_folder", type=str, required=False, default='temp_images')
    parser.add_argument("--split_number", type=int, required=False, default=1)
    args, left_argv = parser.parse_known_args()
    crawler = WikiImage(**vars(args))
    crawler.crawl()