create saving model and results, add arguments parser for main scipt
This commit is contained in:
parent
fb24ad5bca
commit
4561c65980
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
.vscode
|
||||
__pycache__
|
||||
data/
|
||||
data/
|
||||
results/
|
@ -15,7 +15,7 @@ def evaluate(
|
||||
model: nn.Module,
|
||||
test_data: pd.DataFrame,
|
||||
batch_size: int,
|
||||
) -> None:
|
||||
) -> list[int]:
|
||||
test_dataset = Dataset(test_data)
|
||||
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
@ -26,8 +26,11 @@ def evaluate(
|
||||
shuffle=True,
|
||||
)
|
||||
|
||||
model.to(DEVICE)
|
||||
total_acc_test = 0
|
||||
results = []
|
||||
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
for test_input, test_label in test_dataloader:
|
||||
@ -38,10 +41,8 @@ def evaluate(
|
||||
output = model(input_id, mask)
|
||||
|
||||
acc = (output.argmax(dim=1) == test_label).sum().item()
|
||||
results.extend(output.argmax(dim=1).tolist())
|
||||
total_acc_test += acc
|
||||
|
||||
print(f"Test Accuracy: {total_acc_test / len(test_data): .3f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
return results
|
||||
|
55
src/main.py
55
src/main.py
@ -1,7 +1,8 @@
|
||||
import random
|
||||
from sklearn.model_selection import train_test_split
|
||||
import argparse
|
||||
|
||||
from models import BertClassifier
|
||||
from models import BertClassifier, utils
|
||||
from datasets import NewsDataset
|
||||
from train import train
|
||||
from evaluate import evaluate
|
||||
@ -11,38 +12,60 @@ SEED = 2137
|
||||
# Hyperparameters
|
||||
|
||||
INITIAL_LR = 1e-6
|
||||
NUM_EPOCHS = 5
|
||||
NUM_EPOCHS = 3
|
||||
BATCH_SIZE = 2
|
||||
|
||||
|
||||
# argument parser
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="News classification",
|
||||
description="Train or evaluate model",
|
||||
)
|
||||
parser.add_argument("--train", action="store_true", default=False)
|
||||
parser.add_argument("--test", action="store_true", default=False)
|
||||
parser.add_argument("--model_path", type=str, default="results/model.pt")
|
||||
parser.add_argument("--results_path", type=str, default="results/results.csv")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
# loading & spliting data
|
||||
news_dataset = NewsDataset(data_dir_path="data", data_lenght=2000)
|
||||
|
||||
train_val_data, test_data = train_test_split(
|
||||
news_dataset.data,
|
||||
test_size=0.8,
|
||||
test_size=0.2,
|
||||
shuffle=True,
|
||||
random_state=random.seed(SEED),
|
||||
)
|
||||
|
||||
train_data, val_data = train_test_split(
|
||||
train_val_data,
|
||||
test_size=0.2,
|
||||
shuffle=True,
|
||||
random_state=random.seed(SEED),
|
||||
)
|
||||
|
||||
# trainig model
|
||||
trained_model = train(
|
||||
model=BertClassifier(),
|
||||
train_data=train_data,
|
||||
val_data=val_data,
|
||||
learning_rate=INITIAL_LR,
|
||||
epochs=NUM_EPOCHS,
|
||||
batch_size=BATCH_SIZE,
|
||||
)
|
||||
if args.train:
|
||||
trained_model = train(
|
||||
model=BertClassifier(),
|
||||
train_data=test_data,
|
||||
val_data=val_data,
|
||||
learning_rate=INITIAL_LR,
|
||||
epochs=NUM_EPOCHS,
|
||||
batch_size=BATCH_SIZE,
|
||||
)
|
||||
utils.save_model(model=trained_model, model_path=args.model_path)
|
||||
|
||||
# evaluating model
|
||||
evaluate(
|
||||
model=trained_model,
|
||||
test_data=test_data,
|
||||
batch_size=BATCH_SIZE,
|
||||
)
|
||||
if args.test:
|
||||
model = utils.load_model(model=BertClassifier(), model_path=args.model_path) # loading model from model.pt file
|
||||
results = evaluate(
|
||||
model=model,
|
||||
test_data=test_data,
|
||||
batch_size=BATCH_SIZE,
|
||||
)
|
||||
utils.save_results(labels=test_data["label"], results=results, file_path=args.results_path)
|
||||
|
@ -1,3 +1,4 @@
|
||||
__all__ = ["BertClassifier"]
|
||||
__all__ = ["BertClassifier", "utils"]
|
||||
|
||||
from .bert_model import BertClassifier
|
||||
from .utils import utils
|
||||
|
34
src/models/utils.py
Normal file
34
src/models/utils.py
Normal file
@ -0,0 +1,34 @@
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class utils:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def save_model(model: nn.Module, model_path: str) -> None:
|
||||
model_path = Path(model_path)
|
||||
model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
torch.save(model.state_dict(), model_path)
|
||||
print(f"[INFO]\t Model saved at: {model_path}")
|
||||
|
||||
@staticmethod
|
||||
def load_model(model: nn.Module, model_path: str) -> nn.Module:
|
||||
model_path = Path(model_path)
|
||||
model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model.load_state_dict(torch.load(model_path))
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def save_results(labels: list[int], results: list[int], file_path: str) -> None:
|
||||
file_path = Path(file_path)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
df = pd.DataFrame({"labels": labels, "results": results})
|
||||
df.to_csv(file_path, index=False)
|
Loading…
Reference in New Issue
Block a user