added server and finalized tree creation

This commit is contained in:
Robert Bendun 2021-05-24 14:58:18 +02:00
parent 1b71852d8b
commit 334d15ead5
5 changed files with 100 additions and 20 deletions

View File

@ -1,2 +1,16 @@
# projekt-sztuczna-inteligencja # projekt-sztuczna-inteligencja
## Uruchomienie projektu
Wymaga Pythona oraz pip.
Należy zainstalować zależności:
```sh
pip install -r requirements.txt
```
A następnie uruchomić serwer
```sh
cd src
python3 main.py
```

5
requirements.txt Normal file
View File

@ -0,0 +1,5 @@
Flask==2.0.1
matplotlib==3.4.2
numpy==1.20.3
scikit-learn==0.24.2
scipy==1.6.3

View File

@ -1,31 +1,31 @@
#!/usr/bin/python3 #!/usr/bin/python3
from sklearn import tree from sklearn import tree
from pprint import PrettyPrinter from pprint import PrettyPrinter
import pickle
import os
pp = PrettyPrinter(indent=2, compact=True) pp = PrettyPrinter(indent=2, compact=True)
def p(*args, **kwargs): def p(*args, **kwargs):
pp.pprint(*args, **kwargs) pp.pprint(*args, **kwargs)
def invoke_consume_exceptions(function, *args, **kwargs):
try:
return function(*args, **kwargs)
except:
return None
def read_tsv_from(filename): def read_tsv_from(filename):
from csv import reader from csv import reader
with open(filename, 'r') as f: with open(filename, 'r') as f:
header, *rows = list(reader(f, delimiter='\t')) header, *rows = list(reader(f, delimiter='\t'))
return [dict(zip(header, (el.strip() for el in row))) for row in rows] return [dict(zip(header, (el.strip() for el in row))) for row in rows]
def main(): def serialize(clf):
from sys import argv with open('./data/tree', 'wb') as f:
import os pickle.dump(clf, f)
import pathlib
source_file = argv[1]
invoke_consume_exceptions(os.mkdir, os.path.dirname(source_file))
data = read_tsv_from(source_file)
def deserialize():
with open('./data/tree', 'rb') as f:
return pickle.load(f)
def init(source_file, cache=False):
from random import randint
defaults, *data = read_tsv_from(source_file)
defaults.pop("nazwa")
types = dict() types = dict()
for row in data: for row in data:
for (key, value) in row.items(): for (key, value) in row.items():
@ -42,6 +42,7 @@ def main():
t2n = dict((key, dict(v)) for (key, v) in base.items()) t2n = dict((key, dict(v)) for (key, v) in base.items())
n2t = dict((key, dict((b, a) for (a, b) in v)) for (key, v) in base.items()) n2t = dict((key, dict((b, a) for (a, b) in v)) for (key, v) in base.items())
X = [[ X = [[
t2n[name][feature] t2n[name][feature]
for (name, feature) in sample.items() for (name, feature) in sample.items()
@ -49,11 +50,21 @@ def main():
Y = [t2n['polka'][sample['polka']] for sample in data if 'polka' in sample] Y = [t2n['polka'][sample['polka']] for sample in data if 'polka' in sample]
if cache and os.path.isfile('./data/tree'):
clf = deserialize()
else:
clf = tree.DecisionTreeClassifier() clf = tree.DecisionTreeClassifier()
clf.fit(X, Y) clf.fit(X, Y)
l = clf.get_n_leaves() if cache:
d = clf.get_depth() serialize(clf)
print(f'Leaves: {l}\nDepth: {d}')
return [
types,
t2n,
n2t,
defaults,
clf
]
if __name__ == '__main__': if __name__ == '__main__':
main() init('./data/productsTree.tsv')

40
src/main.py Normal file
View File

@ -0,0 +1,40 @@
from flask import Flask, redirect, jsonify, request
import decision_tree
app = Flask(
__name__,
static_url_path='',
static_folder='.')
@app.route('/api/types/', defaults={ 'searched_type': None })
@app.route('/api/types/<searched_type>')
def api_get_features_types(searched_type):
return t2n[searched_type] if searched_type else t2n
@app.route('/api/types/defaults')
def api_get_default_features():
return jsonify(defaults)
@app.route('/api/decide', methods=['POST'])
def api_predict_shelf():
json = request.get_json(force=True)
defs = dict(defaults)
defs.pop('polka')
keys = list(json.keys())
X = []
for (key, default) in defs.items():
if key in keys:
X.append(t2n[key][json[key]])
else:
X.append(t2n[key][default])
return jsonify([ n2t['polka'][clf.predict([X]).tolist()[0]] ])
@app.route('/')
def index():
return redirect('/index.html')
categories, t2n, n2t, defaults, clf = decision_tree.init('data/productsTree.tsv', cache=True)
categories = dict((key, list(vals)) for (key, vals) in categories.items())
app.run()

10
test.sh Executable file
View File

@ -0,0 +1,10 @@
#!/bin/sh
decide () {
echo Input: "$1"
echo Output:
curl "http://localhost:5000/api/decide" -d"$1"
}
decide '{ "nazwa": "chleb", "kategoria": "kuchnia", "typ_zywnosci": "gotowe" }'
decide '{ "nazwa": "parowki", "kategoria": "kuchnia", "typ_zywnosci": "mieso", "termin_przydatnosci": "krotki" }'