Add iris classification

This commit is contained in:
Jacob 2019-05-05 17:34:26 +02:00
parent 9d19eeea22
commit f36f544e18
7 changed files with 229 additions and 173 deletions

BIN
iris_model.h5 Normal file

Binary file not shown.

View File

@ -1,68 +1,66 @@
class Forklift { class Forklift {
constructor(x, y) { constructor(x, y) {
this.positoin = createVector(x, y); this.positoin = createVector(x, y);
this.speed = 5; this.speed = 5;
this.path = []; this.path = [];
this.currentTarget = ''; this.currentTarget = '';
this.targetId = 0; this.targetId = 0;
this.velocity = createVector(0, 0); this.velocity = createVector(0, 0);
this.direction = createVector(0, 0); this.direction = createVector(0, 0);
this.end = false; this.end = false;
this.currentSection = 0; this.currentSection = 0;
} }
draw() { draw() {
fill(225, 255, 0); fill(225, 255, 0);
ellipse(this.positoin.x, this.positoin.y, 20); ellipse(this.positoin.x, this.positoin.y, 20);
} }
setPath(path) { setPath(path) {
this.end = false; this.end = false;
this.path = path; this.path = path;
this.currentTarget = this.path[this.targetId]; this.currentTarget = this.path[this.targetId];
this.setVelocity(); this.setVelocity();
} }
nextTarget() { nextTarget() {
this.targetId += 1; this.targetId += 1;
if (this.targetId < this.path.length) { if (this.targetId < this.path.length) {
this.currentTarget = this.path[this.targetId]; this.currentTarget = this.path[this.targetId];
} else { } else {
this.end = true; this.end = true;
this.targetId = 0; this.targetId = 0;
}
} }
}
targetReached() { targetReached() {
this.currentSection = this.currentTarget; this.currentSection = this.currentTarget;
return this.end; return this.end;
} }
setVelocity() { setVelocity() {
debugger; this.direction = this.sub(sections[this.currentTarget], this.positoin);
this.direction = this.sub(sections[this.currentTarget], this.positoin); this.velocity = this.direction.setMag(this.speed);
this.velocity = this.direction.setMag(this.speed); }
}
move() { move() {
this.positoin = this.add(this.positoin, this.velocity); this.positoin = this.add(this.positoin, this.velocity);
if ( if (
Math.abs(this.positoin.x - sections[this.currentTarget].x) <= Math.abs(this.positoin.x - sections[this.currentTarget].x) <=
this.speed && this.speed &&
Math.abs(this.positoin.y - sections[this.currentTarget].y) <= Math.abs(this.positoin.y - sections[this.currentTarget].y) <= this.speed
this.speed ) {
) { this.positoin = sections[this.currentTarget];
this.positoin = sections[this.currentTarget]; this.nextTarget();
this.nextTarget(); this.setVelocity();
this.setVelocity();
}
} }
}
sub(target, pos) { sub(target, pos) {
return createVector(target.x - pos.x, target.y - pos.y); return createVector(target.x - pos.x, target.y - pos.y);
} }
add(target, pos) { add(target, pos) {
return createVector(target.x + pos.x, target.y + pos.y); return createVector(target.x + pos.x, target.y + pos.y);
} }
} }

View File

