diff --git a/.gitignore b/.gitignore index a98485a..fc96cd3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .vscode __pycache__ -data/ \ No newline at end of file +data/ +results/ \ No newline at end of file diff --git a/src/evaluate.py b/src/evaluate.py index 0877170..052a7ce 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -15,7 +15,7 @@ def evaluate( model: nn.Module, test_data: pd.DataFrame, batch_size: int, -) -> None: +) -> list[int]: test_dataset = Dataset(test_data) test_dataloader = torch.utils.data.DataLoader( @@ -26,8 +26,11 @@ def evaluate( shuffle=True, ) - model.to(DEVICE) total_acc_test = 0 + results = [] + + model.to(DEVICE) + model.eval() with torch.no_grad(): for test_input, test_label in test_dataloader: @@ -38,10 +41,8 @@ def evaluate( output = model(input_id, mask) acc = (output.argmax(dim=1) == test_label).sum().item() + results.extend(output.argmax(dim=1).tolist()) total_acc_test += acc print(f"Test Accuracy: {total_acc_test / len(test_data): .3f}") - - -if __name__ == "__main__": - pass + return results diff --git a/src/main.py b/src/main.py index 39ceda9..b48d32c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,8 @@ import random from sklearn.model_selection import train_test_split +import argparse -from models import BertClassifier +from models import BertClassifier, utils from datasets import NewsDataset from train import train from evaluate import evaluate @@ -11,38 +12,60 @@ SEED = 2137 # Hyperparameters INITIAL_LR = 1e-6 -NUM_EPOCHS = 5 +NUM_EPOCHS = 3 BATCH_SIZE = 2 + +# argument parser + +parser = argparse.ArgumentParser( + prog="News classification", + description="Train or evaluate model", +) +parser.add_argument("--train", action="store_true", default=False) +parser.add_argument("--test", action="store_true", default=False) +parser.add_argument("--model_path", type=str, default="results/model.pt") +parser.add_argument("--results_path", type=str, default="results/results.csv") + + if __name__ == "__main__": + 
args = parser.parse_args() + # loading & splitting data news_dataset = NewsDataset(data_dir_path="data", data_lenght=2000) train_val_data, test_data = train_test_split( news_dataset.data, - test_size=0.8, + test_size=0.2, shuffle=True, random_state=random.seed(SEED), ) + train_data, val_data = train_test_split( train_val_data, test_size=0.2, shuffle=True, random_state=random.seed(SEED), ) + # trainig model - trained_model = train( - model=BertClassifier(), - train_data=train_data, - val_data=val_data, - learning_rate=INITIAL_LR, - epochs=NUM_EPOCHS, - batch_size=BATCH_SIZE, - ) + if args.train: + trained_model = train( + model=BertClassifier(), + train_data=train_data, + val_data=val_data, + learning_rate=INITIAL_LR, + epochs=NUM_EPOCHS, + batch_size=BATCH_SIZE, + ) + utils.save_model(model=trained_model, model_path=args.model_path) # evaluating model - evaluate( - model=trained_model, - test_data=test_data, - batch_size=BATCH_SIZE, - ) + if args.test: + model = utils.load_model(model=BertClassifier(), model_path=args.model_path) # loading model from args.model_path + results = evaluate( + model=model, + test_data=test_data, + batch_size=BATCH_SIZE, + ) + utils.save_results(labels=test_data["label"], results=results, file_path=args.results_path) diff --git a/src/models/__init__.py b/src/models/__init__.py index fed70df..d255641 100644 --- a/src/models/__init__.py +++ b/src/models/__init__.py @@ -1,3 +1,4 @@ -__all__ = ["BertClassifier"] +__all__ = ["BertClassifier", "utils"] from .bert_model import BertClassifier +from .utils import utils diff --git a/src/models/utils.py b/src/models/utils.py new file mode 100644 index 0000000..ca5465f --- /dev/null +++ b/src/models/utils.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import torch +import torch.nn as nn +import pandas as pd + + +class utils: + def __init__(self) -> None: + pass + + @staticmethod + def save_model(model: nn.Module, model_path: str) -> None: + model_path = Path(model_path) + 
model_path.parent.mkdir(parents=True, exist_ok=True) + + torch.save(model.state_dict(), model_path) + print(f"[INFO]\t Model saved at: {model_path}") + + @staticmethod + def load_model(model: nn.Module, model_path: str) -> nn.Module: + model_path = Path(model_path) + model_path.parent.mkdir(parents=True, exist_ok=True) + + model.load_state_dict(torch.load(model_path)) + return model + + @staticmethod + def save_results(labels: list[int], results: list[int], file_path: str) -> None: + file_path = Path(file_path) + file_path.parent.mkdir(parents=True, exist_ok=True) + + df = pd.DataFrame({"labels": labels, "results": results}) + df.to_csv(file_path, index=False)