Add saving of the model and results; add an argument parser to the main script

This commit is contained in:
Filip Patyk 2023-05-06 21:31:51 +02:00
parent fb24ad5bca
commit 4561c65980
5 changed files with 84 additions and 24 deletions

3
.gitignore vendored
View File

@ -1,3 +1,4 @@
.vscode .vscode
__pycache__ __pycache__
data/ data/
results/

View File

@ -15,7 +15,7 @@ def evaluate(
model: nn.Module, model: nn.Module,
test_data: pd.DataFrame, test_data: pd.DataFrame,
batch_size: int, batch_size: int,
) -> None: ) -> list[int]:
test_dataset = Dataset(test_data) test_dataset = Dataset(test_data)
test_dataloader = torch.utils.data.DataLoader( test_dataloader = torch.utils.data.DataLoader(
@ -26,8 +26,11 @@ def evaluate(
shuffle=True, shuffle=True,
) )
model.to(DEVICE)
total_acc_test = 0 total_acc_test = 0
results = []
model.to(DEVICE)
model.eval()
with torch.no_grad(): with torch.no_grad():
for test_input, test_label in test_dataloader: for test_input, test_label in test_dataloader:
@ -38,10 +41,8 @@ def evaluate(
output = model(input_id, mask) output = model(input_id, mask)
acc = (output.argmax(dim=1) == test_label).sum().item() acc = (output.argmax(dim=1) == test_label).sum().item()
results.extend(output.argmax(dim=1).tolist())
total_acc_test += acc total_acc_test += acc
print(f"Test Accuracy: {total_acc_test / len(test_data): .3f}") print(f"Test Accuracy: {total_acc_test / len(test_data): .3f}")
return results
if __name__ == "__main__":
pass

View File

@ -1,7 +1,8 @@
import random import random
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import argparse
from models import BertClassifier from models import BertClassifier, utils
from datasets import NewsDataset from datasets import NewsDataset
from train import train from train import train
from evaluate import evaluate from evaluate import evaluate
@ -11,38 +12,60 @@ SEED = 2137
# Hyperparameters # Hyperparameters
INITIAL_LR = 1e-6 INITIAL_LR = 1e-6
NUM_EPOCHS = 5 NUM_EPOCHS = 3
BATCH_SIZE = 2 BATCH_SIZE = 2
# argument parser
parser = argparse.ArgumentParser(
prog="News classification",
description="Train or evaluate model",
)
parser.add_argument("--train", action="store_true", default=False)
parser.add_argument("--test", action="store_true", default=False)
parser.add_argument("--model_path", type=str, default="results/model.pt")
parser.add_argument("--results_path", type=str, default="results/results.csv")
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args()
# loading & spliting data # loading & spliting data
news_dataset = NewsDataset(data_dir_path="data", data_lenght=2000) news_dataset = NewsDataset(data_dir_path="data", data_lenght=2000)
train_val_data, test_data = train_test_split( train_val_data, test_data = train_test_split(
news_dataset.data, news_dataset.data,
test_size=0.8, test_size=0.2,
shuffle=True, shuffle=True,
random_state=random.seed(SEED), random_state=random.seed(SEED),
) )
train_data, val_data = train_test_split( train_data, val_data = train_test_split(
train_val_data, train_val_data,
test_size=0.2, test_size=0.2,
shuffle=True, shuffle=True,
random_state=random.seed(SEED), random_state=random.seed(SEED),
) )
# trainig model # trainig model
trained_model = train( if args.train:
model=BertClassifier(), trained_model = train(
train_data=train_data, model=BertClassifier(),
val_data=val_data, train_data=test_data,
learning_rate=INITIAL_LR, val_data=val_data,
epochs=NUM_EPOCHS, learning_rate=INITIAL_LR,
batch_size=BATCH_SIZE, epochs=NUM_EPOCHS,
) batch_size=BATCH_SIZE,
)
utils.save_model(model=trained_model, model_path=args.model_path)
# evaluating model # evaluating model
evaluate( if args.test:
model=trained_model, model = utils.load_model(model=BertClassifier(), model_path=args.model_path) # loading model from model.pt file
test_data=test_data, results = evaluate(
batch_size=BATCH_SIZE, model=model,
) test_data=test_data,
batch_size=BATCH_SIZE,
)
utils.save_results(labels=test_data["label"], results=results, file_path=args.results_path)

View File

@ -1,3 +1,4 @@
__all__ = ["BertClassifier"] __all__ = ["BertClassifier", "utils"]
from .bert_model import BertClassifier from .bert_model import BertClassifier
from .utils import utils

34
src/models/utils.py Normal file
View File

@ -0,0 +1,34 @@
from pathlib import Path
import torch
import torch.nn as nn
import pandas as pd
class utils:
    """Namespace of static helpers for persisting models and evaluation results."""

    def __init__(self) -> None:
        pass

    @staticmethod
    def save_model(model: nn.Module, model_path: str) -> None:
        """Save the state dict of ``model`` to ``model_path``.

        Args:
            model: Module whose ``state_dict()`` is serialized.
            model_path: Destination file; missing parent directories are
                created.
        """
        model_path = Path(model_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(model.state_dict(), model_path)
        print(f"[INFO]\t Model saved at: {model_path}")

    @staticmethod
    def load_model(model: nn.Module, model_path: str) -> nn.Module:
        """Load a state dict from ``model_path`` into ``model`` and return it.

        Args:
            model: Freshly constructed module with the same architecture as
                the one that was saved.
            model_path: File previously written by ``save_model``.

        Returns:
            The same ``model`` instance with the loaded parameters.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        model_path = Path(model_path)
        # Loading must not touch the file system: do not create directories
        # for a checkpoint that is missing, report it instead.
        if not model_path.is_file():
            raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
        # Map tensors to CPU so a checkpoint saved on a GPU can be restored on
        # a CPU-only machine; load_state_dict copies the values into the
        # model's own parameters, whatever device they live on.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        return model

    @staticmethod
    def save_results(labels: list[int], results: list[int], file_path: str) -> None:
        """Write true labels and predictions side by side to a CSV file.

        Args:
            labels: Ground-truth labels (a list or a pandas Series).
            results: Predicted labels, in the same order as ``labels``.
            file_path: Destination CSV file; missing parent directories are
                created.

        Raises:
            ValueError: If ``labels`` and ``results`` differ in length.
        """
        file_path = Path(file_path)
        file_path.parent.mkdir(parents=True, exist_ok=True)
        # Convert to plain lists so the two columns are paired by position and
        # never re-aligned on a pandas index carried by either input.
        df = pd.DataFrame({"labels": list(labels), "results": list(results)})
        df.to_csv(file_path, index=False)