test4
This commit is contained in:
parent
76d1ec7f77
commit
cd18524ccb
@ -2,6 +2,7 @@
|
||||
<project version="4">
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="d25a65da-2ba0-4272-a0a5-c59cbecb6088" name="Default Changelist" comment="">
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/dev-0/out.tsv" beforeDir="false" afterPath="$PROJECT_DIR$/dev-0/out.tsv" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/dev-0/out_float.tsv" beforeDir="false" afterPath="$PROJECT_DIR$/dev-0/out_float.tsv" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/s.py" beforeDir="false" afterPath="$PROJECT_DIR$/s.py" afterDir="false" />
|
||||
|
10592
dev-0/out.tsv
10592
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
21856
dev-0/out_float.tsv
21856
dev-0/out_float.tsv
File diff suppressed because it is too large
Load Diff
18
s.py
18
s.py
@ -177,7 +177,7 @@ testA_x = torch.cat([testA_x_temp2, testA_x_words_onehot], 1)
|
||||
dataset_train = TrainDataset(x, y)
|
||||
trainloader=DataLoader(dataset=dataset_train, batch_size=minibatch_size, shuffle=True)
|
||||
|
||||
def train_loop(i = 50):
|
||||
def train_loop(i = 20):
|
||||
for i in range(i):
|
||||
for xb, yb_expected in trainloader:
|
||||
optimizer.zero_grad()
|
||||
@ -197,13 +197,13 @@ def train_loop(i = 50):
|
||||
dev_y_pred_float_df = pandas.DataFrame(dev_y_pred_float_tensor.detach().numpy())
|
||||
auc_score = roc_auc_score(dev_y_test, dev_y_pred_float_df)
|
||||
print("auc:\t", auc_score, "\tloss:\t", loss.item())
|
||||
if ((auc_score > 0.70)):
|
||||
if ((auc_score > 0.9)):
|
||||
break
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
if ((auc_score > 0.70)):
|
||||
if ((auc_score > 0.9)):
|
||||
break
|
||||
#print(loss)
|
||||
|
||||
@ -223,10 +223,10 @@ for i in range(0,11026):
|
||||
file2.write(str(dev_y[i].data.item()) + "\n")
|
||||
dev_y_pred_float.append(dev_y[i].data.item())
|
||||
var = dev_y[i].data.item()
|
||||
if var < 0.5:
|
||||
file.write("0" + "\n")
|
||||
else:
|
||||
if var > 0.999:
|
||||
file.write("1" + "\n")
|
||||
else:
|
||||
file.write("0" + "\n")
|
||||
file.close()
|
||||
file2.close()
|
||||
|
||||
@ -246,9 +246,9 @@ file2=open("test-A/out_float.tsv","w")
|
||||
|
||||
for i in range(0,11061):
|
||||
file2.write(str(testA_y[i].data.item()) + "\n")
|
||||
if testA_y[i].data.item() < 0.5:
|
||||
file.write("0" + "\n")
|
||||
else:
|
||||
if testA_y[i].data.item() > 0.999:
|
||||
file.write("1" + "\n")
|
||||
else:
|
||||
file.write("0" + "\n")
|
||||
file.close()
|
||||
file2.close()
|
10418
test-A/out.tsv
10418
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
22046
test-A/out_float.tsv
22046
test-A/out_float.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user