#!/usr/bin/env python3
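"""Prepare the AG News dataset for classification and seq2seq experiments.

The script downloads the `ag_news` dataset with the HuggingFace `datasets`
library and writes three JSON Lines files (`train.json`, `valid.json`,
`test.json`) into `data/`. Because AG News ships only train and test splits,
the test split is divided per class into validation and test halves. For every
file an additional `s2s-` prefixed copy with textual labels is written via
`save_as_translations`.
"""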
import json
import logging
from pathlib import Path
from typing import List, Dict

from datasets import load_dataset

logger = logging.getLogger(__name__)

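# Mapping from AG News integer labels to the textual labels written
# to the seq2seq ('s2s-') files.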
MAP_LABEL_TRANSLATION = {
    0: 'world',
    1: 'sport',
    2: 'business',
    3: 'scitech'
}


def save_as_translations(original_save_path: Path, data_to_save: List[Dict]) -> None:
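    """Write `data_to_save` next to the original file with an 's2s-' prefix,
    replacing each integer label with its textual form from MAP_LABEL_TRANSLATION.
    """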
    file_name = 's2s-' + original_save_path.name
    file_path = original_save_path.parent / file_name

    print(f'Saving into: {file_path}')
    with open(file_path, 'wt') as f_write:
        for data_line in data_to_save:
            label = data_line['label']
            new_label = MAP_LABEL_TRANSLATION[label]
            data_line['label'] = new_label
            data_line_str = json.dumps(data_line)
            f_write.write(f'{data_line_str}\n')


def main() -> None:
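    """Build train/valid/test JSON Lines files for AG News under data/."""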
    loaded_data = load_dataset('ag_news')
    logger.info(f'Loaded dataset ag_news: {loaded_data}')

    save_path = Path('data/')
    save_train_path = save_path / 'train.json'
    save_valid_path = save_path / 'valid.json'
    save_test_path = save_path / 'test.json'
    if not save_path.exists():
        save_path.mkdir()

    # Read the train split as-is; ag_news has no validation split, so the test
    # split is loaded into data_valid and divided into validation/test below.
    data_train, data_valid, data_test = [], [], []
    for source_data, dataset, max_size in [
        (loaded_data['train'], data_train, None),
        (loaded_data['test'], data_valid, None)
    ]:
        for i, data in enumerate(source_data):
            if max_size is not None and i >= max_size:
                break
            data_line = {
                'label': int(data['label']),
                'text': data['text'],
            }
            dataset.append(data_line)
    logger.info(f'Train: {len(data_train):6d}')

    # Group the held-out data by class so each class can be split evenly
    # between the new validation and test sets
    world, sport, business, scitech = [], [], [], []
    for data in data_valid:
        label = data['label']
        if label == 0:
            world.append(data)
        elif label == 1:
            sport.append(data)
        elif label == 2:
            business.append(data)
        elif label == 3:
            scitech.append(data)

    logger.info(f'World: {len(world):6d}')
    logger.info(f'Sport: {len(sport):6d}')
    logger.info(f'Business: {len(business):6d}')
    logger.info(f'Scitech: {len(scitech):6d}')

    print(f'World: {len(world):6d}')
    print(f'Sport: {len(sport):6d}')
    print(f'Business: {len(business):6d}')
    print(f'Scitech: {len(scitech):6d}')

    # Split each class in half: first half for validation, second half for test
    size_half_world = int(len(world) / 2)
    size_half_sport = int(len(sport) / 2)
    size_half_business = int(len(business) / 2)
    size_half_scitech = int(len(scitech) / 2)

    data_valid = world[:size_half_world] + sport[:size_half_sport] + business[:size_half_business] + scitech[:size_half_scitech]
    data_test = world[size_half_world:] + sport[size_half_sport:] + business[size_half_business:] + scitech[size_half_scitech:]
    logger.info(f'Valid: {len(data_valid):6d}')
    logger.info(f'Test : {len(data_test):6d}')

    # Save each split as JSON Lines, plus an 's2s-' copy with textual labels
    for file_path, data_to_save in [
        (save_train_path, data_train),
        (save_valid_path, data_valid),
        (save_test_path, data_test)
    ]:
        print(f'Saving into: {file_path}')
        with open(file_path, 'wt') as f_write:
            for data_line in data_to_save:
                data_line_str = json.dumps(data_line)
                f_write.write(f'{data_line_str}\n')

        save_as_translations(file_path, data_to_save)


if __name__ == '__main__':
    # Configure logging so the logger.info messages above are actually emitted
    logging.basicConfig(level=logging.INFO)
    main()