UG-final/data_prep.py

80 lines
2.3 KiB
Python
Raw Normal View History

2023-02-14 23:44:20 +01:00
#!/usr/bin/env python3
import json
import logging
from itertools import islice
from pathlib import Path
from typing import Dict, List, Optional

from datasets import load_dataset
logger = logging.getLogger(__name__)

# Maps the integer `label` values read from the `emotion` dataset to the label
# names written into the "s2s-" (text-label) output files. Keys are the
# positions in the tuple below, i.e. 0..5 in this exact order.
MAP_LABEL_TRANSLATION = dict(enumerate((
    'sadness',
    'joy',
    'love',
    'anger',
    'fear',
    'surprise',
)))
def save_as_translations(
    original_save_path: Path,
    data_to_save: List[Dict],
    label_map: Optional[Dict[int, str]] = None,
) -> None:
    """Write a copy of *data_to_save* with integer labels replaced by names.

    The output file sits next to *original_save_path* and gets an ``s2s-``
    prefix (e.g. ``data/train.json`` -> ``data/s2s-train.json``). It holds one
    JSON object per line.

    Args:
        original_save_path: Path of the integer-label file. Only its parent
            directory and file name are used to build the output path.
        data_to_save: Records, each with a ``'label'`` key whose value is a key
            of *label_map*. The records are NOT modified.
        label_map: Mapping from integer label to label name. Defaults to
            ``MAP_LABEL_TRANSLATION``.

    Raises:
        KeyError: If a record has no ``'label'`` key, or its label is missing
            from *label_map*.
    """
    if label_map is None:
        label_map = MAP_LABEL_TRANSLATION
    file_path = original_save_path.parent / ('s2s-' + original_save_path.name)
    print(f'Saving into: {file_path}')
    with open(file_path, 'wt', encoding='utf-8') as f_write:
        for data_line in data_to_save:
            # Build a new dict instead of assigning into data_line, so the
            # caller's records keep their integer labels. Key order is
            # unchanged because 'label' already exists in data_line.
            translated = {**data_line, 'label': label_map[data_line['label']]}
            f_write.write(f'{json.dumps(translated)}\n')
def _to_records(source_data, max_size: Optional[int]) -> List[Dict]:
    """Convert one dataset split into plain ``{'label': int, 'text': ...}`` dicts.

    Args:
        source_data: Iterable of examples indexable by ``'label'`` and
            ``'text'`` (here, one split of the ``emotion`` dataset).
        max_size: Keep at most this many examples. ``None`` keeps all of them.

    Raises:
        ValueError: If *max_size* is negative (raised by ``islice``).
    """
    # islice(it, None) iterates to the end, so None means "no cap".
    return [
        {'label': int(data['label']), 'text': data['text']}
        for data in islice(source_data, max_size)
    ]


def main(max_size: Optional[int] = None) -> None:
    """Load the ``emotion`` dataset and export each split as JSON-lines files.

    Writes ``data/train.json``, ``data/valid.json`` and ``data/test.json``.
    Each line is one ``{"label": <int>, "text": ...}`` record. For each file,
    ``save_as_translations`` also writes an ``s2s-``-prefixed copy with label
    names.

    Args:
        max_size: If given, keep at most this many examples per split (handy
            for quick experiments). ``None`` (the default) keeps every example.
    """
    loaded_data = load_dataset('emotion')
    logger.info('Loaded dataset emotion: %s', loaded_data)

    save_path = Path('data')
    save_train_path = save_path / 'train.json'
    save_valid_path = save_path / 'valid.json'
    save_test_path = save_path / 'test.json'
    # exist_ok=True replaces the racy exists()-then-mkdir() check.
    save_path.mkdir(parents=True, exist_ok=True)

    data_train = _to_records(loaded_data['train'], max_size)
    data_test = _to_records(loaded_data['test'], max_size)
    data_valid = _to_records(loaded_data['validation'], max_size)
    logger.info('Train: %6d', len(data_train))
    logger.info('Test: %6d', len(data_test))
    logger.info('Validation: %6d', len(data_valid))

    for file_path, data_to_save in (
        (save_train_path, data_train),
        (save_valid_path, data_valid),
        (save_test_path, data_test),
    ):
        print(f'Saving into: {file_path}')
        with open(file_path, 'wt', encoding='utf-8') as f_write:
            # One JSON object per line, integer labels.
            f_write.writelines(f'{json.dumps(line)}\n' for line in data_to_save)
        # Also write the label-name variant ("s2s-" prefixed) next to it.
        save_as_translations(file_path, data_to_save)


if __name__ == '__main__':
    # Without a configured handler, the logger.info() calls in main() are
    # dropped (the root logger defaults to WARNING). Enable INFO output so the
    # dataset summary and split sizes are actually shown.
    logging.basicConfig(level=logging.INFO)
    main()