@ -1,130 +1,149 @@
const serverUrl = 'http://localhost:8000';
let sections; let sections;
let roads; let roads;
let packageClaim; let packageClaim;
let going = false; let going = false;
let forklift; let forklift;
let target;
// This runs once at start // This runs once at start
function setup() { function setup() {
createCanvas(600, 600).parent('canvas'); createCanvas(600, 600).parent('canvas');
frameRate(30); frameRate(30);
createMagazineLayout(); createMagazineLayout();
select('#button').mousePressed(deliver); select('#button').mousePressed(getIrisType);
target = select('#target'); sepalWidth = select('#sepalWidth');
// Create a forklift instance sepalLength = select('#sepalLength');
forklift = new Forklift(sections[0].x, sections[0].y); petalWidth = select('#petalWidth');
petalLength = select('#petalLength');
// Create a forklift instance
forklift = new Forklift(sections[0].x, sections[0].y);
} }
// This runs every frame // This runs every frame
function draw() { function draw() {
background(64); background(64);
drawMagazine(); drawMagazine();
forklift.draw(); forklift.draw();
if (going) { if (going) {
if (forklift.targetReached()) { if (forklift.targetReached()) {
going = false; going = false;
} else { } else {
forklift.move(); forklift.move();
}
} }
}
} }
function drawMagazine() { function drawMagazine() {
noFill(); noFill();
stroke(220); stroke(220);
strokeWeight(10); strokeWeight(10);
// Draw all the roads in the magazine // Draw all the roads in the magazine
for (let road of roads) { for (let road of roads) {
line( line(
sections[road[0]].x, sections[road[0]].x,
sections[road[0]].y, sections[road[0]].y,
sections[road[1]].x, sections[road[1]].x,
sections[road[1]].y, sections[road[1]].y,
); );
} }
noStroke(); noStroke();
textAlign(CENTER, CENTER); textAlign(CENTER, CENTER);
// Draw all sections in the magazine // Draw all sections in the magazine
for (let section of Object.keys(sections)) { for (let section of Object.keys(sections)) {
if (section === 0) { if (section === 0) {
fill(80); fill(80);
rect(sections[section].x - 15, sections[section].y, 30, 30); rect(sections[section].x - 15, sections[section].y, 30, 30);
} else { } else {
fill(30); fill(30);
ellipse(sections[section].x, sections[section].y, 30); ellipse(sections[section].x, sections[section].y, 30);
fill(255); fill(255);
text(section, sections[section].x, sections[section].y); text(section, sections[section].x, sections[section].y);
}
} }
}
} }
function deliver() { function getIrisType() {
let data = { let sw = select('#sepalWidth').value();
graph: magazineToGraph(), let sl = select('#sepalLength').value();
start_node: forklift.currentSection, let pw = select('#petalWidth').value();
dest_node: int(target.value()), let pl = select('#petalLength').value();
}; let data = {
httpPost( sepalWidth: sw,
'http://localhost:8000/shortestPath', sepalLength: sl,
data, petalWidth: pw,
response => { petalLength: pl,
path = response.split('').map(Number); };
forklift.currentTarget = path[0]; httpPost(serverUrl + '/classify', data, response => {
forklift.setPath(path); deliver(response);
going = true; });
}, }
error => {
console.log(error); function deliver(targetSection) {
}, let data = {
); graph: magazineToGraph(),
start_node: forklift.currentSection,
dest_node: int(targetSection),
};
console.log(data);
httpPost(
serverUrl + '/shortestPath',
data,
response => {
path = response.split('').map(Number);
forklift.currentTarget = path[0];
forklift.setPath(path);
going = true;
},
error => {
console.log(error);
},
);
} }
function createMagazineLayout() { function createMagazineLayout() {
unit = width / 6; unit = width / 6;
sections = { sections = {
0: { x: 2 * unit, y: unit }, 0: { x: 2 * unit, y: unit },
1: createVector(unit, 2 * unit), 1: createVector(unit, 2 * unit),
2: createVector(unit, 3 * unit), 2: createVector(unit, 3 * unit),
3: createVector(unit, 4 * unit), 3: createVector(unit, 4 * unit),
4: createVector(3 * unit, 2 * unit), 4: createVector(3 * unit, 2 * unit),
5: createVector(3 * unit, 3 * unit), 5: createVector(3 * unit, 3 * unit),
6: createVector(3 * unit, 4 * unit), 6: createVector(3 * unit, 4 * unit),
}; };
roads = [ roads = [
[1, 5], [1, 5],
[2, 3], [2, 3],
[0, 1], [0, 1],
[0, 4], [0, 4],
[4, 5], [4, 5],
[5, 6], [5, 6],
[3, 6], [3, 6],
[4, 2], [4, 2],
[5, 3], [5, 3],
]; ];
} }
function magazineToGraph() { function magazineToGraph() {
graph = {}; graph = {};
for (let key of Object.keys(sections)) { for (let key of Object.keys(sections)) {
graph[key] = {}; graph[key] = {};
} }
for (let road of roads) { for (let road of roads) {
start = road[0]; start = road[0];
end = road[1]; end = road[1];
graph[start][end] = Math.sqrt( graph[start][end] = Math.sqrt(
Math.pow(sections[start].x - sections[end].x, 2) + Math.pow(sections[start].x - sections[end].x, 2) +
Math.pow(sections[start].y - sections[end].y, 2), Math.pow(sections[start].y - sections[end].y, 2),
); );
graph[end][start] = Math.sqrt( graph[end][start] = Math.sqrt(
Math.pow(sections[start].x - sections[end].x, 2) + Math.pow(sections[start].x - sections[end].x, 2) +
Math.pow(sections[start].y - sections[end].y, 2), Math.pow(sections[start].y - sections[end].y, 2),
); );
} }
return graph; return graph;
} }

