empatia-projekt/chatbot_inference.ipynb
2023-06-21 11:13:04 +02:00

1 line
12 KiB
Plaintext

{"cells":[{"cell_type":"code","execution_count":1,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"j8IFMhJT6LTu","outputId":"c530d055-4e03-49b2-db16-6e508a9164fd","executionInfo":{"status":"ok","timestamp":1687338323582,"user_tz":-120,"elapsed":6268,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n","Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.30.2)\n","Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n","Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.20.3)\n","Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)\n","Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.1)\n","Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n","Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n","Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)\n","Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n","Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n","Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n","Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.1)\n","Requirement already satisfied: tqdm>=4.27 in 
/usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n","Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n","Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.11.1)\n","Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n","Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n","Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n","Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.25.2)\n","Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.5)\n","Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n","Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.4.0)\n","Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.2)\n","Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.15)\n","Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)\n","Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n","Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n","Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n"]}],"source":["!pip install transformers torch 
accelerate"]},
{"cell_type":"code","execution_count":2,"metadata":{"id":"n3QDx8nx3IdC","executionInfo":{"status":"ok","timestamp":1687338323583,"user_tz":-120,"elapsed":15,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[],"source":["import random"]},
{"cell_type":"code","execution_count":3,"metadata":{"id":"2wi9bSPF4O2M","executionInfo":{"status":"ok","timestamp":1687338323584,"user_tz":-120,"elapsed":12,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[],"source":["openings = ['Cześć Dawid! Co słychać?', 'Halo Dawid, ile dziś kaski z donejtów?', 'Jak się masz Dawid? Pozdrów Ryszarda!']\n","endings = ['Dobra nara, spoko?', 'Do zobaczenia na twoim live YouTube albo Twitch :)']\n","\n","# Each entry is (trigger phrases, candidate responses). generate_output checks whether a\n","# trigger occurs as a substring of the user's message, and the chat loop lower-cases that\n","# message first, so the triggers below are written in lower case.\n","EMPHATY_MODULE = [\n"," (['smutno', 'to ostatni live', 'kończę live'], ['Porozmawiaj o tym z Ryśkiem i pobaw się z Misiunią', 'Nie przejmuj się hejterami, jesteś super.', 'Odpocznij Dawid, należy ci się odrobina relaksu po ciężkiej pracy.']),\n"," (['mam dobry humor', 'dużo pieniążków dziś zarobiłem gagri gagri', 'wypiłem dziś colke i zjadłem czekoladke'], ['To świetnie Dawid!', 'Jestem z ciebie dumny!', 'Ale masz dzisiaj dobrze!']),\n"," (['mam wszystko gdzieś', 'dobra to bez sensu'],['Wszystko będzie dobrze, nie ma bomby pod boljerem.', 'Spokojnie Dawid, bądź wyluzowany jak kaczka po pekińsku.']),\n"," (['ulani mnie wkurzył', 'skończcie na mnie mówić buldog'],['Jest okej, na pewno wszystko będzie dobrze :)', 'Rozumiem twoją złość Dawid, masz prawo wyrażać swoje emocje.']),\n"," (['wygrałem grę w li od ledżends', 'ale mu powiedziałem'],['Super! Wiedziałem, że dobrze ci pójdzie.', 'Dobre dobre, tak trzymaj Dawid.']),\n"," (['wysyłają na mnie groźby i na mojego tatę', 'ale się przestraszyłem'],['To tylko hejterzy, nie przejmuj się nimi, oni gadają głupoty, nic ci nie grozi.', 'Nie bój się, to tylko gra, wszystko jest okej, idź po buziaczka od Ryszarda.']),\n"," (['witam cię koleżanko, jesteś bardzo piękna i ładna, powiedz mi z jakiej jesteś miejscowości, jak masz na imię i ile masz lat spoko?'],['Hej Dawid! Nie jestem koleżanką, tylko twoim kolegą. Dzięki za komplement :))']),\n"," (['ale mnie porobił', 'aha'],['Następnym razem pójdzie Ci z pewnością lepiej. Głowa do góry :)','aha', 'aha spoko'])\n","]"]},
{"cell_type":"code","execution_count":4,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"cKhaMyTF5rRt","outputId":"6457628b-5e7a-450e-f962-387863e262c8","executionInfo":{"status":"ok","timestamp":1687338331435,"user_tz":-120,"elapsed":7862,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Mounted at /content/gdrive/\n"]}],"source":["import torch\n","from google.colab import drive\n","\n","drive.mount('/content/gdrive/', force_remount=True)\n","working_dir = '/content/gdrive/My Drive/empatia/'"]},
{"cell_type":"code","execution_count":5,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"dAX55vjm6Qkr","outputId":"aa9208e4-fd62-4149-f92c-06537873659b","executionInfo":{"status":"ok","timestamp":1687338353912,"user_tz":-120,"elapsed":22487,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[{"output_type":"execute_result","data":{"text/plain":["GPT2LMHeadModel(\n"," (transformer): GPT2Model(\n"," (wte): Embedding(51200, 768)\n"," (wpe): Embedding(2048, 768)\n"," (drop): Dropout(p=0.1, inplace=False)\n"," (h): ModuleList(\n"," (0-11): 12 x GPT2Block(\n"," (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (attn): GPT2Attention(\n"," (c_attn): 
Conv1D()\n"," (c_proj): Conv1D()\n"," (attn_dropout): Dropout(p=0.1, inplace=False)\n"," (resid_dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," (mlp): GPT2MLP(\n"," (c_fc): Conv1D()\n"," (c_proj): Conv1D()\n"," (act): FastGELUActivation()\n"," (dropout): Dropout(p=0.1, inplace=False)\n"," )\n"," )\n"," )\n"," (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n"," )\n"," (lm_head): Linear(in_features=768, out_features=51200, bias=False)\n",")"]},"metadata":{},"execution_count":5}],"source":["from transformers import AutoTokenizer, AutoModelForCausalLM\n","\n","device = torch.device(\"cpu\")\n","# Move the model to `device` so that changing it (e.g. to \"cuda\") keeps model and inputs together.\n","model = AutoModelForCausalLM.from_pretrained(working_dir + 'model').to(device)\n","tokenizer = AutoTokenizer.from_pretrained(working_dir + 'model')\n","\n","model.eval()"]},
{"cell_type":"code","execution_count":6,"metadata":{"id":"e_b9NYAn5kCO","executionInfo":{"status":"ok","timestamp":1687338353913,"user_tz":-120,"elapsed":16,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[],"source":["def gpt2_generate(user_input, context, max_new_tokens=80):\n","    \"\"\"Run the GPT-2 model on a single user message.\n","\n","    Builds a 'question: <message>' / 'answer:' prompt and returns the full decoded\n","    sequence (prompt followed by the continuation) with special tokens removed.\n","    Uses the global `model`, `tokenizer` and `device`.\n","\n","    `context` is accepted for call-site compatibility but is not used yet.\n","    `max_new_tokens` caps the length of the generated answer only (default 80).\n","    \"\"\"\n","    input_text = 'question: ' + user_input + \"\\nanswer:\"\n","    # tokenizer(...) also returns the attention mask; pass it explicitly rather than\n","    # letting generate() guess it (ambiguous here, since pad_token_id == eos_token_id).\n","    inputs = tokenizer(input_text, return_tensors='pt').to(device)\n","\n","    # max_new_tokens instead of max_length=100: max_length also counted the prompt,\n","    # so long messages left little or no room for the reply.\n","    # early_stopping only has an effect with beam search (num_beams > 1).\n","    output = model.generate(\n","        inputs['input_ids'],\n","        attention_mask=inputs.get('attention_mask'),\n","        max_new_tokens=max_new_tokens,\n","        early_stopping=True,\n","        pad_token_id=tokenizer.eos_token_id,\n","        no_repeat_ngram_size=2,\n","    )\n","\n","    return tokenizer.decode(output[0], skip_special_tokens=True)\n"]},
{"cell_type":"code","execution_count":7,"metadata":{"id":"cqEymvgN4Pfd","executionInfo":{"status":"ok","timestamp":1687338353914,"user_tz":-120,"elapsed":15,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[],"source":["def generate_output(user_input, context):\n","    \"\"\"Return the bot's reply to `user_input`.\n","\n","    Canned EMPHATY_MODULE replies take priority: if any trigger phrase occurs in the\n","    message (case-insensitive substring match), a random response from that entry is\n","    returned. Otherwise the reply is extracted from gpt2_generate's output; an empty\n","    string is returned if the model output contains no 'answer:' marker.\n","    \"\"\"\n","    normalized_input = user_input.lower()\n","    for phrases, responses in EMPHATY_MODULE:\n","        # Compare case-insensitively on both sides: the chat loop lower-cases the input,\n","        # so a trigger containing capital letters would otherwise never match.\n","        if any(phrase.lower() in normalized_input for phrase in phrases):\n","            return random.choice(responses)\n","\n","    generated_output = gpt2_generate(user_input, context)\n","    # Keep only the text after the prompt's 'answer:' marker. partition() never raises,\n","    # unlike split(\"answer: \")[1], which failed whenever the generated text did not\n","    # begin with a space.\n","    _, _, answer = generated_output.partition('answer:')\n","    # If the model starts another question/answer turn, drop everything from it on.\n","    for marker in ('question:', 'answer:'):\n","        answer = answer.split(marker, 1)[0]\n","    # Collapse newlines and repeated spaces into single spaces (deleting newlines\n","    # outright glued neighbouring words together).\n","    return ' '.join(answer.split())"]},
{"cell_type":"code","execution_count":14,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"id":"glGaga2a2d_s","outputId":"dd90d004-3a6d-45c5-e291-f256479c3a78","executionInfo":{"status":"ok","timestamp":1687338577633,"user_tz":-120,"elapsed":43935,"user":{"displayName":"Michał Ulaniuk","userId":"07769450445479269606"}}},"outputs":[{"output_type":"stream","name":"stdout","text":["Wpisz 'koniec' aby wyjść.\n","Bot: Jak się masz Dawid? Pozdrów Ryszarda!\n","mam wszystko gdzieś\n","Bot: Wszystko będzie dobrze, nie ma bomby pod boljerem.\n","Smutno mi\n","Bot: Odpocznij Dawid, należy ci się odrobina relaksu po ciężkiej pracy.\n","jestem wkurzony\n","Bot: Nie, byłem za młody, ale prawie płakałem\n","aha\n","Bot: aha spoko\n","ok\n","Bot: czy byłeś w stanie uczestniczyć w piątkowym meczu koszykówki?\n","nie\n","Bot: to smutne, ale byłeś ostatnio w kinie?\n","koniec\n","Do zobaczenia na twoim live YouTube albo Twitch :)\n"]}],"source":["context = []\n","\n","print(\"Wpisz 'koniec' aby wyjść.\")\n","response = random.choice(openings)\n","print('Bot:', response)\n","context.append(response)\n","\n","while True:\n","    try:\n","        user_input = input().strip().lower()\n","    except EOFError:\n","        # stdin closed (e.g. non-interactive run): end the chat as if 'koniec' was typed.\n","        user_input = 'koniec'\n","\n","    if user_input == 'koniec':\n","        print(random.choice(endings))\n","        break\n","    if not user_input:\n","        # Skip blank lines instead of sending an empty question to the model.\n","        continue\n","\n","    response = generate_output(user_input, context)\n","    print('Bot:', response)\n","\n","    # The history is recorded, but gpt2_generate does not use it yet.\n","    context.append(user_input)\n","    context.append(response)\n"]}],"metadata":{"colab":{"provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.8.16"}},"nbformat":4,"nbformat_minor":0}