mlflow
This commit is contained in:
parent
ba715dfce1
commit
41cfcb3b1a
8
MLProject
Normal file
8
MLProject
Normal file
@ -0,0 +1,8 @@
|
||||
name: Skin cancer classification
|
||||
|
||||
conda_env: conda.yaml
|
||||
entry_points:
|
||||
train:
|
||||
command: "python train.py"
|
||||
test:
|
||||
command: "python test.py"
|
12
conda.yaml
Normal file
12
conda.yaml
Normal file
@ -0,0 +1,12 @@
|
||||
name: mlflow-env
|
||||
|
||||
channels:
|
||||
- defaults
|
||||
|
||||
dependencies:
|
||||
- python=3.11
|
||||
- pip
|
||||
- pip:
|
||||
- mlflow
|
||||
- torch
|
||||
- torchvision
|
6
test.py
6
test.py
@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import torch
|
||||
import mlflow
|
||||
|
||||
from torchvision import transforms, datasets
|
||||
from torch.utils.data import DataLoader
|
||||
@ -24,6 +25,9 @@ if __name__ == '__main__':
|
||||
BATCH_SIZE = 32
|
||||
MODEL_PATH = 'model.pth'
|
||||
|
||||
mlflow.set_experiment('Skin Cancer Detection - Test set')
|
||||
mlflow.start_run()
|
||||
|
||||
names = {0: 'Benign', 1: 'Malignant'}
|
||||
predictions = []
|
||||
|
||||
@ -47,6 +51,8 @@ if __name__ == '__main__':
|
||||
|
||||
test_accuracy = test_correct / test_total
|
||||
|
||||
mlflow.log_metric('test_accuracy', test_accuracy)
|
||||
|
||||
predictions = [names[pred] for pred in predictions]
|
||||
|
||||
print(f'Accuracy test {test_accuracy:.2%}')
|
||||
|
15
train.py
15
train.py
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import mlflow
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn
|
||||
@ -132,6 +133,9 @@ def train(EPOCHS: int, BATCH_SIZE: int,
|
||||
train_accuracy = train_correct / train_total
|
||||
val_accuracy = val_correct / val_total
|
||||
|
||||
mlflow.log_metric("train_loss", avg_loss, step=epoch)
|
||||
mlflow.log_metric("val_loss", avg_vloss, step=epoch)
|
||||
|
||||
print(f'Loss train {avg_loss:.3f}, loss valid {avg_vloss:.3f}')
|
||||
print(f'Accuracy train {train_accuracy:.2%}, accuracy valid {val_accuracy:.2%}')
|
||||
|
||||
@ -152,6 +156,14 @@ if __name__ == "__main__":
|
||||
EPOCHS = 3
|
||||
LEARNING_RATE = 0.001
|
||||
|
||||
mlflow.set_experiment("Skin cancer classification, custom CNN model")
|
||||
mlflow.start_run()
|
||||
|
||||
mlflow.log_param("IMG_SIZE", IMG_SIZE)
|
||||
mlflow.log_param("BATCH_SIZE", BATCH_SIZE)
|
||||
mlflow.log_param("EPOCHS", EPOCHS)
|
||||
mlflow.log_param("LEARNING_RATE", LEARNING_RATE)
|
||||
|
||||
train_loader, val_loader, train_size = get_data(IMG_SIZE, BATCH_SIZE, device)
|
||||
|
||||
model = Model()
|
||||
@ -163,4 +175,5 @@ if __name__ == "__main__":
|
||||
train(EPOCHS, BATCH_SIZE, model, train_loader,
|
||||
val_loader, criterion, optimizer, train_size)
|
||||
|
||||
torch.save(model.state_dict(), "model.pth")
|
||||
torch.save(model.state_dict(), "model.pth")
|
||||
mlflow.log_artifact("model.pth")
|
Loading…
Reference in New Issue
Block a user