update learning
This commit is contained in:
parent
c7f12908b3
commit
115b7b8a08
@ -11,7 +11,12 @@ import torch.nn.functional as F
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn import preprocessing
|
from sklearn import preprocessing
|
||||||
import sys
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
path = '.'
|
||||||
|
files = os.listdir(".")
|
||||||
|
if not "Car_Prices_Poland_Kaggle.csv" in files:
|
||||||
|
path = "data"
|
||||||
|
|
||||||
class Model(nn.Module):
|
class Model(nn.Module):
|
||||||
def __init__(self, input_dim):
|
def __init__(self, input_dim):
|
||||||
@ -29,15 +34,15 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
def load_dataset_raw():
|
def load_dataset_raw():
|
||||||
""" Load data from .csv file. """
|
""" Load data from .csv file. """
|
||||||
cars = pd.read_csv('./Car_Prices_Poland_Kaggle.csv', usecols=[1, 4, 5, 6, 10], sep=',')
|
cars = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle.csv', usecols=[1, 4, 5, 6, 10], sep=',')
|
||||||
return cars
|
return cars
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_files():
|
def load_dataset_files():
|
||||||
""" Load shuffled, splitted dev and train files from .csv files. """
|
""" Load shuffled, splitted dev and train files from .csv files. """
|
||||||
|
|
||||||
cars_dev = pd.read_csv('./Car_Prices_Poland_Kaggle_dev.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
|
cars_dev = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle_dev.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
|
||||||
cars_train = pd.read_csv('./Car_Prices_Poland_Kaggle_train.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
|
cars_train = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle_train.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
|
||||||
|
|
||||||
return cars_dev, cars_train
|
return cars_dev, cars_train
|
||||||
|
|
||||||
@ -65,15 +70,6 @@ def prepare_labels_features(dataset):
|
|||||||
return lab, feat
|
return lab, feat
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# def draw_plot(lbl):
|
|
||||||
# need to import matplotlib to work
|
|
||||||
# plt.hist(lbl, bins=[i for i in range(len(set(lbl)))], edgecolor="black")
|
|
||||||
# plt.xticks(np.arange(0, len(set(lbl)), 1))
|
|
||||||
# plt.show()
|
|
||||||
|
|
||||||
# Prepare dataset
|
# Prepare dataset
|
||||||
print("Loading dataset...")
|
print("Loading dataset...")
|
||||||
dev, train = load_dataset_files()
|
dev, train = load_dataset_files()
|
||||||
|
Loading…
Reference in New Issue
Block a user