create saving model and results, add arguments parser for main scipt
This commit is contained in:
parent
fb24ad5bca
commit
4561c65980
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
|||||||
.vscode
|
.vscode
|
||||||
__pycache__
|
__pycache__
|
||||||
data/
|
data/
|
||||||
|
results/
|
@ -15,7 +15,7 @@ def evaluate(
|
|||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
test_data: pd.DataFrame,
|
test_data: pd.DataFrame,
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
) -> None:
|
) -> list[int]:
|
||||||
test_dataset = Dataset(test_data)
|
test_dataset = Dataset(test_data)
|
||||||
|
|
||||||
test_dataloader = torch.utils.data.DataLoader(
|
test_dataloader = torch.utils.data.DataLoader(
|
||||||
@ -26,8 +26,11 @@ def evaluate(
|
|||||||
shuffle=True,
|
shuffle=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.to(DEVICE)
|
|
||||||
total_acc_test = 0
|
total_acc_test = 0
|
||||||
|
results = []
|
||||||
|
|
||||||
|
model.to(DEVICE)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for test_input, test_label in test_dataloader:
|
for test_input, test_label in test_dataloader:
|
||||||
@ -38,10 +41,8 @@ def evaluate(
|
|||||||
output = model(input_id, mask)
|
output = model(input_id, mask)
|
||||||
|
|
||||||
acc = (output.argmax(dim=1) == test_label).sum().item()
|
acc = (output.argmax(dim=1) == test_label).sum().item()
|
||||||
|
results.extend(output.argmax(dim=1).tolist())
|
||||||
total_acc_test += acc
|
total_acc_test += acc
|
||||||
|
|
||||||
print(f"Test Accuracy: {total_acc_test / len(test_data): .3f}")
|
print(f"Test Accuracy: {total_acc_test / len(test_data): .3f}")
|
||||||
|
return results
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pass
|
|
||||||
|
35
src/main.py
35
src/main.py
@ -1,7 +1,8 @@
|
|||||||
import random
|
import random
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
import argparse
|
||||||
|
|
||||||
from models import BertClassifier
|
from models import BertClassifier, utils
|
||||||
from datasets import NewsDataset
|
from datasets import NewsDataset
|
||||||
from train import train
|
from train import train
|
||||||
from evaluate import evaluate
|
from evaluate import evaluate
|
||||||
@ -11,38 +12,60 @@ SEED = 2137
|
|||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
|
|
||||||
INITIAL_LR = 1e-6
|
INITIAL_LR = 1e-6
|
||||||
NUM_EPOCHS = 5
|
NUM_EPOCHS = 3
|
||||||
BATCH_SIZE = 2
|
BATCH_SIZE = 2
|
||||||
|
|
||||||
|
|
||||||
|
# argument parser
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="News classification",
|
||||||
|
description="Train or evaluate model",
|
||||||
|
)
|
||||||
|
parser.add_argument("--train", action="store_true", default=False)
|
||||||
|
parser.add_argument("--test", action="store_true", default=False)
|
||||||
|
parser.add_argument("--model_path", type=str, default="results/model.pt")
|
||||||
|
parser.add_argument("--results_path", type=str, default="results/results.csv")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
# loading & spliting data
|
# loading & spliting data
|
||||||
news_dataset = NewsDataset(data_dir_path="data", data_lenght=2000)
|
news_dataset = NewsDataset(data_dir_path="data", data_lenght=2000)
|
||||||
|
|
||||||
train_val_data, test_data = train_test_split(
|
train_val_data, test_data = train_test_split(
|
||||||
news_dataset.data,
|
news_dataset.data,
|
||||||
test_size=0.8,
|
test_size=0.2,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
random_state=random.seed(SEED),
|
random_state=random.seed(SEED),
|
||||||
)
|
)
|
||||||
|
|
||||||
train_data, val_data = train_test_split(
|
train_data, val_data = train_test_split(
|
||||||
train_val_data,
|
train_val_data,
|
||||||
test_size=0.2,
|
test_size=0.2,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
random_state=random.seed(SEED),
|
random_state=random.seed(SEED),
|
||||||
)
|
)
|
||||||
|
|
||||||
# trainig model
|
# trainig model
|
||||||
|
if args.train:
|
||||||
trained_model = train(
|
trained_model = train(
|
||||||
model=BertClassifier(),
|
model=BertClassifier(),
|
||||||
train_data=train_data,
|
train_data=test_data,
|
||||||
val_data=val_data,
|
val_data=val_data,
|
||||||
learning_rate=INITIAL_LR,
|
learning_rate=INITIAL_LR,
|
||||||
epochs=NUM_EPOCHS,
|
epochs=NUM_EPOCHS,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
)
|
)
|
||||||
|
utils.save_model(model=trained_model, model_path=args.model_path)
|
||||||
|
|
||||||
# evaluating model
|
# evaluating model
|
||||||
evaluate(
|
if args.test:
|
||||||
model=trained_model,
|
model = utils.load_model(model=BertClassifier(), model_path=args.model_path) # loading model from model.pt file
|
||||||
|
results = evaluate(
|
||||||
|
model=model,
|
||||||
test_data=test_data,
|
test_data=test_data,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
)
|
)
|
||||||
|
utils.save_results(labels=test_data["label"], results=results, file_path=args.results_path)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
__all__ = ["BertClassifier"]
|
__all__ = ["BertClassifier", "utils"]
|
||||||
|
|
||||||
from .bert_model import BertClassifier
|
from .bert_model import BertClassifier
|
||||||
|
from .utils import utils
|
||||||
|
34
src/models/utils.py
Normal file
34
src/models/utils.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
class utils:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_model(model: nn.Module, model_path: str) -> None:
|
||||||
|
model_path = Path(model_path)
|
||||||
|
model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
torch.save(model.state_dict(), model_path)
|
||||||
|
print(f"[INFO]\t Model saved at: {model_path}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def load_model(model: nn.Module, model_path: str) -> nn.Module:
|
||||||
|
model_path = Path(model_path)
|
||||||
|
model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
model.load_state_dict(torch.load(model_path))
|
||||||
|
return model
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def save_results(labels: list[int], results: list[int], file_path: str) -> None:
|
||||||
|
file_path = Path(file_path)
|
||||||
|
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
df = pd.DataFrame({"labels": labels, "results": results})
|
||||||
|
df.to_csv(file_path, index=False)
|
Loading…
Reference in New Issue
Block a user