add saving of model and results, add argument parser for main script

This commit is contained in:
Filip Patyk 2023-05-06 21:31:51 +02:00
parent fb24ad5bca
commit 4561c65980
5 changed files with 84 additions and 24 deletions

.gitignore

@@ -1,3 +1,4 @@
.vscode
__pycache__
data/
results/

src/evaluate.py

@@ -15,7 +15,7 @@ def evaluate(
model: nn.Module,
test_data: pd.DataFrame,
batch_size: int,
) -> None:
) -> list[int]:
test_dataset = Dataset(test_data)
test_dataloader = torch.utils.data.DataLoader(
@@ -26,8 +26,11 @@ def evaluate(
shuffle=True,
)
model.to(DEVICE)
total_acc_test = 0
results = []
model.to(DEVICE)
model.eval()
with torch.no_grad():
for test_input, test_label in test_dataloader:
@@ -38,10 +41,8 @@ def evaluate(
output = model(input_id, mask)
acc = (output.argmax(dim=1) == test_label).sum().item()
results.extend(output.argmax(dim=1).tolist())
total_acc_test += acc
print(f"Test Accuracy: {total_acc_test / len(test_data): .3f}")
if __name__ == "__main__":
pass
return results
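
The hunks above change evaluate to collect output.argmax(dim=1) for every batch and return the flat list of predicted class indices in dataloader order. Because the test dataloader is built with shuffle=True, a caller that pairs this list with test_data["label"] in its original row order (as the main script below does) would need the loader built with shuffle=False; a minimal sketch of that variant (assuming the loader also receives batch_size, which the hunk does not show):

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,  # keep batch order aligned with the row order of test_data
)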

src/main.py

@@ -1,7 +1,8 @@
import random
from sklearn.model_selection import train_test_split
import argparse
from models import BertClassifier
from models import BertClassifier, utils
from datasets import NewsDataset
from train import train
from evaluate import evaluate
@@ -11,38 +12,60 @@ SEED = 2137
# Hyperparameters
INITIAL_LR = 1e-6
NUM_EPOCHS = 5
NUM_EPOCHS = 3
BATCH_SIZE = 2
# argument parser
parser = argparse.ArgumentParser(
prog="News classification",
description="Train or evaluate model",
)
parser.add_argument("--train", action="store_true", default=False)
parser.add_argument("--test", action="store_true", default=False)
parser.add_argument("--model_path", type=str, default="results/model.pt")
parser.add_argument("--results_path", type=str, default="results/results.csv")
if __name__ == "__main__":
args = parser.parse_args()
# loading & splitting data
news_dataset = NewsDataset(data_dir_path="data", data_lenght=2000)
train_val_data, test_data = train_test_split(
news_dataset.data,
test_size=0.8,
test_size=0.2,
shuffle=True,
random_state=random.seed(SEED),
)
train_data, val_data = train_test_split(
train_val_data,
test_size=0.2,
shuffle=True,
random_state=random.seed(SEED),
)
# training model
trained_model = train(
model=BertClassifier(),
train_data=train_data,
val_data=val_data,
learning_rate=INITIAL_LR,
epochs=NUM_EPOCHS,
batch_size=BATCH_SIZE,
)
if args.train:
trained_model = train(
model=BertClassifier(),
train_data=train_data,
val_data=val_data,
learning_rate=INITIAL_LR,
epochs=NUM_EPOCHS,
batch_size=BATCH_SIZE,
)
utils.save_model(model=trained_model, model_path=args.model_path)
# evaluating model
evaluate(
model=trained_model,
test_data=test_data,
batch_size=BATCH_SIZE,
)
if args.test:
model = utils.load_model(model=BertClassifier(), model_path=args.model_path) # loading model from model.pt file
results = evaluate(
model=model,
test_data=test_data,
batch_size=BATCH_SIZE,
)
utils.save_results(labels=test_data["label"], results=results, file_path=args.results_path)
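
For reference, the new flags would typically be exercised like this (a sketch only; it assumes the entry point is src/main.py and that it is run from inside src/ so the models, datasets, train and evaluate imports resolve):

python main.py --train --test
python main.py --test --model_path results/model.pt --results_path results/results.csv

--train fits the classifier and saves its weights to --model_path; --test restores the weights from that path, evaluates on the held-out split, and writes the predictions alongside the true labels to --results_path.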

src/models/__init__.py

@@ -1,3 +1,4 @@
__all__ = ["BertClassifier"]
__all__ = ["BertClassifier", "utils"]
from .bert_model import BertClassifier
from .utils import utils

src/models/utils.py

@@ -0,0 +1,34 @@
from pathlib import Path

import torch
import torch.nn as nn
import pandas as pd


class utils:
    def __init__(self) -> None:
        pass

    @staticmethod
    def save_model(model: nn.Module, model_path: str) -> None:
        model_path = Path(model_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)

        torch.save(model.state_dict(), model_path)
        print(f"[INFO]\t Model saved at: {model_path}")

    @staticmethod
    def load_model(model: nn.Module, model_path: str) -> nn.Module:
        model_path = Path(model_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)

        model.load_state_dict(torch.load(model_path))
        return model

    @staticmethod
    def save_results(labels: list[int], results: list[int], file_path: str) -> None:
        file_path = Path(file_path)
        file_path.parent.mkdir(parents=True, exist_ok=True)

        df = pd.DataFrame({"labels": labels, "results": results})
        df.to_csv(file_path, index=False)
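
A quick usage sketch of the new helpers (hypothetical example, not part of the commit; nn.Linear stands in for BertClassifier to keep the snippet self-contained, and the import assumes the code runs from inside src/):

import torch.nn as nn
from models.utils import utils

model = nn.Linear(4, 2)  # stand-in module; any nn.Module works the same way

# save the weights, then restore them into a fresh instance of the same architecture
utils.save_model(model=model, model_path="results/model.pt")
restored = utils.load_model(model=nn.Linear(4, 2), model_path="results/model.pt")

# write predictions next to their ground-truth labels as a two-column CSV
utils.save_results(labels=[1, 0, 1], results=[1, 1, 0], file_path="results/results.csv")

Since every method is a @staticmethod and __init__ is never used, the class mainly serves as a namespace; plain module-level functions in models/utils.py would behave the same way.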