update learning
This commit is contained in:
parent
c7f12908b3
commit
115b7b8a08
@ -11,7 +11,12 @@ import torch.nn.functional as F
|
||||
import pandas as pd
|
||||
from sklearn import preprocessing
|
||||
import sys
|
||||
import os
|
||||
|
||||
path = '.'
|
||||
files = os.listdir(".")
|
||||
if not "Car_Prices_Poland_Kaggle.csv" in files:
|
||||
path = "data"
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, input_dim):
|
||||
@ -29,15 +34,15 @@ class Model(nn.Module):
|
||||
|
||||
def load_dataset_raw():
|
||||
""" Load data from .csv file. """
|
||||
cars = pd.read_csv('./Car_Prices_Poland_Kaggle.csv', usecols=[1, 4, 5, 6, 10], sep=',')
|
||||
cars = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle.csv', usecols=[1, 4, 5, 6, 10], sep=',')
|
||||
return cars
|
||||
|
||||
|
||||
def load_dataset_files():
|
||||
""" Load shuffled, splitted dev and train files from .csv files. """
|
||||
|
||||
cars_dev = pd.read_csv('./Car_Prices_Poland_Kaggle_dev.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
|
||||
cars_train = pd.read_csv('./Car_Prices_Poland_Kaggle_train.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
|
||||
cars_dev = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle_dev.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
|
||||
cars_train = pd.read_csv(f'{path}/Car_Prices_Poland_Kaggle_train.csv', usecols=[1, 4, 5, 6, 10], sep=',', names= [str(i) for i in range(5)])
|
||||
|
||||
return cars_dev, cars_train
|
||||
|
||||
@ -65,15 +70,6 @@ def prepare_labels_features(dataset):
|
||||
return lab, feat
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def draw_plot(lbl):
|
||||
# need to import matplotlib to work
|
||||
# plt.hist(lbl, bins=[i for i in range(len(set(lbl)))], edgecolor="black")
|
||||
# plt.xticks(np.arange(0, len(set(lbl)), 1))
|
||||
# plt.show()
|
||||
|
||||
# Prepare dataset
|
||||
print("Loading dataset...")
|
||||
dev, train = load_dataset_files()
|
||||
|
Loading…
Reference in New Issue
Block a user