mlflow
This commit is contained in:
parent
261c5ca6e2
commit
b679e7d9c6
1
.gitignore
vendored
1
.gitignore
vendored
@ -3,3 +3,4 @@ __pycache__
|
||||
data/
|
||||
results/
|
||||
sacred/
|
||||
mlruns/
|
@ -16,5 +16,7 @@ dependencies:
|
||||
- transformers
|
||||
- matplotlib
|
||||
- pymongo
|
||||
- mlflow
|
||||
- pip
|
||||
- pip:
|
||||
- sacred==0.8.4
|
||||
|
54
src/main.py
54
src/main.py
@ -1,13 +1,11 @@
|
||||
import argparse
|
||||
import shutil
|
||||
|
||||
import mlflow
|
||||
import torch
|
||||
from sacred.observers import FileStorageObserver, MongoObserver
|
||||
|
||||
from datasets import NewsDataset
|
||||
from evaluate import evaluate
|
||||
from models import BertClassifier, utils
|
||||
from sacred import Experiment
|
||||
from train import train
|
||||
|
||||
# argument parser
|
||||
@ -28,40 +26,27 @@ parser.add_argument("--learning_rate", "--lr", type=float, default=1e-6)
|
||||
parser.add_argument("--num_epochs", "--epochs", "-e", type=int, default=3)
|
||||
|
||||
|
||||
# sacred stuff
|
||||
ex = Experiment(name="s424714", interactive=True)
|
||||
SACRED_DIR_PATH = "./sacred"
|
||||
if not torch.cuda.is_available():
|
||||
ex.observers.append(MongoObserver(url="mongodb://admin:IUM_2021@172.17.0.1:27017", db_name="sacred"))
|
||||
# ex.observers.append(MongoObserver(url="mongodb://admin:IUM_2021@172.17.0.1:27017", db_name="sacred"))
|
||||
ex.observers.append(FileStorageObserver(SACRED_DIR_PATH))
|
||||
|
||||
ex.add_source_file("./src/train.py")
|
||||
# mlflow stuff
|
||||
mlflow.set_tracking_uri("http://localhost:5000")
|
||||
mlflow.set_experiment("s424714")
|
||||
|
||||
|
||||
@ex.main
|
||||
def main(_run):
|
||||
def main():
|
||||
args = parser.parse_args()
|
||||
|
||||
ex.open_resource(filename="./data/dataset/train.csv", mode="r")
|
||||
ex.open_resource(filename="./data/dataset/test.csv", mode="r")
|
||||
ex.open_resource(filename="./data/dataset/val.csv", mode="r")
|
||||
|
||||
INITIAL_LR = args.learning_rate
|
||||
NUM_EPOCHS = args.num_epochs
|
||||
BATCH_SIZE = args.batch
|
||||
print(BATCH_SIZE)
|
||||
|
||||
@ex.config
|
||||
def hyper_parameters():
|
||||
initial_lr = INITIAL_LR # noqa: F841
|
||||
num_epochs = NUM_EPOCHS # noqa: F841
|
||||
batch_size = BATCH_SIZE # noqa: F841
|
||||
|
||||
print("INITIAL_LR: ", INITIAL_LR)
|
||||
print("NUM_EPOCHS: ", NUM_EPOCHS)
|
||||
print("BATCH_SIZE: ", BATCH_SIZE)
|
||||
print("CUDA: ", torch.cuda.is_available())
|
||||
print("CUDA: ", cuda := torch.cuda.is_available())
|
||||
|
||||
mlflow.log_param("INITIAL_LR", INITIAL_LR)
|
||||
mlflow.log_param("NUM_EPOCHS", NUM_EPOCHS)
|
||||
mlflow.log_param("BATCH_SIZE", BATCH_SIZE)
|
||||
mlflow.log_param("CUDA", cuda)
|
||||
|
||||
# raise
|
||||
# loading & spliting data
|
||||
@ -96,11 +81,12 @@ def main(_run):
|
||||
batch_size=BATCH_SIZE,
|
||||
)
|
||||
utils.save_model(model=trained_model, model_path=args.model_path)
|
||||
ex.add_artifact(args.model_path)
|
||||
_run.log_scalar("train_loss", metrics["train_loss"])
|
||||
_run.log_scalar("val_loss", metrics["val_loss"])
|
||||
_run.log_scalar("train_acc", metrics["train_acc"])
|
||||
_run.log_scalar("val_acc", metrics["val_acc"])
|
||||
|
||||
# mlflow saving metrics
|
||||
mlflow.log_metric("train_acc", metrics["train_acc"])
|
||||
mlflow.log_metric("train_loss", metrics["train_loss"])
|
||||
mlflow.log_metric("val_acc", metrics["val_acc"])
|
||||
mlflow.log_metric("val_loss", metrics["val_loss"])
|
||||
|
||||
# evaluating model
|
||||
if args.test:
|
||||
@ -119,5 +105,7 @@ def main(_run):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ex.run()
|
||||
shutil.make_archive(base_name="./results/sacred-artifacts", format="zip", root_dir=SACRED_DIR_PATH)
|
||||
with mlflow.start_run() as run:
|
||||
print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
|
||||
print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user