{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d42ddd87",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"torch.cuda.empty_cache()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "37fa7d97",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import regex as re\n",
"import csv\n",
"\n",
"def clean_text(text):\n",
" text = text.lower().replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' ')\n",
" text = re.sub(r'\\p{P}', '', text)\n",
" text = text.replace(\"'t\", \" not\").replace(\"'s\", \" is\").replace(\"'ll\", \" will\").replace(\"'m\", \" am\").replace(\"'ve\", \" have\")\n",
"\n",
" return text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "41e2f529",
"metadata": {},
"outputs": [],
"source": [
"train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)\n",
"\n",
"train_data = train_data[[6, 7]]\n",
"train_data = pd.concat([train_data, train_labels], axis=1)\n",
"\n",
"train_data['text'] = train_data[6] + train_data[0] + train_data[7]\n",
"train_data = train_data[['text']]\n",
"\n",
"with open('processed_train.txt', 'w', encoding='utf-8') as file:\n",
" for _, row in train_data.iterrows():\n",
" text = clean_text(str(row['text']))\n",
" file.write(text + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dc73124c",
"metadata": {},
"outputs": [],
"source": [
"vocab_size = 40000\n",
"embed_size = 300\n",
"hidden_size = 128\n",
"\n",
"class SimpleTrigramNeuralLanguageModel(nn.Module):\n",
" def __init__(self, vocabulary_size, embedding_size, hidden_size):\n",
" super(SimpleTrigramNeuralLanguageModel, self).__init__()\n",
" self.embedding = nn.Embedding(vocabulary_size * 2, embedding_size)\n",
" self.linear1 = nn.Linear(embedding_size, hidden_size)\n",
" self.linear2 = nn.Linear(hidden_size, vocabulary_size * 2)\n",
"\n",
" def forward(self, x):\n",
" x = self.embedding(x)\n",
" x = self.linear1(x)\n",
" x = self.linear2(x)\n",
" x = torch.softmax(x, dim=1)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "569b4c88",
"metadata": {},
"outputs": [],
"source": [
"import regex as re\n",
"from itertools import islice, chain\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"from torch.utils.data import IterableDataset\n",
"\n",
"def get_words_from_line(line):\n",
" line = line.rstrip()\n",
" yield ''\n",
" for m in re.finditer(r'[\\p{L}0-9\\*]+|\\p{P}+', line):\n",
" yield m.group(0).lower()\n",
" yield ''\n",
"\n",
"def get_word_lines_from_file(file_name):\n",
" with open(file_name, 'r', encoding='utf-8') as fh:\n",
" for line in fh:\n",
" yield get_words_from_line(line)\n",
" \n",
"def look_ahead_iterator(gen):\n",
" prev_1 = None\n",
" prev_2 = None\n",
" for item in gen:\n",
" if prev_1 and prev_2:\n",
" yield (prev_2 + prev_1, item)\n",
" prev_2 = prev_1\n",
" prev_1 = item"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f95cb913",
"metadata": {},
"outputs": [],
"source": [
"class Trigrams(IterableDataset):\n",
" def __init__(self, text_file, vocabulary_size):\n",
" self.vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(text_file),\n",
" max_tokens = vocabulary_size,\n",
" specials = ['']\n",
" )\n",
" self.vocab.set_default_index(self.vocab[''])\n",
" self.vocabulary_size = vocabulary_size\n",
" self.text_file = text_file\n",
"\n",
" def __iter__(self):\n",
" return look_ahead_iterator((self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "7a51f2b1",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"train_dataset = Trigrams('processed_train.txt', vocab_size)\n",
"model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)\n",
"data = DataLoader(train_dataset, batch_size=800)\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"criterion = torch.nn.NLLLoss()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "474194ae",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 tensor(11.3293, device='cuda:0', grad_fn=)\n",
"100 tensor(8.9417, device='cuda:0', grad_fn=)\n",
"200 tensor(7.0454, device='cuda:0', grad_fn=)\n",
"300 tensor(6.8511, device='cuda:0', grad_fn=)\n",
"400 tensor(6.8680, device='cuda:0', grad_fn=)\n",
"500 tensor(6.8153, device='cuda:0', grad_fn=)\n",
"600 tensor(6.5640, device='cuda:0', grad_fn=)\n",
"700 tensor(6.8175, device='cuda:0', grad_fn=)\n",
"800 tensor(6.6864, device='cuda:0', grad_fn=)\n",
"900 tensor(6.7530, device='cuda:0', grad_fn=)\n",
"1000 tensor(6.5542, device='cuda:0', grad_fn=)\n",
"1100 tensor(6.5068, device='cuda:0', grad_fn=)\n",
"1200 tensor(6.7081, device='cuda:0', grad_fn=)\n",
"1300 tensor(6.2363, device='cuda:0', grad_fn=)\n",
"1400 tensor(6.5277, device='cuda:0', grad_fn=)\n",
"1500 tensor(6.5607, device='cuda:0', grad_fn=)\n",
"1600 tensor(6.5931, device='cuda:0', grad_fn=)\n",
"1700 tensor(6.5355, device='cuda:0', grad_fn=)\n",
"1800 tensor(6.7281, device='cuda:0', grad_fn=)\n",
"1900 tensor(6.4659, device='cuda:0', grad_fn=)\n",
"2000 tensor(6.2887, device='cuda:0', grad_fn=)\n",
"2100 tensor(6.2616, device='cuda:0', grad_fn=)\n",
"2200 tensor(6.3290, device='cuda:0', grad_fn=)\n",
"2300 tensor(6.6389, device='cuda:0', grad_fn=)\n",
"2400 tensor(6.6202, device='cuda:0', grad_fn=)\n",
"2500 tensor(6.3433, device='cuda:0', grad_fn=)\n",
"2600 tensor(6.2726, device='cuda:0', grad_fn=)\n",
"2700 tensor(6.5647, device='cuda:0', grad_fn=)\n",
"2800 tensor(6.7472, device='cuda:0', grad_fn=)\n",
"2900 tensor(6.5692, device='cuda:0', grad_fn=)\n",
"3000 tensor(6.0704, device='cuda:0', grad_fn=)\n",
"3100 tensor(6.3795, device='cuda:0', grad_fn=)\n",
"3200 tensor(6.3263, device='cuda:0', grad_fn=)\n",
"3300 tensor(6.5520, device='cuda:0', grad_fn=)\n",
"3400 tensor(6.3271, device='cuda:0', grad_fn=)\n",
"3500 tensor(6.2009, device='cuda:0', grad_fn=)\n",
"3600 tensor(6.5486, device='cuda:0', grad_fn=)\n",
"3700 tensor(6.2033, device='cuda:0', grad_fn=)\n",
"3800 tensor(6.3768, device='cuda:0', grad_fn=)\n",
"3900 tensor(6.7510, device='cuda:0', grad_fn=)\n",
"4000 tensor(6.3879, device='cuda:0', grad_fn=)\n",
"4100 tensor(6.3350, device='cuda:0', grad_fn=)\n",
"4200 tensor(6.8703, device='cuda:0', grad_fn=)\n",
"4300 tensor(6.3114, device='cuda:0', grad_fn=)\n",
"4400 tensor(6.3841, device='cuda:0', grad_fn=)\n",
"4500 tensor(6.2134, device='cuda:0', grad_fn=)\n",
"4600 tensor(6.2360, device='cuda:0', grad_fn=)\n",
"4700 tensor(6.4428, device='cuda:0', grad_fn=)\n",
"4800 tensor(6.2655, device='cuda:0', grad_fn=)\n",
"4900 tensor(6.5545, device='cuda:0', grad_fn=)\n",
"5000 tensor(6.7002, device='cuda:0', grad_fn=)\n",
"5100 tensor(6.2191, device='cuda:0', grad_fn=)\n",
"5200 tensor(6.3981, device='cuda:0', grad_fn=)\n",
"5300 tensor(6.5035, device='cuda:0', grad_fn=)\n",
"5400 tensor(6.2316, device='cuda:0', grad_fn=)\n",
"5500 tensor(6.4646, device='cuda:0', grad_fn=)\n",
"5600 tensor(6.3733, device='cuda:0', grad_fn=)\n",
"5700 tensor(6.4972, device='cuda:0', grad_fn=)\n",
"5800 tensor(6.1650, device='cuda:0', grad_fn=)\n",
"5900 tensor(6.2509, device='cuda:0', grad_fn=)\n",
"6000 tensor(6.4030, device='cuda:0', grad_fn=)\n",
"6100 tensor(6.8080, device='cuda:0', grad_fn=)\n",
"6200 tensor(6.5556, device='cuda:0', grad_fn=)\n",
"6300 tensor(6.5532, device='cuda:0', grad_fn=)\n",
"6400 tensor(6.2327, device='cuda:0', grad_fn=)\n",
"6500 tensor(6.4358, device='cuda:0', grad_fn=)\n",
"6600 tensor(6.3786, device='cuda:0', grad_fn=)\n",
"6700 tensor(6.6644, device='cuda:0', grad_fn=)\n",
"6800 tensor(6.0746, device='cuda:0', grad_fn=)\n",
"6900 tensor(6.4358, device='cuda:0', grad_fn=)\n",
"7000 tensor(6.9150, device='cuda:0', grad_fn=)\n",
"7100 tensor(6.6115, device='cuda:0', grad_fn=)\n",
"7200 tensor(6.3954, device='cuda:0', grad_fn=)\n",
"7300 tensor(6.4474, device='cuda:0', grad_fn=)\n",
"7400 tensor(6.6758, device='cuda:0', grad_fn=)\n",
"7500 tensor(6.3773, device='cuda:0', grad_fn=)\n",
"7600 tensor(6.0583, device='cuda:0', grad_fn=)\n",
"7700 tensor(6.3850, device='cuda:0', grad_fn=)\n",
"7800 tensor(6.4212, device='cuda:0', grad_fn=)\n",
"7900 tensor(6.4790, device='cuda:0', grad_fn=)\n",
"8000 tensor(6.1858, device='cuda:0', grad_fn=)\n",
"8100 tensor(6.1886, device='cuda:0', grad_fn=)\n",
"8200 tensor(6.5135, device='cuda:0', grad_fn=)\n",
"8300 tensor(6.3304, device='cuda:0', grad_fn=)\n",
"8400 tensor(6.5295, device='cuda:0', grad_fn=)\n",
"8500 tensor(6.2931, device='cuda:0', grad_fn=)\n",
"8600 tensor(6.2511, device='cuda:0', grad_fn=)\n",
"8700 tensor(6.2957, device='cuda:0', grad_fn=)\n",
"8800 tensor(6.3172, device='cuda:0', grad_fn=)\n",
"8900 tensor(6.2837, device='cuda:0', grad_fn=)\n",
"9000 tensor(6.3057, device='cuda:0', grad_fn=)\n",
"9100 tensor(6.5710, device='cuda:0', grad_fn=)\n",
"9200 tensor(6.6593, device='cuda:0', grad_fn=)\n",
"9300 tensor(6.2960, device='cuda:0', grad_fn=)\n",
"9400 tensor(6.6207, device='cuda:0', grad_fn=)\n",
"9500 tensor(6.4218, device='cuda:0', grad_fn=)\n",
"9600 tensor(6.2484, device='cuda:0', grad_fn=)\n",
"9700 tensor(6.1428, device='cuda:0', grad_fn=)\n",
"9800 tensor(6.4388, device='cuda:0', grad_fn=)\n",
"9900 tensor(6.2794, device='cuda:0', grad_fn=)\n",
"10000 tensor(6.1755, device='cuda:0', grad_fn=)\n",
"10100 tensor(6.5736, device='cuda:0', grad_fn=)\n",
"10200 tensor(6.4235, device='cuda:0', grad_fn=)\n",
"10300 tensor(6.4275, device='cuda:0', grad_fn=)\n",
"10400 tensor(6.5050, device='cuda:0', grad_fn=)\n",
"10500 tensor(6.4074, device='cuda:0', grad_fn=)\n",
"10600 tensor(6.0418, device='cuda:0', grad_fn=)\n",
"10700 tensor(6.3675, device='cuda:0', grad_fn=)\n",
"10800 tensor(6.4171, device='cuda:0', grad_fn=)\n",
"10900 tensor(6.5078, device='cuda:0', grad_fn=)\n",
"11000 tensor(6.2692, device='cuda:0', grad_fn=)\n",
"11100 tensor(6.3667, device='cuda:0', grad_fn=)\n",
"11200 tensor(6.3770, device='cuda:0', grad_fn=)\n",
"11300 tensor(6.4283, device='cuda:0', grad_fn=)\n",
"11400 tensor(6.4568, device='cuda:0', grad_fn=)\n",
"11500 tensor(6.3557, device='cuda:0', grad_fn=)\n",
"11600 tensor(6.4649, device='cuda:0', grad_fn=)\n",
"11700 tensor(6.5798, device='cuda:0', grad_fn=)\n",
"11800 tensor(6.4245, device='cuda:0', grad_fn=)\n",
"11900 tensor(6.4913, device='cuda:0', grad_fn=)\n",
"12000 tensor(6.3519, device='cuda:0', grad_fn=)\n",
"12100 tensor(6.4345, device='cuda:0', grad_fn=)\n",
"12200 tensor(6.5832, device='cuda:0', grad_fn=)\n",
"12300 tensor(6.4204, device='cuda:0', grad_fn=)\n",
"12400 tensor(6.2925, device='cuda:0', grad_fn=)\n",
"12500 tensor(6.4187, device='cuda:0', grad_fn=)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"12600 tensor(6.5779, device='cuda:0', grad_fn=)\n",
"12700 tensor(6.1300, device='cuda:0', grad_fn=)\n",
"12800 tensor(6.3179, device='cuda:0', grad_fn=)\n",
"12900 tensor(6.5471, device='cuda:0', grad_fn=)\n",
"13000 tensor(6.2621, device='cuda:0', grad_fn=)\n",
"13100 tensor(6.4863, device='cuda:0', grad_fn=)\n",
"13200 tensor(6.4671, device='cuda:0', grad_fn=)\n",
"13300 tensor(6.5966, device='cuda:0', grad_fn=)\n",
"13400 tensor(6.3855, device='cuda:0', grad_fn=)\n",
"13500 tensor(6.4136, device='cuda:0', grad_fn=)\n",
"13600 tensor(6.4274, device='cuda:0', grad_fn=)\n",
"13700 tensor(6.3050, device='cuda:0', grad_fn=)\n",
"13800 tensor(6.4028, device='cuda:0', grad_fn=)\n",
"13900 tensor(6.1994, device='cuda:0', grad_fn=)\n",
"14000 tensor(6.2238, device='cuda:0', grad_fn=)\n",
"14100 tensor(6.2973, device='cuda:0', grad_fn=)\n",
"14200 tensor(6.3696, device='cuda:0', grad_fn=)\n",
"14300 tensor(6.4446, device='cuda:0', grad_fn=)\n",
"14400 tensor(6.6806, device='cuda:0', grad_fn=)\n",
"14500 tensor(6.5539, device='cuda:0', grad_fn=)\n",
"14600 tensor(6.4135, device='cuda:0', grad_fn=)\n",
"14700 tensor(6.4098, device='cuda:0', grad_fn=)\n",
"14800 tensor(6.2572, device='cuda:0', grad_fn=)\n",
"14900 tensor(6.2828, device='cuda:0', grad_fn=)\n",
"15000 tensor(6.6121, device='cuda:0', grad_fn=)\n",
"15100 tensor(6.4960, device='cuda:0', grad_fn=)\n",
"15200 tensor(6.2099, device='cuda:0', grad_fn=)\n",
"15300 tensor(6.4276, device='cuda:0', grad_fn=)\n",
"15400 tensor(5.9707, device='cuda:0', grad_fn=)\n",
"15500 tensor(6.2765, device='cuda:0', grad_fn=)\n",
"15600 tensor(6.3095, device='cuda:0', grad_fn=)\n",
"15700 tensor(6.3933, device='cuda:0', grad_fn=)\n",
"15800 tensor(6.2718, device='cuda:0', grad_fn=)\n",
"15900 tensor(6.5708, device='cuda:0', grad_fn=)\n",
"16000 tensor(6.1227, device='cuda:0', grad_fn=)\n",
"16100 tensor(6.4434, device='cuda:0', grad_fn=)\n",
"16200 tensor(6.6841, device='cuda:0', grad_fn=)\n",
"16300 tensor(6.0971, device='cuda:0', grad_fn=)\n",
"16400 tensor(6.4550, device='cuda:0', grad_fn=)\n",
"16500 tensor(6.2755, device='cuda:0', grad_fn=)\n",
"16600 tensor(6.4492, device='cuda:0', grad_fn=)\n",
"16700 tensor(6.4977, device='cuda:0', grad_fn=)\n",
"16800 tensor(6.3766, device='cuda:0', grad_fn=)\n",
"16900 tensor(6.1726, device='cuda:0', grad_fn=)\n",
"17000 tensor(6.4672, device='cuda:0', grad_fn=)\n",
"17100 tensor(6.1932, device='cuda:0', grad_fn=)\n",
"17200 tensor(6.3820, device='cuda:0', grad_fn=)\n",
"17300 tensor(6.3394, device='cuda:0', grad_fn=)\n",
"17400 tensor(6.5227, device='cuda:0', grad_fn=)\n",
"17500 tensor(6.6092, device='cuda:0', grad_fn=)\n",
"17600 tensor(6.1775, device='cuda:0', grad_fn=)\n",
"17700 tensor(6.4336, device='cuda:0', grad_fn=)\n",
"17800 tensor(6.2012, device='cuda:0', grad_fn=)\n",
"17900 tensor(6.5930, device='cuda:0', grad_fn=)\n",
"18000 tensor(6.5210, device='cuda:0', grad_fn=)\n",
"18100 tensor(6.3719, device='cuda:0', grad_fn=)\n",
"18200 tensor(6.1121, device='cuda:0', grad_fn=)\n",
"18300 tensor(6.3552, device='cuda:0', grad_fn=)\n",
"18400 tensor(6.4725, device='cuda:0', grad_fn=)\n",
"18500 tensor(6.3435, device='cuda:0', grad_fn=)\n",
"18600 tensor(6.3549, device='cuda:0', grad_fn=)\n",
"18700 tensor(6.4716, device='cuda:0', grad_fn=)\n",
"18800 tensor(6.3291, device='cuda:0', grad_fn=)\n",
"18900 tensor(6.3823, device='cuda:0', grad_fn=)\n",
"19000 tensor(6.2017, device='cuda:0', grad_fn=)\n",
"19100 tensor(6.2470, device='cuda:0', grad_fn=)\n",
"19200 tensor(6.3263, device='cuda:0', grad_fn=)\n",
"19300 tensor(6.5956, device='cuda:0', grad_fn=)\n",
"19400 tensor(6.3802, device='cuda:0', grad_fn=)\n",
"19500 tensor(6.3646, device='cuda:0', grad_fn=)\n",
"19600 tensor(6.1903, device='cuda:0', grad_fn=)\n",
"19700 tensor(6.7986, device='cuda:0', grad_fn=)\n",
"19800 tensor(6.4438, device='cuda:0', grad_fn=)\n",
"19900 tensor(6.4476, device='cuda:0', grad_fn=)\n",
"20000 tensor(6.2691, device='cuda:0', grad_fn=)\n",
"20100 tensor(6.6191, device='cuda:0', grad_fn=)\n",
"20200 tensor(6.5294, device='cuda:0', grad_fn=)\n",
"20300 tensor(6.2749, device='cuda:0', grad_fn=)\n",
"20400 tensor(6.5561, device='cuda:0', grad_fn=)\n",
"20500 tensor(6.3675, device='cuda:0', grad_fn=)\n",
"20600 tensor(6.2805, device='cuda:0', grad_fn=)\n",
"20700 tensor(6.4063, device='cuda:0', grad_fn=)\n",
"20800 tensor(6.2243, device='cuda:0', grad_fn=)\n",
"20900 tensor(6.0176, device='cuda:0', grad_fn=)\n",
"21000 tensor(6.1914, device='cuda:0', grad_fn=)\n",
"21100 tensor(6.4219, device='cuda:0', grad_fn=)\n",
"21200 tensor(6.6379, device='cuda:0', grad_fn=)\n",
"21300 tensor(6.4248, device='cuda:0', grad_fn=)\n",
"21400 tensor(6.5332, device='cuda:0', grad_fn=)\n",
"21500 tensor(6.5993, device='cuda:0', grad_fn=)\n",
"21600 tensor(6.5038, device='cuda:0', grad_fn=)\n",
"21700 tensor(6.5882, device='cuda:0', grad_fn=)\n",
"21800 tensor(6.4390, device='cuda:0', grad_fn=)\n",
"21900 tensor(6.3383, device='cuda:0', grad_fn=)\n",
"22000 tensor(6.3932, device='cuda:0', grad_fn=)\n",
"22100 tensor(6.3587, device='cuda:0', grad_fn=)\n",
"22200 tensor(6.4001, device='cuda:0', grad_fn=)\n",
"22300 tensor(6.1865, device='cuda:0', grad_fn=)\n",
"22400 tensor(6.2366, device='cuda:0', grad_fn=)\n",
"22500 tensor(7.0326, device='cuda:0', grad_fn=)\n",
"22600 tensor(6.3798, device='cuda:0', grad_fn=)\n",
"22700 tensor(6.5353, device='cuda:0', grad_fn=)\n",
"22800 tensor(6.7912, device='cuda:0', grad_fn=)\n",
"22900 tensor(6.3939, device='cuda:0', grad_fn=)\n",
"23000 tensor(6.2855, device='cuda:0', grad_fn=)\n",
"23100 tensor(6.0151, device='cuda:0', grad_fn=)\n",
"23200 tensor(6.2457, device='cuda:0', grad_fn=)\n",
"23300 tensor(6.3422, device='cuda:0', grad_fn=)\n",
"23400 tensor(6.3322, device='cuda:0', grad_fn=)\n",
"23500 tensor(6.0716, device='cuda:0', grad_fn=)\n",
"23600 tensor(6.5486, device='cuda:0', grad_fn=)\n",
"23700 tensor(6.5902, device='cuda:0', grad_fn=)\n",
"23800 tensor(6.4079, device='cuda:0', grad_fn=)\n",
"23900 tensor(6.5497, device='cuda:0', grad_fn=)\n",
"24000 tensor(6.4957, device='cuda:0', grad_fn=)\n",
"24100 tensor(6.3668, device='cuda:0', grad_fn=)\n",
"24200 tensor(6.7314, device='cuda:0', grad_fn=)\n",
"24300 tensor(6.5585, device='cuda:0', grad_fn=)\n",
"24400 tensor(6.4228, device='cuda:0', grad_fn=)\n",
"24500 tensor(6.2029, device='cuda:0', grad_fn=)\n",
"24600 tensor(6.2034, device='cuda:0', grad_fn=)\n",
"24700 tensor(6.6652, device='cuda:0', grad_fn=)\n",
"24800 tensor(6.2777, device='cuda:0', grad_fn=)\n",
"24900 tensor(6.2962, device='cuda:0', grad_fn=)\n",
"25000 tensor(6.3366, device='cuda:0', grad_fn=)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"25100 tensor(6.5767, device='cuda:0', grad_fn=)\n",
"25200 tensor(6.4680, device='cuda:0', grad_fn=