diff --git a/README.md b/README.md index daca471..285df1b 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,16 @@ # projekt-sztuczna-inteligencja +## Uruchomienie projektu + +Wymaga Pythona oraz pip. + +Należy zainstalować zależności: +```sh +pip install -r requirements.txt +``` + +A następnie uruchomić serwer +```sh +cd src +python3 main.py +``` diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7ba0612 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +Flask==2.0.1 +matplotlib==3.4.2 +numpy==1.20.3 +scikit-learn==0.24.2 +scipy==1.6.3 diff --git a/src/decision_tree/main.py b/src/decision_tree.py similarity index 62% rename from src/decision_tree/main.py rename to src/decision_tree.py index 6c3e7de..9b1031b 100755 --- a/src/decision_tree/main.py +++ b/src/decision_tree.py @@ -1,31 +1,31 @@ #!/usr/bin/python3 from sklearn import tree from pprint import PrettyPrinter +import pickle +import os pp = PrettyPrinter(indent=2, compact=True) def p(*args, **kwargs): pp.pprint(*args, **kwargs) -def invoke_consume_exceptions(function, *args, **kwargs): - try: - return function(*args, **kwargs) - except: - return None - def read_tsv_from(filename): from csv import reader with open(filename, 'r') as f: header, *rows = list(reader(f, delimiter='\t')) return [dict(zip(header, (el.strip() for el in row))) for row in rows] -def main(): - from sys import argv - import os - import pathlib - source_file = argv[1] - invoke_consume_exceptions(os.mkdir, os.path.dirname(source_file)) - data = read_tsv_from(source_file) +def serialize(clf): + with open('./data/tree', 'wb') as f: + pickle.dump(clf, f) +def deserialize(): + with open('./data/tree', 'rb') as f: + return pickle.load(f) + +def init(source_file, cache=False): + from random import randint + defaults, *data = read_tsv_from(source_file) + defaults.pop("nazwa") types = dict() for row in data: for (key, value) in row.items(): @@ -42,18 +42,29 @@ def main(): t2n = dict((key, dict(v)) for (key, v) in base.items()) n2t = dict((key, dict((b, a) for (a, b) in v)) for (key, v) in base.items()) + X = [[ - t2n[name][feature] + t2n[name][feature] for (name, feature) in sample.items() if name not in ['nazwa', 'polka']] for sample in data if 'polka' in sample] Y = [t2n['polka'][sample['polka']] for sample in data if 'polka' in sample] - clf = tree.DecisionTreeClassifier() - clf.fit(X, Y) - l = clf.get_n_leaves() - d = clf.get_depth() - print(f'Leaves: {l}\nDepth: {d}') + if cache and os.path.isfile('./data/tree'): + clf = deserialize() + else: + clf = tree.DecisionTreeClassifier() + clf.fit(X, Y) + if cache: + serialize(clf) + + return [ + types, + t2n, + n2t, + defaults, + clf + ] if __name__ == '__main__': - main() \ No newline at end of file + init('./data/productsTree.tsv') diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..ebb4a4f --- /dev/null +++ b/src/main.py @@ -0,0 +1,40 @@ +from flask import Flask, redirect, jsonify, request +import decision_tree + +app = Flask( + __name__, + static_url_path='', + static_folder='.') + +@app.route('/api/types/', defaults={ 'searched_type': None }) +@app.route('/api/types/') +def api_get_features_types(searched_type): + return t2n[searched_type] if searched_type else t2n + +@app.route('/api/types/defaults') +def api_get_default_features(): + return jsonify(defaults) + +@app.route('/api/decide', methods=['POST']) +def api_predict_shelf(): + json = request.get_json(force=True) + defs = dict(defaults) + defs.pop('polka') + keys = list(json.keys()) + X = [] + + for (key, default) in defs.items(): + if key in keys: + X.append(t2n[key][json[key]]) + else: + X.append(t2n[key][default]) + + return jsonify([ n2t['polka'][clf.predict([X]).tolist()[0]] ]) + +@app.route('/') +def index(): + return redirect('/index.html') + +categories, t2n, n2t, defaults, clf = decision_tree.init('data/productsTree.tsv', cache=True) +categories = dict((key, list(vals)) for (key, vals) in categories.items()) +app.run() diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..f0d4ba7 --- /dev/null +++ b/test.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +decide () { + echo Input: "$1" + echo Output: + curl "http://localhost:5000/api/decide" -d"$1" +} + +decide '{ "nazwa": "chleb", "kategoria": "kuchnia", "typ_zywnosci": "gotowe" }' +decide '{ "nazwa": "parowki", "kategoria": "kuchnia", "typ_zywnosci": "mieso", "termin_przydatnosci": "krotki" }'