Homework sacred
This commit is contained in:
parent
b553894099
commit
2a439a88b2
3
.gitignore
vendored
3
.gitignore
vendored
@ -13,4 +13,5 @@ venv
|
|||||||
model_resutls.txt
|
model_resutls.txt
|
||||||
model
|
model
|
||||||
metrics.txt
|
metrics.txt
|
||||||
metrics.png
|
metrics.png
|
||||||
|
my_runs
|
2
Jenkinsfile
vendored
2
Jenkinsfile
vendored
@ -21,7 +21,7 @@ pipeline {
|
|||||||
withEnv(["EPOCH=${params.EPOCH}"]) {
|
withEnv(["EPOCH=${params.EPOCH}"]) {
|
||||||
copyArtifacts filter: '*', projectName: 's444463-create-dataset'
|
copyArtifacts filter: '*', projectName: 's444463-create-dataset'
|
||||||
sh 'python3 ./deepl.py $EPOCH'
|
sh 'python3 ./deepl.py $EPOCH'
|
||||||
archiveArtifacts artifacts: "model"
|
archiveArtifacts artifacts: "model, my_runs"
|
||||||
build job: "s444463-evaluation/master"
|
build job: "s444463-evaluation/master"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
24
deepl.py
24
deepl.py
@ -10,7 +10,17 @@ from torch import nn
|
|||||||
from torch import optim
|
from torch import optim
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import sys
|
import sys
|
||||||
|
from sacred import Experiment
|
||||||
|
from sacred.observers import FileStorageObserver
|
||||||
|
|
||||||
|
ex = Experiment()
|
||||||
|
|
||||||
|
ex.observers.append(FileStorageObserver('my_runs'))
|
||||||
|
vectorizer = TfidfVectorizer()
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def my_config():
|
||||||
|
epochs = 10
|
||||||
|
|
||||||
|
|
||||||
def convert_text_to_model_form(text):
|
def convert_text_to_model_form(text):
|
||||||
@ -18,12 +28,12 @@ def convert_text_to_model_form(text):
|
|||||||
b = torch.tensor(scipy.sparse.csr_matrix.todense(a)).float()
|
b = torch.tensor(scipy.sparse.csr_matrix.todense(a)).float()
|
||||||
return b
|
return b
|
||||||
|
|
||||||
|
@ex.automain
|
||||||
if __name__ == "__main__":
|
def my_main(epochs, _run):
|
||||||
print(sys.argv[1])
|
# print(sys.argv[1])
|
||||||
print(type(sys.argv[1]))
|
# print(type(sys.argv[1]))
|
||||||
print(sys.argv[1])
|
# print(sys.argv[1])
|
||||||
epochs = int(sys.argv[1])
|
# epochs = int(sys.argv[1])
|
||||||
# epochs=10
|
# epochs=10
|
||||||
|
|
||||||
# kaggle.api.authenticate()
|
# kaggle.api.authenticate()
|
||||||
@ -59,7 +69,6 @@ if __name__ == "__main__":
|
|||||||
y_dev = np.array(y_dev)
|
y_dev = np.array(y_dev)
|
||||||
y_test = np.array(y_test)
|
y_test = np.array(y_test)
|
||||||
|
|
||||||
vectorizer = TfidfVectorizer()
|
|
||||||
|
|
||||||
company_profile = vectorizer.fit_transform(company_profile)
|
company_profile = vectorizer.fit_transform(company_profile)
|
||||||
x_train = vectorizer.transform(x_train)
|
x_train = vectorizer.transform(x_train)
|
||||||
@ -172,6 +181,7 @@ if __name__ == "__main__":
|
|||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
torch.save(model, 'model')
|
torch.save(model, 'model')
|
||||||
|
ex.add_artifact("model")
|
||||||
|
|
||||||
|
|
||||||
# plt.figure(figsize=(12, 5))
|
# plt.figure(figsize=(12, 5))
|
||||||
|
Loading…
Reference in New Issue
Block a user