Uczenie_Glebokie/Projekt/project.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "sMvlO4r-2-dQ"
},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow.keras.utils as ku\n",
"from wordcloud import WordCloud\n",
"from tensorflow.keras.preprocessing.sequence import pad_sequences\n",
"from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout, Bidirectional\n",
"from tensorflow.keras.preprocessing.text import Tokenizer\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras import regularizers\n",
"from keras.models import load_model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "Ib8MIaQ33Kqk"
},
"outputs": [],
"source": [
"data_pan_tadeusz = open('pan-tadeusz.txt', encoding=\"utf8\").read()\n",
"data_SI = open('SI_data.txt', encoding=\"utf8\").read()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "wquLVQVj5Tdx"
},
"outputs": [],
"source": [
"def create_corpus(data):\n",
" corpus = data.lower().split(\"\\n\")\n",
" corpus = [element.strip() for element in corpus if element]\n",
" return corpus"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "ZFiZmIeX8Ifi"
},
"outputs": [],
"source": [
"corpus_pan_tadeusz = create_corpus(data_pan_tadeusz)[:4000]\n",
"corpus_SI = create_corpus(data_SI)\n",
"corpus = corpus_pan_tadeusz + corpus_SI"
]
},
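{
"cell_type": "markdown",
"metadata": {},
"source": [
"`WordCloud` and `matplotlib` are imported above but not used later in the notebook; the optional cell below is a minimal sketch that uses them to visualise the most frequent words of the combined corpus. It assumes only the `corpus` list built above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check: word cloud of the combined corpus (illustrative sketch, not part of the training pipeline)\n",
"wordcloud = WordCloud(width=800, height=400, background_color='white').generate(' '.join(corpus))\n",
"plt.figure(figsize=(10, 5))\n",
"plt.imshow(wordcloud, interpolation='bilinear')\n",
"plt.axis('off')\n",
"plt.show()"
]
},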
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "2zw0S_vw8Ksf"
},
"outputs": [],
"source": [
"tokenizer = Tokenizer()\n",
"tokenizer.fit_on_texts(corpus)\n",
"total_words = len(tokenizer.word_index)"
]
},
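{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tokenizer.word_index` assigns indices starting at 1 (index 0 is reserved for padding), which is why the label encoding and the model output layer below work with `total_words + 1` classes. The optional cell below is a small check of the vocabulary size and a few sample entries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick look at the fitted vocabulary (indices start at 1; 0 is kept for padding)\n",
"print('vocabulary size:', total_words)\n",
"print(list(tokenizer.word_index.items())[:5])"
]
},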
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "VHXURJSO7fBk"
},
"outputs": [],
"source": [
"def create_input_sequences(corpus):\n",
" input_sequences = []\n",
" for line in corpus:\n",
" token_list = tokenizer.texts_to_sequences([line])[0]\n",
"\n",
" for i in range(1, len(token_list)):\n",
" n_gram_sequence = token_list[:i+1]\n",
" input_sequences.append(n_gram_sequence)\n",
" return input_sequences"
]
},
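{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each line is expanded into all of its n-gram prefixes: a line tokenised as `[w1, w2, w3, w4]` yields `[w1, w2]`, `[w1, w2, w3]` and `[w1, w2, w3, w4]`, so every prefix becomes a training example whose last token is the word to predict. The optional cell below illustrates this on the first line of the Pan Tadeusz corpus, relying only on the tokenizer fitted above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustration: n-gram prefixes generated from a single corpus line (optional sketch)\n",
"sample_line = corpus_pan_tadeusz[0]\n",
"print(sample_line)\n",
"print(create_input_sequences([sample_line]))"
]
},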
{
"cell_type": "code",
"execution_count": 34,
"metadata": {
"id": "Jl3rcom57ptg"
},
"outputs": [],
"source": [
"input_sequences_pan_tadeusz = create_input_sequences(corpus_pan_tadeusz)\n",
"input_sequences_SI = create_input_sequences(corpus_SI)\n",
"input_sequences = create_input_sequences(corpus)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"id": "5_ah83de7yfc"
},
"outputs": [],
"source": [
"max_sequence_len = max([len(x) for x in input_sequences])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "xyz8co7B8SJa"
},
"outputs": [],
"source": [
"def create_predictors_label(input_sequences, max_sequence_len):\n",
" input_sequences = np.array(pad_sequences(input_sequences,\n",
" maxlen=max_sequence_len,\n",
" padding='pre'))\n",
" predictors, label = input_sequences[:, :-1], input_sequences[:, -1]\n",
" label = ku.to_categorical(label, num_classes=total_words+1)\n",
" return predictors, label"
]
},
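{
"cell_type": "markdown",
"metadata": {},
"source": [
"The prefixes are left-padded (`padding='pre'`) to `max_sequence_len`; the last token of each padded sequence becomes the label and the remaining tokens the predictors, with the label one-hot encoded over `total_words + 1` classes. The optional cell below checks this on two tiny hand-made sequences (the token ids 5, 2 and 7 are arbitrary)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of create_predictors_label (optional sketch; token ids are arbitrary)\n",
"toy_predictors, toy_label = create_predictors_label([[5, 2], [5, 2, 7]], max_sequence_len)\n",
"print(toy_predictors.shape, toy_label.shape)\n",
"print(toy_predictors[:, -3:])  # non-padded tail of each prefix\n",
"print(np.argmax(toy_label, axis=-1))  # label indices recovered from the one-hot vectors"
]
},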
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"id": "pn8D_IT97BWy"
},
"outputs": [],
"source": [
"predictors_pan_tadeusz, label_pan_tadeusz = create_predictors_label(input_sequences_pan_tadeusz, max_sequence_len)\n",
"predictors_SI, label_SI = create_predictors_label(input_sequences_SI, max_sequence_len)\n",
"predictors, label = create_predictors_label(input_sequences, max_sequence_len)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "j6gmo0fd8Tvq",
"outputId": "a17d4649-9916-42f6-f7dd-75dbeb0dcbd2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"sequential_2\"\n",
"_________________________________________________________________\n",
" Layer (type) Output Shape Param # \n",
"=================================================================\n",
" embedding_2 (Embedding) (None, 75, 100) 1072800 \n",
" \n",
" bidirectional_2 (Bidirecti (None, 75, 300) 301200 \n",
" onal) \n",
" \n",
" dropout_2 (Dropout) (None, 75, 300) 0 \n",
" \n",
" lstm_5 (LSTM) (None, 100) 160400 \n",
" \n",
" dense_4 (Dense) (None, 10727) 1083427 \n",
" \n",
" dense_5 (Dense) (None, 10728) 115089984 \n",
" \n",
"=================================================================\n",
"Total params: 117707811 (449.02 MB)\n",
"Trainable params: 117707811 (449.02 MB)\n",
"Non-trainable params: 0 (0.00 Byte)\n",
"_________________________________________________________________\n",
"None\n"
]
}
],
"source": [
"# model = Sequential()\n",
"# model.add(Embedding(total_words+1, 100,\n",
"# input_length=max_sequence_len-1))\n",
"# model.add(Bidirectional(LSTM(150, return_sequences=True)))\n",
"# model.add(Dropout(0.2))\n",
"# model.add(LSTM(100))\n",
"# model.add(Dense(total_words+1/2, activation='relu',\n",
"# kernel_regularizer=regularizers.l2(0.01)))\n",
"# model.add(Dense(total_words+1, activation='softmax'))\n",
"# model.compile(loss='categorical_crossentropy',\n",
"# optimizer='adam', metrics=['accuracy'])\n",
"# print(model.summary())"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From C:\\Users\\Pawel\\anaconda3\\Lib\\site-packages\\keras\\src\\backend.py:1398: The name tf.executing_eagerly_outside_functions is deprecated. Please use tf.compat.v1.executing_eagerly_outside_functions instead.\n",
"\n"
]
}
],
"source": [
"model = load_model('my_model.h5')"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UdZmXNVS8aJk",
"outputId": "3664d91a-a866-4bee-d6fc-4c320d68f118"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"744/744 [==============================] - 1501s 2s/step - loss: 1.4722 - accuracy: 0.7626\n"
]
}
],
"source": [
"history = model.fit(predictors_pan_tadeusz, label_pan_tadeusz, epochs=1, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ykyvfDET-PdY",
"outputId": "71835132-bd74-4feb-c272-815fa05f8661"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3\n",
"55/55 [==============================] - 98s 2s/step - loss: 4.6245 - accuracy: 0.2131\n",
"Epoch 2/3\n",
"55/55 [==============================] - 97s 2s/step - loss: 3.9096 - accuracy: 0.2921\n",
"Epoch 3/3\n",
"55/55 [==============================] - 111s 2s/step - loss: 3.4379 - accuracy: 0.3603\n"
]
}
],
"source": [
"history = model.fit(predictors_SI, label_SI, epochs=3, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"799/799 [==============================] - 1105s 1s/step - loss: 1.7071 - accuracy: 0.7451\n"
]
}
],
"source": [
"history = model.fit(predictors, label, epochs=1, verbose=1)"
]
},
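{
"cell_type": "markdown",
"metadata": {},
"source": [
"`matplotlib` can also be used to inspect the metrics recorded by the most recent `fit` call. The optional cell below is a minimal sketch; note that each `fit` above runs for only one to three epochs, so `history.history` holds just a few points."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: plot loss and accuracy from the last fit call (only a few epochs are available)\n",
"plt.plot(history.history['loss'], label='loss')\n",
"plt.plot(history.history['accuracy'], label='accuracy')\n",
"plt.xlabel('epoch')\n",
"plt.legend()\n",
"plt.show()"
]
},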
{
"cell_type": "code",
"execution_count": 49,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HYnWu0yWRA0l",
"outputId": "604a19b7-028f-4ac9-a562-67a35965f53d"
},
"outputs": [],
"source": [
"model.save('my_model.h5')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "bWWKKkKk8d3i"
},
"outputs": [],
"source": [
"def predict(text, next_words=25):\n",
" for _ in range(next_words):\n",
" token_list = tokenizer.texts_to_sequences([text])[0]\n",
" token_list = pad_sequences(\n",
" [token_list], maxlen=max_sequence_len-1,\n",
" padding='pre')\n",
" predicted = np.argmax(model.predict(token_list,\n",
" verbose=0), axis=-1)\n",
" output_word = \"\"\n",
" for word, index in tokenizer.word_index.items():\n",
" if index == predicted:\n",
" output_word = word\n",
" break\n",
"\n",
" text += \" \" + output_word\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 53
},
"id": "bMcMgTh3-EkL",
"outputId": "1423e627-4e33-4c41-af41-3a88a53a3b38"
},
"outputs": [
{
"data": {
"text/plain": [
"'CNN «wielmożni nieruchomi głowę lecz weźmiem na świat ich umiała się wtłoczyć na końcu które w w chleba gałeczki sieci neuronowych i zdolność do generowania'"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict(\"CNN\", 24)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'GANy i w dawnej surowości prawidłach wychował zakazy żołnierszczyzny na sklepieniu sieci neuronowych w w przetwarzaniu języka naturalnego'"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"predict(\"GANy\", 17)"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 1
}