update learning
All checks were successful
s444507-predict-s444356/pipeline/head This commit looks good
s444507-dvc/pipeline/head This commit looks good
s444507-evaluation/pipeline/head This commit looks good
444507-training/pipeline/head This commit looks good

This commit is contained in:
Adam Wojdyla 2022-06-05 14:44:38 +02:00
parent c7f12908b3
commit 115b7b8a08

View File

@ -11,7 +11,12 @@ import torch.nn.functional as F
import pandas as pd import pandas as pd
from sklearn import preprocessing from sklearn import preprocessing
import sys import sys
import os
path = '.'
files = os.listdir(".")
if not "Car_Prices_Poland_Kaggle.csv" in files:
path = "data"
class Model(nn.Module): class Model(nn.Module):
def __init__(self, input_dim): def __init__(self, input_dim):
@ -29,15 +34,15 @@ class Model(nn.Module):
def load_dataset_raw(): def load_dataset_raw():
""" Load data from .csv file. """ """ Load data from .csv file. """
cars = pd.read_csv('./Car_Prices_Poland_Kaggle.csv', usecols=[1, 4, 5, 6, 10], sep=',') cars = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle.csv', usecols=[1, 4, 5, 6, 10], sep=',')
return cars return cars
def load_dataset_files(): def load_dataset_files():
""" Load shuffled, splitted dev and train files from .csv files. """ """ Load shuffled, splitted dev and train files from .csv files. """
cars_dev = pd.read_csv('./Car_Prices_Poland_Kaggle_dev.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)]) cars_dev = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle_dev.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
cars_train = pd.read_csv('./Car_Prices_Poland_Kaggle_train.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)]) cars_train = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle_train.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
return cars_dev, cars_train return cars_dev, cars_train
@ -65,15 +70,6 @@ def prepare_labels_features(dataset):
return lab, feat return lab, feat
# def draw_plot(lbl):
# need to import matplotlib to work
# plt.hist(lbl, bins=[i for i in range(len(set(lbl)))], edgecolor="black")
# plt.xticks(np.arange(0, len(set(lbl)), 1))
# plt.show()
# Prepare dataset # Prepare dataset
print("Loading dataset...") print("Loading dataset...")
dev, train = load_dataset_files() dev, train = load_dataset_files()