# %% [markdown]
# # Binary text classification with averaged GloVe embeddings
#
# Every document is represented by the mean of the pre-trained word vectors
# (`glove-twitter-200`) of its whitespace-separated tokens. A feed-forward
# network with one hidden layer is fitted on those features.
#
# **Caveat:** the network is fitted on the `dev-0` split and then used to label
# both `dev-0` and `test-A`; the loss/accuracy printed during training is
# therefore measured on the very data the model is fitted on.

# %%
import numpy as np
import gensim
import torch
import pandas as pd
from gensim.models import Word2Vec
from gensim import downloader
from sklearn.feature_extraction.text import TfidfVectorizer

# %% [markdown]
# ## Configuration

# %%
BATCH_SIZE = 10
EPOCHS = 100
FEAUTERES = 200  # dimensionality of the word vectors (identifier spelling kept for compatibility)
SEED = 42

# Seed torch so that weight initialisation (and thus the whole run) is reproducible.
torch.manual_seed(SEED)

# %% [markdown]
# ## Model

# %%
class NeuralNetworkModel(torch.nn.Module):
    """Feed-forward binary classifier: Linear -> ReLU -> Linear -> sigmoid.

    Args:
        input_dim: Number of input features (defaults to ``FEAUTERES``).
        hidden_dim: Width of the hidden layer (defaults to 500).
    """

    def __init__(self, input_dim=FEAUTERES, hidden_dim=500):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Map a ``(batch, input_dim)`` float tensor to ``(batch, 1)`` values in (0, 1)."""
        hidden = torch.relu(self.fc1(x))
        return torch.sigmoid(self.fc2(hidden))

# %% [markdown]
# ## Data loading and feature extraction

# %%
def _read_documents(path):
    """Read a text file and return its stripped, lower-cased lines as a numpy array."""
    with open(path, 'r', encoding='utf8') as f:
        return np.array([line.strip().lower() for line in f])


def readData(fileName):
    """Load documents and integer labels from a data directory.

    Args:
        fileName: Directory containing ``in.tsv`` (one document per line) and
            ``expected.tsv`` (one integer label per line).

    Returns:
        Tuple ``(X, y)``: array of lower-cased document strings and array of
        integer labels.

    Raises:
        ValueError: If a label line is not an integer, or if the two files do
            not contain the same number of lines.
    """
    X = _read_documents(f'{fileName}/in.tsv')
    with open(f'{fileName}/expected.tsv', 'r', encoding='utf8') as f:
        y = np.array([int(line.strip()) for line in f])
    if len(X) != len(y):
        raise ValueError(
            f'{fileName}: {len(X)} documents but {len(y)} labels')
    return X, y


def embed_documents(docs, vectors, dim=FEAUTERES):
    """Turn each document into the mean of the vectors of its known tokens.

    Args:
        docs: Iterable of document strings; tokens are obtained by lower-casing
            and splitting on whitespace.
        vectors: Mapping supporting ``token in vectors`` and ``vectors[token]``
            that yields vectors of length ``dim``.
        dim: Length of each word vector.

    Returns:
        ``float32`` array of shape ``(len(docs), dim)``. A document with no
        known token gets an all-zero row.
    """
    docs = list(docs)
    embedded = np.zeros((len(docs), dim), dtype=np.float32)
    for row, doc in enumerate(docs):
        known = [vectors[token] for token in doc.lower().split() if token in vectors]
        if known:  # otherwise keep the zero row
            embedded[row] = np.mean(known, axis=0)
    return embedded

# %% [markdown]
# ## Training and inference helpers

# %%
def train_model(X_file, y_file, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=0.05):
    """Fit a ``NeuralNetworkModel`` with averaged SGD and binary cross-entropy.

    After every epoch the epoch index is printed, followed by the mean loss and
    the accuracy (threshold 0.5) accumulated over that epoch's batches. Batches
    are taken in the given order; the data is not shuffled.

    Args:
        X_file: Sequence of feature vectors, convertible to a 2-D float array.
        y_file: Sequence of 0/1 labels, one per feature vector.
        epochs: Number of passes over the data.
        batch_size: Number of rows per optimisation step.
        lr: Learning rate passed to ``torch.optim.ASGD``.

    Returns:
        The trained ``NeuralNetworkModel``.

    Raises:
        ValueError: If the data is empty, the features are not 2-D, or the
            number of feature rows differs from the number of labels.
    """
    # Convert once up front instead of re-converting every batch in every epoch.
    features = torch.tensor(np.asarray(X_file, dtype=np.float32))
    targets = torch.tensor(np.asarray(y_file, dtype=np.float32)).reshape(-1, 1)

    n_items = targets.shape[0]
    if n_items == 0 or features.shape[0] == 0:
        raise ValueError('cannot train on an empty data set')
    if features.ndim != 2:
        raise ValueError(
            f'expected 2-D features, got shape {tuple(features.shape)}')
    if features.shape[0] != n_items:
        raise ValueError(
            f'{features.shape[0]} feature rows but {n_items} labels')

    model = NeuralNetworkModel(input_dim=features.shape[1])
    criterion = torch.nn.BCELoss()
    optimizer = torch.optim.ASGD(model.parameters(), lr=lr)

    model.train()
    for epoch in range(epochs):
        print(epoch)
        loss_score = 0
        acc_score = 0
        items_total = 0
        for start in range(0, n_items, batch_size):
            x = features[start:start + batch_size]
            y = targets[start:start + batch_size]

            y_pred = model(x)
            acc_score += torch.sum((y_pred > 0.5) == (y > 0.5)).item()
            items_total += y.shape[0]

            optimizer.zero_grad()
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            # BCELoss returns the batch mean; weight it by batch size so the
            # epoch figure is a per-item mean even when the last batch is short.
            loss_score += loss.item() * y.shape[0]

        print((loss_score / items_total), (acc_score / items_total))
    return model


def predict(model, x_file):
    """Label feature vectors with a trained model (threshold 0.5).

    Args:
        model: A ``torch.nn.Module`` mapping ``(batch, dim)`` float tensors to
            ``(batch, 1)`` probabilities.
        x_file: Sequence of feature vectors.

    Returns:
        List with one boolean tensor of shape ``(1,)`` per input row.
    """
    features = torch.tensor(np.asarray(x_file, dtype=np.float32))
    y_dev = []
    was_training = model.training
    model.eval()  # inference mode; restored below so the caller's model state is untouched
    try:
        with torch.no_grad():
            for start in range(0, features.shape[0], BATCH_SIZE):
                outputs = model(features[start:start + BATCH_SIZE])
                y_dev.extend(outputs > 0.5)
    finally:
        model.train(was_training)
    return y_dev


def wrtieToFile(fileName, y_file):
    """Write predictions as 0/1, one per line, to ``<fileName>/out.tsv``.

    The function name's spelling is kept as-is because callers use it.

    Args:
        fileName: Directory that receives ``out.tsv``.
        y_file: Iterable of predictions; the first element of each item is
            interpreted as a boolean.
    """
    # Read the value directly instead of parsing the tensor's repr, which
    # breaks as soon as the repr carries extra fields (e.g. a device).
    labels = [int(bool(torch.as_tensor(y).reshape(-1)[0])) for y in y_file]
    with open(f'{fileName}/out.tsv', 'w', encoding='utf8') as f:
        f.writelines(f'{label}\n' for label in labels)

# %% [markdown]
# ## Load embeddings and the `dev-0` split

# %%
word2vec = downloader.load("glove-twitter-200")

# %%
X_file, y_file = readData('dev-0')

# %%
x_train_w2v = embed_documents(X_file, word2vec)

# %% [markdown]
# ## Train

# %%
model = train_model(x_train_w2v, y_file)

# %% [markdown]
# ## Predict `dev-0`

# %%
y_dev = predict(model, x_train_w2v)

# %%
wrtieToFile("dev-0", y_dev)

# %% [markdown]
# ## Predict `test-A`
#
# Separate variable names are used here so that re-running the training cell
# can never pick up test features by accident.

# %%
X_test = _read_documents('test-A/in.tsv')

# %%
x_test_w2v = embed_documents(X_test, word2vec)

# %%
y_test = predict(model, x_test_w2v)

# %%
wrtieToFile("test-A", y_test)