# wikisource-crawler/image_class.py
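"""Download Wikisource page-scan images listed in a TSV export and publish them, together with
their transcribed text, as an image dataset on the Hugging Face Hub, in splits of 10,000 pages."""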

import os
import argparse
import pandas as pd
import requests
from PIL import Image
from tqdm import tqdm
import time
import json
from datasets import load_dataset
from huggingface_hub import login
import shutil
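
# Descriptive User-Agent so Wikimedia servers can identify the crawler, per the Wikimedia User-Agent policy.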
headers = {'User-Agent': 'ImageDownloadOcrBot/1.0 (no.rp.mk.info@gmail.com) requests/2.28.1'}


class WikiImage:
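    """Downloads the page images referenced in the input TSV, stores their transcribed text as
    ground truth in metadata.jsonl, and pushes each chunk of max_dataset_len rows to the
    Hugging Face Hub as a separate split."""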
    def __init__(self, input_file_path: str, dataset_name: str, output_folder: str = 'temp_images', split_number: int = 1):
        self.input_file_path = input_file_path
        self.split_number = split_number
        self.max_dataset_len = 10000  # rows per split pushed to the Hub
        self.output_folder = output_folder
        self.dataset_name = dataset_name
        print("Loading input file")
        # Skip rows already covered by earlier splits so a crawl can be resumed.
        self.dataframe = pd.read_csv(self.input_file_path, sep='\t')[(self.split_number - 1) * self.max_dataset_len:]
        if os.path.exists(self.output_folder):
            print("Removing old output folder")
            if os.path.exists('/home/zombely/.cache/huggingface/datasets'):
                shutil.rmtree('/home/zombely/.cache/huggingface/datasets')
            shutil.rmtree(self.output_folder)
        os.mkdir(self.output_folder)
        self.pbar = tqdm(self.dataframe.iterrows(), total=len(self.dataframe), desc=f"Split: {self.split_number}")
        login(os.environ.get("HUG_TOKEN"), add_to_git_credential=True)

    def image_save(self, row):
        """Download a single page image and append its ground-truth record to metadata.jsonl."""
        time.sleep(0.3)  # stay polite to the Wikimedia servers
        image_request = requests.get(f"https:{row[1]['image_url']}", stream=True, headers=headers)
        if image_request.status_code in [500, 404]:
            print(f"Image {row[1]['title']} is not reachable")
            return
        if image_request.status_code != 200:
            # Most likely rate limiting: back off and retry once.
            time.sleep(80)
            image_request = requests.get(f"https:{row[1]['image_url']}", stream=True, headers=headers)
        assert image_request.status_code == 200, f"Response status is different, status_code: {image_request.status_code}, full info: {image_request.__dict__}"
        image = Image.open(image_request.raw)
        if image.mode != "RGB":
            image = image.convert("RGB")
        title = row[1]['title'].replace("Strona:", "").replace("/", "-")
        image.save(f"{self.output_folder}/{title}.png")
        with open(f"{self.output_folder}/metadata.jsonl", mode='a', encoding='utf-8') as f:
            # The 'text_sequance' key spelling is kept as-is so existing splits of the dataset stay consistent.
            json.dump({"file_name": f"{title}.png", "ground_truth": json.dumps({"gt_parse": {"text_sequance": row[1]['text'].replace('"', "'")}}, ensure_ascii=False)}, f, ensure_ascii=False)
            f.write("\n")

    def push_dataset(self, split_name: str):
        """Load the downloaded images as a dataset, push them to the Hub as `split_name`, and reset local state."""
        print(f"Pushing split: {split_name}")
        dataset = load_dataset(self.output_folder)
        dataset[split_name] = dataset.pop('train')
        dataset.push_to_hub(f'Zombely/{self.dataset_name}')
        shutil.rmtree(self.output_folder)
        shutil.rmtree('/home/zombely/.cache/huggingface/datasets')
        os.mkdir(self.output_folder)
        del dataset
        print("Upload finished")

    def crawl(self):
        """Download every row, pushing a new train split after each max_dataset_len images."""
        print("Start download")
        for index, row in enumerate(self.pbar):
            self.image_save(row)
            if (index + 1) % self.max_dataset_len == 0:
                self.push_dataset(f'train_{self.split_number}')
                self.split_number += 1
                self.pbar.set_description(f'Split: {self.split_number}')
        # Whatever is left after the last full split becomes the validation split.
        self.push_dataset('validation')
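

# Example invocation (illustrative file and dataset names; expects a Hugging Face token in HUG_TOKEN
# and a tab-separated input file with 'image_url', 'title' and 'text' columns):
#   python image_class.py --input_file_path pages.tsv --dataset_name wikisource-crawl --split_number 1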
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file_path", type=str, required=True)
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument("--output_folder", type=str, required=False, default='temp_images')
    parser.add_argument("--split_number", type=int, required=False, default=1)
    args, left_argv = parser.parse_known_args()
    crawler = WikiImage(**vars(args))
    crawler.crawl()