This commit is contained in:
ksanu 2019-12-04 00:20:52 +01:00
parent 76d1ec7f77
commit cd18524ccb
6 changed files with 32466 additions and 32465 deletions

View File

@ -2,6 +2,7 @@
<project version="4">
<component name="ChangeListManager">
<list default="true" id="d25a65da-2ba0-4272-a0a5-c59cbecb6088" name="Default Changelist" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/dev-0/out.tsv" beforeDir="false" afterPath="$PROJECT_DIR$/dev-0/out.tsv" afterDir="false" />
<change beforePath="$PROJECT_DIR$/dev-0/out_float.tsv" beforeDir="false" afterPath="$PROJECT_DIR$/dev-0/out_float.tsv" afterDir="false" />
<change beforePath="$PROJECT_DIR$/s.py" beforeDir="false" afterPath="$PROJECT_DIR$/s.py" afterDir="false" />

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

18
s.py
View File

@ -177,7 +177,7 @@ testA_x = torch.cat([testA_x_temp2, testA_x_words_onehot], 1)
dataset_train = TrainDataset(x, y)
trainloader=DataLoader(dataset=dataset_train, batch_size=minibatch_size, shuffle=True)
def train_loop(i = 50):
def train_loop(i = 20):
for i in range(i):
for xb, yb_expected in trainloader:
optimizer.zero_grad()
@ -197,13 +197,13 @@ def train_loop(i = 50):
dev_y_pred_float_df = pandas.DataFrame(dev_y_pred_float_tensor.detach().numpy())
auc_score = roc_auc_score(dev_y_test, dev_y_pred_float_df)
print("auc:\t", auc_score, "\tloss:\t", loss.item())
if ((auc_score > 0.70)):
if ((auc_score > 0.9)):
break
loss.backward()
optimizer.step()
if ((auc_score > 0.70)):
if ((auc_score > 0.9)):
break
#print(loss)
@ -223,10 +223,10 @@ for i in range(0,11026):
file2.write(str(dev_y[i].data.item()) + "\n")
dev_y_pred_float.append(dev_y[i].data.item())
var = dev_y[i].data.item()
if var < 0.5:
file.write("0" + "\n")
else:
if var > 0.999:
file.write("1" + "\n")
else:
file.write("0" + "\n")
file.close()
file2.close()
@ -246,9 +246,9 @@ file2=open("test-A/out_float.tsv","w")
for i in range(0,11061):
file2.write(str(testA_y[i].data.item()) + "\n")
if testA_y[i].data.item() < 0.5:
file.write("0" + "\n")
else:
if testA_y[i].data.item() > 0.999:
file.write("1" + "\n")
else:
file.write("0" + "\n")
file.close()
file2.close()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff