Setup
Requirements
!pip install torch
!pip install datasets
!pip install transformers
!pip install scikit-learn
!pip install evaluate
!pip install accelerate
!pip install sentencepiece
!pip install protobuf
!pip install sacrebleu
!pip install py7zr
Requirement already satisfied for every package (Colab runtime, Python 3.8): torch 1.13.1+cu116, datasets 2.9.0, transformers 4.26.1, scikit-learn 1.0.2, evaluate 0.4.0, accelerate 0.16.0, sentencepiece 0.1.97, protobuf 3.19.6, sacrebleu 2.3.1, py7zr 0.20.4.
Imports
import os
import json
import torch
from google.colab import drive
from pathlib import Path
from typing import Dict, List
from datasets import load_dataset
from transformers import T5Tokenizer
Loading data
loaded_data = load_dataset('emotion')
!mkdir -v -p data
train_path = Path('data/train.json')
valid_path = Path('data/valid.json')
test_path = Path('data/test.json')
data_train, data_valid, data_test = [], [], []
WARNING:datasets.builder:No config specified, defaulting to: emotion/split
WARNING:datasets.builder:Found cached dataset emotion (/root/.cache/huggingface/datasets/emotion/split/1.0.0/cca5efe2dfeb58c1d098e0f9eeb200e9927d889b5a03c67097275dfb5fe463bd)
for source_data, dataset, max_size in [
    (loaded_data['train'], data_train, None),
    (loaded_data['validation'], data_valid, None),
    (loaded_data['test'], data_test, None),
]:
    for i, data in enumerate(source_data):
        # max_size=None keeps the whole split; an int caps it for quick experiments
        if max_size is not None and i >= max_size:
            break
        data_line = {
            'label': int(data['label']),
            'text': data['text'],
        }
        dataset.append(data_line)
print(f'Train: {len(data_train):6d}')
print(f'Valid: {len(data_valid):6d}')
print(f'Test: {len(data_test):6d}')
Train: 16000
Valid: 2000
Test: 2000
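With the three splits loaded, it can help to check how the labels are distributed before reading too much into a single accuracy figure. A minimal sketch, assuming the row format built above (`{'label': int, 'text': str}`); `label_distribution` and `LABEL_NAMES` are hypothetical helpers, and the mapping simply repeats the one defined in the next cell so this block runs on its own.

```python
from collections import Counter
from typing import Dict, List

# Hypothetical copy of the id -> name mapping (same as MAP_LABEL_TRANSLATION).
LABEL_NAMES: Dict[int, str] = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}


def label_distribution(rows: List[Dict]) -> Dict[str, float]:
    """Return the fraction of rows carrying each label name, most frequent first."""
    counts = Counter(row['label'] for row in rows)
    total = sum(counts.values())
    # most_common() orders labels by count, so the dict preserves that order
    return {LABEL_NAMES[label]: count / total for label, count in counts.most_common()}
```

For example, `label_distribution(data_train)` returns one fraction per emotion present in the training split, most frequent first.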
MAP_LABEL_TRANSLATION = {
    0: 'sadness',
    1: 'joy',
    2: 'love',
    3: 'anger',
    4: 'fear',
    5: 'surprise',
}
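def save_as_translations(original_save_path: Path, data_to_save: List[Dict]) -> None:
    # Write a copy of the split with integer labels replaced by their names,
    # e.g. data/train.json -> data/s2s-train.json, for text-to-text (seq2seq) models.
    file_name = 's2s-' + original_save_path.name
    file_path = original_save_path.parent / file_name
    print(f'Saving into: {file_path}')
    with open(file_path, 'wt') as f_write:
        for data_line in data_to_save:
            # Build a new dict instead of overwriting data_line['label'], so the
            # in-memory lists keep their integer labels and a second call does
            # not fail with a KeyError on an already-translated label.
            translated = {**data_line, 'label': MAP_LABEL_TRANSLATION[data_line['label']]}
            f_write.write(f'{json.dumps(translated)}\n')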
for file_path, data_to_save in [(train_path, data_train), (valid_path, data_valid), (test_path, data_test)]:
    print(f'Saving into: {file_path}')
    # One JSON object per line (JSON Lines) with the original integer labels
    with open(file_path, 'wt') as f_write:
        for data_line in data_to_save:
            data_line_str = json.dumps(data_line)
            f_write.write(f'{data_line_str}\n')
    save_as_translations(file_path, data_to_save)
Saving into: data/train.json
Saving into: data/s2s-train.json
Saving into: data/valid.json
Saving into: data/s2s-valid.json
Saving into: data/test.json
Saving into: data/s2s-test.json
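The files are written as JSON Lines, one object per line. A hedged sketch of reading one back and checking that every label is one of the six names used by `save_as_translations`; `iter_jsonl`, `load_and_check` and `ALLOWED_LABELS` are hypothetical names, not part of the notebook.

```python
import json
from pathlib import Path
from typing import Dict, Iterable, Iterator, List

# Hypothetical: the six label names written into the s2s-*.json files.
ALLOWED_LABELS = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}


def iter_jsonl(lines: Iterable[str]) -> Iterator[Dict]:
    """Parse JSON Lines, skipping blank lines such as a trailing newline."""
    for line in lines:
        line = line.strip()
        if line:
            yield json.loads(line)


def load_and_check(path: Path) -> List[Dict]:
    """Load a seq2seq file and raise ValueError on any unexpected label."""
    with open(path, 'rt') as f_read:
        rows = list(iter_jsonl(f_read))
    bad = [row for row in rows if row['label'] not in ALLOWED_LABELS]
    if bad:
        raise ValueError(f'{len(bad)} rows with unexpected labels in {path}, e.g. {bad[0]}')
    return rows
```

If the cell above ran as shown, `load_and_check(Path('data/s2s-train.json'))` should return 16000 rows without raising.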
!head data/train.json
{"label": 0, "text": "i didnt feel humiliated"}
{"label": 0, "text": "i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake"}
{"label": 3, "text": "im grabbing a minute to post i feel greedy wrong"}
{"label": 2, "text": "i am ever feeling nostalgic about the fireplace i will know that it is still on the property"}
{"label": 3, "text": "i am feeling grouchy"}
{"label": 0, "text": "ive been feeling a little burdened lately wasnt sure why that was"}
{"label": 5, "text": "ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny"}
{"label": 4, "text": "i feel as confused about life as a teenager or as jaded as a year old man"}
{"label": 1, "text": "i have been with petronas for years i feel that petronas has performed well and made a huge profit"}
{"label": 2, "text": "i feel romantic too"}
!head data/s2s-train.json
{"label": "sadness", "text": "i didnt feel humiliated"}
{"label": "sadness", "text": "i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake"}
{"label": "anger", "text": "im grabbing a minute to post i feel greedy wrong"}
{"label": "love", "text": "i am ever feeling nostalgic about the fireplace i will know that it is still on the property"}
{"label": "anger", "text": "i am feeling grouchy"}
{"label": "sadness", "text": "ive been feeling a little burdened lately wasnt sure why that was"}
{"label": "surprise", "text": "ive been taking or milligrams or times recommended amount and ive fallen asleep a lot faster but i also feel like so funny"}
{"label": "fear", "text": "i feel as confused about life as a teenager or as jaded as a year old man"}
{"label": "joy", "text": "i have been with petronas for years i feel that petronas has performed well and made a huge profit"}
{"label": "love", "text": "i feel romantic too"}
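# create tiny datasets for debugging purposes (first 250 + last 250 lines of each split)
for file_name in ["s2s-train", "s2s-valid", "s2s-test"]:
    print(f"=== {file_name} ===")
    # Each file ends with '\n', so split('\n') leaves an empty string as the last
    # element; the tail slice therefore holds 249 records plus '', and each
    # "-500" file ends up with 499 records (matching the wc -l output).
    all_text = Path(f"data/{file_name}.json").read_text().split('\n')
    text = all_text[:250] + all_text[-250:]
    Path(f"data/{file_name}-500.json").write_text("\n".join(text))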
=== s2s-train ===
=== s2s-valid ===
=== s2s-test ===
!wc -l data/*
   499 data/s2s-test-500.json
  2000 data/s2s-test.json
   499 data/s2s-train-500.json
 16000 data/s2s-train.json
   499 data/s2s-valid-500.json
  2000 data/s2s-valid.json
  2000 data/test.json
 16000 data/train.json
  2000 data/valid.json
 41497 total
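The `wc -l` counts show 499 lines in each `-500` file rather than 500: the source files end with a newline, so `split('\n')` returns an empty string as its last element and the tail slice spends one of its 250 slots on it. A minimal sketch of a head-plus-tail subset that keeps exactly `n_head + n_tail` records, assuming the input is a JSON Lines string; `head_tail_subset` is a hypothetical helper.

```python
from typing import List


def head_tail_subset(text: str, n_head: int = 250, n_tail: int = 250) -> str:
    """Keep the first n_head and last n_tail lines of a JSON Lines string."""
    # splitlines() yields no empty element for the final newline, unlike split('\n').
    lines: List[str] = text.splitlines()
    if len(lines) <= n_head + n_tail:
        kept = lines  # small enough to keep whole
    else:
        # lines[len(lines) - n_tail:] rather than lines[-n_tail:], which would
        # return the whole list when n_tail == 0.
        kept = lines[:n_head] + lines[len(lines) - n_tail:]
    # Terminate every record with a newline so `wc -l` counts all of them.
    return ''.join(line + '\n' for line in kept)
```

It could stand in for the slicing in the debugging cell, e.g. `Path(f"data/{file_name}-500.json").write_text(head_tail_subset(Path(f"data/{file_name}.json").read_text()))`.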
Zero Shot
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import json
import time
!nvidia-smi
Mon Feb 13 23:18:24 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   71C    P0    31W /  70W |   7320MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      5402      C                                    7317MiB |
+-----------------------------------------------------------------------------+
# Hugging Face pipelines take a device index: 0 = first GPU, -1 = CPU
if torch.cuda.is_available():
    device = 0
else:
    device = -1
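def get_pipeline(pipeline_type: str, model_name: str, torch_dtype: torch.dtype = torch.float32):
    # Load a seq2seq model (e.g. FLAN-T5) and its tokenizer, then wrap them in a pipeline.
    # torch_dtype is forwarded to from_pretrained; it defaults to float32 because
    # T5 checkpoints are known to be prone to overflow in float16.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch_dtype)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return pipeline(pipeline_type, model=model, tokenizer=tokenizer, device=device)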
lm_pipeline = get_pipeline('text2text-generation', 'google/flan-t5-large')
def generate_prompt(text):
    # Zero-shot prompt: candidate labels, the input text, then an open "label:" slot
    labels = "possible labels: sadness, joy, love, anger, surprise, fear"
    prompt = labels + '\n' + f'text: {text}' + '\n' + 'label: '
    return prompt
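`generate_prompt` hard-codes the label list, so it can silently drift from `MAP_LABEL_TRANSLATION`. A sketch of deriving the list from the mapping instead; `build_prompt` and `LABEL_NAMES` are hypothetical. Note that it lists labels in id order (`fear` before `surprise`), whereas the prompt used for the reported accuracy lists `surprise` before `fear`; whether that order changes FLAN-T5's answers has not been tested here.

```python
from typing import Dict

# Hypothetical copy of MAP_LABEL_TRANSLATION so this block runs on its own.
LABEL_NAMES: Dict[int, str] = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}


def build_prompt(text: str, label_names: Dict[int, str] = LABEL_NAMES) -> str:
    """Same layout as generate_prompt, with the label list taken from the mapping."""
    labels = ', '.join(label_names[i] for i in sorted(label_names))
    return f'possible labels: {labels}\ntext: {text}\nlabel: '
```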
def predict(text):
    # Greedy decoding (do_sample=False), so predictions are deterministic
    return lm_pipeline(generate_prompt(text), do_sample=False)[0]['generated_text']
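`predict` returns the raw generated string, and the evaluation below counts a hit only on an exact match with the gold label. A hedged sketch of normalizing the output first (surrounding whitespace, a trailing period, case) and returning `None` when it is not one of the six labels, so off-label generations can be counted separately; `normalize_prediction` is a hypothetical helper, and the reported accuracy was computed without it.

```python
from typing import Optional, Set

# Hypothetical: the six label names the model is asked to choose from.
LABEL_SET: Set[str] = {'sadness', 'joy', 'love', 'anger', 'fear', 'surprise'}


def normalize_prediction(generated: str, labels: Set[str] = LABEL_SET) -> Optional[str]:
    """Map raw generated text to a known label, or None if it matches none."""
    candidate = generated.strip().rstrip('.').strip().lower()
    return candidate if candidate in labels else None
```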
with open('data/s2s-test.json') as f:
    time_start = time.time()
    total = 0
    correct = 0
    lines = f.readlines()
    test_cases_amount = len(lines)
    for line in lines:
        item = json.loads(line)
        text = item['text']
        label = item['label']
        total += 1
        if total % 50 == 0:
            print(f'{total}/{test_cases_amount}')
        # Exact string match between the generated text and the gold label name
        if predict(text) == label:
            correct += 1
    time_end = time.time()
print(f'Minutes elapsed: {(time_end - time_start) / 60}')
print(f'Accuracy: {correct/total}')
50/2000 100/2000 150/2000 200/2000 250/2000 300/2000 350/2000 400/2000 450/2000 500/2000 550/2000 600/2000 650/2000 700/2000 750/2000 800/2000 850/2000 900/2000 950/2000 1000/2000 1050/2000 1100/2000 1150/2000 1200/2000 1250/2000 1300/2000 1350/2000 1400/2000 1450/2000 1500/2000 1550/2000 1600/2000 1650/2000 1700/2000 1750/2000 1800/2000 1850/2000 1900/2000 1950/2000 2000/2000
Minutes elapsed: 3.088933833440145
Accuracy: 0.6505
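A single accuracy number (0.6505 here) hides which emotions the zero-shot model gets wrong. A minimal sketch of a per-label breakdown, assuming the loop above is changed to collect `(label, prediction)` pairs; `per_label_accuracy` is a hypothetical helper, and no per-label results from this run are implied.

```python
from collections import Counter
from typing import Dict, Iterable, Optional, Tuple


def per_label_accuracy(pairs: Iterable[Tuple[str, Optional[str]]]) -> Dict[str, float]:
    """Given (gold, predicted) pairs, return the accuracy for each gold label."""
    totals: Counter = Counter()
    hits: Counter = Counter()
    for gold, predicted in pairs:
        totals[gold] += 1
        if predicted == gold:
            hits[gold] += 1
    return {label: hits[label] / totals[label] for label in sorted(totals)}
```

Inside the loop, one could store `prediction = predict(text)`, compare it as before, and append `(label, prediction)` to a list, then pass that list to `per_label_accuracy` after the loop.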