mlflow
This commit is contained in:
parent
dce050a318
commit
5cc9953cc4
@ -1,5 +1,5 @@
|
||||
FROM ubuntu:22.04
|
||||
|
||||
RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git
|
||||
RUN pip install pandas matplotlib scikit-learn tensorflow
|
||||
RUN pip install pandas matplotlib scikit-learn tensorflow mlflow
|
||||
CMD "bash"
|
||||
|
13
src/MLproject
Normal file
13
src/MLproject
Normal file
@ -0,0 +1,13 @@
|
||||
name: s452639-ium
|
||||
docker_env:
|
||||
image: s452639-image
|
||||
volumes: ['.:/ium/']
|
||||
|
||||
entry_points:
|
||||
train:
|
||||
parameters:
|
||||
epochs: {type: int, default: 2}
|
||||
command: "cd /ium; python tf_train.py {epochs}"
|
||||
entry_point: train
|
||||
test:
|
||||
command: "cd /ium; python tf_test.py"
|
@ -3,6 +3,9 @@ import pandas as pd
|
||||
import tensorflow as tf
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
|
||||
import mlflow
|
||||
mlflow.tensorflow.autolog()
|
||||
|
||||
pd.set_option('display.float_format', lambda x: '%.5f' % x)
|
||||
|
||||
le = LabelEncoder()
|
||||
@ -22,8 +25,7 @@ def load_data(path: str, le: LabelEncoder):
|
||||
|
||||
num_classes = len(le.classes_)
|
||||
|
||||
|
||||
def train(epochs: int):
|
||||
def main(epochs: int):
|
||||
global le
|
||||
|
||||
model = tf.keras.Sequential([
|
||||
@ -52,4 +54,4 @@ def train(epochs: int):
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
epochs = int('2' if len(sys.argv) != 2 else sys.argv[1])
|
||||
train(epochs)
|
||||
main(epochs)
|
||||
|
Loading…
Reference in New Issue
Block a user