feat: add sacred implementation

Filip Patyk 2023-05-11 00:19:17 +02:00
parent 65a60c9dae
commit 34f3b4764d
3 changed files with 40 additions and 12 deletions
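
The change wires the training entry point into Sacred. For orientation, here is a minimal, hedged sketch of the Sacred 0.8 pattern the diff relies on (an Experiment, a config function, a main function, and a programmatic run()); the names below are illustrative and not taken from this repository:

    from sacred import Experiment

    ex = Experiment("demo")  # experiment name is arbitrary in this sketch

    @ex.config
    def config():
        learning_rate = 1e-6  # values defined here become the run's config

    @ex.main
    def run_experiment(_run, learning_rate):
        # _run exposes the active run, e.g. for logging scalar metrics
        _run.log_scalar("train_loss", 0.5)
        return learning_rate

    if __name__ == "__main__":
        ex.run()  # executes the function registered with @ex.main

The commit follows the same skeleton, but takes its hyperparameter values from argparse before declaring the config.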


@@ -15,5 +15,6 @@ dependencies:
   - pytorch-cuda=11.8
   - transformers
   - matplotlib
   - pymongo
   - pip:
+      - sacred==0.8.4
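
The environment already lists pymongo, which Sacred needs only if runs are recorded with its MongoObserver. No observer is attached anywhere in this diff, so the following is just a hedged sketch of how one could be added; the URL and database name are placeholders:

    from sacred.observers import MongoObserver

    # Hypothetical: persist runs (config, logged scalars, artifacts) to MongoDB.
    ex.observers.append(
        MongoObserver(url="mongodb://localhost:27017", db_name="sacred")
    )

Without any observer, the log_scalar and add_artifact calls below still run, but nothing is stored.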


@@ -1,17 +1,13 @@
-# import random
-# from sklearn.model_selection import train_test_split
 import argparse
 import torch
 from models import BertClassifier, utils
 from datasets import NewsDataset
 from train import train
+from sacred import Experiment
 from evaluate import evaluate
-# SEED = 2137
 # argument parser
 parser = argparse.ArgumentParser(
     prog="News classification",
     description="Train or evaluate model",
@@ -29,12 +25,29 @@ parser.add_argument("--learning_rate", "--lr", type=float, default=1e-6)
 parser.add_argument("--num_epochs", "--epochs", "-e", type=int, default=3)
 if __name__ == "__main__":
+    # sacred stuff
+    ex = Experiment("s424714")
+    ex.add_source_file("./src/train.py")
+    @ex.main
+    def main(_run):
         args = parser.parse_args()
+        ex.open_resource(filename="./data/dataset/train.csv", mode="r")
+        ex.open_resource(filename="./data/dataset/test.csv", mode="r")
+        ex.open_resource(filename="./data/dataset/val.csv", mode="r")
         INITIAL_LR = args.learning_rate
         NUM_EPOCHS = args.num_epochs
         BATCH_SIZE = args.batch
         print(BATCH_SIZE)
+        @ex.config
+        def hyper_parameters():
+            initial_lr = INITIAL_LR  # noqa: F841
+            num_epochs = NUM_EPOCHS  # noqa: F841
+            batch_size = BATCH_SIZE  # noqa: F841
         print("INITIAL_LR: ", INITIAL_LR)
         print("NUM_EPOCHS: ", NUM_EPOCHS)
@@ -65,15 +78,20 @@ if __name__ == "__main__":
         # training model
         if args.train:
-            trained_model = train(
+            trained_model, metrics = train(
                 model=BertClassifier(),
-                train_data=test_data,
+                train_data=train_data,
                 val_data=val_data,
                 learning_rate=INITIAL_LR,
                 epochs=NUM_EPOCHS,
                 batch_size=BATCH_SIZE,
             )
             utils.save_model(model=trained_model, model_path=args.model_path)
+            ex.add_artifact(args.model_path)
+            _run.log_scalar("train_loss", metrics["train_loss"])
+            _run.log_scalar("val_loss", metrics["val_loss"])
+            _run.log_scalar("train_acc", metrics["train_acc"])
+            _run.log_scalar("val_acc", metrics["val_acc"])
         # evaluating model
         if args.test:
@@ -89,3 +107,7 @@ if __name__ == "__main__":
                 build_id=int(args.build_id),
                 data=accuracy,
             )
+if __name__ == "__main__":
+    ex.run()
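
Experiment.run() executes the function registered with @ex.main and returns a sacred.run.Run object, so the captured configuration and final status can be inspected afterwards; a small hedged sketch with an illustrative variable name:

    # Hypothetical: inspect the run object returned by ex.run().
    run = ex.run()
    print(run.config)  # configuration values captured for this run
    print(run.status)  # e.g. "COMPLETED"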


@@ -84,5 +84,10 @@ def train(
             | Val Loss: {total_loss_val / len(val_data): .3f} \
             | Val Accuracy: {total_acc_val / len(val_data): .3f}"
         )
-    return model
+    metrics = {
+        "train_acc": total_acc_train / len(train_data),
+        "train_loss": total_loss_train / len(train_data),
+        "val_acc": total_acc_val / len(val_data),
+        "val_loss": total_loss_val / len(val_data),
+    }
+    return model, metrics