feat: add sacred implementation

This commit is contained in:
Filip Patyk 2023-05-11 00:19:17 +02:00
parent 65a60c9dae
commit 34f3b4764d
3 changed files with 40 additions and 12 deletions

View File

@ -15,5 +15,6 @@ dependencies:
- pytorch-cuda=11.8
- transformers
- matplotlib
- pymongo
- pip:
- sacred==0.8.4

View File

@ -1,17 +1,13 @@
# import random
# from sklearn.model_selection import train_test_split
import argparse
import torch
from models import BertClassifier, utils
from datasets import NewsDataset
from train import train
from sacred import Experiment
from evaluate import evaluate
# SEED = 2137
# argument parser
parser = argparse.ArgumentParser(
    prog="News classification",
    description="Train or evaluate model",
@ -29,12 +25,29 @@ parser.add_argument("--learning_rate", "--lr", type=float, default=1e-6)
parser.add_argument("--num_epochs", "--epochs", "-e", type=int, default=3)
if __name__ == "__main__": # sacred stuff
ex = Experiment("s424714")
ex.add_source_file("./src/train.py")
@ex.main
def main(_run):
args = parser.parse_args()
ex.open_resource(filename="./data/dataset/train.csv", mode="r")
ex.open_resource(filename="./data/dataset/test.csv", mode="r")
ex.open_resource(filename="./data/dataset/val.csv", mode="r")
INITIAL_LR = args.learning_rate
NUM_EPOCHS = args.num_epochs
BATCH_SIZE = args.batch
print(BATCH_SIZE)
@ex.config
def hyper_parameters():
initial_lr = INITIAL_LR # noqa: F841
num_epochs = NUM_EPOCHS # noqa: F841
batch_size = BATCH_SIZE # noqa: F841
print("INITIAL_LR: ", INITIAL_LR)
print("NUM_EPOCHS: ", NUM_EPOCHS)
@ -65,15 +78,20 @@ if __name__ == "__main__":
# trainig model
if args.train:
trained_model = train( trained_model, metrics = train(
model=BertClassifier(),
train_data=test_data, train_data=train_data,
val_data=val_data,
learning_rate=INITIAL_LR,
epochs=NUM_EPOCHS,
batch_size=BATCH_SIZE,
)
utils.save_model(model=trained_model, model_path=args.model_path)
ex.add_artifact(args.model_path)
_run.log_scalar("train_loss", metrics["train_loss"])
_run.log_scalar("val_loss", metrics["val_loss"])
_run.log_scalar("train_acc", metrics["train_acc"])
_run.log_scalar("val_acc", metrics["val_acc"])
# evaluating model # evaluating model
if args.test:
@ -89,3 +107,7 @@ if __name__ == "__main__":
build_id=int(args.build_id),
data=accuracy,
)
if __name__ == "__main__":
ex.run()

View File

@ -84,5 +84,10 @@ def train(
| Val Loss: {total_loss_val / len(val_data): .3f} \
| Val Accuracy: {total_acc_val / len(val_data): .3f}"
)
-    return model
+    metrics = {
+        "train_acc": total_acc_train / len(train_data),
+        "train_loss": total_loss_train / len(train_data),
+        "val_acc": total_acc_val / len(val_data),
+        "val_loss": total_loss_val / len(val_data),
+    }
+    return model, metrics