uczenie/flanT5 (1).ipynb

1 line
14 KiB
Plaintext

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SST-2 sentiment classification with a Hugging Face pipeline\n",
    "\n",
    "Evaluates the `sentiment-analysis` pipeline with the `distilbert-base-uncased-finetuned-sst-2-english` checkpoint (revision `af0f99b`) on the SST-2 validation split (872 sentences), which this notebook uses as its test set.\n",
    "\n",
    "- Accuracy on that split: **0.911** (accuracy cell below).\n",
    "- Neutral, factual sentences still receive a POSITIVE or NEGATIVE label with scores of 0.90-1.00 (last cell).\n",
    "\n",
    "**TODO:** the notebook file name refers to FLAN-T5, but no FLAN-T5 model takes part in the evaluation. Add a FLAN-T5 run if that comparison is intended."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {"colab": {"base_uri": "https://localhost:8080/"}, "id": "4iEk32rC8Ju2", "executionInfo": {"status": "ok", "timestamp": 1705260831679, "user_tz": -60, "elapsed": 5009, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}, "outputId": "58e868a4-2f00-4768-c552-eea7081303c9"},
   "outputs": [],
   "source": [
    "# Pinned to the versions this notebook was last run with on Colab.\n",
    "%pip install -q transformers==4.35.2 datasets==2.16.1 torch==2.1.0 scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {"id": "Ly0PUAQz9WBy", "executionInfo": {"status": "ok", "timestamp": 1705260845948, "user_tz": -60, "elapsed": 14272, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSeq2SeqLM\n",
    "from transformers import AutoTokenizer\n",
    "from transformers import pipeline\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {"id": "KGF4hcZi9cMG", "executionInfo": {"status": "ok", "timestamp": 1705260845948, "user_tz": -60, "elapsed": 14, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}},
   "outputs": [],
   "source": [
    "def load_and_process_dataset(seed=42, validation_size=1600):\n",
    "    \"\"\"Load SST-2 and rearrange it into train / test / validation splits.\n",
    "\n",
    "    The original ``test`` split is discarded, the original ``validation``\n",
    "    split becomes ``test``, and a new ``validation`` split of\n",
    "    ``validation_size`` rows is carved out of ``train``. ``seed`` fixes that\n",
    "    carve-out so the train/validation partition is reproducible.\n",
    "\n",
    "    Returns:\n",
    "        DatasetDict with ``train``, ``test`` and ``validation`` splits whose\n",
    "        ``idx`` column has been dropped.\n",
    "    \"\"\"\n",
    "    dataset = load_dataset(\"sst2\")\n",
    "    # remove_columns returns a new DatasetDict instead of modifying it in place,\n",
    "    # so the result has to be reassigned for `idx` to actually disappear.\n",
    "    dataset = dataset.remove_columns('idx')\n",
    "    del dataset['test']\n",
    "    dataset['test'] = dataset.pop('validation')\n",
    "    split_dataset = dataset['train'].train_test_split(test_size=validation_size, seed=seed)\n",
    "    dataset['train'] = split_dataset['train']\n",
    "    dataset['validation'] = split_dataset['test']\n",
    "    return dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {"colab": {"base_uri": "https://localhost:8080/"}, "id": "mLIbkreD9fWo", "executionInfo": {"status": "ok", "timestamp": 1705260848231, "user_tz": -60, "elapsed": 2296, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}, "outputId": "9f8b76e7-d6bc-4fe8-b429-a6b59bb2f4f7"},
   "outputs": [],
   "source": [
    "dataset = load_and_process_dataset()\n",
    "dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {"id": "_W-Lg-zs9m3v", "executionInfo": {"status": "ok", "timestamp": 1705260857868, "user_tz": -60, "elapsed": 9240, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}, "colab": {"base_uri": "https://localhost:8080/"}, "outputId": "e9c641e8-7115-4cde-8b85-d5b116d5fee1"},
   "outputs": [],
   "source": [
    "def transform_dataset(dataset, split='test'):\n",
    "    \"\"\"Turn one split of ``dataset`` into a list of plain dicts for evaluation.\n",
    "\n",
    "    Each row becomes ``{'sentence': str, 'label': str}``: newlines are removed\n",
    "    from the sentence, label 0 becomes 'negative' and any other label\n",
    "    becomes 'positive'.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        {\n",
    "            'sentence': row['sentence'].replace(\"\\n\", \"\"),\n",
    "            'label': \"negative\" if row['label'] == 0 else \"positive\",\n",
    "        }\n",
    "        for row in dataset[split]\n",
    "    ]\n",
    "\n",
    "\n",
    "test_rows = transform_dataset(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "# Pin the checkpoint and revision instead of relying on the task default;\n",
    "# this is the model `pipeline(\"sentiment-analysis\")` falls back to when none is given.\n",
    "sentiment_classifier = pipeline(\n",
    "    \"sentiment-analysis\",\n",
    "    model=\"distilbert-base-uncased-finetuned-sst-2-english\",\n",
    "    revision=\"af0f99b\",\n",
    "    device=device,\n",
    ")\n",
    "\n",
    "\n",
    "def create_predictions(test_data):\n",
    "    \"\"\"Classify every row of ``test_data`` with ``sentiment_classifier``.\n",
    "\n",
    "    Args:\n",
    "        test_data: iterable of dicts with ``'sentence'`` and ``'label'`` keys,\n",
    "            e.g. the output of ``transform_dataset``.\n",
    "\n",
    "    Returns:\n",
    "        ``(predictions, expected)`` as two parallel lists: each prediction is\n",
    "        'positive' or 'negative'; each expected value is the row's ``'label'``.\n",
    "    \"\"\"\n",
    "    predictions = []\n",
    "    expected = []\n",
    "    for row in tqdm(test_data):\n",
    "        result = sentiment_classifier(row['sentence'])\n",
    "        # The pipeline emits upper-case labels (POSITIVE / NEGATIVE); map them to the\n",
    "        # lower-case names used in test_data, treating anything else as negative.\n",
    "        label = result[0]['label'].lower()\n",
    "        predictions.append('positive' if label == 'positive' else 'negative')\n",
    "        expected.append(row['label'])\n",
    "    return predictions, expected\n",
    "\n",
    "\n",
    "def print_sentiments(sentences):\n",
    "    \"\"\"Print the classifier's label and confidence score for each sentence.\"\"\"\n",
    "    for sentence in sentences:\n",
    "        result = sentiment_classifier(sentence)\n",
    "        print(f\"Sentence: '{sentence}'\\nSentiment: {result[0]['label']}, Score: {result[0]['score']:.2f}\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {"colab": {"base_uri": "https://localhost:8080/"}, "id": "MtPHUYaj99wF", "executionInfo": {"status": "ok", "timestamp": 1705260926708, "user_tz": -60, "elapsed": 68843, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}, "outputId": "8c3ba337-ac8e-47bc-a412-6b176a38cc73"},
   "outputs": [],
   "source": [
    "predictions, expected = create_predictions(test_rows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {"id": "Bf-jXafeBxif", "colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"status": "ok", "timestamp": 1705260926709, "user_tz": -60, "elapsed": 16, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}, "outputId": "56dd356a-17c6-4997-8725-90e28bbb980e"},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Accuracy: 0.9105504587155964\n"
     ]
    }
   ],
   "source": [
    "# create_predictions already maps every prediction to 'positive' or 'negative',\n",
    "# so the predictions can be scored directly.\n",
    "accuracy = accuracy_score(expected, predictions)\n",
    "print(\"Accuracy:\", accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spot checks\n",
    "\n",
    "Opinionated sentences first, then neutral factual ones. The neutral sentences still come back as POSITIVE or NEGATIVE with high scores."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {"colab": {"base_uri": "https://localhost:8080/"}, "id": "nJHp-Fmml6_t", "executionInfo": {"status": "ok", "timestamp": 1705263137607, "user_tz": -60, "elapsed": 628, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}, "outputId": "0a0ad0c8-dae9-4c70-9832-af978ded1777"},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Sentence: 'This movie was an amazing journey.'\n",
      "Sentiment: POSITIVE, Score: 1.00\n",
      "\n",
      "Sentence: 'I really did not like the new web design.'\n",
      "Sentiment: NEGATIVE, Score: 1.00\n",
      "\n",
      "Sentence: 'The team did a great job with this project.'\n",
      "Sentiment: POSITIVE, Score: 1.00\n",
      "\n",
      "Sentence: 'I am not happy with the service.'\n",
      "Sentiment: NEGATIVE, Score: 1.00\n",
      "\n",
      "Sentence: 'This is the best book I have ever read!'\n",
      "Sentiment: POSITIVE, Score: 1.00\n",
      "\n"
     ]
    }
   ],
   "source": [
    "opinion_sentences = [\n",
    "    \"This movie was an amazing journey.\",\n",
    "    \"I really did not like the new web design.\",\n",
    "    \"The team did a great job with this project.\",\n",
    "    \"I am not happy with the service.\",\n",
    "    \"This is the best book I have ever read!\",\n",
    "]\n",
    "\n",
    "print_sentiments(opinion_sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {"colab": {"base_uri": "https://localhost:8080/"}, "id": "hCo4Kh6xmbyV", "executionInfo": {"status": "ok", "timestamp": 1705263130268, "user_tz": -60, "elapsed": 672, "user": {"displayName": "Marcin Rostkowski", "userId": "16749256502154511679"}}, "outputId": "a5d6f5f8-0b3b-4d74-e8ab-7b3e1a929047"},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Sentence: 'The cat is sitting on the mat.'\n",
      "Sentiment: NEGATIVE, Score: 0.98\n",
      "\n",
      "Sentence: 'There are clouds in the sky.'\n",
      "Sentiment: POSITIVE, Score: 1.00\n",
      "\n",
      "Sentence: 'The book is on the table.'\n",
      "Sentiment: POSITIVE, Score: 0.99\n",
      "\n",
      "Sentence: 'A car is parked outside.'\n",
      "Sentiment: POSITIVE, Score: 0.90\n",
      "\n",
      "Sentence: 'The door is closed.'\n",
      "Sentiment: NEGATIVE, Score: 0.98\n",
      "\n"
     ]
    }
   ],
   "source": [
    "neutral_sentences = [\n",
    "    \"The cat is sitting on the mat.\",\n",
    "    \"There are clouds in the sky.\",\n",
    "    \"The book is on the table.\",\n",
    "    \"A car is parked outside.\",\n",
    "    \"The door is closed.\",\n",
    "]\n",
    "\n",
    "print_sentiments(neutral_sentences)"
   ]
  }
 ],
 "metadata": {"colab": {"provenance": [], "gpuType": "T4", "authorship_tag": "ABX9TyPpPr27PrVxUEX9wAZNh5OO"}, "kernelspec": {"name": "python3", "display_name": "Python 3"}, "language_info": {"name": "python"}, "accelerator": "GPU"},
 "nbformat": 4,
 "nbformat_minor": 0
}