uczenie_glebokie_projekt/flan-t5.ipynb

2 lines
52 KiB
Plaintext
Raw Normal View History

2023-02-13 20:13:05 +01:00
{"cells":[{"attachments":{},"cell_type":"markdown","metadata":{},"source":["# FLAN-T5"]},{"attachments":{},"cell_type":"markdown","metadata":{},"source":["## Colab link with better-looking output:\n","https://colab.research.google.com/drive/1bVujvgH49tyY83eqZoWYcDBZ2JurmLTL?usp=sharing"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"cF27ATTsqUKG"},"outputs":[],"source":["!pip install transformers datasets torch sentencepiece"]},{"cell_type":"code","execution_count":426,"metadata":{"executionInfo":{"elapsed":730,"status":"ok","timestamp":1676231947361,"user":{"displayName":"Szymon Jadczak","userId":"16229243011212847322"},"user_tz":-60},"id":"8DLGqnFgqcmQ"},"outputs":[],"source":["from datasets import load_dataset\n","from transformers import AutoModelForSeq2SeqLM\n","from transformers import AutoTokenizer\n","import torch\n","from tqdm import tqdm\n","from sklearn.metrics import accuracy_score"]},{"cell_type":"code","execution_count":427,"metadata":{"executionInfo":{"elapsed":2,"status":"ok","timestamp":1676231949164,"user":{"displayName":"Szymon Jadczak","userId":"16229243011212847322"},"user_tz":-60},"id":"LY2b5R1Wqs7S"},"outputs":[],"source":["def load_and_process_dataset(seed=42):\n","    \"\"\"Load SST-2 and rearrange it into train/validation/test splits.\n","\n","    The 'idx' column is dropped, the original 'test' split is discarded and\n","    the original 'validation' split is exposed as 'test' instead. 1600\n","    examples are then held out of 'train' to form the new 'validation' split.\n","\n","    Args:\n","        seed: Seed passed to train_test_split so the hold-out is reproducible.\n","\n","    Returns:\n","        The DatasetDict with 'train', 'test' and 'validation' splits.\n","    \"\"\"\n","    dataset = load_dataset(\"sst2\")\n","    # remove_columns returns a new DatasetDict instead of mutating in place,\n","    # so the result has to be reassigned for 'idx' to actually be dropped.\n","    dataset = dataset.remove_columns('idx')\n","    # Replace the original test split with the original validation split.\n","    dataset['test'] = dataset.pop('validation')\n","    # Hold out 1600 training examples as the new validation split.\n","    split_dataset = dataset['train'].train_test_split(test_size=1600, seed=seed)\n","    dataset['train'] = split_dataset['train']\n","    dataset['validation'] = split_dataset['test']\n","    return 
dataset"]},{"cell_type":"code","execution_count":428,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":329,"referenced_widgets":["a0e441327eee4112ac7b74781c77690e","70fac8ce031c46589955c693c245d627","f44e3cd240a844f9813bc544e68369c6","a9c34e3bd64f4365bf84cca088a67ed3","d3b6970711bd4a91911e36c6b93b0149","b9f228b973bd46c19f87a609f79b376f","80c430dd861144fd84427c50d7afb13c","0882f50062b6429cbcedec4c9538f477","1ba724a4b509400c921c9efc1edb6090","8d90794748b2466b8e9c66c1db803821","f02311e734734bd69180e5e89212f41a"]},"executionInfo":{"elapsed":5033,"status":"ok","timestamp":1676231956945,"user":{"displayName":"Szymon Jadczak","userId":"16229243011212847322"},"user_tz":-60},"id":"mJK9hcjKqtai","outputId":"6c30f689-268f-43cf-b625-4e10cb12ce0c"},"outputs":[{"name":"stderr","output_type":"stream","text":["WARNING:datasets.builder:Found cached dataset sst2 (/root/.cache/huggingface/datasets/sst2/default/2.0.0/9896208a8d85db057ac50c72282bcb8fe755accc671a57dd8059d4e130961ed5)\n"]},{"data":{"application/vnd.jupyter.widget-view+json":{"model_id":"a0e441327eee4112ac7b74781c77690e","version_major":2,"version_minor":0},"text/plain":[" 0%| | 0/3 [00:00<?, ?it/s]"]},"metadata":{},"output_type":"display_data"},{"data":{"text/plain":["DatasetDict({\n"," train: Dataset({\n"," features: ['idx', 'sentence', 'label'],\n"," num_rows: 65749\n"," })\n"," test: Dataset({\n"," features: ['idx', 'sentence', 'label'],\n"," num_rows: 872\n"," })\n"," validation: Dataset({\n"," features: ['idx', 'sentence', 'label'],\n"," num_rows: 1600\n"," })\n","})"]},"execution_count":428,"metadata":{},"output_type":"execute_result"}],"source":["dataset = 
load_and_process_dataset()\n","dataset"]},{"cell_type":"code","execution_count":11,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":229,"referenced_widgets":["a0e442f4fa7c4badb0fa1a46840ba751","86fe78ea29bc446a8557cc291ff6300d","502d330613514ca29f1206533282601f","8d568373d5ce4af58f063dc4f30f0b71","465397f6e45a422eaddb86a96af4afd1","449f077725564ebdb8e0096f9aabec99","9f7b2c98a57240c98f1eca099011ae8c","26f887c435a846e39b049cc07ae9f738","bb02703a2688473a8309e9dd4117400a","7a76968e08894d4f886346615873ed8d","8d1a9c60b5334d97be88d9154bef822f","dec176fa12fa4379b0e02e08d0d744f9","339e942966964acba7db87e098e01bb9","186fc01aeb5246f7a3e9da6cc07e1aaf","b53c2d65531f4c5e8a1f0c837856a2e3","64fad61fc73e4fdcb978d8cfcee74b99","c26bd724e80445ccb977ddd9a9139a50","7a0d3d716dc