Homework sacred

This commit is contained in:
Mikołaj Pokrywka 2022-05-07 14:23:09 +02:00
parent b553894099
commit 2a439a88b2
3 changed files with 20 additions and 9 deletions

3
.gitignore vendored
View File

@ -13,4 +13,5 @@ venv
model_resutls.txt
model
metrics.txt
metrics.png
metrics.png
my_runs

2
Jenkinsfile vendored
View File

@ -21,7 +21,7 @@ pipeline {
withEnv(["EPOCH=${params.EPOCH}"]) {
copyArtifacts filter: '*', projectName: 's444463-create-dataset'
sh 'python3 ./deepl.py $EPOCH'
archiveArtifacts artifacts: "model"
archiveArtifacts artifacts: "model, my_runs"
build job: "s444463-evaluation/master"
}
}

View File

@ -10,7 +10,17 @@ from torch import nn
from torch import optim
import matplotlib.pyplot as plt
import sys
from sacred import Experiment
from sacred.observers import FileStorageObserver
ex = Experiment()
ex.observers.append(FileStorageObserver('my_runs'))
vectorizer = TfidfVectorizer()
@ex.config
def my_config():
epochs = 10
def convert_text_to_model_form(text):
@ -18,12 +28,12 @@ def convert_text_to_model_form(text):
b = torch.tensor(scipy.sparse.csr_matrix.todense(a)).float()
return b
if __name__ == "__main__":
print(sys.argv[1])
print(type(sys.argv[1]))
print(sys.argv[1])
epochs = int(sys.argv[1])
@ex.automain
def my_main(epochs, _run):
# print(sys.argv[1])
# print(type(sys.argv[1]))
# print(sys.argv[1])
# epochs = int(sys.argv[1])
# epochs=10
# kaggle.api.authenticate()
@ -59,7 +69,6 @@ if __name__ == "__main__":
y_dev = np.array(y_dev)
y_test = np.array(y_test)
vectorizer = TfidfVectorizer()
company_profile = vectorizer.fit_transform(company_profile)
x_train = vectorizer.transform(x_train)
@ -172,6 +181,7 @@ if __name__ == "__main__":
f.close()
torch.save(model, 'model')
ex.add_artifact("model")
# plt.figure(figsize=(12, 5))