This commit is contained in:
Filip Patyk 2023-05-12 19:55:19 +02:00
parent 261c5ca6e2
commit b679e7d9c6
3 changed files with 25 additions and 34 deletions

1
.gitignore vendored
View File

@@ -3,3 +3,4 @@ __pycache__
data/
results/
sacred/
mlruns/

View File

@@ -16,5 +16,7 @@ dependencies:
- transformers
- matplotlib
- pymongo
- mlflow
- pip
- pip:
- sacred==0.8.4

View File

@@ -1,13 +1,11 @@
import argparse
import shutil
import mlflow
import torch
from sacred.observers import FileStorageObserver, MongoObserver
from datasets import NewsDataset
from evaluate import evaluate
from models import BertClassifier, utils
from sacred import Experiment
from train import train
# argument parser
@@ -28,40 +26,27 @@ parser.add_argument("--learning_rate", "--lr", type=float, default=1e-6)
parser.add_argument("--num_epochs", "--epochs", "-e", type=int, default=3)
# sacred stuff
ex = Experiment(name="s424714", interactive=True)
SACRED_DIR_PATH = "./sacred"
if not torch.cuda.is_available():
ex.observers.append(MongoObserver(url="mongodb://admin:IUM_2021@172.17.0.1:27017", db_name="sacred"))
# ex.observers.append(MongoObserver(url="mongodb://admin:IUM_2021@172.17.0.1:27017", db_name="sacred"))
ex.observers.append(FileStorageObserver(SACRED_DIR_PATH))
ex.add_source_file("./src/train.py")
# mlflow stuff
mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("s424714")
@ex.main
def main(_run):
def main():
args = parser.parse_args()
ex.open_resource(filename="./data/dataset/train.csv", mode="r")
ex.open_resource(filename="./data/dataset/test.csv", mode="r")
ex.open_resource(filename="./data/dataset/val.csv", mode="r")
INITIAL_LR = args.learning_rate
NUM_EPOCHS = args.num_epochs
BATCH_SIZE = args.batch
print(BATCH_SIZE)
@ex.config
def hyper_parameters():
initial_lr = INITIAL_LR # noqa: F841
num_epochs = NUM_EPOCHS # noqa: F841
batch_size = BATCH_SIZE # noqa: F841
print("INITIAL_LR: ", INITIAL_LR)
print("NUM_EPOCHS: ", NUM_EPOCHS)
print("BATCH_SIZE: ", BATCH_SIZE)
print("CUDA: ", torch.cuda.is_available())
print("CUDA: ", cuda := torch.cuda.is_available())
mlflow.log_param("INITIAL_LR", INITIAL_LR)
mlflow.log_param("NUM_EPOCHS", NUM_EPOCHS)
mlflow.log_param("BATCH_SIZE", BATCH_SIZE)
mlflow.log_param("CUDA", cuda)
# raise
# loading & spliting data
@@ -96,11 +81,12 @@ def main(_run):
batch_size=BATCH_SIZE,
)
utils.save_model(model=trained_model, model_path=args.model_path)
ex.add_artifact(args.model_path)
_run.log_scalar("train_loss", metrics["train_loss"])
_run.log_scalar("val_loss", metrics["val_loss"])
_run.log_scalar("train_acc", metrics["train_acc"])
_run.log_scalar("val_acc", metrics["val_acc"])
# mlflow saving metrics
mlflow.log_metric("train_acc", metrics["train_acc"])
mlflow.log_metric("train_loss", metrics["train_loss"])
mlflow.log_metric("val_acc", metrics["val_acc"])
mlflow.log_metric("val_loss", metrics["val_loss"])
# evaluating model
if args.test:
@@ -119,5 +105,7 @@ def main(_run):
if __name__ == "__main__":
ex.run()
shutil.make_archive(base_name="./results/sacred-artifacts", format="zip", root_dir=SACRED_DIR_PATH)
with mlflow.start_run() as run:
print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))
main()