61 lines
2.2 KiB
Python
61 lines
2.2 KiB
Python
#!/usr/bin/python
|
|
import os.path
|
|
import sys
|
|
|
|
import pandas as pd
|
|
from kaggle import api
|
|
from pandas import read_csv
|
|
from sklearn.model_selection import train_test_split
|
|
|
|
|
|
def download_and_save_dataset(data_path, dataset_name,
                              kaggle_dataset='shivamb/real-or-fake-fake-jobposting-prediction',
                              source_file='fake_job_postings.csv'):
    """Download the Kaggle dataset into ``data_path`` unless it is already there.

    When ``data_path/dataset_name`` does not exist, the Kaggle dataset
    ``kaggle_dataset`` is downloaded and unzipped into ``data_path``. The
    extracted ``source_file`` is then renamed to ``dataset_name``. When the
    target file already exists, nothing is downloaded.

    Args:
        data_path: Directory that receives the dataset.
        dataset_name: File name the extracted CSV is stored under.
        kaggle_dataset: Kaggle ``owner/slug`` identifier to download.
        source_file: Name of the CSV extracted from the downloaded archive.
    """
    target = os.path.join(data_path, dataset_name)
    # Skip the (slow, credentialed) download when the dataset is already cached.
    if os.path.exists(target):
        return

    api.authenticate()
    api.dataset_download_files(kaggle_dataset, path=data_path, unzip=True)
    os.rename(os.path.join(data_path, source_file), target)
|
|
|
|
|
|
def preprocess_dataset(data, columns=('job_id', 'department', 'salary_range', 'benefits')):
    """Return a copy of ``data`` without the columns that are not used downstream.

    The default list drops ``job_id``, which looks like a row identifier, and
    ``department``, ``salary_range`` and ``benefits``, which the original
    author marked as having many nulls.

    Args:
        data: Raw dataset as a pandas DataFrame.
        columns: Column labels to drop; every label must be present.

    Returns:
        A new DataFrame; ``data`` itself is not modified.

    Raises:
        KeyError: if any label in ``columns`` is missing from ``data``.
    """
    # list(): pandas treats a tuple passed to ``drop`` as ONE label, not several.
    return data.drop(columns=list(columns))
|
|
|
|
|
|
def split_dataset(data_path, dataset_name, train_ratio=0.6, validation_ratio=0.2,
                  test_ratio=0.2, label_column='fraudulent', random_state=123):
    """Load, preprocess and split the dataset into train/validation/test parts.

    Both splits are stratified on ``label_column``, so each part keeps the
    label distribution of the full dataset.

    Args:
        data_path: Directory containing the dataset CSV.
        dataset_name: File name of the dataset CSV inside ``data_path``.
        train_ratio: Fraction of rows assigned to the training part.
        validation_ratio: Fraction of rows assigned to the validation part.
        test_ratio: Fraction of rows assigned to the test part.
        label_column: Name of the target column.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        ``(x_train, x_val, x_test, y_train, y_val, y_test)``. The ``x_*``
        frames hold every column except ``label_column``. The ``y_*`` frames
        are one-column DataFrames holding ``label_column``.

    Raises:
        ValueError: if the three ratios do not sum to 1.
    """
    if abs(train_ratio + validation_ratio + test_ratio - 1.0) > 1e-9:
        raise ValueError('train, validation and test ratios must sum to 1')

    data = preprocess_dataset(read_csv(os.path.join(data_path, dataset_name)))

    # Select the target by name rather than assuming it is the last column,
    # so the target and the stratification key are always the same column.
    data_x = data.drop(columns=[label_column])
    data_y = data[[label_column]]

    x_train, x_rest, y_train, y_rest = train_test_split(
        data_x, data_y, test_size=1 - train_ratio,
        random_state=random_state, stratify=data_y[label_column])

    # Stratify the second split as well; otherwise validation and test can
    # end up with different label ratios.
    x_val, x_test, y_val, y_test = train_test_split(
        x_rest, y_rest, test_size=test_ratio / (test_ratio + validation_ratio),
        random_state=random_state, stratify=y_rest[label_column])

    return x_train, x_val, x_test, y_train, y_val, y_test
|
|
|
|
|
|
def save_dataset(data_path, data, name):
    """Write ``data`` as CSV to the file ``name`` inside ``data_path``.

    The DataFrame index is written as the first CSV column (pandas default).
    """
    target_file = os.path.join(data_path, name)
    data.to_csv(target_file)
|
|
|
|
|
|
def main(argv=None):
    """Download the dataset and write train/test/dev CSV splits next to it.

    Usage: ``script DATA_PATH DATASET_NAME``. The dataset is stored as
    ``DATA_PATH/DATASET_NAME``. The splits are written to ``train_data.csv``,
    ``test_data.csv`` and ``dev_data.csv`` in the same directory; each holds
    the feature columns followed by the label column.

    Args:
        argv: Positional arguments without the program name; defaults to
            ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) < 2:
        # Report a usage message instead of crashing with an IndexError.
        sys.exit('usage: %s DATA_PATH DATASET_NAME' % os.path.basename(sys.argv[0]))

    data_path, dataset_name = args[0], args[1]
    abs_data_path = os.path.abspath(data_path)

    download_and_save_dataset(abs_data_path, dataset_name)
    x_train, x_val, x_test, y_train, y_val, y_test = split_dataset(abs_data_path, dataset_name)

    # The validation split is saved under the name "dev".
    splits = (
        ('train_data.csv', x_train, y_train),
        ('test_data.csv', x_test, y_test),
        ('dev_data.csv', x_val, y_val),
    )
    for name, features, labels in splits:
        save_dataset(abs_data_path, pd.concat([features, labels], axis=1), name)
|
|
|
|
|
|
# Script entry point: run main() only when executed directly, not on import.
if __name__ == '__main__':
    sys.exit(main())
|