#!/usr/bin/env python3
"""Download the HuggingFace 'emotion' dataset and save splits as JSON lines.

For each split (train/valid/test) two files are written into ./data:
  * <split>.json      -- labels as integer class ids
  * s2s-<split>.json  -- labels as text names (seq2seq-friendly variant)
"""
import json
import logging
from pathlib import Path
from typing import Dict, List

from datasets import load_dataset

logger = logging.getLogger(__name__)

# Integer class id -> human-readable label for the 'emotion' dataset.
MAP_LABEL_TRANSLATION = {
    0: 'sadness',
    1: 'joy',
    2: 'love',
    3: 'anger',
    4: 'fear',
    5: 'surprise',
}


def save_as_translations(original_save_path: Path, data_to_save: List[Dict]) -> None:
    """Write *data_to_save* with text labels to an 's2s-'-prefixed sibling file.

    Args:
        original_save_path: Path of the integer-labelled JSON-lines file;
            the translated copy goes next to it as 's2s-<name>'.
        data_to_save: List of {'label': int, 'text': str} records.

    The input dicts are NOT mutated: each output line is a fresh dict whose
    'label' is mapped through MAP_LABEL_TRANSLATION.  (The previous in-place
    mutation made a second call on the same data raise KeyError.)
    """
    file_name = 's2s-' + original_save_path.name
    file_path = original_save_path.parent / file_name

    print(f'Saving into: {file_path}')
    with open(file_path, 'wt', encoding='utf-8') as f_write:
        for data_line in data_to_save:
            # Copy so the caller's integer-labelled dicts stay intact.
            translated = {**data_line, 'label': MAP_LABEL_TRANSLATION[data_line['label']]}
            f_write.write(json.dumps(translated) + '\n')


def main() -> None:
    """Fetch the 'emotion' dataset and dump all splits under ./data."""
    loaded_data = load_dataset('emotion')
    # Lazy %-formatting: only rendered when INFO logging is enabled.
    logger.info('Loaded dataset emotion: %s', loaded_data)

    save_path = Path('data')
    save_train_path = save_path / 'train.json'
    save_valid_path = save_path / 'valid.json'
    save_test_path = save_path / 'test.json'

    # Race-free and safe on re-runs (replaces exists()+mkdir()).
    save_path.mkdir(parents=True, exist_ok=True)

    data_train, data_valid, data_test = [], [], []
    for source_data, dataset, max_size in [
        (loaded_data['train'], data_train, None),
        (loaded_data['test'], data_test, None),
        (loaded_data['validation'], data_valid, None),
    ]:
        for i, data in enumerate(source_data):
            # max_size is a subsampling hook; None keeps every row.
            if max_size is not None and i >= max_size:
                break
            dataset.append({
                'label': int(data['label']),
                'text': data['text'],
            })

    logger.info('Train: %6d', len(data_train))
    logger.info('Test: %6d', len(data_test))
    logger.info('Validation: %6d', len(data_valid))

    for file_path, data_to_save in [
        (save_train_path, data_train),
        (save_valid_path, data_valid),
        (save_test_path, data_test),
    ]:
        print(f'Saving into: {file_path}')
        with open(file_path, 'wt', encoding='utf-8') as f_write:
            for data_line in data_to_save:
                f_write.write(json.dumps(data_line) + '\n')
        # Sibling file with text labels; does not mutate data_to_save.
        save_as_translations(file_path, data_to_save)


if __name__ == '__main__':
    main()