Add iris classification #2

Merged
s434624 merged 1 commits from iris-clasification into master 2019-05-05 17:36:08 +02:00
7 changed files with 229 additions and 173 deletions
Showing only changes of commit f36f544e18 - Show all commits

BIN
iris_model.h5 Normal file

Binary file not shown.

View File

@ -39,7 +39,6 @@ class Forklift {
}
setVelocity() {
debugger;
this.direction = this.sub(sections[this.currentTarget], this.positoin);
this.velocity = this.direction.setMag(this.speed);
}
@ -49,8 +48,7 @@ class Forklift {
if (
Math.abs(this.positoin.x - sections[this.currentTarget].x) <=
this.speed &&
Math.abs(this.positoin.y - sections[this.currentTarget].y) <=
this.speed
Math.abs(this.positoin.y - sections[this.currentTarget].y) <= this.speed
) {
this.positoin = sections[this.currentTarget];
this.nextTarget();

View File

@ -1,11 +1,10 @@
const serverUrl = 'http://localhost:8000';
let sections;
let roads;
let packageClaim;
let going = false;
let forklift;
let target;
// This runs once at start
function setup() {
createCanvas(600, 600).parent('canvas');
@ -14,8 +13,11 @@ function setup() {
createMagazineLayout();
select('#button').mousePressed(deliver);
target = select('#target');
select('#button').mousePressed(getIrisType);
sepalWidth = select('#sepalWidth');
sepalLength = select('#sepalLength');
petalWidth = select('#petalWidth');
petalLength = select('#petalLength');
// Create a forklift instance
forklift = new Forklift(sections[0].x, sections[0].y);
}
@ -63,14 +65,31 @@ function drawMagazine() {
}
}
function deliver() {
function getIrisType() {
let sw = select('#sepalWidth').value();
let sl = select('#sepalLength').value();
let pw = select('#petalWidth').value();
let pl = select('#petalLength').value();
let data = {
sepalWidth: sw,
sepalLength: sl,
petalWidth: pw,
petalLength: pl,
};
httpPost(serverUrl + '/classify', data, response => {
deliver(response);
});
}
function deliver(targetSection) {
let data = {
graph: magazineToGraph(),
start_node: forklift.currentSection,
dest_node: int(target.value()),
dest_node: int(targetSection),
};
console.log(data);
httpPost(
'http://localhost:8000/shortestPath',
serverUrl + '/shortestPath',
data,
response => {
path = response.split('').map(Number);

View File

@ -42,27 +42,21 @@
<div class="package">
<h3 style="margin-top: 0;">Package description</h1>
<label for="width">Sepal Width</label>
<input type="number" name="width">
<input type="number" id="sepalWidth">
<label for="topWidth">Sepal Length</label>
<input type="number" name="topWidth">
<input type="number" id="sepalLength">
<label for="botWidth">Petal Width</label>
<input type="number" name="botWidth">
<input type="number" id="petalWidth">
<label for="height">Petal Length</label>
<input type="number" name="height">
<label for="target">Target</label>
<input type="number" id="target" name="target">
<input type="number" id="petalLength">
<button id="button">Send Package</button>
</div>
<div id="canvas" style="margin: 10px;"></div>
<div class="legend">
<h3 style="margin-top: 0">Sections</h3>
<p>A - Cartons</p>
<p>B - Barrels</p>
<p>C - Plastic boxes</p>
<p>D - ______</p>
<p>E - ______</p>
<p>F - ______</p>
<p>G - ______</p>
<p>1 - Setosa</p>
<p>2 - Versicolor</p>
<p>3 - Viginica</p>
</div>
</div>
</body>

View File

@ -1,10 +1,11 @@
import math
import json
from django.shortcuts import render
from django.http import HttpResponse
from django.views.decorators.csrf import csrf_exempt
import json
import math
import tensorflow as tf
import numpy as np
# Create your views here.
@ -14,7 +15,21 @@ def index(request):
@csrf_exempt
def classify(request):
return HttpResponse(json.load(request))
loaded_request = json.load(request)
sw = loaded_request['sepalWidth']
sl = loaded_request['sepalLength']
pw = loaded_request['petalWidth']
pl = loaded_request['petalLength']
model = tf.keras.models.load_model('iris_model.h5')
output = model.predict(np.array([[sw, sl, pw, pl]]))
if output[0][0] > output[0][1] and output[0][0] > output[0][1]:
guess = 1
elif output[0][1] > output[0][0] and output[0][1] > output[0][2]:
guess = 2
else:
guess = 3
return HttpResponse(guess)
@csrf_exempt
@ -61,4 +76,6 @@ def shortestPath(request):
current = predecessor[current]
path[node] = p[::-1]
print(path)
return HttpResponse(path[dest_node][1:])

28
train_model.py Normal file
View File

@ -0,0 +1,28 @@
import numpy as np
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
# Getting data
data_set = load_iris()
x = data_set['data']
y = to_categorical(data_set['target'])
train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2)
# Building the model
model = tf.keras.Sequential()
model.add(layers.Dense(8, activation='relu', input_dim=4))
model.add(layers.Dense(3, activation='sigmoid'))
model.compile(optimizer='adam', loss='categorical_crossentropy',
metrics=['accuracy'])
# Training the model
model.fit(train_x, train_y, validation_data=(test_x, test_y), epochs=2000)
model.save('iris_model.h5')