stop tracking healthcare-dataset-stroke-data.csv
This commit is contained in:
parent
d6e037c2b4
commit
6cb4c72e4d
3
.dvc/.gitignore
vendored
Normal file
3
.dvc/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
||||
/config.local
|
||||
/tmp
|
||||
/cache
|
0
.dvc/config
Normal file
0
.dvc/config
Normal file
107
.dvc/plots/confusion.json
Normal file
107
.dvc/plots/confusion.json
Normal file
@ -0,0 +1,107 @@
|
||||
{
|
||||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
|
||||
"data": {
|
||||
"values": "<DVC_METRIC_DATA>"
|
||||
},
|
||||
"title": "<DVC_METRIC_TITLE>",
|
||||
"facet": {
|
||||
"field": "rev",
|
||||
"type": "nominal"
|
||||
},
|
||||
"spec": {
|
||||
"transform": [
|
||||
{
|
||||
"aggregate": [
|
||||
{
|
||||
"op": "count",
|
||||
"as": "xy_count"
|
||||
}
|
||||
],
|
||||
"groupby": [
|
||||
"<DVC_METRIC_Y>",
|
||||
"<DVC_METRIC_X>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"impute": "xy_count",
|
||||
"groupby": [
|
||||
"rev",
|
||||
"<DVC_METRIC_Y>"
|
||||
],
|
||||
"key": "<DVC_METRIC_X>",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"impute": "xy_count",
|
||||
"groupby": [
|
||||
"rev",
|
||||
"<DVC_METRIC_X>"
|
||||
],
|
||||
"key": "<DVC_METRIC_Y>",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"joinaggregate": [
|
||||
{
|
||||
"op": "max",
|
||||
"field": "xy_count",
|
||||
"as": "max_count"
|
||||
}
|
||||
],
|
||||
"groupby": []
|
||||
},
|
||||
{
|
||||
"calculate": "datum.xy_count / datum.max_count",
|
||||
"as": "percent_of_max"
|
||||
}
|
||||
],
|
||||
"encoding": {
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "nominal",
|
||||
"sort": "ascending",
|
||||
"title": "<DVC_METRIC_X_LABEL>"
|
||||
},
|
||||
"y": {
|
||||
"field": "<DVC_METRIC_Y>",
|
||||
"type": "nominal",
|
||||
"sort": "ascending",
|
||||
"title": "<DVC_METRIC_Y_LABEL>"
|
||||
}
|
||||
},
|
||||
"layer": [
|
||||
{
|
||||
"mark": "rect",
|
||||
"width": 300,
|
||||
"height": 300,
|
||||
"encoding": {
|
||||
"color": {
|
||||
"field": "xy_count",
|
||||
"type": "quantitative",
|
||||
"title": "",
|
||||
"scale": {
|
||||
"domainMin": 0,
|
||||
"nice": true
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"mark": "text",
|
||||
"encoding": {
|
||||
"text": {
|
||||
"field": "xy_count",
|
||||
"type": "quantitative"
|
||||
},
|
||||
"color": {
|
||||
"condition": {
|
||||
"test": "datum.percent_of_max > 0.5",
|
||||
"value": "white"
|
||||
},
|
||||
"value": "black"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
112
.dvc/plots/confusion_normalized.json
Normal file
112
.dvc/plots/confusion_normalized.json
Normal file
@ -0,0 +1,112 @@
|
||||
{
|
||||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
|
||||
"data": {
|
||||
"values": "<DVC_METRIC_DATA>"
|
||||
},
|
||||
"title": "<DVC_METRIC_TITLE>",
|
||||
"facet": {
|
||||
"field": "rev",
|
||||
"type": "nominal"
|
||||
},
|
||||
"spec": {
|
||||
"transform": [
|
||||
{
|
||||
"aggregate": [
|
||||
{
|
||||
"op": "count",
|
||||
"as": "xy_count"
|
||||
}
|
||||
],
|
||||
"groupby": [
|
||||
"<DVC_METRIC_Y>",
|
||||
"<DVC_METRIC_X>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"impute": "xy_count",
|
||||
"groupby": [
|
||||
"rev",
|
||||
"<DVC_METRIC_Y>"
|
||||
],
|
||||
"key": "<DVC_METRIC_X>",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"impute": "xy_count",
|
||||
"groupby": [
|
||||
"rev",
|
||||
"<DVC_METRIC_X>"
|
||||
],
|
||||
"key": "<DVC_METRIC_Y>",
|
||||
"value": 0
|
||||
},
|
||||
{
|
||||
"joinaggregate": [
|
||||
{
|
||||
"op": "sum",
|
||||
"field": "xy_count",
|
||||
"as": "sum_y"
|
||||
}
|
||||
],
|
||||
"groupby": [
|
||||
"<DVC_METRIC_Y>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"calculate": "datum.xy_count / datum.sum_y",
|
||||
"as": "percent_of_y"
|
||||
}
|
||||
],
|
||||
"encoding": {
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "nominal",
|
||||
"sort": "ascending",
|
||||
"title": "<DVC_METRIC_X_LABEL>"
|
||||
},
|
||||
"y": {
|
||||
"field": "<DVC_METRIC_Y>",
|
||||
"type": "nominal",
|
||||
"sort": "ascending",
|
||||
"title": "<DVC_METRIC_Y_LABEL>"
|
||||
}
|
||||
},
|
||||
"layer": [
|
||||
{
|
||||
"mark": "rect",
|
||||
"width": 300,
|
||||
"height": 300,
|
||||
"encoding": {
|
||||
"color": {
|
||||
"field": "percent_of_y",
|
||||
"type": "quantitative",
|
||||
"title": "",
|
||||
"scale": {
|
||||
"domain": [
|
||||
0,
|
||||
1
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"mark": "text",
|
||||
"encoding": {
|
||||
"text": {
|
||||
"field": "percent_of_y",
|
||||
"type": "quantitative",
|
||||
"format": ".2f"
|
||||
},
|
||||
"color": {
|
||||
"condition": {
|
||||
"test": "datum.percent_of_y > 0.5",
|
||||
"value": "white"
|
||||
},
|
||||
"value": "black"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
31
.dvc/plots/default.json
Normal file
31
.dvc/plots/default.json
Normal file
@ -0,0 +1,31 @@
|
||||
{
|
||||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
|
||||
"data": {
|
||||
"values": "<DVC_METRIC_DATA>"
|
||||
},
|
||||
"title": "<DVC_METRIC_TITLE>",
|
||||
"width": 300,
|
||||
"height": 300,
|
||||
"mark": {
|
||||
"type": "line"
|
||||
},
|
||||
"encoding": {
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "quantitative",
|
||||
"title": "<DVC_METRIC_X_LABEL>"
|
||||
},
|
||||
"y": {
|
||||
"field": "<DVC_METRIC_Y>",
|
||||
"type": "quantitative",
|
||||
"title": "<DVC_METRIC_Y_LABEL>",
|
||||
"scale": {
|
||||
"zero": false
|
||||
}
|
||||
},
|
||||
"color": {
|
||||
"field": "rev",
|
||||
"type": "nominal"
|
||||
}
|
||||
}
|
||||
}
|
116
.dvc/plots/linear.json
Normal file
116
.dvc/plots/linear.json
Normal file
@ -0,0 +1,116 @@
|
||||
{
|
||||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
|
||||
"data": {
|
||||
"values": "<DVC_METRIC_DATA>"
|
||||
},
|
||||
"title": "<DVC_METRIC_TITLE>",
|
||||
"width": 300,
|
||||
"height": 300,
|
||||
"layer": [
|
||||
{
|
||||
"encoding": {
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "quantitative",
|
||||
"title": "<DVC_METRIC_X_LABEL>"
|
||||
},
|
||||
"y": {
|
||||
"field": "<DVC_METRIC_Y>",
|
||||
"type": "quantitative",
|
||||
"title": "<DVC_METRIC_Y_LABEL>",
|
||||
"scale": {
|
||||
"zero": false
|
||||
}
|
||||
},
|
||||
"color": {
|
||||
"field": "rev",
|
||||
"type": "nominal"
|
||||
}
|
||||
},
|
||||
"layer": [
|
||||
{
|
||||
"mark": "line"
|
||||
},
|
||||
{
|
||||
"selection": {
|
||||
"label": {
|
||||
"type": "single",
|
||||
"nearest": true,
|
||||
"on": "mouseover",
|
||||
"encodings": [
|
||||
"x"
|
||||
],
|
||||
"empty": "none",
|
||||
"clear": "mouseout"
|
||||
}
|
||||
},
|
||||
"mark": "point",
|
||||
"encoding": {
|
||||
"opacity": {
|
||||
"condition": {
|
||||
"selection": "label",
|
||||
"value": 1
|
||||
},
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"transform": [
|
||||
{
|
||||
"filter": {
|
||||
"selection": "label"
|
||||
}
|
||||
}
|
||||
],
|
||||
"layer": [
|
||||
{
|
||||
"mark": {
|
||||
"type": "rule",
|
||||
"color": "gray"
|
||||
},
|
||||
"encoding": {
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "quantitative"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"encoding": {
|
||||
"text": {
|
||||
"type": "quantitative",
|
||||
"field": "<DVC_METRIC_Y>"
|
||||
},
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "quantitative"
|
||||
},
|
||||
"y": {
|
||||
"field": "<DVC_METRIC_Y>",
|
||||
"type": "quantitative"
|
||||
}
|
||||
},
|
||||
"layer": [
|
||||
{
|
||||
"mark": {
|
||||
"type": "text",
|
||||
"align": "left",
|
||||
"dx": 5,
|
||||
"dy": -5
|
||||
},
|
||||
"encoding": {
|
||||
"color": {
|
||||
"type": "nominal",
|
||||
"field": "rev"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
104
.dvc/plots/scatter.json
Normal file
104
.dvc/plots/scatter.json
Normal file
@ -0,0 +1,104 @@
|
||||
{
|
||||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
|
||||
"data": {
|
||||
"values": "<DVC_METRIC_DATA>"
|
||||
},
|
||||
"title": "<DVC_METRIC_TITLE>",
|
||||
"width": 300,
|
||||
"height": 300,
|
||||
"layer": [
|
||||
{
|
||||
"encoding": {
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "quantitative",
|
||||
"title": "<DVC_METRIC_X_LABEL>"
|
||||
},
|
||||
"y": {
|
||||
"field": "<DVC_METRIC_Y>",
|
||||
"type": "quantitative",
|
||||
"title": "<DVC_METRIC_Y_LABEL>",
|
||||
"scale": {
|
||||
"zero": false
|
||||
}
|
||||
},
|
||||
"color": {
|
||||
"field": "rev",
|
||||
"type": "nominal"
|
||||
}
|
||||
},
|
||||
"layer": [
|
||||
{
|
||||
"mark": "point"
|
||||
},
|
||||
{
|
||||
"selection": {
|
||||
"label": {
|
||||
"type": "single",
|
||||
"nearest": true,
|
||||
"on": "mouseover",
|
||||
"encodings": [
|
||||
"x"
|
||||
],
|
||||
"empty": "none",
|
||||
"clear": "mouseout"
|
||||
}
|
||||
},
|
||||
"mark": "point",
|
||||
"encoding": {
|
||||
"opacity": {
|
||||
"condition": {
|
||||
"selection": "label",
|
||||
"value": 1
|
||||
},
|
||||
"value": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"transform": [
|
||||
{
|
||||
"filter": {
|
||||
"selection": "label"
|
||||
}
|
||||
}
|
||||
],
|
||||
"layer": [
|
||||
{
|
||||
"encoding": {
|
||||
"text": {
|
||||
"type": "quantitative",
|
||||
"field": "<DVC_METRIC_Y>"
|
||||
},
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "quantitative"
|
||||
},
|
||||
"y": {
|
||||
"field": "<DVC_METRIC_Y>",
|
||||
"type": "quantitative"
|
||||
}
|
||||
},
|
||||
"layer": [
|
||||
{
|
||||
"mark": {
|
||||
"type": "text",
|
||||
"align": "left",
|
||||
"dx": 5,
|
||||
"dy": -5
|
||||
},
|
||||
"encoding": {
|
||||
"color": {
|
||||
"type": "nominal",
|
||||
"field": "rev"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
39
.dvc/plots/smooth.json
Normal file
39
.dvc/plots/smooth.json
Normal file
@ -0,0 +1,39 @@
|
||||
{
|
||||
"$schema": "https://vega.github.io/schema/vega-lite/v4.json",
|
||||
"data": {
|
||||
"values": "<DVC_METRIC_DATA>"
|
||||
},
|
||||
"title": "<DVC_METRIC_TITLE>",
|
||||
"mark": {
|
||||
"type": "line"
|
||||
},
|
||||
"encoding": {
|
||||
"x": {
|
||||
"field": "<DVC_METRIC_X>",
|
||||
"type": "quantitative",
|
||||
"title": "<DVC_METRIC_X_LABEL>"
|
||||
},
|
||||
"y": {
|
||||
"field": "<DVC_METRIC_Y>",
|
||||
"type": "quantitative",
|
||||
"title": "<DVC_METRIC_Y_LABEL>",
|
||||
"scale": {
|
||||
"zero": false
|
||||
}
|
||||
},
|
||||
"color": {
|
||||
"field": "rev",
|
||||
"type": "nominal"
|
||||
}
|
||||
},
|
||||
"transform": [
|
||||
{
|
||||
"loess": "<DVC_METRIC_Y>",
|
||||
"on": "<DVC_METRIC_X>",
|
||||
"groupby": [
|
||||
"rev"
|
||||
],
|
||||
"bandwidth": 0.3
|
||||
}
|
||||
]
|
||||
}
|
3
.dvcignore
Normal file
3
.dvcignore
Normal file
@ -0,0 +1,3 @@
|
||||
# Add patterns of files dvc should ignore, which could improve
|
||||
# the performance. Learn more at
|
||||
# https://dvc.org/doc/user-guide/dvcignore
|
@ -3,9 +3,12 @@ FROM ubuntu:latest
|
||||
RUN apt-get update && apt-get install -y python3-pip && pip3 install setuptools && pip3 install numpy && pip3 install pandas && pip3 install wget && pip3 install scikit-learn && pip3 install matplotlib && rm -rf /var/lib/apt/lists/*
|
||||
RUN pip3 install torch torchvision torchaudio
|
||||
RUN pip3 install sacred && pip3 install GitPython && pip3 install pymongo
|
||||
RUN pip3 install dvc
|
||||
RUN pip3 install 'dvc[ssh]' paramiko
|
||||
WORKDIR /app
|
||||
|
||||
COPY ./create.py ./
|
||||
COPY ./stats.py ./
|
||||
COPY ./stroke-pytorch.py ./
|
||||
COPY ./stroke-pytorch-eval.py ./
|
||||
COPY ./train-dvc.py ./
|
||||
|
37
Jenkinsfile-dvc
Normal file
37
Jenkinsfile-dvc
Normal file
@ -0,0 +1,37 @@
|
||||
pipeline {
|
||||
agent {
|
||||
dockerfile true
|
||||
}
|
||||
parameters{
|
||||
buildSelector(
|
||||
defaultSelector: lastSuccessful(),
|
||||
description: 'Which build to use for copying artifacts',
|
||||
name: 'WHICH_BUILD'
|
||||
)
|
||||
}
|
||||
stages {
|
||||
stage('dvc') {
|
||||
steps {
|
||||
withCredentials([sshUserPrivateKey(credentialsId: '48ac7004-216e-4260-abba-1fe5db753e18', keyFileVariable: 'IUM_SFTP_KEY')]) {
|
||||
copyArtifacts fingerprintArtifacts: true, projectName: 's434766-create-dataset', selector: buildParameter('WHICH_BUILD')
|
||||
sh 'dvc remote add -f -d ium_ssh_remote ssh://ium-sftp@tzietkiewicz.vm.wmi.amu.edu.pl/ium-sftp'
|
||||
sh 'dvc remote modify --local ium_ssh_remote keyfile $IUM_SFTP_KEY'
|
||||
sh "dvc pull -f"
|
||||
sh "dvc reproduce"
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
post {
|
||||
success {
|
||||
mail body: 'DVC success', subject: 's434766', to: '26ab8f35.uam.onmicrosoft.com@emea.teams.ms'
|
||||
archiveArtifacts 'accuracy.txt'
|
||||
|
||||
}
|
||||
|
||||
failure {
|
||||
mail body: 'DVC failure', subject: 's434766', to: '26ab8f35.uam.onmicrosoft.com@emea.teams.ms'
|
||||
}
|
||||
}
|
||||
}
|
18
dvc.yaml
Normal file
18
dvc.yaml
Normal file
@ -0,0 +1,18 @@
|
||||
stages:
|
||||
download_and_split:
|
||||
cmd: python3 split_10.py
|
||||
deps:
|
||||
- healthcare-dataset-stroke-data.csv
|
||||
- create.py
|
||||
outs:
|
||||
- data_train.csv
|
||||
- data_test.csv
|
||||
- data_val.csv
|
||||
train_model:
|
||||
cmd: python3 train-dvc.py
|
||||
deps:
|
||||
- data_train.csv
|
||||
- data_test.csv
|
||||
- data_val.csv
|
||||
outs:
|
||||
- Y_pred.txt
|
File diff suppressed because it is too large
Load Diff
32
lab2.ipynb
32
lab2.ipynb
@ -10,13 +10,12 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.5-final"
|
||||
"version": "3.8.5"
|
||||
},
|
||||
"orig_nbformat": 2,
|
||||
"kernelspec": {
|
||||
"name": "python3",
|
||||
"display_name": "Python 3",
|
||||
"language": "python"
|
||||
"name": "python385jvsc74a57bd02cef13873963874fd5439bd04a135498d1dd9725d9d90f40de0b76178a8e03b1",
|
||||
"display_name": "Python 3.8.5 64-bit ('base': conda)"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
@ -24,7 +23,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
@ -84,7 +83,7 @@
|
||||
" print(data.describe(include='all'))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"downloadCSV()\n",
|
||||
"# downloadCSV()\n",
|
||||
"data = dropNaN()\n",
|
||||
"data = NormalizeData(data)\n",
|
||||
"\n",
|
||||
@ -95,6 +94,27 @@
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"output_type": "execute_result",
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"array(['private', 'self_employed', 'govt_job', 'children', 'never_worked'],\n",
|
||||
" dtype=object)"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"execution_count": 6
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"pd.unique(data['work_type'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
|
15
mlruns/0/119a232b4f5b4792afb6fbda18d95262/meta.yaml
Normal file
15
mlruns/0/119a232b4f5b4792afb6fbda18d95262/meta.yaml
Normal file
@ -0,0 +1,15 @@
|
||||
artifact_uri: file:///home/przemek/ium_434766/mlruns/0/119a232b4f5b4792afb6fbda18d95262/artifacts
|
||||
end_time: 1622122401430
|
||||
entry_point_name: ''
|
||||
experiment_id: '0'
|
||||
lifecycle_stage: active
|
||||
name: ''
|
||||
run_id: 119a232b4f5b4792afb6fbda18d95262
|
||||
run_uuid: 119a232b4f5b4792afb6fbda18d95262
|
||||
source_name: ''
|
||||
source_type: 4
|
||||
source_version: ''
|
||||
start_time: 1622122401247
|
||||
status: 3
|
||||
tags: []
|
||||
user_id: owcap
|
1
mlruns/0/119a232b4f5b4792afb6fbda18d95262/metrics/rmse
Normal file
1
mlruns/0/119a232b4f5b4792afb6fbda18d95262/metrics/rmse
Normal file
@ -0,0 +1 @@
|
||||
1622122401410 0.12816519 0
|
@ -0,0 +1 @@
|
||||
0.46289125084877014
|
@ -0,0 +1 @@
|
||||
16
|
1
mlruns/0/119a232b4f5b4792afb6fbda18d95262/params/epochs
Normal file
1
mlruns/0/119a232b4f5b4792afb6fbda18d95262/params/epochs
Normal file
@ -0,0 +1 @@
|
||||
5
|
@ -0,0 +1 @@
|
||||
d4912c0bdcc4ecba96dfd2b643b5e816d51c6bda
|
@ -0,0 +1 @@
|
||||
.\lab8-mlflow.py
|
@ -0,0 +1 @@
|
||||
LOCAL
|
@ -0,0 +1 @@
|
||||
owcap
|
75
train-dvc.py
Normal file
75
train-dvc.py
Normal file
@ -0,0 +1,75 @@
|
||||
import torch
|
||||
import sys
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.autograd import Variable
|
||||
import torchvision.transforms as transforms
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from sklearn.metrics import accuracy_score
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sacred import Experiment
|
||||
from sacred.observers import FileStorageObserver
|
||||
np.set_printoptions(suppress=False)
|
||||
|
||||
|
||||
class LogisticRegressionModel(nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super(LogisticRegressionModel, self).__init__()
|
||||
self.linear = nn.Linear(input_dim, output_dim)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
def forward(self, x):
|
||||
out = self.linear(x)
|
||||
return self.sigmoid(out)
|
||||
|
||||
|
||||
data_train = pd.read_csv("data_train.csv")
|
||||
data_test = pd.read_csv("data_test.csv")
|
||||
data_val = pd.read_csv("data_val.csv")
|
||||
FEATURES = ['age','hypertension','heart_disease','ever_married', 'avg_glucose_level', 'bmi']
|
||||
|
||||
x_train = data_train[FEATURES].astype(np.float32)
|
||||
y_train = data_train['stroke'].astype(np.float32)
|
||||
|
||||
x_test = data_test[FEATURES].astype(np.float32)
|
||||
y_test = data_test['stroke'].astype(np.float32)
|
||||
|
||||
fTrain = torch.from_numpy(x_train.values)
|
||||
tTrain = torch.from_numpy(y_train.values.reshape(2945,1))
|
||||
|
||||
fTest= torch.from_numpy(x_test.values)
|
||||
tTest = torch.from_numpy(y_test.values)
|
||||
|
||||
batch_size = int(sys.argv[1]) if len(sys.argv) > 1 else 16
|
||||
num_epochs = int(sys.argv[2]) if len(sys.argv) > 2 else 5
|
||||
learning_rate = 0.001
|
||||
input_dim = 6
|
||||
output_dim = 1
|
||||
|
||||
model = LogisticRegressionModel(input_dim, output_dim)
|
||||
|
||||
criterion = torch.nn.BCELoss(reduction='mean')
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
# print ("Epoch #",epoch)
|
||||
model.train()
|
||||
optimizer.zero_grad()
|
||||
# Forward pass
|
||||
y_pred = model(fTrain)
|
||||
# Compute Loss
|
||||
loss = criterion(y_pred, tTrain)
|
||||
# print(loss.item())
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
y_pred = model(fTest)
|
||||
# print("predicted Y value: ", y_pred.data)
|
||||
|
||||
|
||||
txt_file = open("Y_pred.txt", "w")
|
||||
n = txt_file.write(f"Y_pred: { y_pred.data}")
|
||||
txt_file.close()
|
||||
# torch.save(model.state_dict(), 'stroke.pth')
|
||||
|
Loading…
Reference in New Issue
Block a user