sacred
This commit is contained in:
parent
3c4da67f2d
commit
a428cad996
@ -2,16 +2,10 @@ import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from sacred.observers import FileStorageObserver, MongoObserver
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import pandas as pd
|
||||
|
||||
|
||||
# Parametry z konsoli
|
||||
try:
|
||||
epochs = int(sys.argv[1])
|
||||
except:
|
||||
print('No epoch number passed. Defaulting to 100')
|
||||
epochs = 100
|
||||
from sacred import Experiment
|
||||
|
||||
|
||||
# Model
|
||||
@ -29,6 +23,32 @@ class Model(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
# Sacred
|
||||
ex = Experiment()
|
||||
ex.observers.append(FileStorageObserver('my_runs'))
|
||||
# Parametry treningu -> my_runs/X/config.json
|
||||
# Plik z modelem jako artefakt -> my_runs/X/model.pkl
|
||||
# Kod źródłowy -> my_runs/_sources/biblioteki_ml_XXXXXXXXXXX.py
|
||||
# Wyniki (ostateczny loss) -> my_runs/X/metrics.json
|
||||
ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017',
|
||||
db_name='sacred'))
|
||||
|
||||
|
||||
@ex.config
|
||||
def my_config():
|
||||
epochs = 100
|
||||
|
||||
|
||||
@ex.automain
|
||||
def train_main(epochs, _run):
|
||||
# Parametry z konsoli
|
||||
# try:
|
||||
# epochs = int(sys.argv[1])
|
||||
# except:
|
||||
# print('No epoch number passed. Defaulting to 100')
|
||||
# epochs = 100
|
||||
|
||||
|
||||
# Ładowanie danych
|
||||
train_set = pd.read_csv('d_train.csv', encoding='latin-1')
|
||||
train_set = train_set[['Rating', 'Branch', 'Reviewer_Location']]
|
||||
@ -83,6 +103,7 @@ for i in range(epochs):
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
_run.log_scalar("training.final_loss", losses[-1].item()) # Ostateczny loss
|
||||
|
||||
|
||||
# Testy
|
||||
@ -100,3 +121,7 @@ print(f"{df['Correct'].sum() / len(df)} percent of predictions correct")
|
||||
# Zapis do pliku
|
||||
df.to_csv('neural_network_prediction_results.csv', index=False)
|
||||
torch.save(model, "model.pkl")
|
||||
|
||||
# Zapis Sacred
|
||||
ex.add_artifact("model.pkl")
|
||||
ex.add_artifact("neural_network_prediction_results.csv")
|
||||
|
@ -24,13 +24,13 @@ pipeline {
|
||||
stage('Train model') {
|
||||
steps {
|
||||
withEnv(["EPOCH=${params.EPOCH}"]) {
|
||||
sh 'python biblioteki_ml.py $EPOCH'
|
||||
sh 'python biblioteki_ml.py with "epochs=$EPOCH"'
|
||||
}
|
||||
}
|
||||
}
|
||||
stage('Archive model') {
|
||||
stage('Archive artifacts') {
|
||||
steps {
|
||||
archiveArtifacts artifacts: 'model.pkl, neural_network_prediction_results.csv'
|
||||
archiveArtifacts artifacts: 'model.pkl, neural_network_prediction_results.csv, my_run'
|
||||
}
|
||||
}
|
||||
stage ('Model - evaluation') {
|
||||
|
Loading…
Reference in New Issue
Block a user