UGP/prepare_ag_news.py

#!/usr/bin/env python3
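"""Prepare the AG News dataset as JSON Lines files.

Downloads AG News with Hugging Face `datasets`, keeps the official train split,
splits the official test split in half per class into validation and test sets,
and writes data/train.json, data/valid.json and data/test.json. For each split
it also writes a 5k-line subset ("*-5k.json") and an "s2s-" variant whose
integer labels are replaced with string names.
"""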
import json
import logging
from pathlib import Path

from datasets import load_dataset

LOGGER = logging.getLogger(__name__)


def save_limited_data(file_path: Path) -> None:
    """Write a "-5k" companion file containing only the first and last 2500 lines."""
    # splitlines() drops the trailing newline, so the subset holds exactly 5000 records
    all_text = file_path.read_text().splitlines()
    text = all_text[:2500] + all_text[-2500:]
    save_path = file_path.parent / f"{file_path.stem}-5k.json"
    save_path.write_text("\n".join(text))
    LOGGER.info(f"Saved limited ({len(text)}) version in: {save_path}")


def save_as_translations(original_save_path: Path, data_to_save: list[dict], label_translation: dict) -> None:
    """Write an "s2s-" copy of the split with integer labels mapped to string names."""
    file_name = "s2s-" + original_save_path.name
    file_path = original_save_path.parent / file_name
    LOGGER.info(f"Saving into: {file_path}")
    with open(file_path, "wt") as f_write:
        for data_line in data_to_save:
            # Map the integer label to its string name; fall back to the raw id
            label = data_line["label"]
            new_label = label_translation.get(label, str(label))
            data_line["label"] = new_label
            data_line_str = json.dumps(data_line)
            f_write.write(f"{data_line_str}\n")
    save_limited_data(file_path)


def main() -> None:
    logging.basicConfig(level=logging.INFO)

    loaded_data = load_dataset("ag_news")
    LOGGER.info(f"Loaded dataset ag_news: {loaded_data}")

    save_path = Path("data/")
    save_train_path = save_path / "train.json"
    save_valid_path = save_path / "valid.json"
    save_test_path = save_path / "test.json"
    if not save_path.exists():
        save_path.mkdir()

    # Read the train split; the official test split is loaded into data_valid
    # and divided into validation and test sets further below
    data_train, data_valid, data_test = [], [], []
    for source_data, dataset in [
        (loaded_data["train"], data_train),
        (loaded_data["test"], data_valid),
    ]:
        for data in source_data:
            data_line = {
                "label": int(data["label"]),
                "text": data["text"],
            }
            dataset.append(data_line)
    LOGGER.info(f"Train: {len(data_train):6d}")

    # Group the held-out examples by their 4 classes for a stratified valid/test split
    data_class_1, data_class_2, data_class_3, data_class_4 = [], [], [], []
    for data in data_valid:
        label = data["label"]
        if label == 0:
            data_class_1.append(data)
        elif label == 1:
            data_class_2.append(data)
        elif label == 2:
            data_class_3.append(data)
        elif label == 3:
            data_class_4.append(data)
    LOGGER.info(f"Label 1: {len(data_class_1):6d}")
    LOGGER.info(f"Label 2: {len(data_class_2):6d}")
    LOGGER.info(f"Label 3: {len(data_class_3):6d}")
    LOGGER.info(f"Label 4: {len(data_class_4):6d}")

    # Split each class in half: first half goes to validation, second half to test
    size_half_class_1 = len(data_class_1) // 2
    size_half_class_2 = len(data_class_2) // 2
    size_half_class_3 = len(data_class_3) // 2
    size_half_class_4 = len(data_class_4) // 2
    data_valid = (
        data_class_1[:size_half_class_1]
        + data_class_2[:size_half_class_2]
        + data_class_3[:size_half_class_3]
        + data_class_4[:size_half_class_4]
    )
    data_test = (
        data_class_1[size_half_class_1:]
        + data_class_2[size_half_class_2:]
        + data_class_3[size_half_class_3:]
        + data_class_4[size_half_class_4:]
    )
    LOGGER.info(f"Valid: {len(data_valid):6d}")
    LOGGER.info(f"Test : {len(data_test):6d}")

    # Save files: plain JSONL, a 5k-line subset, and an s2s variant with string labels.
    # AG News label ids and names: 0=World, 1=Sports, 2=Business, 3=Sci/Tech
    label_translation = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
    for file_path, data_to_save in [
        (save_train_path, data_train),
        (save_valid_path, data_valid),
        (save_test_path, data_test),
    ]:
        LOGGER.info(f"Saving into: {file_path}")
        with open(file_path, "wt") as f_write:
            for data_line in data_to_save:
                data_line_str = json.dumps(data_line)
                f_write.write(f"{data_line_str}\n")
        save_limited_data(file_path)
        save_as_translations(file_path, data_to_save, label_translation)


if __name__ == "__main__":
    main()
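
# Illustrative output format (assuming the AG News label names above), e.g. a line in
# data/test.json:      {"label": 2, "text": "..."}
# data/s2s-test.json:  {"label": "Business", "text": "..."}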