From 41c655024deecccc5f886e3969b737843f046b7e Mon Sep 17 00:00:00 2001 From: gedin Date: Wed, 24 May 2023 15:58:36 +0200 Subject: [PATCH] not user local installs on pip + eval metrics file creation --- train-eval/Dockerfile | 12 ++++++------ train-eval/eval.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/train-eval/Dockerfile b/train-eval/Dockerfile index 5716b93..e8aa1c8 100644 --- a/train-eval/Dockerfile +++ b/train-eval/Dockerfile @@ -5,12 +5,12 @@ RUN apt install python3-pip -y RUN apt install unzip -y RUN apt install git -y -RUN pip install --user numpy -RUN pip install --user pandas -RUN pip install --user torch -RUN pip install --user keras -RUN pip install --user tensorflow -RUN pip install --user scikit-learn +RUN pip install numpy +RUN pip install pandas +RUN pip install torch +RUN pip install keras +RUN pip install tensorflow +RUN pip install scikit-learn # RUN echo "alias kaggle='~/.local/bin/kaggle'" >> ~/.bashrc diff --git a/train-eval/eval.py b/train-eval/eval.py index ff153bd..8368e49 100755 --- a/train-eval/eval.py +++ b/train-eval/eval.py @@ -38,10 +38,16 @@ model.load_state_dict(torch.load('model.pt')) x_test = torch.tensor(X.values, dtype=torch.float32) pred = model(x_test) pred = pred.detach().numpy() -print ("The accuracy is", accuracy_score(Y, np.argmax(pred, axis=1))) -print ("The precission score is ", precision_score(Y, np.argmax(pred, axis=1))) -print ("The recall score is ", recall_score(Y, np.argmax(pred, axis=1))) +acc = accuracy_score(Y, np.argmax(pred, axis=1)) +prec = precision_score(Y, np.argmax(pred, axis=1)) +recall = recall_score(Y, np.argmax(pred, axis=1)) +print ("The accuracy is", acc) +print ("The precission score is ", prec) +print ("The recall score is ", recall) +file = open('metrics.txt', 'w') +file.write(str(acc) + '\t' + str(prec) + '\t' + str(recall)) +file.close() np.savetxt('prediction.tsv', pred, delimiter='\t') \ No newline at end of file