{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"## connect to google drive (working on colab)"
],
"metadata": {
"id": "G0ujnpy2tuBE"
}
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Lwuh_S5pWY1j",
"outputId": "27838dab-7be0-4447-883a-95559887c7c8"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!mkdir moj7\n"
],
"metadata": {
"id": "vlnrhRaEWNJF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%cd drive"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "539oWG3pXOAX",
"outputId": "a9ae634d-d4a2-47dd-97d2-10c245c7c5d2"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/drive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%cd MyDrive"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wgmiKs4BiAiT",
"outputId": "fbbd0bc7-76bd-47bf-e38b-e051239e5ba7"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/drive/MyDrive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"%cd moj7"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "jNdkji_hiAlt",
"outputId": "962b875a-8d3f-433d-8d7b-dcd664ee1674"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/drive/MyDrive/moj7\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pwd"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IOHV3Iz4WNLc",
"outputId": "f56a6ab7-73e6-4b03-824b-b5e749c8a82e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"/content/drive/MyDrive/moj7\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Preprocess"
],
"metadata": {
"id": "D7jhQfbttn9D"
}
},
{
"cell_type": "code",
"source": [
"import re"
],
"metadata": {
"id": "_IPWOt2BZ_-q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"train_file ='train/in.tsv.xz'\n",
"test_file = 'test-A/in.tsv.xz'\n",
"out_file = 'test-A/out.tsv'\n",
"\n",
"def preprocess(line):\n",
" line = replace_endline(line)\n",
" line = get_rid_of_header(line)\n",
" return line\n",
"\n",
"def get_rid_of_header(line):\n",
" line = line.split('\\t')[6:]\n",
" return \"\".join(line)\n",
" \n",
"def replace_endline(line):\n",
" line = re.sub(\"\\\\n|\\\\+\", \" \", line)\n",
" return line"
],
"metadata": {
"id": "qDnZdPblWNNr"
},
"execution_count": null,
"outputs": []
},
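{
"cell_type": "markdown",
"source": [
"A quick, hypothetical check of the helpers above (not part of the original run): the sample row below is invented and only mimics the layout the code expects, i.e. six metadata fields followed by the text fields."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Hypothetical sample row: six made-up metadata fields, then two text fields.\n",
"cols = ['id', 'year', 'title', 'source', 'x', 'y', 'some left context', 'some right context']\n",
"sample = '\\t'.join(cols)\n",
"print(preprocess(sample))  # prints the concatenated text fields"
],
"metadata": {},
"execution_count": null,
"outputs": []
},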
{
"cell_type": "code",
"source": [
"from itertools import islice\n",
"import regex as re\n",
"import sys\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"import lzma\n",
"import pickle\n",
"\n",
"\n",
"\n",
"def get_words_from_line(line):\n",
" line = line.rstrip()\n",
" yield ''\n",
" line = preprocess(line)\n",
" for t in line.split(' '):\n",
" yield t\n",
" yield ''\n",
"\n",
"\n",
"def get_word_lines_from_file(file_name):\n",
" n = 0\n",
" with lzma.open(file_name, 'r') as fh:\n",
" for line in fh:\n",
" n+=1\n",
" if n%1000==0:\n",
" print(n)\n",
" yield get_words_from_line(line.decode('utf-8'))\n",
"#vocab_size = 20000\n",
"vocab_size = 20000\n",
"\n",
"vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(train_file),\n",
" max_tokens = vocab_size,\n",
" specials = [''])\n",
"\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RVN0lKVZfwMe",
"outputId": "305b03e4-f626-4560-a371-41bc5a0ea9c7"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1000\n",
"2000\n",
"3000\n",
"4000\n",
"5000\n",
"6000\n",
"7000\n",
"8000\n",
"9000\n",
"10000\n",
"11000\n",
"12000\n",
"13000\n",
"14000\n",
"15000\n",
"16000\n",
"17000\n",
"18000\n",
"19000\n",
"20000\n",
"21000\n",
"22000\n",
"23000\n",
"24000\n",
"25000\n",
"26000\n",
"27000\n",
"28000\n",
"29000\n",
"30000\n",
"31000\n",
"32000\n",
"33000\n",
"34000\n",
"35000\n",
"36000\n",
"37000\n",
"38000\n",
"39000\n",
"40000\n",
"41000\n",
"42000\n",
"43000\n",
"44000\n",
"45000\n",
"46000\n",
"47000\n",
"48000\n",
"49000\n",
"50000\n",
"51000\n",
"52000\n",
"53000\n",
"54000\n",
"55000\n",
"56000\n",
"57000\n",
"58000\n",
"59000\n",
"60000\n",
"61000\n",
"62000\n",
"63000\n",
"64000\n",
"65000\n",
"66000\n",
"67000\n",
"68000\n",
"69000\n",
"70000\n",
"71000\n",
"72000\n",
"73000\n",
"74000\n",
"75000\n",
"76000\n",
"77000\n",
"78000\n",
"79000\n",
"80000\n",
"81000\n",
"82000\n",
"83000\n",
"84000\n",
"85000\n",
"86000\n",
"87000\n",
"88000\n",
"89000\n",
"90000\n",
"91000\n",
"92000\n",
"93000\n",
"94000\n",
"95000\n",
"96000\n",
"97000\n",
"98000\n",
"99000\n",
"100000\n",
"101000\n",
"102000\n",
"103000\n",
"104000\n",
"105000\n",
"106000\n",
"107000\n",
"108000\n",
"109000\n",
"110000\n",
"111000\n",
"112000\n",
"113000\n",
"114000\n",
"115000\n",
"116000\n",
"117000\n",
"118000\n",
"119000\n",
"120000\n",
"121000\n",
"122000\n",
"123000\n",
"124000\n",
"125000\n",
"126000\n",
"127000\n",
"128000\n",
"129000\n",
"130000\n",
"131000\n",
"132000\n",
"133000\n",
"134000\n",
"135000\n",
"136000\n",
"137000\n",
"138000\n",
"139000\n",
"140000\n",
"141000\n",
"142000\n",
"143000\n",
"144000\n",
"145000\n",
"146000\n",
"147000\n",
"148000\n",
"149000\n",
"150000\n",
"151000\n",
"152000\n",
"153000\n",
"154000\n",
"155000\n",
"156000\n",
"157000\n",
"158000\n",
"159000\n",
"160000\n",
"161000\n",
"162000\n",
"163000\n",
"164000\n",
"165000\n",
"166000\n",
"167000\n",
"168000\n",
"169000\n",
"170000\n",
"171000\n",
"172000\n",
"173000\n",
"174000\n",
"175000\n",
"176000\n",
"177000\n",
"178000\n",
"179000\n",
"180000\n",
"181000\n",
"182000\n",
"183000\n",
"184000\n",
"185000\n",
"186000\n",
"187000\n",
"188000\n",
"189000\n",
"190000\n",
"191000\n",
"192000\n",
"193000\n",
"194000\n",
"195000\n",
"196000\n",
"197000\n",
"198000\n",
"199000\n",
"200000\n",
"201000\n",
"202000\n",
"203000\n",
"204000\n",
"205000\n",
"206000\n",
"207000\n",
"208000\n",
"209000\n",
"210000\n",
"211000\n",
"212000\n",
"213000\n",
"214000\n",
"215000\n",
"216000\n",
"217000\n",
"218000\n",
"219000\n",
"220000\n",
"221000\n",
"222000\n",
"223000\n",
"224000\n",
"225000\n",
"226000\n",
"227000\n",
"228000\n",
"229000\n",
"230000\n",
"231000\n",
"232000\n",
"233000\n",
"234000\n",
"235000\n",
"236000\n",
"237000\n",
"238000\n",
"239000\n",
"240000\n",
"241000\n",
"242000\n",
"243000\n",
"244000\n",
"245000\n",
"246000\n",
"247000\n",
"248000\n",
"249000\n",
"250000\n",
"251000\n",
"252000\n",
"253000\n",
"254000\n",
"255000\n",
"256000\n",
"257000\n",
"258000\n",
"259000\n",
"260000\n",
"261000\n",
"262000\n",
"263000\n",
"264000\n",
"265000\n",
"266000\n",
"267000\n",
"268000\n",
"269000\n",
"270000\n",
"271000\n",
"272000\n",
"273000\n",
"274000\n",
"275000\n",
"276000\n",
"277000\n",
"278000\n",
"279000\n",
"280000\n",
"281000\n",
"282000\n",
"283000\n",
"284000\n",
"285000\n",
"286000\n",
"287000\n",
"288000\n",
"289000\n",
"290000\n",
"291000\n",
"292000\n",
"293000\n",
"294000\n",
"295000\n",
"296000\n",
"297000\n",
"298000\n",
"299000\n",
"300000\n",
"301000\n",
"302000\n",
"303000\n",
"304000\n",
"305000\n",
"306000\n",
"307000\n",
"308000\n",
"309000\n",
"310000\n",
"311000\n",
"312000\n",
"313000\n",
"314000\n",
"315000\n",
"316000\n",
"317000\n",
"318000\n",
"319000\n",
"320000\n",
"321000\n",
"322000\n",
"323000\n",
"324000\n",
"325000\n",
"326000\n",
"327000\n",
"328000\n",
"329000\n",
"330000\n",
"331000\n",
"332000\n",
"333000\n",
"334000\n",
"335000\n",
"336000\n",
"337000\n",
"338000\n",
"339000\n",
"340000\n",
"341000\n",
"342000\n",
"343000\n",
"344000\n",
"345000\n",
"346000\n",
"347000\n",
"348000\n",
"349000\n",
"350000\n",
"351000\n",
"352000\n",
"353000\n",
"354000\n",
"355000\n",
"356000\n",
"357000\n",
"358000\n",
"359000\n",
"360000\n",
"361000\n",
"362000\n",
"363000\n",
"364000\n",
"365000\n",
"366000\n",
"367000\n",
"368000\n",
"369000\n",
"370000\n",
"371000\n",
"372000\n",
"373000\n",
"374000\n",
"375000\n",
"376000\n",
"377000\n",
"378000\n",
"379000\n",
"380000\n",
"381000\n",
"382000\n",
"383000\n",
"384000\n",
"385000\n",
"386000\n",
"387000\n",
"388000\n",
"389000\n",
"390000\n",
"391000\n",
"392000\n",
"393000\n",
"394000\n",
"395000\n",
"396000\n",
"397000\n",
"398000\n",
"399000\n",
"400000\n",
"401000\n",
"402000\n",
"403000\n",
"404000\n",
"405000\n",
"406000\n",
"407000\n",
"408000\n",
"409000\n",
"410000\n",
"411000\n",
"412000\n",
"413000\n",
"414000\n",
"415000\n",
"416000\n",
"417000\n",
"418000\n",
"419000\n",
"420000\n",
"421000\n",
"422000\n",
"423000\n",
"424000\n",
"425000\n",
"426000\n",
"427000\n",
"428000\n",
"429000\n",
"430000\n",
"431000\n",
"432000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"vocab['no']"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "b9vMOTlZxISl",
"outputId": "a3a71b17-3fb2-4794-ae5d-a43d9af49b69"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"50"
]
},
"metadata": {},
"execution_count": 19
}
]
},
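{
"cell_type": "markdown",
"source": [
"As an extra sanity check (added here, not part of the original run), the `torchtext` `Vocab` object can be queried in the other direction with `lookup_token`, and `len(vocab)` should equal `vocab_size`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Reverse lookup: index -> token, plus the vocabulary size.\n",
"print(len(vocab))\n",
"print(vocab.lookup_token(50))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},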
{
"cell_type": "code",
"source": [
"with open('filename.pickle', 'wb') as handle:\n",
" pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)"
],
"metadata": {
"id": "6R9l6tuPxB_B"
},
"execution_count": null,
"outputs": []
},
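{
"cell_type": "markdown",
"source": [
"In a later session the pickled vocabulary can be loaded back instead of rebuilding it; a minimal sketch, assuming `filename.pickle` written above is still in the working directory."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Restore the vocabulary saved above (skips the slow build_vocab_from_iterator pass).\n",
"with open('filename.pickle', 'rb') as handle:\n",
"    vocab = pickle.load(handle)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},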
{
"cell_type": "markdown",
"source": [
"## Create NN"
],
"metadata": {
"id": "Be25rS6Uvl4V"
}
},
{
"cell_type": "code",
"source": [
"from torch import nn\n",
"import torch\n",
"import pickle\n",
"# embed_size = 150\n",
"embed_size = 150\n",
"\n",
"class Bigram(nn.Module):\n",
" def __init__(self, vocabulary_size, embedding_size):\n",
" super(Bigram, self).__init__()\n",
" self.model = nn.Sequential(\n",
" nn.Embedding(vocabulary_size, embedding_size),\n",
" nn.Linear(embedding_size, vocabulary_size),\n",
" nn.Softmax()\n",
" )\n",
" def forward(self, x):\n",
" return self.model(x)\n",
"\n",
"model = Bigram(vocab_size, embed_size)\n",
"\n",
"vocab.set_default_index(vocab[''])\n",
"res = torch.tensor(vocab.forward(['order']))\n",
"print(res)\n"
],
"metadata": {
"id": "dGTOmcHwWNSi",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "88c30492-7a9a-4b96-9119-ecdf5865bb51"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([215])\n"
]
}
]
},
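{
"cell_type": "markdown",
"source": [
"A small shape check of the untrained model (a sketch added for illustration, not part of the original run): for a batch of token ids it should return one softmax distribution over the vocabulary per input token."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Run a tiny batch through the untrained model and inspect the output.\n",
"with torch.no_grad():\n",
"    probs = model(torch.tensor([vocab['order'], vocab['the']]))\n",
"print(probs.shape)            # expected: torch.Size([2, vocab_size])\n",
"print(float(probs[0].sum()))  # each softmax row sums to ~1"
],
"metadata": {},
"execution_count": null,
"outputs": []
},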
{
"cell_type": "code",
"source": [
"from torch.utils.data import IterableDataset\n",
"import itertools\n",
"\n",
"def look_ahead_iterator(gen):\n",
" prev = None\n",
" for item in gen:\n",
" if prev is not None:\n",
" yield (prev, item)\n",
" prev = item\n",
"\n",
"class Bigrams(IterableDataset):\n",
" def __init__(self, text_file, vocabulary_size):\n",
" self.vocab = build_vocab_from_iterator(\n",
" get_word_lines_from_file(text_file),\n",
" max_tokens = vocabulary_size,\n",
" specials = [''])\n",
" self.vocab.set_default_index(self.vocab[''])\n",
" self.vocabulary_size = vocabulary_size\n",
" self.text_file = text_file\n",
"\n",
" def __iter__(self):\n",
" return look_ahead_iterator(\n",
" (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))\n",
"\n",
"\n",
"train_dataset = Bigrams(train_file, vocab_size)"
],
"metadata": {
"id": "5CSigeomWNVT",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4fb3b1ff-f91b-4799-fc5f-17bae10b94ec"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1000\n",
"2000\n",
"3000\n",
"4000\n",
"5000\n",
"6000\n",
"7000\n",
"8000\n",
"9000\n",
"10000\n",
"11000\n",
"12000\n",
"13000\n",
"14000\n",
"15000\n",
"16000\n",
"17000\n",
"18000\n",
"19000\n",
"20000\n",
"21000\n",
"22000\n",
"23000\n",
"24000\n",
"25000\n",
"26000\n",
"27000\n",
"28000\n",
"29000\n",
"30000\n",
"31000\n",
"32000\n",
"33000\n",
"34000\n",
"35000\n",
"36000\n",
"37000\n",
"38000\n",
"39000\n",
"40000\n",
"41000\n",
"42000\n",
"43000\n",
"44000\n",
"45000\n",
"46000\n",
"47000\n",
"48000\n",
"49000\n",
"50000\n",
"51000\n",
"52000\n",
"53000\n",
"54000\n",
"55000\n",
"56000\n",
"57000\n",
"58000\n",
"59000\n",
"60000\n",
"61000\n",
"62000\n",
"63000\n",
"64000\n",
"65000\n",
"66000\n",
"67000\n",
"68000\n",
"69000\n",
"70000\n",
"71000\n",
"72000\n",
"73000\n",
"74000\n",
"75000\n",
"76000\n",
"77000\n",
"78000\n",
"79000\n",
"80000\n",
"81000\n",
"82000\n",
"83000\n",
"84000\n",
"85000\n",
"86000\n",
"87000\n",
"88000\n",
"89000\n",
"90000\n",
"91000\n",
"92000\n",
"93000\n",
"94000\n",
"95000\n",
"96000\n",
"97000\n",
"98000\n",
"99000\n",
"100000\n",
"101000\n",
"102000\n",
"103000\n",
"104000\n",
"105000\n",
"106000\n",
"107000\n",
"108000\n",
"109000\n",
"110000\n",
"111000\n",
"112000\n",
"113000\n",
"114000\n",
"115000\n",
"116000\n",
"117000\n",
"118000\n",
"119000\n",
"120000\n",
"121000\n",
"122000\n",
"123000\n",
"124000\n",
"125000\n",
"126000\n",
"127000\n",
"128000\n",
"129000\n",
"130000\n",
"131000\n",
"132000\n",
"133000\n",
"134000\n",
"135000\n",
"136000\n",
"137000\n",
"138000\n",
"139000\n",
"140000\n",
"141000\n",
"142000\n",
"143000\n",
"144000\n",
"145000\n",
"146000\n",
"147000\n",
"148000\n",
"149000\n",
"150000\n",
"151000\n",
"152000\n",
"153000\n",
"154000\n",
"155000\n",
"156000\n",
"157000\n",
"158000\n",
"159000\n",
"160000\n",
"161000\n",
"162000\n",
"163000\n",
"164000\n",
"165000\n",
"166000\n",
"167000\n",
"168000\n",
"169000\n",
"170000\n",
"171000\n",
"172000\n",
"173000\n",
"174000\n",
"175000\n",
"176000\n",
"177000\n",
"178000\n",
"179000\n",
"180000\n",
"181000\n",
"182000\n",
"183000\n",
"184000\n",
"185000\n",
"186000\n",
"187000\n",
"188000\n",
"189000\n",
"190000\n",
"191000\n",
"192000\n",
"193000\n",
"194000\n",
"195000\n",
"196000\n",
"197000\n",
"198000\n",
"199000\n",
"200000\n",
"201000\n",
"202000\n",
"203000\n",
"204000\n",
"205000\n",
"206000\n",
"207000\n",
"208000\n",
"209000\n",
"210000\n",
"211000\n",
"212000\n",
"213000\n",
"214000\n",
"215000\n",
"216000\n",
"217000\n",
"218000\n",
"219000\n",
"220000\n",
"221000\n",
"222000\n",
"223000\n",
"224000\n",
"225000\n",
"226000\n",
"227000\n",
"228000\n",
"229000\n",
"230000\n",
"231000\n",
"232000\n",
"233000\n",
"234000\n",
"235000\n",
"236000\n",
"237000\n",
"238000\n",
"239000\n",
"240000\n",
"241000\n",
"242000\n",
"243000\n",
"244000\n",
"245000\n",
"246000\n",
"247000\n",
"248000\n",
"249000\n",
"250000\n",
"251000\n",
"252000\n",
"253000\n",
"254000\n",
"255000\n",
"256000\n",
"257000\n",
"258000\n",
"259000\n",
"260000\n",
"261000\n",
"262000\n",
"263000\n",
"264000\n",
"265000\n",
"266000\n",
"267000\n",
"268000\n",
"269000\n",
"270000\n",
"271000\n",
"272000\n",
"273000\n",
"274000\n",
"275000\n",
"276000\n",
"277000\n",
"278000\n",
"279000\n",
"280000\n",
"281000\n",
"282000\n",
"283000\n",
"284000\n",
"285000\n",
"286000\n",
"287000\n",
"288000\n",
"289000\n",
"290000\n",
"291000\n",
"292000\n",
"293000\n",
"294000\n",
"295000\n",
"296000\n",
"297000\n",
"298000\n",
"299000\n",
"300000\n",
"301000\n",
"302000\n",
"303000\n",
"304000\n",
"305000\n",
"306000\n",
"307000\n",
"308000\n",
"309000\n",
"310000\n",
"311000\n",
"312000\n",
"313000\n",
"314000\n",
"315000\n",
"316000\n",
"317000\n",
"318000\n",
"319000\n",
"320000\n",
"321000\n",
"322000\n",
"323000\n",
"324000\n",
"325000\n",
"326000\n",
"327000\n",
"328000\n",
"329000\n",
"330000\n",
"331000\n",
"332000\n",
"333000\n",
"334000\n",
"335000\n",
"336000\n",
"337000\n",
"338000\n",
"339000\n",
"340000\n",
"341000\n",
"342000\n",
"343000\n",
"344000\n",
"345000\n",
"346000\n",
"347000\n",
"348000\n",
"349000\n",
"350000\n",
"351000\n",
"352000\n",
"353000\n",
"354000\n",
"355000\n",
"356000\n",
"357000\n",
"358000\n",
"359000\n",
"360000\n",
"361000\n",
"362000\n",
"363000\n",
"364000\n",
"365000\n",
"366000\n",
"367000\n",
"368000\n",
"369000\n",
"370000\n",
"371000\n",
"372000\n",
"373000\n",
"374000\n",
"375000\n",
"376000\n",
"377000\n",
"378000\n",
"379000\n",
"380000\n",
"381000\n",
"382000\n",
"383000\n",
"384000\n",
"385000\n",
"386000\n",
"387000\n",
"388000\n",
"389000\n",
"390000\n",
"391000\n",
"392000\n",
"393000\n",
"394000\n",
"395000\n",
"396000\n",
"397000\n",
"398000\n",
"399000\n",
"400000\n",
"401000\n",
"402000\n",
"403000\n",
"404000\n",
"405000\n",
"406000\n",
"407000\n",
"408000\n",
"409000\n",
"410000\n",
"411000\n",
"412000\n",
"413000\n",
"414000\n",
"415000\n",
"416000\n",
"417000\n",
"418000\n",
"419000\n",
"420000\n",
"421000\n",
"422000\n",
"423000\n",
"424000\n",
"425000\n",
"426000\n",
"427000\n",
"428000\n",
"429000\n",
"430000\n",
"431000\n",
"432000\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"next(iter(DataLoader(train_dataset, batch_size=5)))"
],
"metadata": {
"id": "oYAZ772rWNX7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "da3a811a-36cf-4d34-82d8-b1ce2447c232"
},
"execution_count": null,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[tensor([ 23, 191, 5791, 1, 112]),\n",
" tensor([ 191, 5791, 1, 112, 159])]"
]
},
"metadata": {},
"execution_count": 23
}
]
},
{
"cell_type": "markdown",
"source": [
"## Train"
],
"metadata": {
"id": "1H_dI372vrNh"
}
},
{
"cell_type": "code",
"source": [
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
],
"metadata": {
"id": "N2u4Qmadgtdn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"\n",
"model = Bigram(vocab_size, embed_size).to(device)\n",
"data = DataLoader(train_dataset, batch_size=1000)\n",
"optimizer = torch.optim.Adam(model.parameters())\n",
"criterion = torch.nn.NLLLoss()\n",
"## epochs=2\n",
"for i in range(2):\n",
" print('epoch: =', i)\n",
" model.train()\n",
" step = 0\n",
" for x, y in data:\n",
" x = x.to(device)\n",
" y = y.to(device)\n",
" optimizer.zero_grad()\n",
" ypredicted = model(x)\n",
" loss = criterion(torch.log(ypredicted), y)\n",
" if step % 100 == 0:\n",
" print(step, loss)\n",
" step += 1\n",
" loss.backward()\n",
" optimizer.step()\n",
" torch.save(model.state_dict(), 'model.bin') \n"
],
"metadata": {
"id": "OGk2tjbvWNag",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "868503a1-2849-40ff-e703-886fba094927"
},
"execution_count": null,
"outputs": [
{
"metadata": {
"tags": null
},
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: = 0\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 tensor(10.3037, device='cuda:0', grad_fn=)\n",
"100 tensor(8.7506, device='cuda:0', grad_fn=)\n",
"200 tensor(7.8141, device='cuda:0', grad_fn=)\n",
"1000\n",
"300 tensor(7.4218, device='cuda:0', grad_fn=)\n",
"400 tensor(7.1627, device='cuda:0', grad_fn=)\n",
"500 tensor(6.7964, device='cuda:0', grad_fn=)\n",
"2000\n",
"600 tensor(6.4704, device='cuda:0', grad_fn=)\n",
"700 tensor(6.3798, device='cuda:0', grad_fn=)\n",
"800 tensor(6.2849, device='cuda:0', grad_fn=)\n",
"3000\n",
"900 tensor(6.3975, device='cuda:0', grad_fn=)\n",
"1000 tensor(6.0096, device='cuda:0', grad_fn=)\n",
"1100 tensor(5.7434, device='cuda:0', grad_fn=)\n",
"4000\n",
"1200 tensor(5.9602, device='cuda:0', grad_fn=)\n",
"1300 tensor(6.1623, device='cuda:0', grad_fn=)\n",
"1400 tensor(6.1647, device='cuda:0', grad_fn=)\n",
"5000\n",
"1500 tensor(6.1010, device='cuda:0', grad_fn=)\n",
"1600 tensor(6.0634, device='cuda:0', grad_fn=)\n",
"6000\n",
"1700 tensor(5.9149, device='cuda:0', grad_fn=)\n",
"1800 tensor(5.7918, device='cuda:0', grad_fn=)\n",
"1900 tensor(5.6739, device='cuda:0', grad_fn=)\n",
"7000\n",
"2000 tensor(5.5298, device='cuda:0', grad_fn=)\n",
"2100 tensor(5.8011, device='cuda:0', grad_fn=)\n",
"2200 tensor(5.4338, device='cuda:0', grad_fn=)\n",
"8000\n",
"2300 tensor(5.7522, device='cuda:0', grad_fn=)\n",
"2400 tensor(5.0313, device='cuda:0', grad_fn=)\n",
"2500 tensor(5.7116, device='cuda:0', grad_fn=)\n",
"9000\n",
"2600 tensor(5.2706, device='cuda:0', grad_fn=)\n",
"2700 tensor(5.6324, device='cuda:0', grad_fn=)\n",
"2800 tensor(5.0710, device='cuda:0', grad_fn=)\n",
"10000\n",
"2900 tensor(5.5921, device='cuda:0', grad_fn=)\n",
"3000 tensor(5.4808, device='cuda:0', grad_fn=)\n",
"11000\n",
"3100 tensor(5.3611, device='cuda:0', grad_fn=)\n",
"3200 tensor(5.6228, device='cuda:0', grad_fn=)\n",
"3300 tensor(5.4286, device='cuda:0', grad_fn=)\n",
"12000\n",
"3400 tensor(5.3550, device='cuda:0', grad_fn=)\n",
"3500 tensor(5.4032, device='cuda:0', grad_fn=)\n",
"3600 tensor(5.1070, device='cuda:0', grad_fn=)\n",
"13000\n",
"3700 tensor(5.4506, device='cuda:0', grad_fn=)\n",
"3800 tensor(5.4622, device='cuda:0', grad_fn=)\n",
"3900 tensor(5.4984, device='cuda:0', grad_fn=)\n",
"14000\n",
"4000 tensor(5.1740, device='cuda:0', grad_fn=)\n",
"4100 tensor(5.6064, device='cuda:0', grad_fn=)\n",
"4200 tensor(5.0705, device='cuda:0', grad_fn=)\n",
"15000\n",
"4300 tensor(5.5181, device='cuda:0', grad_fn=)\n",
"4400 tensor(5.2919, device='cuda:0', grad_fn=)\n",
"16000\n",
"4500 tensor(5.5021, device='cuda:0', grad_fn=)\n",
"4600 tensor(5.5308, device='cuda:0', grad_fn=)\n",
"4700 tensor(5.4699, device='cuda:0', grad_fn=)\n",
"17000\n",
"4800 tensor(5.2686, device='cuda:0', grad_fn=)\n",
"4900 tensor(5.4776, device='cuda:0', grad_fn=)\n",
"5000 tensor(5.5061, device='cuda:0', grad_fn=)\n",
"18000\n",
"5100 tensor(5.3180, device='cuda:0', grad_fn=)\n",
"5200 tensor(5.5524, device='cuda:0', grad_fn=)\n",
"5300 tensor(5.3481, device='cuda:0', grad_fn=)\n",
"19000\n",
"5400 tensor(5.2153, device='cuda:0', grad_fn=)\n",
"5500 tensor(5.4478, device='cuda:0', grad_fn=)\n",
"20000\n",
"5600 tensor(5.3441, device='cuda:0', grad_fn=)\n",
"5700 tensor(5.3958, device='cuda:0', grad_fn=)\n",
"5800 tensor(5.8945, device='cuda:0', grad_fn=)\n",
"21000\n",
"5900 tensor(5.5684, device='cuda:0', grad_fn=)\n",
"6000 tensor(5.5715, device='cuda:0', grad_fn=)\n",
"6100 tensor(5.2367, device='cuda:0', grad_fn=)\n",
"22000\n",
"6200 tensor(5.6976, device='cuda:0', grad_fn=)\n",
"6300 tensor(5.5367, device='cuda:0', grad_fn=)\n",
"6400 tensor(5.3024, device='cuda:0', grad_fn=)\n",
"23000\n",
"6500 tensor(5.3010, device='cuda:0', grad_fn=)\n",
"6600 tensor(6.0962, device='cuda:0', grad_fn=)\n",
"6700 tensor(5.0961, device='cuda:0', grad_fn=)\n",
"24000\n",
"6800 tensor(5.1091, device='cuda:0', grad_fn=)\n",
"6900 tensor(5.4123, device='cuda:0', grad_fn=)\n",
"25000\n",
"7000 tensor(5.3128, device='cuda:0', grad_fn=)\n",
"7100 tensor(5.3416, device='cuda:0', grad_fn=)\n",
"7200 tensor(5.4973, device='cuda:0', grad_fn=)\n",
"26000\n",
"7300 tensor(5.4418, device='cuda:0', grad_fn=)\n",
"7400 tensor(5.2171, device='cuda:0', grad_fn=)\n",
"7500 tensor(5.6509, device='cuda:0', grad_fn=)\n",
"27000\n",
"7600 tensor(5.0550, device='cuda:0', grad_fn=)\n",
"7700 tensor(5.4937, device='cuda:0', grad_fn=)\n",
"7800 tensor(5.9218, device='cuda:0', grad_fn=)\n",
"28000\n",
"7900 tensor(5.2853, device='cuda:0', grad_fn=)\n",
"8000 tensor(5.3146, device='cuda:0', grad_fn=)\n",
"8100 tensor(4.8552, device='cuda:0', grad_fn=)\n",
"29000\n",
"8200 tensor(5.3389, device='cuda:0', grad_fn=)\n",
"8300 tensor(5.2421, device='cuda:0', grad_fn=)\n",
"30000\n",
"8400 tensor(5.2460, device='cuda:0', grad_fn=)\n",
"8500 tensor(5.0331, device='cuda:0', grad_fn=)\n",
"8600 tensor(5.0050, device='cuda:0', grad_fn=)\n",
"31000\n",
"8700 tensor(5.3844, device='cuda:0', grad_fn=)\n",
"8800 tensor(5.4491, device='cuda:0', grad_fn=)\n",
"8900 tensor(5.6790, device='cuda:0', grad_fn=)\n",
"32000\n",
"9000 tensor(5.1118, device='cuda:0', grad_fn=)\n",
"9100 tensor(5.3567, device='cuda:0', grad_fn=)\n",
"9200 tensor(5.4141, device='cuda:0', grad_fn=)\n",
"33000\n",
"9300 tensor(5.3085, device='cuda:0', grad_fn=)\n",
"9400 tensor(5.2808, device='cuda:0', grad_fn=)\n",
"34000\n",
"9500 tensor(5.0931, device='cuda:0', grad_fn=)\n",
"9600 tensor(5.1090, device='cuda:0', grad_fn=)\n",
"9700 tensor(5.2519, device='cuda:0', grad_fn=)\n",
"35000\n",
"9800 tensor(5.3852, device='cuda:0', grad_fn=)\n",
"9900 tensor(5.0943, device='cuda:0', grad_fn=)\n",
"10000 tensor(5.4690, device='cuda:0', grad_fn=)\n",
"36000\n",
"10100 tensor(5.4348, device='cuda:0', grad_fn=)\n",
"10200 tensor(5.3262, device='cuda:0', grad_fn=)\n",
"10300 tensor(5.4878, device='cuda:0', grad_fn=)\n",
"37000\n",
"10400 tensor(5.2384, device='cuda:0', grad_fn=)\n",
"10500 tensor(5.2151, device='cuda:0', grad_fn=)\n",
"10600 tensor(4.8722, device='cuda:0', grad_fn=)\n",
"38000\n",
"10700 tensor(5.4325, device='cuda:0', grad_fn=)\n",
"10800 tensor(4.8699, device='cuda:0', grad_fn=)\n",
"39000\n",
"10900 tensor(5.3448, device='cuda:0', grad_fn=)\n",
"11000 tensor(5.1358, device='cuda:0', grad_fn=)\n",
"11100 tensor(5.0432, device='cuda:0', grad_fn=)\n",
"40000\n",
"11200 tensor(5.4062, device='cuda:0', grad_fn=)\n",
"11300 tensor(5.4040, device='cuda:0', grad_fn=)\n",
"11400 tensor(5.5312, device='cuda:0', grad_fn=)\n",
"41000\n",
"11500 tensor(5.4374, device='cuda:0', grad_fn=)\n",
"11600 tensor(5.0998, device='cuda:0', grad_fn=)\n",
"11700 tensor(5.4217, device='cuda:0', grad_fn=)\n",
"42000\n",
"11800 tensor(5.5747, device='cuda:0', grad_fn=)\n",
"11900 tensor(5.0467, device='cuda:0', grad_fn=)\n",
"12000 tensor(5.4270, device='cuda:0', grad_fn=)\n",
"43000\n",
"12100 tensor(5.2043, device='cuda:0', grad_fn=)\n",
"12200 tensor(5.2369, device='cuda:0', grad_fn=)\n",
"44000\n",
"12300 tensor(5.4465, device='cuda:0', grad_fn=)\n",
"12400 tensor(4.9839, device='cuda:0', grad_fn=)\n",
"12500 tensor(5.3214, device='cuda:0', grad_fn=)\n",
"45000\n",
"12600 tensor(5.1928, device='cuda:0', grad_fn=)\n",
"12700 tensor(4.9646, device='cuda:0', grad_fn=)\n",
"12800 tensor(5.3325, device='cuda:0', grad_fn=)\n",
"46000\n",
"12900 tensor(5.4429, device='cuda:0', grad_fn=)\n",
"13000 tensor(5.0652, device='cuda:0', grad_fn=)\n",
"13100 tensor(5.3126, device='cuda:0', grad_fn=)\n",
"47000\n",
"13200 tensor(5.4124, device='cuda:0', grad_fn=)\n",
"13300 tensor(5.5385, device='cuda:0', grad_fn=)\n",
"13400 tensor(5.0986, device='cuda:0', grad_fn=)\n",
"48000\n",
"13500 tensor(5.2693, device='cuda:0', grad_fn=)\n",
"13600 tensor(5.2136, device='cuda:0', grad_fn=)\n",
"49000\n",
"13700 tensor(5.5169, device='cuda:0', grad_fn=)\n",
"13800 tensor(5.1840, device='cuda:0', grad_fn=)\n",
"13900 tensor(5.2700, device='cuda:0', grad_fn=)\n",
"50000\n",
"14000 tensor(5.2077, device='cuda:0', grad_fn=)\n",
"14100 tensor(5.3791, device='cuda:0', grad_fn=)\n",
"14200 tensor(5.4008, device='cuda:0', grad_fn=)\n",
"51000\n",
"14300 tensor(5.3506, device='cuda:0', grad_fn=)\n",
"14400 tensor(4.7662, device='cuda:0', grad_fn=)\n",
"14500 tensor(4.9474, device='cuda:0', grad_fn=)\n",
"52000\n",
"14600 tensor(5.0245, device='cuda:0', grad_fn=)\n",
"14700 tensor(5.3977, device='cuda:0', grad_fn=)\n",
"14800 tensor(4.9653, device='cuda:0', grad_fn=)\n",
"53000\n",
"14900 tensor(4.8947, device='cuda:0', grad_fn=)\n",
"15000 tensor(5.3548, device='cuda:0', grad_fn=)\n",
"54000\n",
"15100 tensor(4.7244, device='cuda:0', grad_fn=)\n",
"15200 tensor(4.9752, device='cuda:0', grad_fn=)\n",
"15300 tensor(5.3929, device='cuda:0', grad_fn=)\n",
"55000\n",
"15400 tensor(5.3096, device='cuda:0', grad_fn=)\n",
"15500 tensor(5.1247, device='cuda:0', grad_fn=)\n",
"15600 tensor(5.2753, device='cuda:0', grad_fn=)\n",
"56000\n",
"15700 tensor(5.2373, device='cuda:0', grad_fn=)\n",
"15800 tensor(4.9997, device='cuda:0', grad_fn=)\n",
"15900 tensor(5.1718, device='cuda:0', grad_fn=)\n",
"57000\n",
"16000 tensor(5.5952, device='cuda:0', grad_fn=)\n",
"16100 tensor(5.3699, device='cuda:0', grad_fn=)\n",
"16200 tensor(5.0923, device='cuda:0', grad_fn=)\n",
"58000\n",
"16300 tensor(4.9985, device='cuda:0', grad_fn=)\n",
"16400 tensor(5.3076, device='cuda:0', grad_fn=)\n",
"59000\n",
"16500 tensor(5.1994, device='cuda:0', grad_fn=)\n",
"16600 tensor(5.3672, device='cuda:0', grad_fn=)\n",
"16700 tensor(5.2054, device='cuda:0', grad_fn=)\n",
"60000\n",
"16800 tensor(5.3379, device='cuda:0', grad_fn=)\n",
"16900 tensor(5.2785, device='cuda:0', grad_fn=)\n",
"17000 tensor(5.2590, device='cuda:0', grad_fn=)\n",
"61000\n",
"17100 tensor(5.3564, device='cuda:0', grad_fn=)\n",
"17200 tensor(5.3598, device='cuda:0', grad_fn=)\n",
"17300 tensor(4.7786, device='cuda:0', grad_fn=)\n",
"62000\n",
"17400 tensor(5.2639, device='cuda:0', grad_fn=)\n",
"17500 tensor(5.2037, device='cuda:0', grad_fn=)\n",
"17600 tensor(5.1158, device='cuda:0', grad_fn=)\n",
"63000\n",
"17700 tensor(4.9831, device='cuda:0', grad_fn=)\n",
"17800 tensor(4.8950, device='cuda:0', grad_fn=)\n",
"64000\n",
"17900 tensor(5.0928, device='cuda:0', grad_fn=)\n",
"18000 tensor(5.3423, device='cuda:0', grad_fn=)\n",
"18100 tensor(5.1760, device='cuda:0', grad_fn=)\n",
"65000\n",
"18200 tensor(5.2021, device='cuda:0', grad_fn=)\n",
"18300 tensor(5.1306, device='cuda:0', grad_fn=)\n",
"18400 tensor(5.1199, device='cuda:0', grad_fn=)\n",
"66000\n",
"18500 tensor(5.2082, device='cuda:0', grad_fn=)\n",
"18600 tensor(5.3290, device='cuda:0', grad_fn=)\n",
"18700 tensor(5.2257, device='cuda:0', grad_fn=)\n",
"67000\n",
"18800 tensor(4.9107, device='cuda:0', grad_fn=)\n",
"18900 tensor(5.3400, device='cuda:0', grad_fn=)\n",
"68000\n",
"19000 tensor(5.1366, device='cuda:0', grad_fn=)\n",
"19100 tensor(5.1199, device='cuda:0', grad_fn=)\n",
"19200 tensor(5.2202, device='cuda:0', grad_fn=)\n",
"69000\n",
"19300 tensor(5.2236, device='cuda:0', grad_fn=)\n",
"19400 tensor(5.2953, device='cuda:0', grad_fn=)\n",
"19500 tensor(5.1308, device='cuda:0', grad_fn=)\n",
"70000\n",
"19600 tensor(5.3578, device='cuda:0', grad_fn=)\n",
"19700 tensor(5.1600, device='cuda:0', grad_fn=)\n",
"19800 tensor(4.6220, device='cuda:0', grad_fn=)\n",
"71000\n",
"19900 tensor(5.3731, device='cuda:0', grad_fn=)\n",
"20000 tensor(4.9936, device='cuda:0', grad_fn=)\n",
"20100 tensor(5.0817, device='cuda:0', grad_fn=)\n",
"72000\n",
"20200 tensor(5.1613, device='cuda:0', grad_fn=)\n",
"20300 tensor(5.3877, device='cuda:0', grad_fn=)\n",
"73000\n",
"20400 tensor(5.4114, device='cuda:0', grad_fn=)\n",
"20500 tensor(5.2609, device='cuda:0', grad_fn=)\n",
"20600 tensor(5.1378, device='cuda:0', grad_fn=)\n",
"74000\n",
"20700 tensor(5.0799, device='cuda:0', grad_fn=)\n",
"20800 tensor(5.3615, device='cuda:0', grad_fn=)\n",
"20900 tensor(5.3365, device='cuda:0', grad_fn=)\n",
"75000\n",
"21000 tensor(4.9244, device='cuda:0', grad_fn=)\n",
"21100 tensor(5.5084, device='cuda:0', grad_fn=)\n",
"21200 tensor(4.8769, device='cuda:0', grad_fn=)\n",
"76000\n",
"21300 tensor(5.3414, device='cuda:0', grad_fn=)\n",
"21400 tensor(5.0648, device='cuda:0', grad_fn=)\n",
"21500 tensor(5.0594, device='cuda:0', grad_fn=)\n",
"77000\n",
"21600 tensor(5.2537, device='cuda:0', grad_fn=)\n",
"21700 tensor(5.1834, device='cuda:0', grad_fn=)\n",
"21800 tensor(4.8151, device='cuda:0', grad_fn=)\n",
"78000\n",
"21900 tensor(5.3335, device='cuda:0', grad_fn=)\n",
"22000 tensor(4.9580, device='cuda:0', grad_fn=)\n",
"79000\n",
"22100 tensor(5.2262, device='cuda:0', grad_fn=)\n",
"22200 tensor(5.1946, device='cuda:0', grad_fn=)\n",
"22300 tensor(5.2404, device='cuda:0', grad_fn=)\n",
"80000\n",
"22400 tensor(4.9491, device='cuda:0', grad_fn=)\n",
"22500 tensor(4.6901, device='cuda:0', grad_fn=)\n",
"22600 tensor(5.1937, device='cuda:0', grad_fn=)\n",
"81000\n",
"22700 tensor(4.9937, device='cuda:0', grad_fn=)\n",
"22800 tensor(5.1401, device='cuda:0', grad_fn=)\n",
"22900 tensor(5.0599, device='cuda:0', grad_fn=)\n",
"82000\n",
"23000 tensor(5.4315, device='cuda:0', grad_fn=)\n",
"23100 tensor(5.1854, device='cuda:0', grad_fn=)\n",
"83000\n",
"23200 tensor(5.1033, device='cuda:0', grad_fn=)\n",
"23300 tensor(5.2352, device='cuda:0', grad_fn=)\n",
"23400 tensor(5.2004, device='cuda:0', grad_fn=)\n",
"84000\n",
"23500 tensor(5.0866, device='cuda:0', grad_fn=)\n",
"23600 tensor(5.2372, device='cuda:0', grad_fn=)\n",
"23700 tensor(5.4711, device='cuda:0', grad_fn=)\n",
"85000\n",
"23800 tensor(5.4030, device='cuda:0', grad_fn=)\n",
"23900 tensor(5.3589, device='cuda:0', grad_fn=)\n",
"24000 tensor(5.1646, device='cuda:0', grad_fn=)\n",
"86000\n",
"24100 tensor(5.4865, device='cuda:0', grad_fn=)\n",
"24200 tensor(5.3663, device='cuda:0', grad_fn=)\n",
"24300 tensor(5.1760, device='cuda:0', grad_fn=)\n",
"87000\n",
"24400 tensor(5.2950, device='cuda:0', grad_fn=)\n",
"24500 tensor(5.0376, device='cuda:0', grad_fn=)\n",
"88000\n",
"24600 tensor(5.1229, device='cuda:0', grad_fn=)\n",
"24700 tensor(5.3261, device='cuda:0', grad_fn=)\n",
"24800 tensor(5.3953, device='cuda:0', grad_fn=)\n",
"89000\n",
"24900 tensor(5.2734, device='cuda:0', grad_fn=)\n",
"25000 tensor(5.5544, device='cuda:0', grad_fn=