# en-ner-conll-2003/run.py
# Python, 115 lines, 2.8 KiB (as reported by the repository web viewer).
# Last modified: 2022-06-13 21:44:52 +02:00
import numpy as np
import gensim
import re
import torch
import pandas as pd
from gensim.models import Word2Vec
from gensim import downloader
from sklearn.feature_extraction.text import TfidfVectorizer
from torchtext.vocab import vocab
from collections import Counter, OrderedDict
# Number of rows per slice in train_model and predict.
BATCH_SIZE: int = 64
# Number of full passes over the training data in train_model.
EPOCHS: int = 50
# Not referenced in the visible code; presumably the embedding width
# (NeuralNetworkModel hard-codes 200) — TODO confirm.
FEATURES: int = 200
class NeuralNetworkModel(torch.nn.Module):
    """Embedding lookup followed by one linear layer over the flattened sequence.

    The defaults reproduce the original hard-coded sizes:
    - a 24000-entry, 200-dimensional embedding table;
    - a Linear layer from 2400 inputs (12 positions x 200 dims) to 20 outputs.

    NOTE(review): ``num_embeddings`` must exceed every token index fed in
    (e.g. the size of the vocabulary built by ``build_vocab``) — confirm.

    Args:
        num_embeddings: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        seq_len: Number of token positions per example. The embeddings of all
            positions are concatenated before ``fc1``.
        num_outputs: Width of the output layer.
    """

    def __init__(self, num_embeddings=24000, embedding_dim=200, seq_len=12, num_outputs=20):
        # nn.Module.__init__ must run before any submodule is assigned.
        # Without it, the first assignment below raises AttributeError.
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings, embedding_dim)
        self.fc1 = torch.nn.Linear(seq_len * embedding_dim, num_outputs)

    def forward(self, x):
        """Map integer token indices to raw output scores.

        Args:
            x: Integer tensor of shape ``(seq_len,)`` for one example, or
                ``(..., seq_len)`` for a batch.

        Returns:
            Tensor of shape ``(num_outputs,)`` or ``(..., num_outputs)``.
            No sigmoid/softmax is applied, so the values are logits.
        """
        x = self.emb(x)
        # Flatten only the trailing (seq_len, embedding_dim) axes so leading
        # batch axes are kept. The former reshape(2400) accepted exactly one
        # example and failed on batches.
        x = x.flatten(start_dim=-2)
        return self.fc1(x)
def train_model(X_train, y_train):
    """Train a new ``NeuralNetworkModel`` and return it.

    Runs ``EPOCHS`` passes over the data in slices of ``BATCH_SIZE`` rows.
    It uses ASGD (lr=0.05) with a binary cross-entropy objective on the
    model's raw outputs.

    At the start of each epoch it prints the epoch number. At the end of the
    epoch it prints two values:
    - the row-weighted mean loss;
    - the number of thresholded outputs that match the targets, divided by
      the number of rows.

    Args:
        X_train: Indexable collection of token-index rows. Each slice must
            convert to a rectangular integer array.
        y_train: Indexable collection (list, array, ...) of targets aligned
            with ``X_train``.

    Returns:
        The trained model.
    """
    model = NeuralNetworkModel()
    # The model ends in a plain Linear layer (no sigmoid), so its outputs are
    # logits. BCEWithLogitsLoss applies the sigmoid internally and is
    # numerically stable. BCELoss requires inputs in [0, 1].
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.ASGD(model.parameters(), lr=0.05)
    for epoch in range(EPOCHS):
        print(epoch)
        loss_score = 0.0
        acc_score = 0
        items_total = 0
        for i in range(0, len(y_train), BATCH_SIZE):
            # nn.Embedding only accepts integer indices, so build a long
            # tensor rather than a float32 one.
            x = torch.as_tensor(np.asarray(X_train[i:i + BATCH_SIZE]), dtype=torch.long)
            y = torch.as_tensor(
                np.asarray(y_train[i:i + BATCH_SIZE], dtype=np.float32)
            ).reshape(-1, 1)
            # NOTE(review): the model emits 20 outputs per row, but y is
            # reshaped to a single column. BCEWithLogitsLoss needs matching
            # shapes, so the label encoding and the model output width must be
            # reconciled — TODO confirm the intended target format.
            y_pred = model(x)
            # A logit > 0 is the same decision as a probability > 0.5.
            acc_score += torch.sum((y_pred > 0) == y).item()
            items_total += y.shape[0]
            optimizer.zero_grad()
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            loss_score += loss.item() * y.shape[0]
        # Guard against empty training data (previously ZeroDivisionError).
        denom = max(items_total, 1)
        print((loss_score / denom), (acc_score / denom))
    return model
