#!/usr/bin/env python3
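"""Prepare the 'real or fake' job-posting dataset for training.

Downloads the Kaggle dataset (if not already present), cleans and tokenizes
the text fields, performs a stratified 60/20/20 train/dev/test split, and
writes train_data.csv, dev_data.csv, and test_data.csv to the data directory.
"""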
import os.path
import sys

import pandas as pd
import numpy as np
from kaggle import api
from sklearn.model_selection import train_test_split
from nltk.tokenize import RegexpTokenizer


def download_and_save_dataset(data_path, dataset_name):
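    """Download the Kaggle fake-job-postings dataset into data_path (once).

    Skips the download when the target CSV already exists. Requires valid
    Kaggle API credentials (e.g. ~/.kaggle/kaggle.json) for authentication.
    """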
    if not os.path.exists(os.path.join(data_path, dataset_name)):
        api.authenticate()
        api.dataset_download_files('shivamb/real-or-fake-fake-jobposting-prediction', path=data_path,
                                   unzip=True)
        os.rename(os.path.join(data_path, 'fake_job_postings.csv'), os.path.join(data_path, dataset_name))


def preprocess_dataset(data):
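    """Clean and tokenize the raw postings.

    Fills missing values, strips URL placeholders and punctuation from the
    description, merges the free-text columns into one lowercased string,
    and word-tokenizes it into a new 'tokens' column. The merged source
    columns are dropped before returning.
    """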
    # The text merge below joins the columns as strings, so fill NaNs with ''.
    data = data.fillna('')

    # The raw dump anonymizes links with placeholders like '#URL_...#'
    # (uppercase), so strip them case-insensitively before the punctuation
    # pass removes their '#' delimiters, then collapse leftover whitespace.
    data['description'] = data['description'].str.replace(r"url_\w+", " ", case=False, regex=True)
    data['description'] = data['description'].str.replace(r"\W+", " ", regex=True)
    data['description'] = data['description'].str.replace(r"\s+", " ", regex=True)

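    # Concatenate the free-text fields into one space-separated document per
    # posting; NaNs were already replaced with '', so join() is safe here.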
    data['text'] = data[['title', 'department', 'company_profile', 'description', 'requirements', 'benefits']].apply(
        lambda x: ' '.join(x), axis=1)
    data['text'] = data['text'].str.lower()

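    # \w+ keeps alphanumeric/underscore runs and drops all punctuation.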
    tokenizer = RegexpTokenizer(r'\w+')
    data['tokens'] = data['text'].apply(tokenizer.tokenize)

    return data.drop(['job_id', 'department', 'company_profile', 'description', 'requirements', 'benefits', 'text'],
                     axis=1)


def split_dataset(data_path, dataset_name):
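    """Load, preprocess, and split the data 60/20/20 into train/val/test."""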
    data = pd.read_csv(os.path.join(data_path, dataset_name))
    data = preprocess_dataset(data)
    train_ratio, validation_ratio, test_ratio = 0.6, 0.2, 0.2
    # 'fraudulent' is the label; select it by name rather than by position,
    # since preprocessing appends the derived 'tokens' column last.
    data_x, data_y = data.drop(columns=['fraudulent']), data[['fraudulent']]

    # First carve off the 60% training set, stratified on the label to
    # preserve the heavily imbalanced class ratio.
    x_train, x_test, y_train, y_test = train_test_split(data_x, data_y, test_size=1 - train_ratio,
                                                        random_state=123, stratify=data_y['fraudulent'])

    # Split the remaining 40% evenly into validation and test, again
    # stratifying so both halves keep the same fraud rate.
    x_val, x_test, y_val, y_test = train_test_split(x_test, y_test,
                                                    test_size=test_ratio / (test_ratio + validation_ratio),
                                                    random_state=123, stratify=y_test['fraudulent'])

    return x_train, x_val, x_test, y_train, y_val, y_test


def save_dataset(data_path, data, name):
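    """Write one split to CSV.

    Note that the list-valued 'tokens' column is serialized as its string
    representation; downstream readers need e.g. ast.literal_eval to
    recover the lists.
    """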
    data.to_csv(os.path.join(data_path, name), index=False)


def main():
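    """Entry point: read <data_path> and <dataset_name> from the command line."""
    if len(sys.argv) != 3:
        sys.exit(f'usage: {sys.argv[0]} <data_path> <dataset_name>')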
    data_path, dataset_name = sys.argv[1], sys.argv[2]
    abs_data_path = os.path.abspath(data_path)
    download_and_save_dataset(abs_data_path, dataset_name)

    x_train, x_val, x_test, y_train, y_val, y_test = split_dataset(abs_data_path, dataset_name)

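    # Re-attach the label column so each saved split is self-contained.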
    train_data = pd.concat([x_train, y_train], axis=1)
    test_data = pd.concat([x_test, y_test], axis=1)
    dev_data = pd.concat([x_val, y_val], axis=1)

    for data, name in ((train_data, 'train_data.csv'), (test_data, 'test_data.csv'), (dev_data, 'dev_data.csv')):
        save_dataset(abs_data_path, data, name)


if __name__ == '__main__':
    main()