paranormal-or-skeptic/run.ipynb

461 lines
12 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import gensim\n",
"import torch\n",
"import pandas as pd\n",
"from gensim.models import Word2Vec\n",
"from gensim import downloader\n",
"from sklearn.feature_extraction.text import TfidfVectorizer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 10\n",
"EPOCHS = 100\n",
"FEAUTERES = 200\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class NeuralNetworkModel(torch.nn.Module):\n",
" \n",
" def __init__(self):\n",
" super(NeuralNetworkModel, self).__init__()\n",
" self.fc1 = torch.nn.Linear(FEAUTERES,500)\n",
" self.fc2 = torch.nn.Linear(500,1)\n",
"\n",
" def forward(self, x):\n",
" x = self.fc1(x)\n",
" x = torch.relu(x)\n",
" x = self.fc2(x)\n",
" x = torch.sigmoid(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Pre-trained GloVe Twitter word vectors obtained through gensim's downloader API.\n",
"word2vec = downloader.load(name=\"glove-twitter-200\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def readData(fileName): \n",
" with open(f'{fileName}/in.tsv', 'r', encoding='utf8') as f:\n",
" X = np.array([x.strip().lower() for x in f.readlines()])\n",
" with open(f'{fileName}/expected.tsv', 'r', encoding='utf8') as f:\n",
" y = np.array([int(x.strip()) for x in f.readlines()])\n",
" return X,y"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Documents and labels of the dev-0 split.\n",
"# NOTE(review): later cells both fit and score the model on this same split - confirm that is intended.\n",
"X_file, y_file = readData(fileName='dev-0')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# One vector per document: the mean of the embeddings of its in-vocabulary words.\n",
"# A document without any known word falls back to a single all-zero vector.\n",
"x_train_w2v = [\n",
"    np.mean(\n",
"        [word2vec[token] for token in map(str.lower, doc.split()) if token in word2vec]\n",
"        or [np.zeros(FEAUTERES)],\n",
"        axis=0,\n",
"    )\n",
"    for doc in X_file\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def train_model(X_file, y_file, epochs=None, batch_size=None, lr=0.05):\n",
"    \"\"\"Fit a fresh ``NeuralNetworkModel`` with mini-batch ASGD and BCE loss.\n",
"\n",
"    Args:\n",
"        X_file: sequence of equally sized feature vectors, one per example.\n",
"        y_file: sequence of 0/1 labels aligned with ``X_file``.\n",
"        epochs: passes over the data; defaults to ``EPOCHS``.\n",
"        batch_size: rows per optimisation step; defaults to ``BATCH_SIZE``.\n",
"        lr: learning rate handed to ``torch.optim.ASGD``.\n",
"\n",
"    Returns:\n",
"        The trained model.\n",
"\n",
"    Raises:\n",
"        ValueError: if the inputs are empty or differ in length.\n",
"    \"\"\"\n",
"    # Read the notebook-level defaults at call time, not at definition time.\n",
"    epochs = EPOCHS if epochs is None else epochs\n",
"    batch_size = BATCH_SIZE if batch_size is None else batch_size\n",
"\n",
"    # Convert once up front instead of rebuilding tensors for every batch of every epoch.\n",
"    features = torch.tensor(np.asarray(X_file, dtype=np.float32))\n",
"    labels = torch.tensor(np.asarray(y_file, dtype=np.float32)).reshape(-1, 1)\n",
"    if labels.shape[0] == 0:\n",
"        raise ValueError('train_model needs at least one training example')\n",
"    if features.shape[0] != labels.shape[0]:\n",
"        raise ValueError(\n",
"            f'{features.shape[0]} feature rows but {labels.shape[0]} labels'\n",
"        )\n",
"\n",
"    model = NeuralNetworkModel()\n",
"    criterion = torch.nn.BCELoss()\n",
"    optimizer = torch.optim.ASGD(model.parameters(), lr=lr)\n",
"\n",
"    for epoch in range(epochs):\n",
"        print(epoch)\n",
"        loss_score = 0.0\n",
"        acc_score = 0\n",
"        items_total = 0\n",
"        for start in range(0, labels.shape[0], batch_size):\n",
"            x = features[start:start + batch_size]\n",
"            y = labels[start:start + batch_size]\n",
"\n",
"            optimizer.zero_grad()\n",
"            y_pred = model(x)\n",
"            loss = criterion(y_pred, y)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            # Weight by batch length so a short final batch is averaged correctly.\n",
"            loss_score += loss.item() * y.shape[0]\n",
"            acc_score += torch.sum((y_pred > 0.5) == (y > 0.5)).item()\n",
"            items_total += y.shape[0]\n",
"\n",
"        # Mean loss and accuracy over the examples seen in this epoch.\n",
"        print((loss_score / items_total), (acc_score / items_total))\n",
"    return model"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def predict(model,x_file):\n",
" y_dev = []\n",
" with torch.no_grad():\n",
" for i in range(0, len(x_file), BATCH_SIZE):\n",
" x = x_file[i:i+BATCH_SIZE]\n",
" x = torch.tensor(np.array(x).astype(np.float32))\n",
" outputs = model(x)\n",
" y = (outputs > 0.5)\n",
" y_dev.extend(y)\n",
" return y_dev\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def wrtieToFile(fileName,y_file):\n",
" y_out = []\n",
" for y in y_file:\n",
" y_out.append(int(str(y[0]).split('(')[1].split(')')[0]=='True'))\n",
" with open(f'{fileName}/out.tsv','w',encoding='utf8') as f:\n",
" for y in y_out:\n",
" f.write(f'{y}\\n')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n",
"0.6414709375416563 0.6464339908952959\n",
"1\n",
"0.6118579905971953 0.6589529590288316\n",
"2\n",
"0.5930351529140393 0.677731411229135\n",
"3\n",
"0.5807589731138194 0.6936646433990895\n",
"4\n",
"0.5711128521026628 0.7031487101669196\n",
"5\n",
"0.5637358135638451 0.7065629742033384\n",
"6\n",
"0.5573145605239321 0.710546282245827\n",
"7\n",
"0.5521481898931252 0.715288315629742\n",
"8\n",
"0.5475104518053836 0.7181335356600911\n",
"9\n",
"0.5430893454028008 0.7202200303490136\n",
"10\n",
"0.5395108298066443 0.7236342943854325\n",
"11\n",
"0.5361589408495177 0.7257207890743551\n",
"12\n",
"0.53314527610885 0.7270485584218513\n",
"13\n",
"0.5298747769267226 0.7297040971168437\n",
"14\n",
"0.5269876997833096 0.7319802731411229\n",
"15\n",
"0.5245049590914763 0.7336874051593323\n",
"16\n",
"0.5220209190930057 0.7363429438543247\n",
"17\n",
"0.5203242429527871 0.7365326251896813\n",
"18\n",
"0.5182899421417297 0.737670713201821\n",
"19\n",
"0.5155506848000069 0.7401365705614568\n",
"20\n",
"0.5131794015095429 0.7403262518968133\n",
"21\n",
"0.5113656374375719 0.7412746585735963\n",
"22\n",
"0.5092821710139558 0.7420333839150227\n",
"23\n",
"0.5067137854063547 0.7441198786039454\n",
"24\n",
"0.5047900934558085 0.745257966616085\n",
"25\n",
"0.5025694217866397 0.7488619119878603\n",
"26\n",
"0.5007175219885451 0.7486722306525038\n",
"27\n",
"0.4981631609315847 0.747154779969651\n",
"28\n",
"0.4961598192105615 0.7498103186646434\n",
"29\n",
"0.49438970515077685 0.7507587253414264\n",
"30\n",
"0.49240998727621366 0.7507587253414264\n",
"31\n",
"0.4907134136267018 0.7520864946889226\n",
"32\n",
"0.48826086573438415 0.7541729893778453\n",
"33\n",
"0.4871711270185541 0.7560698027314112\n",
"34\n",
"0.48422483688330614 0.756638846737481\n",
"35\n",
"0.48217912709371546 0.7604324734446131\n",
"36\n",
"0.48009182657475535 0.761380880121396\n",
"37\n",
"0.4778907883217013 0.7632776934749621\n",
"38\n",
"0.47551582766660067 0.7621396054628224\n",
"39\n",
"0.47324845619635353 0.7646054628224582\n",
"40\n",
"0.47138607904755925 0.7653641881638846\n",
"41\n",
"0.4684638544374424 0.7659332321699545\n",
"42\n",
"0.4662012148575012 0.7685887708649469\n",
"43\n",
"0.46414706633568986 0.7693474962063733\n",
"44\n",
"0.4620490156040613 0.7702959028831563\n",
"45\n",
"0.46027336999977486 0.7706752655538694\n",
"46\n",
"0.4574687189093264 0.7746585735963581\n",
"47\n",
"0.45456105805311653 0.7748482549317147\n",
"48\n",
"0.45308226045385025 0.7769347496206374\n",
"49\n",
"0.44969080617490237 0.7792109256449166\n",
"50\n",
"0.4477136310823092 0.77902124430956\n",
"51\n",
"0.44523295281067693 0.7841426403641881\n",
"52\n",
"0.44300158465442235 0.7835735963581184\n",
"53\n",
"0.44147631555388656 0.7852807283763278\n",
"54\n",
"0.43824701448718767 0.78850531107739\n",
"55\n",
"0.437326367692923 0.7936267071320182\n",
"56\n",
"0.43404240863558824 0.7936267071320182\n",
"57\n",
"0.43146262304328825 0.7959028831562974\n",
"58\n",
"0.429094969041996 0.7938163884673748\n",
"59\n",
"0.42631421763059857 0.7977996965098634\n",
"60\n",
"0.4239590879280985 0.798937784522003\n",
"61\n",
"0.4216488930983229 0.8014036418816388\n",
"62\n",
"0.41922062316595693 0.8033004552352049\n",
"63\n",
"0.417561381201688 0.8053869499241275\n",
"64\n",
"0.4144452941633637 0.8051972685887708\n",
"65\n",
"0.41305530049212064 0.8080424886191199\n",
"66\n",
"0.410686616688311 0.8072837632776935\n",
"67\n",
"0.4076426998430889 0.8114567526555387\n",
"68\n",
"0.4061218895193342 0.811267071320182\n",
"69\n",
"0.4029337710281198 0.8139226100151745\n",
"70\n",
"0.40099998707395496 0.8143019726858877\n",
"71\n",
"0.39854915830701004 0.8133535660091047\n",
"72\n",
"0.39473064304845285 0.8201820940819423\n",
"73\n",
"0.3931978788616896 0.8198027314112292\n",
"74\n",
"0.3905544553760422 0.8218892261001517\n",
"75\n",
"0.3894510168316513 0.8211305007587253\n",
"76\n",
"0.38586248252229916 0.8247344461305007\n",
"77\n",
"0.3851398667786977 0.8256828528072838\n",
"78\n",
"0.38457902678046857 0.8247344461305007\n",
"79\n",
"0.3803209278461197 0.8272003034901366\n",
"80\n",
"0.37845283393127693 0.8287177541729894\n",
"81\n",
"0.37618811287505943 0.8294764795144158\n",
"82\n",
"0.37400476449368486 0.8323216995447648\n",
"83\n",
"0.3726042910890261 0.8332701062215478\n",
"84\n",
"0.36963997851373215 0.8338391502276176\n",
"85\n",
"0.3680792153446917 0.8363050075872535\n",
"86\n",
"0.36542417398160704 0.8361153262518968\n",
"87\n",
"0.36405448698366627 0.8376327769347496\n",
"88\n",
"0.3595154614517061 0.8423748103186647\n",
"89\n",
"0.35860147739566967 0.8419954476479514\n",
"90\n",
"0.3578952589836848 0.8404779969650986\n",
"91\n",
"0.35602253879814755 0.8414264036418816\n",
"92\n",
"0.3523210818386087 0.8446509863429439\n",
"93\n",
"0.34952340000598764 0.8480652503793626\n",
"94\n",
"0.3513405356550524 0.845030349013657\n",
"95\n",
"0.349314306160274 0.8493930197268589\n",
"96\n",
"0.34516190266595626 0.8492033383915023\n",
"97\n",
"0.34279035948137776 0.8524279210925645\n",
"98\n",
"0.34358633996576793 0.8518588770864947\n",
"99\n",
"0.3396215371445545 0.8528072837632777\n"
]
}
],
"source": [
"# Fit the classifier on the averaged document embeddings; progress is printed per epoch.\n",
"model = train_model(X_file=x_train_w2v, y_file=y_file)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Predict labels for the dev-0 vectors built above.\n",
"y_dev = predict(model=model, x_file=x_train_w2v)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# Store the dev-0 predictions in dev-0/out.tsv.\n",
"wrtieToFile(fileName=\"dev-0\", y_file=y_dev)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Only the documents are read for test-A; no expected.tsv is opened here.\n",
"with open(f'test-A/in.tsv', 'r', encoding='utf8') as test_in:\n",
"    X = np.array([line.strip().lower() for line in test_in])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Embed the test-A documents: mean of in-vocabulary word vectors, zeros when none is known.\n",
"# The name x_train_w2v is reused for these test vectors because the next cell reads it.\n",
"x_train_w2v = [\n",
"    np.mean(known_vectors or [np.zeros(FEAUTERES)], axis=0)\n",
"    for known_vectors in (\n",
"        [word2vec[token] for token in map(str.lower, doc.split()) if token in word2vec]\n",
"        for doc in X\n",
"    )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# Predict labels for the test-A vectors (the result reuses the name y_dev).\n",
"y_dev = predict(model=model, x_file=x_train_w2v)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"# Store the test-A predictions in test-A/out.tsv.\n",
"wrtieToFile(fileName=\"test-A\", y_file=y_dev)"
]
}
],
"metadata": {
"interpreter": {
"hash": "f08154012ddadd8e950e6e9e035c7a7b32c136e7647e9b7c77e02eb723a8bedb"
},
"kernelspec": {
"display_name": "Python 3.9.7 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}