added server and finalized tree creation
This commit is contained in:
parent
1b71852d8b
commit
334d15ead5
14
README.md
14
README.md
@ -1,2 +1,16 @@
|
||||
# projekt-sztuczna-inteligencja
|
||||
|
||||
## Uruchomienie projektu
|
||||
|
||||
Wymaga Pythona oraz pip.
|
||||
|
||||
Należy zainstalować zależności:
|
||||
```sh
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
A następnie uruchomić serwer
|
||||
```sh
|
||||
cd src
|
||||
python3 main.py
|
||||
```
|
||||
|
5
requirements.txt
Normal file
5
requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
Flask==2.0.1
|
||||
matplotlib==3.4.2
|
||||
numpy==1.20.3
|
||||
scikit-learn==0.24.2
|
||||
scipy==1.6.3
|
@ -1,31 +1,31 @@
|
||||
#!/usr/bin/python3
|
||||
from sklearn import tree
|
||||
from pprint import PrettyPrinter
|
||||
import pickle
|
||||
import os
|
||||
|
||||
pp = PrettyPrinter(indent=2, compact=True)
|
||||
def p(*args, **kwargs):
|
||||
pp.pprint(*args, **kwargs)
|
||||
|
||||
def invoke_consume_exceptions(function, *args, **kwargs):
|
||||
try:
|
||||
return function(*args, **kwargs)
|
||||
except:
|
||||
return None
|
||||
|
||||
def read_tsv_from(filename):
|
||||
from csv import reader
|
||||
with open(filename, 'r') as f:
|
||||
header, *rows = list(reader(f, delimiter='\t'))
|
||||
return [dict(zip(header, (el.strip() for el in row))) for row in rows]
|
||||
|
||||
def main():
|
||||
from sys import argv
|
||||
import os
|
||||
import pathlib
|
||||
source_file = argv[1]
|
||||
invoke_consume_exceptions(os.mkdir, os.path.dirname(source_file))
|
||||
data = read_tsv_from(source_file)
|
||||
def serialize(clf):
|
||||
with open('./data/tree', 'wb') as f:
|
||||
pickle.dump(clf, f)
|
||||
|
||||
def deserialize():
|
||||
with open('./data/tree', 'rb') as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def init(source_file, cache=False):
|
||||
from random import randint
|
||||
defaults, *data = read_tsv_from(source_file)
|
||||
defaults.pop("nazwa")
|
||||
types = dict()
|
||||
for row in data:
|
||||
for (key, value) in row.items():
|
||||
@ -42,18 +42,29 @@ def main():
|
||||
t2n = dict((key, dict(v)) for (key, v) in base.items())
|
||||
n2t = dict((key, dict((b, a) for (a, b) in v)) for (key, v) in base.items())
|
||||
|
||||
|
||||
X = [[
|
||||
t2n[name][feature]
|
||||
t2n[name][feature]
|
||||
for (name, feature) in sample.items()
|
||||
if name not in ['nazwa', 'polka']] for sample in data if 'polka' in sample]
|
||||
|
||||
Y = [t2n['polka'][sample['polka']] for sample in data if 'polka' in sample]
|
||||
|
||||
clf = tree.DecisionTreeClassifier()
|
||||
clf.fit(X, Y)
|
||||
l = clf.get_n_leaves()
|
||||
d = clf.get_depth()
|
||||
print(f'Leaves: {l}\nDepth: {d}')
|
||||
if cache and os.path.isfile('./data/tree'):
|
||||
clf = deserialize()
|
||||
else:
|
||||
clf = tree.DecisionTreeClassifier()
|
||||
clf.fit(X, Y)
|
||||
if cache:
|
||||
serialize(clf)
|
||||
|
||||
return [
|
||||
types,
|
||||
t2n,
|
||||
n2t,
|
||||
defaults,
|
||||
clf
|
||||
]
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
init('./data/productsTree.tsv')
|
40
src/main.py
Normal file
40
src/main.py
Normal file
@ -0,0 +1,40 @@
|
||||
from flask import Flask, redirect, jsonify, request
|
||||
import decision_tree
|
||||
|
||||
app = Flask(
|
||||
__name__,
|
||||
static_url_path='',
|
||||
static_folder='.')
|
||||
|
||||
@app.route('/api/types/', defaults={ 'searched_type': None })
|
||||
@app.route('/api/types/<searched_type>')
|
||||
def api_get_features_types(searched_type):
|
||||
return t2n[searched_type] if searched_type else t2n
|
||||
|
||||
@app.route('/api/types/defaults')
|
||||
def api_get_default_features():
|
||||
return jsonify(defaults)
|
||||
|
||||
@app.route('/api/decide', methods=['POST'])
|
||||
def api_predict_shelf():
|
||||
json = request.get_json(force=True)
|
||||
defs = dict(defaults)
|
||||
defs.pop('polka')
|
||||
keys = list(json.keys())
|
||||
X = []
|
||||
|
||||
for (key, default) in defs.items():
|
||||
if key in keys:
|
||||
X.append(t2n[key][json[key]])
|
||||
else:
|
||||
X.append(t2n[key][default])
|
||||
|
||||
return jsonify([ n2t['polka'][clf.predict([X]).tolist()[0]] ])
|
||||
|
||||
@app.route('/')
|
||||
def index():
|
||||
return redirect('/index.html')
|
||||
|
||||
categories, t2n, n2t, defaults, clf = decision_tree.init('data/productsTree.tsv', cache=True)
|
||||
categories = dict((key, list(vals)) for (key, vals) in categories.items())
|
||||
app.run()
|
10
test.sh
Executable file
10
test.sh
Executable file
@ -0,0 +1,10 @@
|
||||
#!/bin/sh
|
||||
|
||||
decide () {
|
||||
echo Input: "$1"
|
||||
echo Output:
|
||||
curl "http://localhost:5000/api/decide" -d"$1"
|
||||
}
|
||||
|
||||
decide '{ "nazwa": "chleb", "kategoria": "kuchnia", "typ_zywnosci": "gotowe" }'
|
||||
decide '{ "nazwa": "parowki", "kategoria": "kuchnia", "typ_zywnosci": "mieso", "termin_przydatnosci": "krotki" }'
|
Loading…
Reference in New Issue
Block a user