This commit is contained in:
filnow 2024-05-14 20:31:53 +02:00
parent ba715dfce1
commit 41cfcb3b1a
4 changed files with 40 additions and 1 deletions

8
MLProject Normal file
View File

@ -0,0 +1,8 @@
name: Skin cancer classification
conda_env: conda.yaml
entry_points:
train:
command: "python train.py"
test:
command: "python test.py"

12
conda.yaml Normal file
View File

@ -0,0 +1,12 @@
name: mlflow-env
channels:
- defaults
dependencies:
- python=3.11
- pip
- pip:
- mlflow
- torch
- torchvision

View File

@ -1,5 +1,6 @@
import csv
import torch
import mlflow
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
@ -24,6 +25,9 @@ if __name__ == '__main__':
BATCH_SIZE = 32
MODEL_PATH = 'model.pth'
mlflow.set_experiment('Skin Cancer Detection - Test set')
mlflow.start_run()
names = {0: 'Benign', 1: 'Malignant'}
predictions = []
@ -47,6 +51,8 @@ if __name__ == '__main__':
test_accuracy = test_correct / test_total
mlflow.log_metric('test_accuracy', test_accuracy)
predictions = [names[pred] for pred in predictions]
print(f'Accuracy test {test_accuracy:.2%}')

View File

@ -1,4 +1,5 @@
import torch
import mlflow
import torch.nn.functional as F
from torch import nn
@ -132,6 +133,9 @@ def train(EPOCHS: int, BATCH_SIZE: int,
train_accuracy = train_correct / train_total
val_accuracy = val_correct / val_total
mlflow.log_metric("train_loss", avg_loss, step=epoch)
mlflow.log_metric("val_loss", avg_vloss, step=epoch)
print(f'Loss train {avg_loss:.3f}, loss valid {avg_vloss:.3f}')
print(f'Accuracy train {train_accuracy:.2%}, accuracy valid {val_accuracy:.2%}')
@ -152,6 +156,14 @@ if __name__ == "__main__":
EPOCHS = 3
LEARNING_RATE = 0.001
mlflow.set_experiment("Skin cancer classification, custom CNN model")
mlflow.start_run()
mlflow.log_param("IMG_SIZE", IMG_SIZE)
mlflow.log_param("BATCH_SIZE", BATCH_SIZE)
mlflow.log_param("EPOCHS", EPOCHS)
mlflow.log_param("LEARNING_RATE", LEARNING_RATE)
train_loader, val_loader, train_size = get_data(IMG_SIZE, BATCH_SIZE, device)
model = Model()
@ -164,3 +176,4 @@ if __name__ == "__main__":
val_loader, criterion, optimizer, train_size)
torch.save(model.state_dict(), "model.pth")
mlflow.log_artifact("model.pth")