added parameters and mlflow to dockerfile
This commit is contained in:
parent
b665e697d5
commit
abab284e2e
@ -8,5 +8,6 @@ RUN pip install scikit-learn
|
|||||||
RUN pip install datasets
|
RUN pip install datasets
|
||||||
RUN pip install torch
|
RUN pip install torch
|
||||||
RUN pip install seaborn
|
RUN pip install seaborn
|
||||||
|
RUN pip install mlflow
|
||||||
COPY ./zad1.py ./
|
COPY ./zad1.py ./
|
||||||
CMD ["python3", 'zad1.py']
|
CMD ["python3", 'zad1.py']
|
13
zad1.py
13
zad1.py
@ -6,6 +6,13 @@ import numpy as np
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='IUM script')
|
||||||
|
parser.add_argument('--num_epochs', type=int, default=10, help='Number of epochs')
|
||||||
|
parser.add_argument('--lr', type=float, default=0.001, help='Learning rate')
|
||||||
|
parser.add_argument('--alpha', type=float, default=0.001, help='Learning rate')
|
||||||
|
args = parser.parse_args()
|
||||||
logging.basicConfig(level=logging.WARN)
|
logging.basicConfig(level=logging.WARN)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -98,9 +105,9 @@ model = TabularModel(input_dim, hidden_dim, output_dim)
|
|||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters())
|
optimizer = torch.optim.Adam(model.parameters())
|
||||||
|
|
||||||
num_epochs = 10
|
num_epochs = args.num_epochs
|
||||||
lr = 0.01
|
lr = args.lr
|
||||||
alpha = 0.01
|
alpha = args.alpha
|
||||||
model = TabularModel(input_dim=len(wine_train.columns)-1, hidden_dim=hidden_dim, output_dim=output_dim)
|
model = TabularModel(input_dim=len(wine_train.columns)-1, hidden_dim=hidden_dim, output_dim=output_dim)
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=alpha)
|
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=alpha)
|
||||||
|
Loading…
Reference in New Issue
Block a user