mlflow
This commit is contained in:
parent
dce050a318
commit
5cc9953cc4
@ -1,5 +1,5 @@
|
|||||||
FROM ubuntu:22.04
|
FROM ubuntu:22.04
|
||||||
|
|
||||||
RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git
|
RUN apt update && apt install -y vim make python3 python3-pip python-is-python3 gcc g++ golang wget unzip git
|
||||||
RUN pip install pandas matplotlib scikit-learn tensorflow
|
RUN pip install pandas matplotlib scikit-learn tensorflow mlflow
|
||||||
CMD "bash"
|
CMD "bash"
|
||||||
|
13
src/MLproject
Normal file
13
src/MLproject
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
name: s452639-ium
|
||||||
|
docker_env:
|
||||||
|
image: s452639-image
|
||||||
|
volumes: ['.:/ium/']
|
||||||
|
|
||||||
|
entry_points:
|
||||||
|
train:
|
||||||
|
parameters:
|
||||||
|
epochs: {type: int, default: 2}
|
||||||
|
command: "cd /ium; python tf_train.py {epochs}"
|
||||||
|
entry_point: train
|
||||||
|
test:
|
||||||
|
command: "cd /ium; python tf_test.py"
|
@ -3,6 +3,9 @@ import pandas as pd
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from sklearn.preprocessing import LabelEncoder
|
from sklearn.preprocessing import LabelEncoder
|
||||||
|
|
||||||
|
import mlflow
|
||||||
|
mlflow.tensorflow.autolog()
|
||||||
|
|
||||||
pd.set_option('display.float_format', lambda x: '%.5f' % x)
|
pd.set_option('display.float_format', lambda x: '%.5f' % x)
|
||||||
|
|
||||||
le = LabelEncoder()
|
le = LabelEncoder()
|
||||||
@ -22,8 +25,7 @@ def load_data(path: str, le: LabelEncoder):
|
|||||||
|
|
||||||
num_classes = len(le.classes_)
|
num_classes = len(le.classes_)
|
||||||
|
|
||||||
|
def main(epochs: int):
|
||||||
def train(epochs: int):
|
|
||||||
global le
|
global le
|
||||||
|
|
||||||
model = tf.keras.Sequential([
|
model = tf.keras.Sequential([
|
||||||
@ -52,4 +54,4 @@ def train(epochs: int):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
import sys
|
||||||
epochs = int('2' if len(sys.argv) != 2 else sys.argv[1])
|
epochs = int('2' if len(sys.argv) != 2 else sys.argv[1])
|
||||||
train(epochs)
|
main(epochs)
|
||||||
|
Loading…
Reference in New Issue
Block a user