diff --git a/Konopka_QA.ipynb b/Konopka_QA.ipynb new file mode 100644 index 0000000..7836824 --- /dev/null +++ b/Konopka_QA.ipynb @@ -0,0 +1,2332 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Importy i sprawdzenie GPU" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 1, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Available GPU True\n" + ] + } + ], + "source": [ + "import datasets\n", + "import torch\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline\n", + "import json\n", + "import torch\n", + "\n", + "is_gpu_available = torch.cuda.is_available()\n", + "print(\"Available GPU\", is_gpu_available)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "# The dataset" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "Check SubjQA dataset domains" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "collapsed": true, + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/plain": "['books', 'electronics', 'grocery', 'movies', 'restaurants', 'tripadvisor']" + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "datasets.get_dataset_config_names(\"subjqa\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "SubjQA is a question answering dataset that focuses on subjective questions and answers.\n", + "The dataset consists of roughly 10,000 questions over reviews from 6 different 
domains: books, movies, grocery,\n", + "electronics, TripAdvisor (i.e. hotels), and restaurants. \n", + "\n", + "FEATURES {'domain': Value(dtype='string', id=None), 'nn_mod': Value(dtype='string', id=None), 'nn_asp': Value(dtype='string', id=None), 'query_mod': Value(dtype='string', id=None), 'query_asp': Value(dtype='string', id=None), 'q_reviews_id': Value(dtype='string', id=None), 'question_subj_level': Value(dtype='int64', id=None), 'ques_subj_score': Value(dtype='float32', id=None), 'is_ques_subjective': Value(dtype='bool', id=None), 'review_id': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None), 'context': Value(dtype='string', id=None), 'question': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None), 'answer_subj_level': Value(dtype='int64', id=None), 'ans_subj_score': Value(dtype='float32', id=None), 'is_ans_subjective': Value(dtype='bool', id=None)}, length=-1, id=None)} \n", + "\n", + "SPLITS {'train': SplitInfo(name='train', num_bytes=1574953, num_examples=1165, dataset_name='subjqa'), 'test': SplitInfo(name='test', num_bytes=689440, num_examples=512, dataset_name='subjqa'), 'validation': SplitInfo(name='validation', num_bytes=312577, num_examples=230, dataset_name='subjqa')} \n", + "\n" + ] + } + ], + "source": [ + "selected_config = 'tripadvisor'\n", + "ds_builder = datasets.load_dataset_builder(\"subjqa\", selected_config)\n", + "print(ds_builder.info.description, \"\\n\")\n", + "print(\"FEATURES\", ds_builder.info.features, \"\\n\")\n", + "print(\"SPLITS\", ds_builder.info.splits, \"\\n\")" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 4, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset subjqa 
(/home/karo/.cache/huggingface/datasets/subjqa/tripadvisor/1.1.0/e5588f9298ff2d70686a00cc377e4bdccf4e32287459e3c6baf2dc5ab57fe7fd)\n" + ] + }, + { + "data": { + "text/plain": " 0%| | 0/3 [00:00\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
idtitlequestionanswers.textanswers.answer_startcontext
25440d78196c65a3bec3913e2291cd9b771fYl2TN9c23ZGLUBSD9ks5UwIs it present ?[so it's a bit of a splurge meal][177]Byblos has a really beautiful decor, that I wo...
88457b688d836d9502a076fe6a50a9cde81usa_san francisco_hotel_adagioWhere can I locate the hotel staff?[][]Stayed at the Adagio for 4 nights end of June-...
\n" + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "qa_cols = [\"id\", \"title\", \"question\", \"answers.text\",\n", + " \"answers.answer_start\", \"context\"]\n", + "sample_df = dfs[\"train\"][qa_cols].sample(2, random_state=7)\n", + "sample_df" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "Sprawdzenie generowania odpowiedzi z kontekstu" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [ + { + "data": { + "text/plain": "\"so it's a bit of a splurge meal\"" + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "start_idx = sample_df[\"answers.answer_start\"].iloc[0][0]\n", + "end_idx = start_idx + len(sample_df[\"answers.text\"].iloc[0][0])\n", + "sample_df[\"context\"].iloc[0][start_idx:end_idx]" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "Usuń obiekty które nie posiadają odpowiedzi" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 16, + "outputs": [ + { + "data": { + "text/plain": " 0%| | 0/2 [00:00 0)\n", + "subjqa['validation'] = subjqa['validation'].filter(lambda example: len(example[\"answers\"]['text']) > 0)\n", + "dfs = {split: dset.to_pandas() for split, dset in subjqa.flatten().items()}\n", + "subjqa" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [ + { + "data": { + "text/plain": " domain nn_mod nn_asp query_mod query_asp \\\n0 tripadvisor excellent hotel friendly hotel \n1 tripadvisor neat hotel cozy hotel \n2 tripadvisor excellent value for money good value 
for money \n3 tripadvisor convenient place safe hotel \n4 tripadvisor helpful staff helpfull staff \n\n q_reviews_id question_subj_level ques_subj_score \\\n0 b96b8478f5202ac9534eaf75167016f7 1 0.0 \n1 ca174c824baba906d30eed207350e37e 1 0.0 \n2 5d41bfd0e2166e14eeae1c1be9085555 3 0.0 \n3 a5880d95aa1161cfaac48584ee58934d 1 0.6 \n4 902b8e2a5c10f1796abdc830c4a4acd2 2 0.0 \n\n is_ques_subjective review_id \\\n0 False tripadvisor_review_1509 \n1 False tripadvisor_review_6133 \n2 False tripadvisor_review_4872 \n3 True tripadvisor_review_4715 \n4 False tripadvisor_review_1093 \n\n id \\\n0 d1c352b70d1225245569a0a1acbf5e04 \n1 b6fbf58ca273ad8f9b47a9be6a36e707 \n2 ea95bccdd762284ad7040be8d016da4f \n3 413df1095d03e4f967d349f0490a2514 \n4 d7fdc86b464f2a797dca33b058b68078 \n\n title \\\n0 usa_san francisco_argonaut_hotel_a_kimpton_hotel \n1 usa_san francisco_best_western_tuscan_inn_fish... \n2 usa_san francisco_castle_inn \n3 usa_san francisco_castle_inn \n4 usa_san francisco_chancellor_hotel_on_union_sq... \n\n context \\\n0 Great setting at the end of the wharf (so your... \n1 My wife and I took two trips to San Fran in 20... \n2 Yep, I have to agree with all those folks who ... \n3 On first sight the Castle hotel is not great, ... \n4 Stayed at the Chancellor recently for 3 nights... \n\n question \\\n0 How was the hotel? \n1 How is the hotel? \n2 Is it value for money? \n3 Does the hotel offer good service? \n4 How do you rate the staff? \n\n answers.text answers.answer_start \\\n0 [excellent hotels] [527] \n1 [The hotel location was great] [129] \n2 [And very reasonably priced. Overall, excellen... [501] \n3 [On first sight the Castle hotel is not great] [0] \n4 [Staff very helpful] [86] \n\n answers.answer_subj_level answers.ans_subj_score answers.is_ans_subjective \n0 [1] [1.0] [True] \n1 [1] [0.75] [True] \n2 [3] [0.58] [True] \n3 [1] [0.5416667] [True] \n4 [2] [0.3] [False] ", + "text/html": "
\n\n\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n
domainnn_modnn_aspquery_modquery_aspq_reviews_idquestion_subj_levelques_subj_scoreis_ques_subjectivereview_ididtitlecontextquestionanswers.textanswers.answer_startanswers.answer_subj_levelanswers.ans_subj_scoreanswers.is_ans_subjective
0tripadvisorexcellenthotelfriendlyhotelb96b8478f5202ac9534eaf75167016f710.0Falsetripadvisor_review_1509d1c352b70d1225245569a0a1acbf5e04usa_san francisco_argonaut_hotel_a_kimpton_hotelGreat setting at the end of the wharf (so your...How was the hotel?[excellent hotels][527][1][1.0][True]
1tripadvisorneathotelcozyhotelca174c824baba906d30eed207350e37e10.0Falsetripadvisor_review_6133b6fbf58ca273ad8f9b47a9be6a36e707usa_san francisco_best_western_tuscan_inn_fish...My wife and I took two trips to San Fran in 20...How is the hotel?[The hotel location was great][129][1][0.75][True]
2tripadvisorexcellentvalue for moneygoodvalue for money5d41bfd0e2166e14eeae1c1be908555530.0Falsetripadvisor_review_4872ea95bccdd762284ad7040be8d016da4fusa_san francisco_castle_innYep, I have to agree with all those folks who ...Is it value for money?[And very reasonably priced. Overall, excellen...[501][3][0.58][True]
3tripadvisorconvenientplacesafehotela5880d95aa1161cfaac48584ee58934d10.6Truetripadvisor_review_4715413df1095d03e4f967d349f0490a2514usa_san francisco_castle_innOn first sight the Castle hotel is not great, ...Does the hotel offer good service?[On first sight the Castle hotel is not great][0][1][0.5416667][True]
4tripadvisorhelpfulstaffhelpfullstaff902b8e2a5c10f1796abdc830c4a4acd220.0Falsetripadvisor_review_1093d7fdc86b464f2a797dca33b058b68078usa_san francisco_chancellor_hotel_on_union_sq...Stayed at the Chancellor recently for 3 nights...How do you rate the staff?[Staff very helpful][86][2][0.3][False]
\n
" + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dfs['validation'].head()" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "Sprawdzenie typów pytań" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 17, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "\n" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "
",
+ "image/png": "\n"
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def show_question_types_counts(dfs_split):\n",
+ "    \"\"\"Plot how many questions of one split start with each common question word.\n",
+ "\n",
+ "    Args:\n",
+ "        dfs_split: Key into the notebook-level `dfs` dict of DataFrames, e.g. \"train\".\n",
+ "    \"\"\"\n",
+ "    question_types = [\"What\", \"How\", \"Is\", \"Does\", \"Do\", \"Was\", \"Where\", \"Why\"]\n",
+ "    questions = dfs[dfs_split][\"question\"]\n",
+ "\n",
+ "    # Anchor every prefix at a word boundary so that \"Do\" no longer also counts \"Does\".\n",
+ "    # Summing the boolean mask yields 0 when nothing matches, which replaces the former\n",
+ "    # bare `except:` fallback without hiding unrelated errors such as a wrong split key.\n",
+ "    counts = {\n",
+ "        q: int(questions.str.match(q + r\"\\b\", na=False).sum())\n",
+ "        for q in question_types\n",
+ "    }\n",
+ "\n",
+ "    pd.Series(counts).sort_values().plot.barh()\n",
+ "    plt.title(\"Frequency of Question Types \" + dfs_split)\n",
+ "    plt.show()\n",
+ "\n",
+ "\n",
+ "show_question_types_counts(\"train\")\n",
+ "show_question_types_counts(\"validation\")"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Pytania i ilości odpowiedzi"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAhYAAAGzCAYAAABzfl4TAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAk7ElEQVR4nO3deVxU9eL/8fcAMYg4ICKSCopmuaGlprmglhQhpaV5c82stIVS89rXJcs2g251q9tiZl0tQy3ttrlk3tyvSy65ZVmWJpnKzWJRcxT4/P7ox9wmXAA/DKCv5+Mxj0dzzpk5n5kPwctzzoDDGGMEAABggV95DwAAAJw7CAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAvAsmXLlsnhcGju3LnlPZRiOXjwoG666SbVqFFDDodDzz//fHkPCWewZ88eORwOTZ8+vbyHAhRBWKBSmj59uhwOh4KCgrRv374i67t27armzZuXw8gqn/vvv1+LFi3SuHHjNGPGDF177bVnfExWVpaCgoLkcDj01Vdf+WCUlc/MmTOJNJyXCAtUam63W2lpaeU9jEptyZIl6tmzp0aPHq2BAweqcePGZ3zMnDlz5HA4FBUVpfT0dB+MsvIpy7CoV6+efvvtNw0aNKhMnh84G4QFKrVLL71UU6dO1U8//VTeQ/G5I0eOWHmezMxMhYWFlegxb7/9trp3765+/fpp5syZVsZR0dh6f4vj2LFjKigoKPb2hUfr/P39y3BUQOkQFqjUxo8fr/z8/DMetTjdOWmHw6FHHnnEc/+RRx6Rw+HQN998o4EDByo0NFQ1a9bUQw89JGOMMjIy1LNnT7lcLkVFRenZZ5896T7z8/M1fvx4RUVFqWrVqurRo4cyMjKKbLdu3Tpde+21Cg0NVXBwsLp06aL//Oc/XtsUjmnHjh3q37+/qlevrk6dOp32NX///ffq06ePwsPDFRwcrCuuuELz58/3rC88nWSM0csvvyyHwyGHw3Ha55SkvXv3auXKlerbt6/69u2r3bt3a/Xq1UW2KzwdtWPHDl155ZUKDg5WnTp19Le//a3Iti+++KKaNWum4OBgVa9eXW3atPEEy9atW+VwOPTRRx95tt+4caMcDodatWrl9TxJSUlq166d17KFCxcqPj5eVatWVbVq1ZScnKwvv/zSa5tbb71VISEh+u6779S9e3dVq1ZNAwYMkCR9++236t27t6KiohQUFKS6deuqb9++ys7OPuV71LVrV82fP18//PCD532tX7++pP9dgzN79mxNmDBBderUUXBwsHJycvTLL79o9OjRiouLU0hIiFwul5KSkrRlyxav5z/Z13Pha9i3b59uuOEGhYSEqGbNmho9erTy8/NPOVbAtoDyHgBwNmJjY3XLLbdo6tSpGjt2rGrXrm3tuW+++WY1adJEaWlpmj9/vp544gmFh4drypQpuuqqq/TUU08pPT1do0eP1uWXX67OnTt7PX7SpElyOBwaM2aMMjMz9fzzzyshIUGbN29WlSpVJP1+GiIpKUmtW7fWxIkT5efnp2nTpumqq67SypUr1bZtW6/n7NOnjxo1aqQnn3xSxphTjv3gwYPq0KGDjh49quHDh6tGjRp688031aNHD82dO1c33nijOnfurBkzZmjQoEG6+uqrdcsttxTrfZk1a5aqVq2q6667TlWqVFHDhg2Vnp6uDh06FNn2119/1bXXXqtevXrpL3/5i+bOnasxY8YoLi5OSUlJkqSpU6dq+PDhuummmzRixAgdO3ZMW7du1bp169S/f381b95cYWFhWrFihXr06CFJWrlypfz8/LRlyxbl5OTI5XKpoKBAq1ev1rBhwzz7nzFjhgYPHqzExEQ99dRTOnr0qCZPnqxOnTrpiy++8Pywl6S8vDwlJiaqU6dOeuaZZxQcHKzjx48rMTFRbrdb9913n6KiorRv3z7NmzdPW
VlZCg0NPel79OCDDyo7O1s//vijnnvuOUlSSEiI1zaPP/64AgMDNXr0aLndbgUGBmrHjh364IMP1KdPH8XGxurgwYOaMmWKunTpoh07dpzx6zs/P1+JiYlq166dnnnmGf373//Ws88+q4YNG+ruu+8+8+QCNhigEpo2bZqRZNavX2++++47ExAQYIYPH+5Z36VLF9OsWTPP/d27dxtJZtq0aUWeS5KZOHGi5/7EiRONJDNs2DDPsry8PFO3bl3jcDhMWlqaZ/mvv/5qqlSpYgYPHuxZtnTpUiPJ1KlTx+Tk5HiWv/vuu0aSeeGFF4wxxhQUFJhGjRqZxMREU1BQ4Nnu6NGjJjY21lx99dVFxtSvX79ivT8jR440kszKlSs9y3Jzc01sbKypX7++yc/P93r9KSkpxXpeY4yJi4szAwYM8NwfP368iYiIMCdOnPDarkuXLkaSeeuttzzL3G63iYqKMr179/Ys69mzp9dcnUxycrJp27at536vXr1Mr169jL+/v1m4cKExxphNmzYZSebDDz/0vN6wsDAzdOhQr+c6cOCACQ0N9Vo+ePBgI8mMHTvWa9svvvjCSDJz5sw57fhONeZ69eoVWV749dGgQQNz9OhRr3XHjh3zmhtjfv/adTqd5rHHHvNa9uev58LX8MftjDHmsssuM61bty7x+IHS4lQIKr0GDRpo0KBBeu2117R//35rz3vHHXd4/tvf319t2rSRMUa33367Z3lYWJguueQSff/990Uef8stt6hatWqe+zfddJMuvPBCLViwQJK0efNmffvtt+rfv78OHTqkn3/+WT///LOOHDmibt26acWKFUXOu991113FGvuCBQvUtm1br9MlISEhGjZsmPbs2aMdO3YU7034k61bt2rbtm3q16+fZ1m/fv30888/a9GiRUW2DwkJ0cCBAz33AwMD1bZtW6/3KywsTD/++KPWr19/yv3Gx8dr06ZNnuseVq1ape7du+vSSy/VypUrJf1+FMPhcHhe8+LFi5WVleUZX+HN399f7dq109KlS4vs58//qi88IrFo0SIdPXr0jO9PSQwePNhz5KqQ0+mUn9/v35bz8/N16NAhhYSE6JJLLtGmTZuK9bx//hqJj48/6dcnUFYIC5wTJkyYoLy8PKufEImJifG6HxoaqqCgIEVERBRZ/uuvvxZ5fKNGjbzuOxwOXXTRRdqzZ4+k38/dS7//gKlZs6bX7fXXX5fb7S5yHj82NrZYY//hhx90ySWXFFnepEkTz/rSePvtt1W1alU1aNBAu3bt0q5duxQUFKT69euf9NMhdevWLXLdRvXq1b3erzFjxigkJERt27ZVo0aNlJKSUuQak/j4eOXl5WnNmjXauXOnMjMzFR8fr86dO3uFRdOmTRUeHi7pf+/vVVddVeT9/fTTT5WZmem1j4CAANWtW9drWWxsrEaNGqXXX39dERERSkxM1Msvv3za6yuK62RzWVBQoOeee06NGjWS0+lURESEatasqa1btxZrn0FBQapZs6bXsj+/30BZ4xoLnBMaNGiggQMH6rXXXtPYsWOLrD/VRYmnu6jtZFfcn+oqfHOa6x1OpfBoxNNPP61LL730pNv8+bz8n/+F60vGGM2aNUtHjhxR06ZNi6zPzMzU4cOHvcZcnPerSZMm2rlzp+bNm6dPPvlE7733nl555RU9/PDDevTRRyVJbdq0UVBQkFasWKGYmBhFRkbq4osvVnx8vF555RW53W6tXLlSN954o+d5C9/fGTNmKCoqqsgYAgK8v/398WjBHz377LO69dZb9eGHH+rTTz/V8OHDlZqaqrVr1xYJkZI42Vw++eSTeuihh3Tbbbfp8ccfV3h4uPz8/DRy5MhifWqET4mgIiAscM6YMGGC3n77bT311FNF1lWvXl3S77/Y6Y9K+y/34ij8F3MhY4x27dqlFi1aSJIaNmwoSXK5XEpISLC673r16mnnzp1Fln/99dee9SW1fPly/fjjj3rsscc8R
z4K/frrrxo2bJg++OADr1MfxVW1alXdfPPNuvnmm3X8+HH16tVLkyZN0rhx4xQUFOQ5hbJy5UrFxMQoPj5e0u9HMtxut9LT03Xw4EGvC2gL39/IyMizfn/j4uIUFxenCRMmaPXq1erYsaNeffVVPfHEE6d8THE+YfNnc+fO1ZVXXqk33njDa3lWVlaRI2VARcWpEJwzGjZsqIEDB2rKlCk6cOCA1zqXy6WIiAitWLHCa/krr7xSZuN56623lJub67k/d+5c7d+/3/NpiNatW6thw4Z65plndPjw4SKP/+9//1vqfXfv3l2ff/651qxZ41l25MgRvfbaa6pfv/5JjzicSeFpkAceeEA33XST123o0KFq1KhRqX5Z1qFDh7zuBwYGqmnTpjLG6MSJE57l8fHxWrdunZYuXeoJi4iICDVp0sQTk4XLJSkxMVEul0tPPvmk1/MUKs77m5OTo7y8PK9lcXFx8vPzk9vtPu1jq1atWuJTJv7+/kWOfs2ZM+ekv10WqKg4YoFzyoMPPqgZM2Zo586datasmde6O+64Q2lpabrjjjvUpk0brVixQt98802ZjSU8PFydOnXSkCFDdPDgQT3//PO66KKLNHToUEmSn5+fXn/9dSUlJalZs2YaMmSI6tSpo3379mnp0qVyuVz6+OOPS7XvsWPHatasWUpKStLw4cMVHh6uN998U7t379Z777130kP+p+N2u/Xee+/p6quvVlBQ0Em36dGjh1544QVlZmYqMjKy2M99zTXXKCoqSh07dlStWrX01Vdf6aWXXlJycrLXxa/x8fGaNGmSMjIyvAKic+fOmjJliurXr+91asLlcmny5MkaNGiQWrVqpb59+6pmzZrau3ev5s+fr44dO+qll1467diWLFmie++9V3369NHFF1+svLw8zZgxQ/7+/urdu/dpH9u6dWu98847GjVqlC6//HKFhITo+uuvP+1jrrvuOj322GMaMmSIOnTooG3btik9PV0NGjQ47eOAioSwwDnloosu0sCBA/Xmm28WWffwww/rv//9r+bOnat3331XSUlJWrhwYYl+CJbE+PHjtXXrVqWmpio3N1fdunXTK6+8ouDgYM82Xbt21Zo1a/T444/rpZde0uHDhxUVFaV27drpzjvvLPW+a9WqpdWrV2vMmDF68cUXdezYMbVo0UIff/yxkpOTS/x88+fPV1ZW1ml/MF5//fV69tlnNXv2bA0fPrzYz33nnXcqPT1df//733X48GHVrVtXw4cP14QJE7y269Chg/z9/RUcHKyWLVt6lsfHx2vKlClesVGof//+ql27ttLS0vT000/L7XarTp06io+P15AhQ844tpYtWyoxMVEff/yx9u3b59n3woULdcUVV5z2sffcc482b96sadOm6bnnnlO9evXOGBbjx4/XkSNHNHPmTL3zzjtq1aqV5s+ff9LrhoCKymFKc9UZAADASXCNBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGCNz3+PRUFBgX766SdVq1atVL/yFgAA+J4xRrm5uapdu/Zpf8mez8Pip59+UnR0tK93CwAALMjIyDjtH+DzeVgU/orejIwMuVwuX+8eAACUQk5OjqKjo71+1f7J+DwsCk9/uFwuwgIAgErmTJcxcPEmAACwhrAAAADWEBYAAMAawgIAAFhDWAAAAGsICwAAYA1hAQAArCEsAACANYQFAACwhrAAAADWEBYAAMAawgIAAFhDWAAAAGsICwAAYA1hAQAArCEsAACANYQFAACwhrAAAADWEBYAAMAawgIAAFgTUF47bj5xkfycweW1ewAAzjl70pLLewgcsQAAAPYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwJoShUVqaqouv/xyVatWTZGRkbrhh
hu0c+fOshobAACoZEoUFsuXL1dKSorWrl2rxYsX68SJE7rmmmt05MiRshofAACoRAJKsvEnn3zidX/69OmKjIzUxo0b1blzZ6sDAwAAlU+JwuLPsrOzJUnh4eGn3Mbtdsvtdnvu5+TknM0uAQBABVbqizcLCgo0cuRIdezYUc2bNz/ldqmpqQoNDfXcoqOjS7tLAABQwZU6LFJSUrR9+3bNnj37tNuNGzdO2dnZnltGRkZpdwkAACq4Up0KuffeezVv3jytWLFCdevWPe22TqdTTqezVIMDAACVS4nCwhij++67T++//76WLVum2NjYshoXAACohEoUFikpKZo5c6Y+/PBDVatWTQcOHJAkhYaGqkqVKmUyQAAAUHmU6BqLyZMnKzs7W127dtWFF17oub3zzjtlNT4AAFCJlPhUCAAAwKnwt0IAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsCSivHW9/NFEul6u8dg8AAMoARywAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWBNQXjtuPnGR/JzB5bV7AICP7ElLLu8hwIc4YgEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDUlDosVK1bo+uuvV+3ateVwOPTBBx+UwbAAAEBlVOKwOHLkiFq2bKmXX365LMYDAAAqsYCSPiApKUlJSUllMRYAAFDJlTgsSsrtdsvtdnvu5+TklPUuAQBAOSnzizdTU1MVGhrquUVHR5f1LgEAQDkp87AYN26csrOzPbeMjIyy3iUAACgnZX4qxOl0yul0lvVuAABABcDvsQAAANaU+IjF4cOHtWvXLs/93bt3a/PmzQoPD1dMTIzVwQEAgMqlxGGxYcMGXXnllZ77o0aNkiQNHjxY06dPtzYwAABQ+ZQ4LLp27SpjTFmMBQAAVHJcYwEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGBNQHntePujiXK5XOW1ewAAUAY4YgEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYU25h0XziovLaNQAAKCMcsQAAANYQFgAAwBrCAgAAWENYAAAAawgLAABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwBrCAgAAWENYAAAAawgLA
ABgDWEBAACsISwAAIA1hAUAALCGsAAAANYQFgAAwJqzCou0tDQ5HA6NHDnS0nAAAEBlVuqwWL9+vaZMmaIWLVrYHA8AAKjEShUWhw8f1oABAzR16lRVr17d9pgAAEAlVaqwSElJUXJyshISEs64rdvtVk5OjtcNAACcmwJK+oDZs2dr06ZNWr9+fbG2T01N1aOPPlrigQEAgMqnREcsMjIyNGLECKWnpysoKKhYjxk3bpyys7M9t4yMjFINFAAAVHwlOmKxceNGZWZmqlWrVp5l+fn5WrFihV566SW53W75+/t7PcbpdMrpdNoZLQAAqNBKFBbdunXTtm3bvJYNGTJEjRs31pgxY4pEBQAAOL+UKCyqVaum5s2bey2rWrWqatSoUWQ5AAA4//CbNwEAgDUl/lTIny1btszCMAAAwLmAIxYAAMAawgIAAFhDWAAAAGsICwAAYA1hAQAArCEsAACANYQFAACwhrAAAADWEBYAAMAawgIAAFhDWAAAAGsICwAAYA1hAQAArCEsAACANYQFAACwhrAAAADWlFtYbH80sbx2DQAAyghHLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwpUVhMnjxZLVq0kMvlksvlUvv27bVw4cKyGhsAAKhkShQWdevWVVpamjZu3KgNGzboqquuUs+ePfXll1+W1fgAAEAl4jDGmLN5gvDwcD399NO6/fbbi7V9Tk6OQkNDlZ2dLZfLdTa7BgAAPlLcn98Bpd1Bfn6+5syZoyNHjqh9+/an3M7tdsvtdnsNDAAAnJtKfPHmtm3bFBISIqfTqbvuukvvv/++mjZtesrtU1NTFRoa6rlFR0ef1YABAEDFVeJTIcePH9fevXuVnZ2tuXPn6vXXX9fy5ctPGRcnO2IRHR3NqRAAACqR4p4KOetrLBISEtSwYUNNmTLF6sAAAEDFUdyf32f9eywKCgq8jkgAAIDzV4ku3hw3bpySkpIUExOj3NxczZw5U8uWLdOiRYvKanwAAKASKVFYZGZm6pZbbtH+/fsVGhqqFi1aaNGiRbr66qvLanwAAKASKVFYvPHGG2U1DgAAcA7gb4UAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1hAWAADAGsICAABYQ1gAAABrCAsAAGANYQEAAKwhLAAAgDWEBQAAsIawAAAA1gT4eofGGElSTk6Or3cNAABKqfDnduHP8VPxeVgcOnRIkhQdHe3rXQMAgLOUm5ur0NDQU673eViEh4dLkvbu3XvagcH3cnJyFB0drYyMDLlcrvIeDv4/5qXiYm4qLubGPmOMcnNzVbt27dNu5/Ow8PP7/bKO0NBQJruCcrlczE0FxLxUXMxNxcXc2FWcA
wJcvAkAAKwhLAAAgDU+Dwun06mJEyfK6XT6etc4A+amYmJeKi7mpuJibsqPw5zpcyMAAADFxKkQAABgDWEBAACsISwAAIA1hAUAALCGsAAAANb4NCxefvll1a9fX0FBQWrXrp0+//xzX+7+vJOamqrLL79c1apVU2RkpG644Qbt3LnTa5tjx44pJSVFNWrUUEhIiHr37q2DBw96bbN3714lJycrODhYkZGReuCBB5SXl+fLl3LOS0tLk8Ph0MiRIz3LmJvys2/fPg0cOFA1atRQlSpVFBcXpw0bNnjWG2P08MMP68ILL1SVKlWUkJCgb7/91us5fvnlFw0YMEAul0thYWG6/fbbdfjwYV+/lHNKfn6+HnroIcXGxqpKlSpq2LChHn/8ca8/isXcVADGR2bPnm0CAwPNP//5T/Pll1+aoUOHmrCwMHPw4EFfDeG8k5iYaKZNm2a2b99uNm/ebLp3725iYmLM4cOHPdvcddddJjo62nz22Wdmw4YN5oorrjAdOnTwrM/LyzPNmzc3CQkJ5osvvjALFiwwERERZty4ceXxks5Jn3/+ualfv75p0aKFGTFihGc5c1M+fvnlF1OvXj1z6623mnXr1pnvv//eLFq0yOzatcuzTVpamgkNDTUffPCB2bJli+nRo4eJjY01v/32m2eba6+91rRs2dKsXbvWrFy50lx00UWmX79+5fGSzhmTJk0yNWrUMPPmzTO7d+82c+bMMSEhIeaFF17wbMPclD+fhUXbtm1NSkqK535+fr6pXbu2SU1N9dUQznuZmZlGklm+fLkxxpisrCxzwQUXmDlz5ni2+eqrr4wks2bNGmOMMQsWLDB+fn7mwIEDnm0mT55sXC6Xcbvdvn0B56Dc3FzTqFEjs3jxYtOlSxdPWDA35WfMmDGmU6dOp1xfUFBgoqKizNNPP+1ZlpWVZZxOp5k1a5YxxpgdO3YYSWb9+vWebRYuXGgcDofZt29f2Q3+HJecnGxuu+02r2W9evUyAwYMMMYwNxWFT06FHD9+XBs3blRCQoJnmZ+fnxISErRmzRpfDAGSsrOzJf3vL8xu3LhRJ06c8JqXxo0bKyYmxjMva9asUVxcnGrVquXZJjExUTk5Ofryyy99OPpzU0pKipKTk73mQGJuytNHH32kNm3aqE+fPoqMjNRll12mqVOnetbv3r1bBw4c8Jqb0NBQtWvXzmtuwsLC1KZNG882CQkJ8vPz07p163z3Ys4xHTp00GeffaZvvvlGkrRlyxatWrVKSUlJkpibisInf930559/Vn5+vtc3QEmqVauWvv76a18M4bxXUFCgkSNHqmPHjmrevLkk6cCBAwoMDFRYWJjXtrVq1dKBAwc825xs3grXofRmz56tTZs2af369UXWMTfl5/vvv9fkyZM1atQojR8/XuvXr9fw4cMVGBiowYMHe97bk733f5ybyMhIr/UBAQEKDw9nbs7C2LFjlZOTo8aNG8vf31/5+fmaNGmSBgwYIEnMTQXh8z+bjvKRkpKi7du3a9WqVeU9FEjKyMjQiBEjtHjxYgUFBZX3cPAHBQUFatOmjZ588klJ0mWXXabt27fr1Vdf1eDBg8t5dOe3d999V+np6Zo5c6aaNWumzZs3a+TIkapduzZzU4H45FRIRESE/P39i1zRfvDgQUVFRfliCOe1e++9V/PmzdPSpUtVt25dz/KoqCgdP35cWVlZXtv/cV6ioqJOOm+F61A6GzduVGZmplq1aqWAgAAFBARo+fLl+sc//qGAgADVqlWLuSknF154oZo2beq1rEmTJtq7d6+k/723p/t+FhUVpczMTK/1eXl5+uWXX5ibs/DAAw9o7Nix6tu3r+Li4jRo0CDdf//9Sk1NlcTcVBQ+CYvAwEC1bt1an332mWdZQUGBPvvsM7Vv394XQzgvGWN077336v3339eSJUsUGxvrtb5169a64IILvOZl586d2rt3r2de2rdvr23btnn9j7h48WK5XK4i33xRf
N26ddO2bdu0efNmz61NmzYaMGCA57+Zm/LRsWPHIh/L/uabb1SvXj1JUmxsrKKiorzmJicnR+vWrfOam6ysLG3cuNGzzZIlS1RQUKB27dr54FWcm44ePSo/P+8fW/7+/iooKJDE3FQYvrpKdPbs2cbpdJrp06ebHTt2mGHDhpmwsDCvK9ph1913321CQ0PNsmXLzP79+z23o0ePera56667TExMjFmyZInZsGGDad++vWnfvr1nfeFHGq+55hqzefNm88knn5iaNWvykcYy8MdPhRjD3JSXzz//3AQEBJhJkyaZb7/91qSnp5vg4GDz9ttve7ZJS0szYWFh5sMPPzRbt241PXv2POlHGi+77DKzbt06s2rVKtOoUSM+0niWBg8ebOrUqeP5uOm//vUvExERYf7v//7Psw1zU/58FhbGGPPiiy+amJgYExgYaNq2bWvWrl3ry92fdySd9DZt2jTPNr/99pu55557TPXq1U1wcLC58cYbzf79+72eZ8+ePSYpKclUqVLFREREmL/+9a/mxIkTPn41574/hwVzU34+/vhj07x5c+N0Ok3jxo3Na6+95rW+oKDAPPTQQ6ZWrVrG6XSabt26mZ07d3ptc+jQIdOvXz8TEhJiXC6XGTJkiMnNzfXlyzjn5OTkmBEjRpiYmBgTFBRkGjRoYB588EGvj1czN+XPYcwffmUZAADAWeBvhQAAAGsICwAAYA1hAQAArCEsAACANYQFAACwhrAAAADWEBYAAMAawgIAAFhDWAAAAGsICwAAYA1hAQAArPl/uYiapeA9OlwAAAAASUVORK5CYII=\n" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "
",
+ "image/png": "\n"
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "def show_answers_counts(dfs_split):\n",
+ "    \"\"\"Plot, for one split, how many rows carry each number of answers.\n",
+ "\n",
+ "    Args:\n",
+ "        dfs_split: Key into the notebook-level `dfs` dict of DataFrames, e.g. \"train\".\n",
+ "    \"\"\"\n",
+ "    answer_texts = dfs[dfs_split][\"answers.text\"]\n",
+ "    # .str.len() is the length of each cell's answer collection, i.e. the answer count.\n",
+ "    answers_per_row = answer_texts.str.len()\n",
+ "    counts = answer_texts.groupby(answers_per_row).count().to_dict()\n",
+ "\n",
+ "    pd.Series(counts).sort_values().plot.barh()\n",
+ "    plt.title(\"Number of Answers \" + dfs_split)\n",
+ "    plt.show()\n",
+ "\n",
+ "\n",
+ "show_answers_counts(\"train\")\n",
+ "show_answers_counts(\"validation\")\n"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Przykładowe pytania"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "How was the service staff?\n",
+ "How is the stay?\n",
+ "How tastefully decorated was the room?\n",
+ "What is the customer service?\n",
+ "What do you think about dinner?\n",
+ "What is the most expensive price of food?\n",
+ "Is it service ?\n",
+ "Is it location ?\n",
+ "Is it a good place to stay?\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Print three reproducibly sampled training questions per question word.\n",
+ "for question_type in [\"How\", \"What\", \"Is\"]:\n",
+ "    for question in (\n",
+ "        dfs[\"train\"][dfs[\"train\"].question.str.startswith(question_type)]\n",
+ "        .sample(n=3, random_state=42)['question']\n",
+ "    ):\n",
+ "        print(question)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Preprocessing"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Tokenizacja"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "True"
+ },
+
"execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_ckpt = \"bert-base-cased\"\n", + "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n", + "tokenizer.is_fast" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 21, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Token indices sequence length is longer than the specified maximum sequence length for this model (556 > 512). Running this sequence through the model will result in indexing errors\n" + ] + }, + { + "data": { + "text/plain": "
",
+ "image/png": "\n"
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "MAX_SEQ_LENGTH = 384  # x position of the \"Maximum sequence length\" marker drawn below\n",
+ "\n",
+ "\n",
+ "def compute_input_length(row):\n",
+ "    \"\"\"Return how many token ids the tokenizer produces for one question-context pair.\n",
+ "\n",
+ "    Args:\n",
+ "        row: Row-like mapping with \"question\" and \"context\" entries.\n",
+ "\n",
+ "    Returns:\n",
+ "        Length of the `input_ids` entry returned by the notebook-level `tokenizer`.\n",
+ "    \"\"\"\n",
+ "    # No truncation is requested, so over-long pairs are measured at their full length.\n",
+ "    return len(tokenizer(row[\"question\"], row[\"context\"])[\"input_ids\"])\n",
+ "\n",
+ "\n",
+ "dfs[\"train\"][\"n_tokens\"] = dfs[\"train\"].apply(compute_input_length, axis=1)\n",
+ "\n",
+ "fig, ax = plt.subplots()\n",
+ "dfs[\"train\"][\"n_tokens\"].hist(bins=100, grid=False, ec=\"C0\", ax=ax)\n",
+ "ax.axvline(x=MAX_SEQ_LENGTH, ymin=0, ymax=1, linestyle=\"--\", color=\"C1\",\n",
+ "           label=\"Maximum sequence length\")\n",
+ "ax.set_xlabel(\"Number of tokens in question-context pair\")\n",
+ "ax.set_ylabel(\"Count\")\n",
+ "ax.legend()\n",
+ "plt.show()"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Użycie stride w celu pozyskania okna z kontekstem i okrojenia wielkości tesktu, prezentacja działania"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[CLS] How many employees does the company have? [SEP] We had a great two night stay in Dec 2008. Hotel staff were very accomodating with our 18 month old - providing a cot, toy and stroller for him. The staff were friendly and upbeat. Our room was also upgraded. Our only dissappointment was the food at the downstairs restaurant - way too expensive and did not meet expectations. [SEP]\n",
+ "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP] My wife and I stayed here on our honeymoon for 2 nights in early March and had a great stay. I had booked a room at another hotel on hotels. com on Jan. 6th though. Then 4 days before we were to start our honeymoon in San Fran., hotels. 
com send me an e - mail to say they had to change my reservation from a hotel downtown to a hotel at the [SEP]\n", + "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP]. com on Jan. 6th though. Then 4 days before we were to start our honeymoon in San Fran., hotels. com send me an e - mail to say they had to change my reservation from a hotel downtown to a hotel at the airport. I rufused and was given a refund and no hotel. There was a huge 30, 000 person convention in town at the same [SEP]\n", + "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP] - mail to say they had to change my reservation from a hotel downtown to a hotel at the airport. I rufused and was given a refund and no hotel. There was a huge 30, 000 person convention in town at the same time so all the hotels were full. I got lucky though and was able to book a bay view room at The Argonaut. This was the best [SEP]\n", + "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP]nd and no hotel. There was a huge 30, 000 person convention in town at the same time so all the hotels were full. I got lucky though and was able to book a bay view room at The Argonaut. This was the best thing that could of happened to us! This place was amazing! We arrived at the hotel around 12 : 30pm and were able to check in right away [SEP]\n", + "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP] and was able to book a bay view room at The Argonaut. This was the best thing that could of happened to us! This place was amazing! We arrived at the hotel around 12 : 30pm and were able to check in right away! The staff were all so nice and called us by name whenever they saw us. The room was amazing! We had a beautiful view of Alcatraz [SEP]\n", + "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP]! We arrived at the hotel around 12 : 30pm and were able to check in right away! 
The staff were all so nice and called us by name whenever they saw us. The room was amazing! We had a beautiful view of Alcatraz Island and could even see the golden gate. Internet worked great and the TV was fine. It's not a big fancy TV, but we didn'[SEP]\n", + "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP] whenever they saw us. The room was amazing! We had a beautiful view of Alcatraz Island and could even see the golden gate. Internet worked great and the TV was fine. It's not a big fancy TV, but we didn't fly all the way to San Fran. to watch TV anyway. What made our stay so great was the concierge. There was a wonderful [SEP]\n", + "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP] and the TV was fine. It's not a big fancy TV, but we didn't fly all the way to San Fran. to watch TV anyway. What made our stay so great was the concierge. There was a wonderful young lady there that recommended 2 wonderful resturants for dinners and an amazing dim sum resturant for lunch. We enjoyed San Fran. way [SEP]\n", + "[CLS] Was the atmosphere of the tourist areas of san fransisco peaceful? [SEP] TV anyway. What made our stay so great was the concierge. There was a wonderful young lady there that recommended 2 wonderful resturants for dinners and an amazing dim sum resturant for lunch. We enjoyed San Fran. way more then we ever thought we would and we owe that to The Argonaut Hotel and there staff. We can't wait to come visit again. [SEP]\n", + "[CLS] How is the employee service on this hotel? [SEP] Spent one night at The Argonaut and wish that we could have stayed longer. We even missed the free wine hour but what the heck, the staff was so pleasant and not in a superficial,'I don't mean it'kind of way. Valet was on top of their game and was very helpful with directions or anything we might need. 
The roofm was super clean and I was looking for dirt in the areas [SEP]\n", + "[CLS] How is the employee service on this hotel? [SEP] a superficial,'I don't mean it'kind of way. Valet was on top of their game and was very helpful with directions or anything we might need. The roofm was super clean and I was looking for dirt in the areas that sometimes don't get the best housekeeping job. I couldn't find as much as a crumb! We were on the fourth floor in a two queen bedded room [SEP]\n", + "[CLS] How is the employee service on this hotel? [SEP] was super clean and I was looking for dirt in the areas that sometimes don't get the best housekeeping job. I couldn't find as much as a crumb! We were on the fourth floor in a two queen bedded room as we were traveling with our two children. We were upgraded to a Cannery / Alcatraz view because of our membership in Kimpton's In Touch program but it was really just [SEP]\n", + "[CLS] How is the employee service on this hotel? [SEP] were on the fourth floor in a two queen bedded room as we were traveling with our two children. We were upgraded to a Cannery / Alcatraz view because of our membership in Kimpton's In Touch program but it was really just a view of the Mexican restaurant below unless you really, really cranked your head. Still nice to receive. We also had a note welcoming us and a big bottle of water but didn [SEP]\n", + "[CLS] How is the employee service on this hotel? [SEP] Kimpton's In Touch program but it was really just a view of the Mexican restaurant below unless you really, really cranked your head. Still nice to receive. We also had a note welcoming us and a big bottle of water but didn't drink it because it didn't say it was complementary. Overall a great experience and I wouldn't hesitate to return whenever I am in the San Francisco area. Great decor [SEP]\n", + "[CLS] How is the employee service on this hotel? 
[SEP] a note welcoming us and a big bottle of water but didn't drink it because it didn't say it was complementary. Overall a great experience and I wouldn't hesitate to return whenever I am in the San Francisco area. Great decor. This was our first stay in a Kimtpon property but I am going to seek them out from now on. [SEP]\n", + "[CLS] How is attraction? [SEP] This is a great hotel! I loved the nautical decoration! The room was immaculate, very comfortable and we didn't have any issues with noise. The location is great, especially for a quick stay as you are close to the tourist attractions. There are countless kiosks in the area for different tours and attractions. We took the cable car tour around San Fran and over the Golden Gate Bridge. The ferry to Alcatraz is a 10 - [SEP]\n", + "[CLS] How is attraction? [SEP] you are close to the tourist attractions. There are countless kiosks in the area for different tours and attractions. We took the cable car tour around San Fran and over the Golden Gate Bridge. The ferry to Alcatraz is a 10 - 15 minute walk away from the hotel. There are loads of restaurants nearby. The restaurant attached to the hotel was great for breakfast. I would definitely recommend this hotel. 
[SEP]\n" + ] + } + ], + "source": [ + "context = subjqa[\"train\"][2:6][\"context\"]\n", + "question = subjqa[\"train\"][2:6][\"question\"]\n", + "\n", + "inputs = tokenizer(\n", + " question,\n", + " context,\n", + " max_length=100,\n", + " truncation=\"only_second\",\n", + " stride=50,\n", + " return_overflowing_tokens=True,\n", + " return_offsets_mapping=True,\n", + ")\n", + "for ids in inputs[\"input_ids\"]:\n", + " print(tokenizer.decode(ids))" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 23, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The 4 examples gave 18 features.\n", + "Here is where each comes from: [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3].\n" + ] + } + ], + "source": [ + "print(f\"The 4 examples gave {len(inputs['input_ids'])} features.\")\n", + "print(f\"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.\")" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 24, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[31, 0, 0, 0, 0, 0, 0, 0, 0, 61, 41, 0, 0, 0, 0, 0, 70, 27]\n", + "[31, 0, 0, 0, 0, 0, 0, 0, 0, 72, 45, 0, 0, 0, 0, 0, 70, 27]\n" + ] + } + ], + "source": [ + "answers = subjqa[\"train\"][2:6][\"answers\"]\n", + "start_positions = []\n", + "end_positions = []\n", + "\n", + "for i, offset in enumerate(inputs[\"offset_mapping\"]):\n", + " sample_idx = inputs[\"overflow_to_sample_mapping\"][i]\n", + " sequence_ids = inputs.sequence_ids(i)\n", + " answer = answers[sample_idx]\n", + "\n", + " # Find start and end of context\n", + " context_start = sequence_ids.index(1)\n", + " context_end = len(sequence_ids) - 2\n", + " start_char = 0\n", + " end_char = 0\n", + " if len(answer['answer_start']) > 0:\n", + " for idx, answer_start in enumerate(answer['answer_start']):\n", + " 
def _find_answer_token_span(offsets, sequence_ids, answer):
    """Locate the token span of the first answer that fits inside one feature.

    Args:
        offsets: per-token ``(char_start, char_end)`` pairs for the feature.
        sequence_ids: per-token sequence ids for the feature (``1`` marks
            context tokens).
        answer: mapping with parallel lists ``"answer_start"`` (character
            offsets into the context) and ``"text"`` (answer strings).

    Returns:
        ``(start_token, end_token)`` token indices, or ``(0, 0)`` when the
        example has no answer or none of its answers lies fully inside the
        context window of this feature.

    Raises:
        ValueError: if the feature contains no context token at all.
    """
    sequence_ids = list(sequence_ids)
    context_start = sequence_ids.index(1)
    # Find the last context token by its sequence id, not by position: the
    # tail of a feature may be padding, in which case "len - 2" is not a
    # context token and the containment test below would always fail.
    context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

    window_start_char = offsets[context_start][0]
    window_end_char = offsets[context_end][1]

    for answer_start, answer_text in zip(answer["answer_start"], answer["text"]):
        answer_end = answer_start + len(answer_text)
        # Skip answers that are not fully covered by this window.
        if window_start_char > answer_start or window_end_char < answer_end:
            continue

        # Last token that starts at or before the first answer character.
        token = context_start
        while token <= context_end and offsets[token][0] <= answer_start:
            token += 1
        start_token = token - 1

        # First token that ends at or after the last answer character.
        token = context_end
        while token >= context_start and offsets[token][1] >= answer_end:
            token -= 1
        end_token = token + 1

        return start_token, end_token

    return 0, 0


def preprocess_train_data(examples, max_length=384, stride=128):
    """Tokenize a batch of QA examples into overlapping, padded training features.

    Long contexts are split into several features (windows of ``max_length``
    tokens overlapping by ``stride`` tokens); only the context is truncated.
    Each feature is labelled with the token span of the first answer that fits
    entirely inside its window, or with ``(0, 0)`` when no answer does.

    Uses the module-level ``tokenizer``.

    Args:
        examples: batch mapping with ``"question"``, ``"context"`` and
            ``"answers"`` columns (lists of equal length).
        max_length: number of tokens per feature (features are padded to it).
        stride: number of overlapping tokens between consecutive windows.

    Returns:
        The tokenizer output, extended with ``"start_positions"`` and
        ``"end_positions"`` (one entry per feature).
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offsets in enumerate(inputs["offset_mapping"]):
        # One example can yield several features; map back to its answers.
        sample_idx = inputs["overflow_to_sample_mapping"][i]
        start_token, end_token = _find_answer_token_span(
            offsets, inputs.sequence_ids(i), answers[sample_idx]
        )
        start_positions.append(start_token)
        end_positions.append(end_token)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
import numpy as np


def predict_from_trained():
    """Run the trained model over the validation split and report QA metrics.

    Uses the module-level ``subjqa``, ``preprocess_validation_examples``,
    ``trained_model``, ``device`` and ``compute_metrics``. The validation
    split is turned into features, fed through ``trained_model`` in batches
    of 8 without gradient tracking, and the collected start/end logits are
    passed to ``compute_metrics`` together with the features and the raw
    examples.

    Returns:
        None. The result of ``compute_metrics`` is discarded.

    Raises:
        ValueError: if preprocessing produced no features to predict on.
    """
    eval_set = subjqa["validation"].map(
        preprocess_validation_examples,
        batched=True,
        remove_columns=subjqa["validation"].column_names,
    )
    # These two columns are only needed for post-processing, not by the model.
    eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
    eval_set_for_model.set_format("torch")

    num_rows = eval_set_for_model.num_rows
    if num_rows == 0:
        raise ValueError("validation split produced no features to predict on")

    batch_size = 8
    # Fetch each column once instead of on every batch.
    columns = {
        name: eval_set_for_model[name] for name in eval_set_for_model.column_names
    }

    all_start_logits = []
    all_end_logits = []
    # NOTE(review): assumes trained_model is already in eval mode — confirm.
    # Step by batch_size so the last batch is the (possibly shorter) remainder
    # and no empty batch is ever sent to the model.
    for first_row in range(0, num_rows, batch_size):
        batch = {
            name: values[first_row:first_row + batch_size].to(device)
            for name, values in columns.items()
        }
        with torch.no_grad():
            outputs = trained_model(**batch)
        all_start_logits.append(outputs.start_logits.cpu().numpy())
        all_end_logits.append(outputs.end_logits.cpu().numpy())

    start_logits = np.concatenate(all_start_logits, axis=0)
    end_logits = np.concatenate(all_end_logits, axis=0)
    _ = compute_metrics(start_logits, end_logits, eval_set, subjqa["validation"])
\n \n \n [ 2/762 : < :, Epoch 0.00/3]\n
\n \n \n \n \n \n \n \n \n \n
StepTraining Loss

" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving model checkpoint to data/bert-finetuned-subjqa/checkpoint-254\n", + "Configuration saved in data/bert-finetuned-subjqa/checkpoint-254/config.json\n", + "Model weights saved in data/bert-finetuned-subjqa/checkpoint-254/pytorch_model.bin\n", + "tokenizer config file saved in data/bert-finetuned-subjqa/checkpoint-254/tokenizer_config.json\n", + "Special tokens file saved in data/bert-finetuned-subjqa/checkpoint-254/special_tokens_map.json\n", + "Saving model checkpoint to data/bert-finetuned-subjqa/checkpoint-508\n", + "Configuration saved in data/bert-finetuned-subjqa/checkpoint-508/config.json\n", + "Model weights saved in data/bert-finetuned-subjqa/checkpoint-508/pytorch_model.bin\n", + "tokenizer config file saved in data/bert-finetuned-subjqa/checkpoint-508/tokenizer_config.json\n", + "Special tokens file saved in data/bert-finetuned-subjqa/checkpoint-508/special_tokens_map.json\n", + "Saving model checkpoint to data/bert-finetuned-subjqa/checkpoint-762\n", + "Configuration saved in data/bert-finetuned-subjqa/checkpoint-762/config.json\n", + "Model weights saved in data/bert-finetuned-subjqa/checkpoint-762/pytorch_model.bin\n", + "tokenizer config file saved in data/bert-finetuned-subjqa/checkpoint-762/tokenizer_config.json\n", + "Special tokens file saved in data/bert-finetuned-subjqa/checkpoint-762/special_tokens_map.json\n", + "\n", + "\n", + "Training completed. 
from transformers import Trainer, TrainingArguments

# Fine-tune the BERT QA model on the preprocessed SubjQA training features.
args = TrainingArguments(
    # Where checkpoints go; an existing directory is overwritten.
    output_dir="data/bert-finetuned-subjqa",
    overwrite_output_dir=True,
    # Optimisation settings.
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    fp16=True,
    # Save a checkpoint per epoch; no evaluation during training.
    save_strategy="epoch",
    evaluation_strategy="no",
)
trainer = Trainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
)
# Left as the last expression so the notebook displays the training output.
trainer.train()

\n \n \n [ 1/41 : < :]\n
\n " + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": " 0%| | 0/265 [00:00", + "text/html": "\n
\n \n \n [ 2/750 : < :, Epoch 0.00/3]\n
\n \n \n \n \n \n \n \n \n \n
StepTraining Loss

" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Saving model checkpoint to data/roberta-finetuned-subjqa/checkpoint-250\n", + "Configuration saved in data/roberta-finetuned-subjqa/checkpoint-250/config.json\n", + "Model weights saved in data/roberta-finetuned-subjqa/checkpoint-250/pytorch_model.bin\n", + "tokenizer config file saved in data/roberta-finetuned-subjqa/checkpoint-250/tokenizer_config.json\n", + "Special tokens file saved in data/roberta-finetuned-subjqa/checkpoint-250/special_tokens_map.json\n", + "Saving model checkpoint to data/roberta-finetuned-subjqa/checkpoint-500\n", + "Configuration saved in data/roberta-finetuned-subjqa/checkpoint-500/config.json\n", + "Model weights saved in data/roberta-finetuned-subjqa/checkpoint-500/pytorch_model.bin\n", + "tokenizer config file saved in data/roberta-finetuned-subjqa/checkpoint-500/tokenizer_config.json\n", + "Special tokens file saved in data/roberta-finetuned-subjqa/checkpoint-500/special_tokens_map.json\n", + "Saving model checkpoint to data/roberta-finetuned-subjqa/checkpoint-750\n", + "Configuration saved in data/roberta-finetuned-subjqa/checkpoint-750/config.json\n", + "Model weights saved in data/roberta-finetuned-subjqa/checkpoint-750/pytorch_model.bin\n", + "tokenizer config file saved in data/roberta-finetuned-subjqa/checkpoint-750/tokenizer_config.json\n", + "Special tokens file saved in data/roberta-finetuned-subjqa/checkpoint-750/special_tokens_map.json\n", + "\n", + "\n", + "Training completed. 
# Fine-tune the RoBERTa QA model on the RoBERTa-tokenized SubjQA features.
args = TrainingArguments(
    # Where checkpoints go; an existing directory is overwritten.
    output_dir="data/roberta-finetuned-subjqa",
    overwrite_output_dir=True,
    # Optimisation settings.
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    fp16=True,
    # Save a checkpoint per epoch; no evaluation during training.
    save_strategy="epoch",
    evaluation_strategy="no",
)
trainer = Trainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=roberta_train_dataset,
    eval_dataset=roberta_validation_dataset,
)
# Left as the last expression so the notebook displays the training output.
trainer.train()

\n \n \n [ 1/41 : < :]\n
\n " + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": " 0%| | 0/265 [00:00