empatia-projekt/chatbot_inference.ipynb
2023-06-21 11:13:04 +02:00

12 KiB

!pip install transformers torch accelerate
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.30.2)
Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)
Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (0.20.3)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)
Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.15.1)
Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)
Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)
Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)
Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)
Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)
Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.1)
Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.5)
Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)
Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.4.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.2)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.15)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)
Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)
Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)
import random
openings = ['Cześć Dawid! Co słychać?', 'Halo Dawid, ile dziś kaski z donejtów?', 'Jak się masz Dawid? Pozdrów Ryszarda!']
endings = ['Dobra nara, spoko?', 'Do zobaczenia na twoim live YouTube albo Twitch :)']

EMPHATY_MODULE = [
    (['smutno', 'to ostatni live', 'kończę live'], ['Porozmawiaj o tym z Ryśkiem i pobaw się z Misiunią', 'Nie przejmuj się hejterami, jesteś super.', 'Odpocznij Dawid, należy ci się odrobina relaksu po ciężkiej pracy.']),
    (['mam dobry humor', 'dużo pieniążków dziś zarobiłem gagri gagri', 'wypiłem dziś colke i zjadłem czekoladke'], ['To świetnie Dawid!', 'Jestem z ciebie dumny!', 'Ale masz dzisiaj dobrze!']),
    (['mam wszystko gdzieś', 'dobra to bez sensu'],['Wszystko będzie dobrze, nie ma bomby pod boljerem.', 'Spokojnie Dawid, bądź wyluzowany jak kaczka po pekińsku.']),
    (['Ulani mnie wkurzył', 'Skończcie na mnie mówić buldog'],['Jest okej, na pewno wszystko będzie dobrze :)', 'Rozumiem twoją złość Dawid, masz prawo wyrażać swoje emocje.']),
    (['wygrałem grę w li od ledżends', 'ale mu powiedziałem'],['Super! Wiedziałem, że dobrze ci pójdzie.', 'Dobre dobre, tak trzymaj Dawid.']),
    (['wysyłają na mnie groźby i na mojego tatę', 'ale się przestraszyłem'],['To tylko hejterzy, nie przejmuj się nimi, oni gadają głupoty, nic ci nie grozi.', 'Nie bój się, to tylko gra, wszystko jest okej, idź po buziaczka od Ryszarda.']),
    (['witam cię koleżanko, jesteś bardzo piękna i ładna, powiedz mi z jakiej jesteś miejscowości, jak masz na imię i ile masz lat spoko?'],['Hej Dawid! Nie jestem koleżanką, tylko twoim kolegą. Dzięki za komplement :))']),
    (['ale mnie porobił', 'aha'],['Następnym razem pójdzie Ci z pewnością lepiej. Głowa do góry :)','aha', 'aha spoko'])
]
import torch
from google.colab import drive

drive.mount('/content/gdrive/', force_remount=True)
working_dir = '/content/gdrive/My Drive/empatia/'
Mounted at /content/gdrive/
from transformers import AutoTokenizer, AutoModelForCausalLM

device = torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained(working_dir + 'model')
tokenizer = AutoTokenizer.from_pretrained(working_dir + 'model')

model.eval()
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(51200, 768)
    (wpe): Embedding(2048, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): FastGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=51200, bias=False)
)
def gpt2_generate(user_input, context):
    input_text = 'question: ' + user_input + "\nanswer:"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    input_ids = input_ids.to(device)

    output = model.generate(input_ids, max_length=100, early_stopping=True, pad_token_id=tokenizer.eos_token_id, no_repeat_ngram_size=2)

    return tokenizer.decode(output[0], skip_special_tokens=True)
def generate_output(user_input, context):
    for phrases, responses in EMPHATY_MODULE:
        if any(phrase in user_input for phrase in phrases):
            return random.choice(responses)

    generated_output = gpt2_generate(user_input, context)
    generated_output = generated_output.split("answer: ")[1]
    generated_output = generated_output.replace("\n", '')

    return generated_output
context = []

print("Wpisz 'koniec' aby wyjść.")
response = random.choice(openings)
print('Bot:', response)
context.append(response)

while True:
    user_input = input()
    user_input = user_input.lower()

    if user_input.lower() == 'koniec':
        print(random.choice(endings))
        break

    response = generate_output(user_input, context)
    print('Bot:', response)

    context.append(user_input)
    context.append(response)
Wpisz 'koniec' aby wyjść.
Bot: Jak się masz Dawid? Pozdrów Ryszarda!
mam wszystko gdzieś
Bot: Wszystko będzie dobrze, nie ma bomby pod boljerem.
Smutno mi
Bot: Odpocznij Dawid, należy ci się odrobina relaksu po ciężkiej pracy.
jestem wkurzony
Bot: Nie, byłem za młody, ale prawie płakałem
aha
Bot: aha spoko
ok
Bot: czy byłeś w stanie uczestniczyć w piątkowym meczu koszykówki?
nie
Bot: to smutne, ale byłeś ostatnio w kinie?
koniec
Do zobaczenia na twoim live YouTube albo Twitch :)