scared first attempt
This commit is contained in:
parent
3f0f274eb3
commit
dc82d390f1
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
sacred_runs
|
1
evaluation/Jenkinsfile
vendored
1
evaluation/Jenkinsfile
vendored
@ -27,6 +27,7 @@ pipeline {
|
|||||||
steps {
|
steps {
|
||||||
sh 'python ./evaluation/eval.py'
|
sh 'python ./evaluation/eval.py'
|
||||||
sh 'python ./evaluation/plot.py'
|
sh 'python ./evaluation/plot.py'
|
||||||
|
sh 'python ./evaluation/scared-fileobserver.py'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stage('archiveArtifacts') {
|
stage('archiveArtifacts') {
|
||||||
|
92
evaluation/scared-fileobserver.py
Normal file
92
evaluation/scared-fileobserver.py
Normal file
@ -0,0 +1,92 @@
|
|||||||
|
import datetime
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from torch.autograd import Variable
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from sacred import Experiment
|
||||||
|
from sacred.observers import FileStorageObserver
|
||||||
|
import csv
|
||||||
|
|
||||||
|
ex = Experiment("434700-file", interactive=False, save_git_info=False)
|
||||||
|
ex.observers.append(FileStorageObserver('sacred_runs/my_runs'))
|
||||||
|
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def my_config():
|
||||||
|
epochs = 10
|
||||||
|
batch_size = 16
|
||||||
|
|
||||||
|
|
||||||
|
@ex.capture
|
||||||
|
def prepare_model(epochs, batch_size, _run):
|
||||||
|
INPUT_DIM = 1
|
||||||
|
OUTPUT_DIM = 1
|
||||||
|
LEARNING_RATE = 0.01
|
||||||
|
EPOCHS = epochs
|
||||||
|
|
||||||
|
dataset = pd.read_csv('datasets/train_set.csv')
|
||||||
|
|
||||||
|
x_values = [datetime.datetime.strptime(
|
||||||
|
item, "%Y-%m-%d").month for item in dataset['date'].values]
|
||||||
|
x_train = np.array(x_values, dtype=np.float32)
|
||||||
|
x_train = x_train.reshape(-1, 1)
|
||||||
|
|
||||||
|
y_values = [min(dataset['result_1'].values[i]/dataset['result_2'].values[i], dataset['result_2'].values[i] /
|
||||||
|
dataset['result_1'].values[i]) for i in range(len(dataset['result_1'].values))]
|
||||||
|
y_train = np.array(y_values, dtype=np.float32)
|
||||||
|
y_train = y_train.reshape(-1, 1)
|
||||||
|
|
||||||
|
class LinearRegression(torch.nn.Module):
|
||||||
|
def __init__(self, inputSize, outputSize):
|
||||||
|
super(LinearRegression, self).__init__()
|
||||||
|
self.linear = torch.nn.Linear(inputSize, outputSize)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = self.linear(x)
|
||||||
|
return out
|
||||||
|
|
||||||
|
model = LinearRegression(INPUT_DIM, OUTPUT_DIM)
|
||||||
|
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
|
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
|
||||||
|
|
||||||
|
for epoch in range(EPOCHS):
|
||||||
|
inputs = Variable(torch.from_numpy(x_train))
|
||||||
|
labels = Variable(torch.from_numpy(y_train))
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
outputs = model(inputs)
|
||||||
|
|
||||||
|
loss = criterion(outputs, labels)
|
||||||
|
print(loss)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
print('epoch {}, loss {}'.format(epoch, loss.item()))
|
||||||
|
|
||||||
|
torch.save(model.state_dict(), 'model-experiment.pt')
|
||||||
|
|
||||||
|
with torch.no_grad(): # we don't need gradients in the testing phase
|
||||||
|
predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
|
||||||
|
|
||||||
|
with open('model_experiment_results.csv', mode='w') as filee:
|
||||||
|
writer = csv.writer(filee, delimiter=',', quotechar='"',
|
||||||
|
quoting=csv.QUOTE_MINIMAL)
|
||||||
|
|
||||||
|
writer.writerow(['x', 'y', 'predicted_y'])
|
||||||
|
|
||||||
|
for i in range(len(x_train)):
|
||||||
|
writer.writerow([x_train[i][0], y_train[i][0], predicted[i][0]])
|
||||||
|
|
||||||
|
|
||||||
|
@ex.automain
|
||||||
|
def my_main(epochs, batch_size):
|
||||||
|
print(prepare_model())
|
||||||
|
|
||||||
|
|
||||||
|
ex.run()
|
||||||
|
ex.add_artifact('model-experiment.pt')
|
||||||
|
ex.add_artifact('model_experiment_results.csv')
|
@ -4,3 +4,4 @@ torch
|
|||||||
numpy
|
numpy
|
||||||
matplotlib
|
matplotlib
|
||||||
seaborn
|
seaborn
|
||||||
|
sacred
|
Loading…
Reference in New Issue
Block a user