feat: add sacred implementation
This commit is contained in:
parent
65a60c9dae
commit
34f3b4764d
|
@ -15,5 +15,6 @@ dependencies:
|
|||
- pytorch-cuda=11.8
|
||||
- transformers
|
||||
- matplotlib
|
||||
|
||||
|
||||
- pymongo
|
||||
- pip:
|
||||
- sacred==0.8.4
|
||||
|
|
38
src/main.py
38
src/main.py
|
@ -1,17 +1,13 @@
|
|||
# import random
|
||||
|
||||
# from sklearn.model_selection import train_test_split
|
||||
import argparse
|
||||
import torch
|
||||
from models import BertClassifier, utils
|
||||
from datasets import NewsDataset
|
||||
from train import train
|
||||
from sacred import Experiment
|
||||
from evaluate import evaluate
|
||||
|
||||
# SEED = 2137
|
||||
|
||||
# argument parser
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="News classification",
|
||||
description="Train or evaluate model",
|
||||
|
@ -29,12 +25,29 @@ parser.add_argument("--learning_rate", "--lr", type=float, default=1e-6)
|
|||
parser.add_argument("--num_epochs", "--epochs", "-e", type=int, default=3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# sacred stuff
|
||||
ex = Experiment("s424714")
|
||||
ex.add_source_file("./src/train.py")
|
||||
|
||||
|
||||
@ex.main
|
||||
def main(_run):
|
||||
args = parser.parse_args()
|
||||
|
||||
ex.open_resource(filename="./data/dataset/train.csv", mode="r")
|
||||
ex.open_resource(filename="./data/dataset/test.csv", mode="r")
|
||||
ex.open_resource(filename="./data/dataset/val.csv", mode="r")
|
||||
|
||||
INITIAL_LR = args.learning_rate
|
||||
NUM_EPOCHS = args.num_epochs
|
||||
BATCH_SIZE = args.batch
|
||||
print(BATCH_SIZE)
|
||||
|
||||
@ex.config
|
||||
def hyper_parameters():
|
||||
initial_lr = INITIAL_LR # noqa: F841
|
||||
num_epochs = NUM_EPOCHS # noqa: F841
|
||||
batch_size = BATCH_SIZE # noqa: F841
|
||||
|
||||
print("INITIAL_LR: ", INITIAL_LR)
|
||||
print("NUM_EPOCHS: ", NUM_EPOCHS)
|
||||
|
@ -65,15 +78,20 @@ if __name__ == "__main__":
|
|||
|
||||
# trainig model
|
||||
if args.train:
|
||||
trained_model = train(
|
||||
trained_model, metrics = train(
|
||||
model=BertClassifier(),
|
||||
train_data=test_data,
|
||||
train_data=train_data,
|
||||
val_data=val_data,
|
||||
learning_rate=INITIAL_LR,
|
||||
epochs=NUM_EPOCHS,
|
||||
batch_size=BATCH_SIZE,
|
||||
)
|
||||
utils.save_model(model=trained_model, model_path=args.model_path)
|
||||
ex.add_artifact(args.model_path)
|
||||
_run.log_scalar("train_loss", metrics["train_loss"])
|
||||
_run.log_scalar("val_loss", metrics["val_loss"])
|
||||
_run.log_scalar("train_acc", metrics["train_acc"])
|
||||
_run.log_scalar("val_acc", metrics["val_acc"])
|
||||
|
||||
# evaluating model
|
||||
if args.test:
|
||||
|
@ -89,3 +107,7 @@ if __name__ == "__main__":
|
|||
build_id=int(args.build_id),
|
||||
data=accuracy,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ex.run()
|
||||
|
|
|
@ -84,5 +84,10 @@ def train(
|
|||
| Val Loss: {total_loss_val / len(val_data): .3f} \
|
||||
| Val Accuracy: {total_acc_val / len(val_data): .3f}"
|
||||
)
|
||||
|
||||
return model
|
||||
metrics = {
|
||||
"train_acc": total_acc_train / len(train_data),
|
||||
"train_loss": total_loss_train / len(train_data),
|
||||
"val_acc": total_acc_val / len(val_data),
|
||||
"val_loss": total_loss_val / len(val_data),
|
||||
}
|
||||
return model, metrics
|
||||
|
|
Loading…
Reference in New Issue