34 lines
726 B
Python
34 lines
726 B
Python
#!/usr/local/bin/python3
|
|
|
|
from flask import Flask, abort, jsonify, request
|
|
from flask_cors import CORS, cross_origin
|
|
|
|
|
|
from run_generation import generate_text
|
|
|
|
app = Flask(__name__)
cors = CORS(app)
app.config['CORS_HEADERS'] = 'Content-Type'


@app.route("/generate", methods=['POST'])
@cross_origin()
def get_gen():
    """Generate text from a prompt using the requested model.

    Expects a JSON object body with:
      - ``text``: non-empty string used as the generation prompt.
      - ``model``: non-empty string forwarded to ``generate_text`` as
        ``model_name_or_path`` (presumably a model name or local path --
        confirm against run_generation).

    Calls ``generate_text`` with ``model_type='gpt2'`` and ``length=100``.

    Returns:
        A JSON response ``{"result": <generate_text return value>}``.

    Responds 400 (via ``abort``) when the body is not a JSON object or
    when ``text``/``model`` are missing, empty, or not strings.
    """
    # silent=True yields None for a missing/invalid JSON body instead of
    # raising, so every malformed request gets the same 400 below rather
    # than a TypeError (500) on the membership tests.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        abort(400)

    text = data.get('text')
    model = data.get('model')
    # abort() raises, so no else-branch is needed after these checks.
    if not isinstance(text, str) or not text:
        abort(400)
    if not isinstance(model, str) or not model:
        abort(400)

    # NOTE(review): `model` comes straight from the client and is used as
    # model_name_or_path; consider restricting it to an allowlist so
    # callers cannot point the server at arbitrary paths or models.
    result = generate_text(
        model_type='gpt2',
        length=100,
        prompt=text,
        model_name_or_path=model,
    )
    return jsonify({'result': result})


if __name__ == "__main__":
    # Start the dev server only when executed as a script, and only after
    # the route above has been registered (app.run() blocks).
    app.run()