View File

@ -42,27 +42,21 @@
<div class="package"> <div class="package">
<h3 style="margin-top: 0;">Package description</h1> <h3 style="margin-top: 0;">Package description</h1>
<label for="width">Sepal Width</label> <label for="width">Sepal Width</label>
<input type="number" name="width"> <input type="number" id="sepalWidth">
<label for="topWidth">Sepal Length</label> <label for="topWidth">Sepal Length</label>
<input type="number" name="topWidth"> <input type="number" id="sepalLength">
<label for="botWidth">Petal Width</label> <label for="botWidth">Petal Width</label>
<input type="number" name="botWidth"> <input type="number" id="petalWidth">
<label for="height">Petal Length</label> <label for="height">Petal Length</label>
<input type="number" name="height"> <input type="number" id="petalLength">
<label for="target">Target</label>
<input type="number" id="target" name="target">
<button id="button">Send Package</button> <button id="button">Send Package</button>
</div> </div>
<div id="canvas" style="margin: 10px;"></div> <div id="canvas" style="margin: 10px;"></div>
<div class="legend"> <div class="legend">
<h3 style="margin-top: 0">Sections</h3> <h3 style="margin-top: 0">Sections</h3>
<p>A - Cartons</p> <p>1 - Setosa</p>
<p>B - Barrels</p> <p>2 - Versicolor</p>
<p>C - Plastic boxes</p> <p>3 - Viginica</p>
<p>D - ______</p>
<p>E - ______</p>
<p>F - ______</p>
<p>G - ______</p>
</div> </div>
</div> </div>
</body> </body>

View File

@ -1,10 +1,11 @@
import math
import json
from django.shortcuts import render from django.shortcuts import render
from django.http import HttpResponse from django.http import HttpResponse
from django.views.decorators.csrf import csrf_exempt from django.views.decorators.csrf import csrf_exempt
import json import tensorflow as tf
import math import numpy as np
# Create your views here. # Create your views here.
@ -14,7 +15,21 @@ def index(request):
@csrf_exempt @csrf_exempt
def classify(request): def classify(request):
return HttpResponse(json.load(request)) loaded_request = json.load(request)
sw = loaded_request['sepalWidth']
sl = loaded_request['sepalLength']
pw = loaded_request['petalWidth']
pl = loaded_request['petalLength']
model = tf.keras.models.load_model('iris_model.h5')
output = model.predict(np.array([[sw, sl, pw, pl]]))
if output[0][0] > output[0][1] and output[0][0] > output[0][1]:
guess = 1
elif output[0][1] > output[0][0] and output[0][1] > output[0][2]:
guess = 2
else:
guess = 3
return HttpResponse(guess)
@csrf_exempt @csrf_exempt
@ -61,4 +76,6 @@ def shortestPath(request):
current = predecessor[current] current = predecessor[current]
path[node] = p[::-1] path[node] = p[::-1]
print(path)
return HttpResponse(path[dest_node][1:]) return HttpResponse(path[dest_node][1:])

28
train_model.py Normal file
View File

@ -0,0 +1,28 @@
import numpy as np
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
# Getting data
data_set = load_iris()
x = data_set['data']
y = to_categorical(data_set['target'])
train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.2)
# Building the model
model = tf.keras.Sequential()
model.add(layers.Dense(8, activation='relu', input_dim=4))
model.add(layers.Dense(3, activation='sigmoid'))
model.compile(optimizer='adam', loss='categorical_crossentropy',
metrics=['accuracy'])
# Training the model
model.fit(train_x, train_y, validation_data=(test_x, test_y), epochs=2000)
model.save('iris_model.h5')