Out files with prediction results
This commit is contained in:
parent
892f21fc34
commit
63d362dc73
5272
dev-0/out.tsv
Normal file
5272
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
35
main.py
35
main.py
@ -47,9 +47,6 @@ def read_data():
|
|||||||
|
|
||||||
x_labels, y_labels, x_train, y_train, x_dev, x_test = read_data()
|
x_labels, y_labels, x_train, y_train, x_dev, x_test = read_data()
|
||||||
|
|
||||||
print(len(y_train))
|
|
||||||
print(len(x_train))
|
|
||||||
|
|
||||||
x_train = x_train[x_labels[0]].str.lower()
|
x_train = x_train[x_labels[0]].str.lower()
|
||||||
x_dev = x_dev[x_labels[0]].str.lower()
|
x_dev = x_dev[x_labels[0]].str.lower()
|
||||||
x_test = x_test[x_labels[0]].str.lower()
|
x_test = x_test[x_labels[0]].str.lower()
|
||||||
@ -88,5 +85,35 @@ for epoch in range(5):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
y_dev = []
|
||||||
|
y_test = []
|
||||||
|
|
||||||
print(Y_predictions)
|
nn_model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(0, len(x_dev), BATCH_SIZE):
|
||||||
|
X = x_dev[i:i+BATCH_SIZE]
|
||||||
|
X = torch.tensor(X)
|
||||||
|
|
||||||
|
outputs = nn_model(X.float())
|
||||||
|
|
||||||
|
y = (outputs > 0.5)
|
||||||
|
y_dev.extend(y)
|
||||||
|
|
||||||
|
for i in range(0, len(x_test), BATCH_SIZE):
|
||||||
|
X = x_test[i:i+BATCH_SIZE]
|
||||||
|
X = torch.tensor(X)
|
||||||
|
|
||||||
|
outputs = nn_model(X.float())
|
||||||
|
|
||||||
|
y = (outputs > 0.5)
|
||||||
|
y_test.extend(y)
|
||||||
|
|
||||||
|
y_dev = np.asarray(y_dev, dtype=np.int32)
|
||||||
|
y_test = np.asarray(y_test, dtype=np.int32)
|
||||||
|
|
||||||
|
Y_dev = pd.DataFrame({'label': y_dev})
|
||||||
|
Y_test = pd.DataFrame({'label': y_test})
|
||||||
|
|
||||||
|
Y_dev.to_csv(r'dev-0/out.tsv', sep='\t', index=False, header=False)
|
||||||
|
Y_test.to_csv(r'test-A/out.tsv', sep='\t', index=False, header=False)
|
||||||
|
5152
test-A/out.tsv
Normal file
5152
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user