34 lines
726 B
Python
34 lines
726 B
Python
|
#!/usr/local/bin/python3
|
||
|
|
||
|
from flask import Flask, abort, jsonify, request
|
||
|
from flask_cors import CORS, cross_origin
|
||
|
|
||
|
|
||
|
from run_generation import generate_text
|
||
|
|
||
|
# Flask application with CORS enabled for every route.
app = Flask(__name__)
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'


@app.route("/generate", methods=['POST'])
@cross_origin()
def get_gen():
    """Handle ``POST /generate``: run text generation for a JSON request.

    The request body must be a JSON object containing:

    * ``text``  -- non-empty prompt, passed to ``generate_text`` as ``prompt``.
    * ``model`` -- passed to ``generate_text`` as ``model_name_or_path``.

    Returns:
        A JSON response ``{"result": <value returned by generate_text>}``.

    Raises:
        400 Bad Request (via ``abort``) when the body is not a JSON object,
        when ``text`` is missing or empty, or when ``model`` is missing.
    """
    # silent=True yields None instead of raising on a missing/malformed JSON
    # body, so every kind of bad payload is rejected by the single check below.
    data = request.get_json(silent=True)

    # A JSON body may legally be null, a list, a string, ... -- only an object
    # can carry the required keys, so anything else is a client error.
    if not isinstance(data, dict) or not data.get('text') or 'model' not in data:
        abort(400)

    result = generate_text(
        model_type='gpt2',
        length=100,
        prompt=data['text'],
        model_name_or_path=data['model'],
    )

    return jsonify({'result': result})


if __name__ == "__main__":
    # Start the development server only after all routes are registered:
    # app.run() blocks, so calling it earlier would serve an app with no
    # /generate route. The guard also keeps a plain import from starting it.
    app.run()
|