def predict(model, x_test, batch_size=None):
    """Run ``model`` over ``x_test`` in batches and threshold its outputs.

    The model's outputs are treated as logits: an output counts as positive
    when ``sigmoid(output) > 0.5``, i.e. when the logit is > 0.

    Args:
        model: Callable taking an integer (long) tensor of token indices.
        x_test: Indexable collection of token-index rows. Each slice must
            convert to a rectangular integer array.
        batch_size: Rows per forward pass. Defaults to the module-level
            ``BATCH_SIZE``.

    Returns:
        A list with one boolean tensor per input row (one element per entry
        along the first axis of each batch output).
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    y_dev = []
    with torch.no_grad():
        for i in range(0, len(x_test), batch_size):
            # nn.Embedding needs integer indices, not float32.
            x = torch.as_tensor(np.asarray(x_test[i:i + batch_size]), dtype=torch.long)
            outputs = model(x)
            # NeuralNetworkModel returns raw Linear outputs (logits), so
            # convert to probabilities before applying the 0.5 cut-off.
            y_dev.extend(torch.sigmoid(outputs) > 0.5)
    return y_dev
def load_data(path):
    """Read a tab-separated file of label sequences and token sequences.

    Each line is expected to be ``<labels>\\t<tokens>``, where both columns
    are whitespace-separated (presumably NER tags and words). Lines are
    skipped, not raised on, in two cases:
    - they have fewer than two tab-separated fields (e.g. blank lines);
    - their label count differs from their token count.

    Args:
        path: Path of the UTF-8 encoded input file.

    Returns:
        ``(x, y)``, where ``x`` is a list of token lists (second column) and
        ``y`` is the aligned list of label lists (first column).
    """
    x, y = [], []
    with open(path, 'r', encoding='utf8') as f:
        for line in f:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                # Blank or single-column lines (e.g. a trailing empty line)
                # previously raised IndexError on fields[1].
                continue
            labels, tokens = fields[0].split(), fields[1].split()
            if len(labels) == len(tokens):
                y.append(labels)
                x.append(tokens)
    return x, y
def write_res(data, path):
    """Write each item of ``data`` on its own line to ``path`` and log it.

    Items are converted to text with an f-string (``str()``). The file is
    written as UTF-8, matching the encoding ``load_data`` reads.

    Args:
        data: Iterable of items to write, one per line.
        path: Output file path. The file is created or overwritten.
    """
    with open(path, 'w', encoding='utf8') as f:
        f.writelines(f'{line}\n' for line in data)
    # The message used to say "{path}/out.tsv", but the data goes to `path`.
    print(f"Data written {path}")
def build_vocab(dataset):
    """Build a torchtext vocabulary from a collection of tokenized documents.

    Token frequencies are accumulated document by document into a ``Counter``.
    The four special symbols ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>`` are
    passed as ``specials``. Out-of-vocabulary lookups fall back to index 0.
    That is presumably ``<unk>``, since it is listed first among the specials;
    confirm against torchtext's ``special_first`` default.

    Args:
        dataset: Iterable of token sequences.

    Returns:
        The torchtext ``Vocab`` object.
    """
    frequencies = Counter()
    for tokens in dataset:
        frequencies.update(tokens)
    special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
    token_vocab = vocab(frequencies, specials=special_symbols)
    fallback_index = 0  # index returned for tokens not in the vocabulary
    token_vocab.set_default_index(fallback_index)
    return token_vocab
def main():
    """Script entry point: prepare the training split and its vocabulary.

    Steps:
    1. Print whether CUDA is available and pick the corresponding device.
    2. Load ``train/train.tsv``.
    3. Build a vocabulary from its token column.

    NOTE(review): training and prediction are not wired up yet. The device
    and the vocabulary are computed but not used further.
    """
    print(torch.cuda.is_available())
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    train_tokens, train_labels = load_data('train/train.tsv')
    train_vocab = build_vocab(train_tokens)


if __name__ == '__main__':
    main()