sol
This commit is contained in:
commit
a5d20e8107
3
.idea/.gitignore
vendored
Normal file
3
.idea/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
14
.idea/inspectionProfiles/Project_Default.xml
Normal file
14
.idea/inspectionProfiles/Project_Default.xml
Normal file
@ -0,0 +1,14 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<option name="ignoredErrors">
|
||||
<list>
|
||||
<option value="N813" />
|
||||
<option value="N802" />
|
||||
<option value="N806" />
|
||||
</list>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
Normal file
@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
4
.idea/misc.xml
Normal file
4
.idea/misc.xml
Normal file
@ -0,0 +1,4 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8" project-jdk-type="Python SDK" />
|
||||
</project>
|
8
.idea/modules.xml
Normal file
8
.idea/modules.xml
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/petite-difference-challenge2.iml" filepath="$PROJECT_DIR$/.idea/petite-difference-challenge2.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
8
.idea/petite-difference-challenge2.iml
Normal file
8
.idea/petite-difference-challenge2.iml
Normal file
@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
6
.idea/vcs.xml
Normal file
6
.idea/vcs.xml
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
1
config.txt
Normal file
1
config.txt
Normal file
@ -0,0 +1 @@
|
||||
--metric Likelihood --metric Accuracy --metric {Likelihood:N<Likelihood>,Accuracy:N<Accuracy>}P<2>{f<in[2]:for-humans>N<+H>,f<in[3]:contaminated>N<+C>,f<in[3]:not-contaminated>N<-C>} --precision 5
|
5272
dev-0/expected.tsv
Normal file
5272
dev-0/expected.tsv
Normal file
File diff suppressed because it is too large
Load Diff
5272
dev-0/in.tsv
Normal file
5272
dev-0/in.tsv
Normal file
File diff suppressed because one or more lines are too long
BIN
dev-0/in.tsv.xz
Normal file
BIN
dev-0/in.tsv.xz
Normal file
Binary file not shown.
137314
dev-0/out.tsv
Normal file
137314
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
7
gonito.yaml
Normal file
7
gonito.yaml
Normal file
@ -0,0 +1,7 @@
|
||||
description: Simple Solution for Logistic Regression
|
||||
tags:
|
||||
- logistic-regression
|
||||
- pytorch-nn
|
||||
param-files:
|
||||
- "*.yaml"
|
||||
- config/*.yaml
|
98
main.py
Normal file
98
main.py
Normal file
@ -0,0 +1,98 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
import torch.optim as optim
|
||||
from string import punctuation
|
||||
from unidecode import unidecode
|
||||
import pickle
|
||||
|
||||
# Directory whose in.tsv is classified and whose out.tsv receives predictions.
PREDICT_DIR = "test-A"
# File the trained model's state_dict is written to and reloaded from.
TRAINED_MODEL_FILE = 'model.bin'
# File the pickled word -> column-index vocabulary is written to.
INDEX_FILE = "index.pkl"

# Read the training inputs and labels inside context managers so the handles
# are closed as soon as the lines are in memory.
with open('train/in.tsv', 'r', encoding='utf-8') as _train_in:
    train_lines = _train_in.readlines()
with open('train/expected.tsv', 'r', encoding='utf-8') as _train_expected:
    result_lines = _train_expected.readlines()

# These two handles intentionally stay open: the prediction loop at the end of
# the script consumes them and closes them itself. They are opened as UTF-8
# explicitly so prediction input is decoded the same way as the training data
# regardless of the platform's default encoding.
input_file = open(PREDICT_DIR+'/in.tsv', 'r', encoding='utf-8')
out_file = open(PREDICT_DIR+'/out.tsv', 'w', encoding='utf-8')

# Maps a label value parsed from expected.tsv to the class index used as the
# loss target.
LABEL_TO_INDEX = {0: 0, 1: 1}
|
||||
|
||||
|
||||
class Classifier(nn.Module):
    """Single linear layer followed by a log-softmax over the label scores."""

    def __init__(self, num_labels, vocab_size):
        """Create the linear layer.

        Args:
            num_labels: number of output classes.
            vocab_size: length of the input bag-of-words vector.
        """
        super().__init__()
        self.linear = nn.Linear(vocab_size, num_labels)

    def forward(self, bow_vec):
        """Return log-probabilities over the labels for ``bow_vec``.

        The softmax is taken over the last dimension, so a batched
        ``(batch, vocab_size)`` input behaves exactly as before while a single
        unbatched ``(vocab_size,)`` vector is accepted as well.
        """
        return F.log_softmax(self.linear(bow_vec), dim=-1)
|
||||
|
||||
|
||||
# Translation table that turns every ASCII punctuation character into a space.
# str.replace() returns a new string, so calling it without using the result
# leaves the text unchanged; translate() applies all replacements in one pass
# and its result is actually kept.
_PUNCT_TO_SPACE = str.maketrans(punctuation, " " * len(punctuation))

# Every training line needs exactly one label; fail early with a clear message
# instead of an IndexError (or silently ignored labels) further down.
if len(train_lines) != len(result_lines):
    raise ValueError(
        "train/in.tsv has %d lines but train/expected.tsv has %d"
        % (len(train_lines), len(result_lines))
    )

# data holds one (tokens, label) pair per training line, where tokens are the
# purely alphabetic words left after lower-casing, transliterating to ASCII
# and replacing punctuation with spaces.
data = []
for raw_line, raw_label in zip(train_lines, result_lines):
    normalized = unidecode(raw_line.lower()).translate(_PUNCT_TO_SPACE)
    only_alpha = [w for w in normalized.split() if w.isalpha()]
    data.append((only_alpha, int(raw_label)))
|
||||
|
||||
# Vocabulary: every distinct training token gets the next free column index,
# in order of first appearance.
word_indexes = {}
for tokens, _label in data:
    for token in tokens:
        # setdefault only stores the index when the token is new; the index is
        # the vocabulary size before insertion.
        word_indexes.setdefault(token, len(word_indexes))

NUM_LABELS = 2
VOCAB_SIZE = len(word_indexes)
|
||||
|
||||
|
||||
# Train on the GPU when one is available, otherwise fall back to the CPU
# instead of crashing on a hard-coded 'cuda:0'.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = Classifier(NUM_LABELS, VOCAB_SIZE)
model.to(device)
loss_function = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

NUM_EPOCHS = 3
data_len = len(data)

# Plain SGD with one training instance per step.
for epoch in range(NUM_EPOCHS):
    i = 1
    for instance, label in data:
        model.zero_grad()

        # Bag-of-words count vector, filled on the CPU and moved to the
        # model's device once it is complete.
        vec = torch.zeros(len(word_indexes))
        for word in instance:
            if word in word_indexes:
                vec[word_indexes[word]] += 1

        target = torch.LongTensor([LABEL_TO_INDEX[label]]).to(device)

        # The model must be *called* on a (1, vocab_size) batch; passing the
        # module object itself to the loss function cannot work.
        log_probs = model(vec.view(1, -1).to(device))

        loss = loss_function(log_probs, target)
        loss.backward()
        optimizer.step()

        # Progress report every 10000 processed instances.
        i = i + 1
        if i % 10000 == 0:
            print(str(epoch) + "/" + str(NUM_EPOCHS) + " " + str(i/data_len))
|
||||
|
||||
# Persist the trained weights and the vocabulary so prediction can be run from
# the saved artefacts.
torch.save(model.state_dict(), TRAINED_MODEL_FILE)
# The context manager closes the index file even if pickling raises.
with open(INDEX_FILE, "wb") as f:
    pickle.dump(word_indexes, f)
|
||||
|
||||
# Reload the saved weights once. map_location='cpu' lets a checkpoint that was
# saved from a GPU be restored on a machine without one; the model rebuilt
# below lives on the CPU.
state_dict = torch.load(TRAINED_MODEL_FILE, map_location='cpu')
linear_weight = state_dict['linear.weight']
linear_bias = state_dict['linear.bias']
# Recover the layer sizes from the stored tensors: one bias entry per label,
# one weight column per vocabulary word.
model = Classifier(len(linear_bias), len(linear_weight[0]))
model.load_state_dict(state_dict)
model.eval()

# NOTE: pickle must only be used on trusted files; this index is the one
# written by this script.
with open(INDEX_FILE, 'rb') as file:
    word_indexes = pickle.load(file)

# Normalisation the training preprocessing is meant to apply: lower-case,
# transliterate to ASCII, treat punctuation as whitespace.
_PUNCT_TO_SPACE = str.maketrans(punctuation, " " * len(punctuation))
_vocab_size = len(linear_weight[0])

try:
    with torch.no_grad():
        for line in input_file:
            text = unidecode(line.lower()).translate(_PUNCT_TO_SPACE)

            # Bag-of-words count vector; words outside the vocabulary are
            # ignored.
            vec = torch.zeros(_vocab_size)
            for word in text.split():
                if word.isalpha() and word in word_indexes:
                    vec[word_indexes[word]] += 1

            # Run the model on a (1, vocab_size) batch to get the
            # log-probabilities of both labels for this line.
            log_probs = model(vec.view(1, -1))

            # Write the label with the higher log-probability; ties go to 1.
            if log_probs[0][0] > log_probs[0][1]:
                out_file.write("0\n")
            else:
                out_file.write("1\n")
finally:
    # Close both files even when prediction fails part-way through.
    input_file.close()
    out_file.close()
|
||||
|
5152
test-A/in.tsv
Normal file
5152
test-A/in.tsv
Normal file
File diff suppressed because one or more lines are too long
BIN
test-A/in.tsv.xz
Normal file
BIN
test-A/in.tsv.xz
Normal file
Binary file not shown.
0
test-A/out.tsv
Normal file
0
test-A/out.tsv
Normal file
|
289579
train/expected.tsv
Normal file
289579
train/expected.tsv
Normal file
File diff suppressed because it is too large
Load Diff
289579
train/in.tsv
Normal file
289579
train/in.tsv
Normal file
File diff suppressed because one or more lines are too long
BIN
train/in.tsv.xz
Normal file
BIN
train/in.tsv.xz
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user