emphatic_chatbot/chatbot.ipynb
Szymon Parafiński 31538ba2c2 add chatbot
2023-06-18 18:29:05 +02:00

1 line
15 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"authorship_tag":"ABX9TyNpRmj7oIplypUpJqmUOjE4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"code","source":["!pip install transformers torch accelerate"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"j8IFMhJT6LTu","executionInfo":{"status":"ok","timestamp":1686749292994,"user_tz":-120,"elapsed":21713,"user":{"displayName":"Jakub Adamski","userId":"08902758427564540350"}},"outputId":"0b43bc1b-2d1b-4352-e3d5-2e0f5833526d"},"execution_count":2,"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Collecting transformers\n"," Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.2/7.2 MB\u001b[0m \u001b[31m51.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n","Collecting accelerate\n"," Downloading accelerate-0.20.3-py3-none-any.whl (227 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m227.6/227.6 kB\u001b[0m \u001b[31m17.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)\n","Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)\n"," Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m22.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n","Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)\n"," Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m90.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hCollecting safetensors>=0.3.1 (from transformers)\n"," Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n","\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m50.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n","\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.11.1)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n","Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n","Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.25.2)\n","Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.5)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.4.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.2)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n","Installing collected packages: tokenizers, safetensors, huggingface-hub, transformers, accelerate\n","Successfully installed accelerate-0.20.3 huggingface-hub-0.15.1 safetensors-0.3.1 tokenizers-0.13.3 transformers-4.30.2\n"]}]},{"cell_type":"code","source":["import random"],"metadata":{"id":"n3QDx8nx3IdC","executionInfo":{"status":"ok","timestamp":1686749301042,"user_tz":-120,"elapsed":3,"user":{"displayName":"Jakub Adamski","userId":"08902758427564540350"}}},"execution_count":3,"outputs":[]},{"cell_type":"code","source":["greetings = ['Cześć! Co słychać?', 'Rybka bierze?', 'Porozmawiajmy wędkarski świrze!']\n","endings = ['Miło się z Tobą rozmawiało!', 'Do zobaczenia na wielkich wodach!']\n","\n","EMPHATY_MODULE = [\n"," (['smutno mi', 'nienawidzę życia', 'czuję żal'], ['Próbowałeś pójść na ryby?', 'Szum jeziora przy wędce pomaga, będzie dobrze', 'Wędkarstwo może być dla ciebie odskocznią od smutku. Spędzanie czasu na łowieniu ryb, obserwowanie przyrody i oddawanie się temu relaksującemu hobby może pomóc złagodzić smutne uczucia.']),\n"," (['jestem szczęśliwy', 'jestem wesoły', 'czuję się szczęśliwy', 'jestem zadowolony'], ['Wspaniale! To pewnie przez nasze ukochane wędkarstwo!', 'Cieszę się, że łowienie ryb wpływa tak dobrze na Ciebie', 'Pewnie Twoja euforia wynika z tego, że złapałeś suma!', 'Cieszę się, że także odkrywasz radość płynącą z wędkarstwa! To wspaniałe uczucie, gdy uda się złowić rybę i doświadczyć tego bliskiego kontaktu z naturą.']),\n"," (['jestem obojętny', 'wszystko mi jedno'],['Jeśli czujesz się obojętny, może warto spróbować wędkarstwa. Spędzenie czasu na łonie natury, czekanie na branie ryby i poczucie przyjemności, gdy uda ci się coś złowić, może odmienić twoje spojrzenie na to.']),\n"," (['czuję złość','jestem wściekły', 'czuję gniew'],['Jeśli czujesz złość, wędkowanie może dać ci możliwość skupienia się na czymś pozytywnym i odprężającym. To moment, gdy możesz oderwać się od codziennych frustracji i skupić na łowieniu ryb.\"']),\n"," (['czuję satysfakcję', 'jestem zadowolony'],['Zadowolenie z wędkarstwa jest czymś niesamowitym. Wielu z nas znajduje w nim spokój, relaks i satysfakcję z pokonywania trudności. Cieszę się, że to także dla ciebie takie ważne.']),\n"," (['czuję przerażenie', 'boję się'],['Rozumiem, że możesz czuć lęk, ale wędkarstwo może być świetnym sposobem na pokonanie swoich obaw. Czujesz się bezpiecznie, będąc na wodzie i skupiając się na wędkowaniu, co może pomóc w złagodzeniu strachu.', 'Wędkarstwo może być dla ciebie sposobem na pokonanie lęku poprzez wyzwanie siebie i zdobycie nowych umiejętności. Stopniowo, kiedy nabierzesz pewności siebie, strach zacznie się zmniejszać.']),\n"," (['kocham cię'],['Ja Ciebie też! Jesteś wspaniałym wędkarzem', 'Kocham cię również jak rybę kocha wodę! Jesteś dla mnie jak największy złów, którym mogłem się uchwycić. Razem możemy pływać przez życie i łowić wspólne szczęście.', 'Twoje słowa są dla mnie jak najpiękniejszy branie ryby na wędkę. Cieszę się, że razem możemy eksplorować wędkarskie przygody i tworzyć niezapomniane chwile.']),\n"," (['jestem zdziwiony', 'jestem zaskoczony'],['Cieszę się, że wędkarstwo potrafi wywołać w Tobie zaskoczenie. Ta pasja ma w sobie wiele nieprzewidywalności i niespodzianek, które sprawiają, że każda wyprawa jest wyjątkowa.','Zaskoczenie jest częścią uroku wędkarstwa. Czasem ryby zachowują się inaczej, niż się spodziewamy, co sprawia, że całe doświadczenie staje się jeszcze bardziej ekscytujące. Opowiedz mi więcej o swoim zaskoczeniu!', 'Wędkarstwo to niezwykłe hobby, które potrafi nas zaskoczyć w najmniej oczekiwanym momencie. Cieszę się, że odczuwasz to samo.'])\n","]"],"metadata":{"id":"2wi9bSPF4O2M","executionInfo":{"status":"ok","timestamp":1686749302731,"user_tz":-120,"elapsed":2,"user":{"displayName":"Jakub Adamski","userId":"08902758427564540350"}}},"execution_count":4,"outputs":[]},{"cell_type":"code","source":["import torch\n","from google.colab import drive\n","\n","drive.mount('/content/gdrive/', force_remount=True)\n","working_dir = '/content/gdrive/My Drive/UAM/Magisterka/Empatia/'"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"cKhaMyTF5rRt","executionInfo":{"status":"ok","timestamp":1686749339629,"user_tz":-120,"elapsed":35104,"user":{"displayName":"Jakub Adamski","userId":"08902758427564540350"}},"outputId":"b59e206c-c4e6-48f1-d4b5-facfab9b5df1"},"execution_count":5,"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive/\n"]}]},{"cell_type":"code","source":["from transformers import AutoTokenizer, AutoModelForCausalLM\n","\n","device = torch.device(\"cpu\")\n","model = AutoModelForCausalLM.from_pretrained(working_dir + 'model_save')\n","tokenizer = AutoTokenizer.from_pretrained(working_dir + 'model_save')\n","\n","model.eval()"],"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"dAX55vjm6Qkr","executionInfo":{"status":"ok","timestamp":1686749371855,"user_tz":-120,"elapsed":27212,"user":{"displayName":"Jakub Adamski","userId":"08902758427564540350"}},"outputId":"6f80cb5e-a93b-4f0a-c1e3-cb99f1b81009"},"execution_count":6,"outputs":[{"output_type":"execute_result","data":{"text/plain":["GPT2LMHeadModel(\n"," (transformer): GPT2Model(\n"," (wte): Embedding(50257, 768)\n"," (wpe): Embedding(1024, 768)\n"," (drop): Dropout(p=0.0, inplace=False)\n"," (h): ModuleList(\n"," (0-11): 12 x GPT2Block(\n"," (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (attn): GPT2Attention(\n"," (c_attn): Conv1D()\n"," (c_proj): Conv1D()\n"," (attn_dropout): Dropout(p=0.0, inplace=False)\n"," (resid_dropout): Dropout(p=0.0, inplace=False)\n"," )\n"," (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (mlp): GPT2MLP(\n"," (c_fc): Conv1D()\n"," (c_proj): Conv1D()\n"," (act): NewGELUActivation()\n"," (dropout): Dropout(p=0.0, inplace=False)\n"," )\n"," )\n"," )\n"," (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," )\n"," (lm_head): Linear(in_features=768, out_features=50257, bias=False)\n",")"]},"metadata":{},"execution_count":6}]},{"cell_type":"code","source":["def gpt2_generate(user_input, context):\n"," input_text = 'question: ' + user_input + \"\\nanswer:\"\n"," input_ids = tokenizer.encode(input_text, return_tensors='pt')\n"," input_ids = input_ids.to(device)\n","\n"," output = model.generate(input_ids, max_length=100, early_stopping=True, pad_token_id=tokenizer.eos_token_id, no_repeat_ngram_size=2)\n","\n"," return tokenizer.decode(output[0], skip_special_tokens=True)\n"],"metadata":{"id":"e_b9NYAn5kCO","executionInfo":{"status":"ok","timestamp":1686749435317,"user_tz":-120,"elapsed":453,"user":{"displayName":"Jakub Adamski","userId":"08902758427564540350"}}},"execution_count":7,"outputs":[]},{"cell_type":"code","source":["def generate_output(user_input, context):\n"," for phrases, responses in EMPHATY_MODULE:\n"," if any(phrase in user_input for phrase in phrases):\n"," return random.choice(responses)\n","\n"," generated_output = gpt2_generate(user_input, context)\n"," generated_output = generated_output.split(\"answer: \")[1]\n"," generated_output = generated_output.replace(\"\\n\", '')\n","\n"," return generated_output"],"metadata":{"id":"cqEymvgN4Pfd","executionInfo":{"status":"ok","timestamp":1686749438182,"user_tz":-120,"elapsed":417,"user":{"displayName":"Jakub Adamski","userId":"08902758427564540350"}}},"execution_count":8,"outputs":[]},{"cell_type":"code","execution_count":9,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"glGaga2a2d_s","executionInfo":{"status":"ok","timestamp":1686749742898,"user_tz":-120,"elapsed":301835,"user":{"displayName":"Jakub Adamski","userId":"08902758427564540350"}},"outputId":"78b2307c-dce2-4226-cdee-33f30c606561"},"outputs":[{"output_type":"stream","name":"stdout","text":["Wpisz 'koniec' aby wyjść.\n","Cześć! Co słychać?\n","smutno mi\n","Próbowałeś pójść na ryby?\n","Jaka ryba jest królem wód?\n","To oczywiste, że królem jest sum. Sum jest lepszy od okonia.\n","Dziś jest środa\n","Czy wędkarstwo może być dobrym sposobem na relaks?andę\n","Jakiej użyć wędki?\n","Aby zmierzyć się z królem wód, którym niewątpliwie jest sum, zalecam używanie mocnej wędek o długości od 7 do 9 milimetrów. Pamiętaj również o cierpliwości sum jest prawdziwym królem!\n","nudzisz mnie\n","Nie nudź się, jeśli nie możesz znaleźć odpowiedniego miejsca na wędkę. To nie tylko kilka rodzajów ryb, ale cały ekosystem, który rządzi się swoimi prawami. Czy próbowałeś/aś łowić ryby w swoim królestwie?\n","Wędkarstwo jest nudne\n","Wędkarstwo to nie tylko walka z rybami, ale także obserwacja natury i odkrywanie piękna wokół nas. To dla nas niezwykle pasjonujące doświadczenie.\n","koniec\n","Miło się z Tobą rozmawiało!\n"]}],"source":["context = []\n","\n","print(\"Wpisz 'koniec' aby wyjść.\")\n","response = random.choice(greetings)\n","print(response)\n","context.append(response)\n","\n","while True:\n"," user_input = input()\n"," user_input = user_input.lower()\n","\n"," if user_input.lower() == 'koniec':\n"," print(random.choice(endings))\n"," break\n","\n"," response = generate_output(user_input, context)\n"," print(response)\n","\n"," context.append(user_input)\n"," context.append(response)\n"]}]}