feat: add sacred implementation
This commit is contained in:
parent
65a60c9dae
commit
34f3b4764d
|
@ -15,5 +15,6 @@ dependencies:
|
||||||
- pytorch-cuda=11.8
|
- pytorch-cuda=11.8
|
||||||
- transformers
|
- transformers
|
||||||
- matplotlib
|
- matplotlib
|
||||||
|
- pymongo
|
||||||
|
- pip:
|
||||||
|
- sacred==0.8.4
|
||||||
|
|
38
src/main.py
38
src/main.py
|
@ -1,17 +1,13 @@
|
||||||
# import random
|
|
||||||
|
|
||||||
# from sklearn.model_selection import train_test_split
|
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
from models import BertClassifier, utils
|
from models import BertClassifier, utils
|
||||||
from datasets import NewsDataset
|
from datasets import NewsDataset
|
||||||
from train import train
|
from train import train
|
||||||
|
from sacred import Experiment
|
||||||
from evaluate import evaluate
|
from evaluate import evaluate
|
||||||
|
|
||||||
# SEED = 2137
|
|
||||||
|
|
||||||
# argument parser
|
# argument parser
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
prog="News classification",
|
prog="News classification",
|
||||||
description="Train or evaluate model",
|
description="Train or evaluate model",
|
||||||
|
@ -29,12 +25,29 @@ parser.add_argument("--learning_rate", "--lr", type=float, default=1e-6)
|
||||||
parser.add_argument("--num_epochs", "--epochs", "-e", type=int, default=3)
|
parser.add_argument("--num_epochs", "--epochs", "-e", type=int, default=3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
# sacred stuff
|
||||||
|
ex = Experiment("s424714")
|
||||||
|
ex.add_source_file("./src/train.py")
|
||||||
|
|
||||||
|
|
||||||
|
@ex.main
|
||||||
|
def main(_run):
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
ex.open_resource(filename="./data/dataset/train.csv", mode="r")
|
||||||
|
ex.open_resource(filename="./data/dataset/test.csv", mode="r")
|
||||||
|
ex.open_resource(filename="./data/dataset/val.csv", mode="r")
|
||||||
|
|
||||||
INITIAL_LR = args.learning_rate
|
INITIAL_LR = args.learning_rate
|
||||||
NUM_EPOCHS = args.num_epochs
|
NUM_EPOCHS = args.num_epochs
|
||||||
BATCH_SIZE = args.batch
|
BATCH_SIZE = args.batch
|
||||||
|
print(BATCH_SIZE)
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def hyper_parameters():
|
||||||
|
initial_lr = INITIAL_LR # noqa: F841
|
||||||
|
num_epochs = NUM_EPOCHS # noqa: F841
|
||||||
|
batch_size = BATCH_SIZE # noqa: F841
|
||||||
|
|
||||||
print("INITIAL_LR: ", INITIAL_LR)
|
print("INITIAL_LR: ", INITIAL_LR)
|
||||||
print("NUM_EPOCHS: ", NUM_EPOCHS)
|
print("NUM_EPOCHS: ", NUM_EPOCHS)
|
||||||
|
@ -65,15 +78,20 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
# trainig model
|
# trainig model
|
||||||
if args.train:
|
if args.train:
|
||||||
trained_model = train(
|
trained_model, metrics = train(
|
||||||
model=BertClassifier(),
|
model=BertClassifier(),
|
||||||
train_data=test_data,
|
train_data=train_data,
|
||||||
val_data=val_data,
|
val_data=val_data,
|
||||||
learning_rate=INITIAL_LR,
|
learning_rate=INITIAL_LR,
|
||||||
epochs=NUM_EPOCHS,
|
epochs=NUM_EPOCHS,
|
||||||
batch_size=BATCH_SIZE,
|
batch_size=BATCH_SIZE,
|
||||||
)
|
)
|
||||||
utils.save_model(model=trained_model, model_path=args.model_path)
|
utils.save_model(model=trained_model, model_path=args.model_path)
|
||||||
|
ex.add_artifact(args.model_path)
|
||||||
|
_run.log_scalar("train_loss", metrics["train_loss"])
|
||||||
|
_run.log_scalar("val_loss", metrics["val_loss"])
|
||||||
|
_run.log_scalar("train_acc", metrics["train_acc"])
|
||||||
|
_run.log_scalar("val_acc", metrics["val_acc"])
|
||||||
|
|
||||||
# evaluating model
|
# evaluating model
|
||||||
if args.test:
|
if args.test:
|
||||||
|
@ -89,3 +107,7 @@ if __name__ == "__main__":
|
||||||
build_id=int(args.build_id),
|
build_id=int(args.build_id),
|
||||||
data=accuracy,
|
data=accuracy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ex.run()
|
||||||
|
|
|
@ -84,5 +84,10 @@ def train(
|
||||||
| Val Loss: {total_loss_val / len(val_data): .3f} \
|
| Val Loss: {total_loss_val / len(val_data): .3f} \
|
||||||
| Val Accuracy: {total_acc_val / len(val_data): .3f}"
|
| Val Accuracy: {total_acc_val / len(val_data): .3f}"
|
||||||
)
|
)
|
||||||
|
metrics = {
|
||||||
return model
|
"train_acc": total_acc_train / len(train_data),
|
||||||
|
"train_loss": total_loss_train / len(train_data),
|
||||||
|
"val_acc": total_acc_val / len(val_data),
|
||||||
|
"val_loss": total_loss_val / len(val_data),
|
||||||
|
}
|
||||||
|
return model, metrics
|
||||||
|
|
Loading…
Reference in New Issue