Update notebooks

This commit is contained in:
Ryszard Staruch 2024-05-17 17:12:00 +02:00
parent 4abe74e453
commit 97d4e400fe
2 changed files with 21 additions and 6 deletions

View File

@ -12,7 +12,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"#### Zadanie 1 (150 punktów)\n", "#### Zadanie 1 (170 punktów)\n",
"\n", "\n",
"Na podstawie zbioru danych https://huggingface.co/datasets/mteb/tweet_sentiment_extraction stwórz model bazujący na dwukierunkowej sieci neuronowej LSTM (proszę skorzystać z gotowego modułu LSTM w bibliotece torch) do klasyfikacji sentymentu tekstów w postaci tweetów. Można skorzystać z gotowych embeddingów lub wytrenować własne - względem uznania. Metody filtrowania tekstów (często zawierają wiele różnych znaków/symboli, które mogą mieć znaczenie) również należą do Państwa zadania. \n", "Na podstawie zbioru danych https://huggingface.co/datasets/mteb/tweet_sentiment_extraction stwórz model bazujący na dwukierunkowej sieci neuronowej LSTM (proszę skorzystać z gotowego modułu LSTM w bibliotece torch) do klasyfikacji sentymentu tekstów w postaci tweetów. Można skorzystać z gotowych embeddingów lub wytrenować własne - względem uznania. Metody filtrowania tekstów (często zawierają wiele różnych znaków/symboli, które mogą mieć znaczenie) również należą do Państwa zadania. \n",
"\n", "\n",

View File

@ -139,11 +139,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 42,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"lstm_layer = torch.nn.LSTM(5, 5, 2, batch_first=True, bidirectional=True)\n", "lstm_layer = torch.nn.LSTM(5, 5, 30, batch_first=True, bidirectional=True)\n",
"\n", "\n",
"embedded_inputs = embedding(padded_input)\n", "embedded_inputs = embedding(padded_input)\n",
"x = torch.nn.utils.rnn.pack_padded_sequence(embedded_inputs, lengths, batch_first=True, enforce_sorted=False)\n", "x = torch.nn.utils.rnn.pack_padded_sequence(embedded_inputs, lengths, batch_first=True, enforce_sorted=False)\n",
@ -151,25 +151,32 @@
"output, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)" "output, _ = torch.nn.utils.rnn.pad_packed_sequence(output, batch_first=True)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Zmienna hidden zawiera wszystkie ukryte stany na przestrzeni wszystkich warstw, natomiast zmienna output zawiera jedynie stany w ostatniej warstwie"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"#### Wartościami, które należy wykorzystać do klasyfikacji to (jedna z dwóch opcji):\n", "#### Wartościami, które należy wykorzystać do klasyfikacji to (jedna z dwóch opcji):\n",
"* konkatenacja ostatniego i przedostatniego elementu z warstwy hidden (sieć jest dwukierunkowa, więc chcemy się dostać do stanów z ostatniej warstwy jednego oraz drugiego kierunku)\n", "* konkatenacja ostatniego i przedostatniego elementu ze zmiennej hidden (sieć jest dwukierunkowa, więc chcemy się dostać do stanów z ostatniej warstwy jednego oraz drugiego kierunku)\n",
"* pierwszy element dla każdego przykładu ze zmiennej out (tam jest automatycznie skonkatenowany output dla obu kierunków, dlatego mamy na końcu rozmiar 10)" "* pierwszy element dla każdego przykładu ze zmiennej out (tam jest automatycznie skonkatenowany output dla obu kierunków, dlatego mamy na końcu rozmiar 10)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 43,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"torch.Size([6, 3, 5])\n", "torch.Size([60, 3, 5])\n",
"torch.Size([3, 7, 10])\n" "torch.Size([3, 7, 10])\n"
] ]
} }
@ -178,6 +185,14 @@
"print(hidden.shape)\n", "print(hidden.shape)\n",
"print(output.shape)" "print(output.shape)"
] ]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"torch.Size([6, 3, 5])\n",
"torch.Size([3, 7, 10])"
]
} }
], ],
"metadata": { "metadata": {