pbr-ayct-core/app.py
2022-01-24 19:22:04 +01:00

38 lines
883 B
Python

from flask import Flask, request
from utils import POC
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import json
app = Flask(__name__)
poc = POC()

# Start from the stock 'gpt2' tokenizer and model, then overwrite the model's
# weights with the state dict stored in `model_name`.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model_name = "model-4.pt"

is_available = torch.cuda.is_available()
# With CUDA present, keep torch.load's default placement (map_location=None);
# without it, remap every stored tensor onto the CPU so GPU-saved weights load.
# NOTE(review): nothing here moves `model` to the GPU, so presumably inference
# runs on CPU even when CUDA is available — confirm against POC.generate.
model.load_state_dict(
    torch.load(
        model_name,
        map_location=None if is_available else torch.device('cpu'),
    )
)
@app.route("/generate", methods=['POST'])
def hello_world():
    """Generate text from a JSON prompt (the ``/generate`` endpoint).

    The function name is kept because Flask uses it as the endpoint name
    (e.g. for ``url_for``).

    Example request body::

        {
            "data": "asdasdas",
            "length": 300
        }

    ``data`` is the prompt handed to ``poc.generate`` and ``length`` is passed
    through as its ``entry_length`` argument (presumably a token/word budget —
    confirm against POC). The body is parsed as JSON regardless of the
    request's Content-Type header.

    Returns:
        The generated entries joined with single spaces, as a plain-text body.
        ``(message, 400)`` when the body is not a JSON object, ``data`` is not
        a string, or ``length`` is not a positive integer.
    """
    # force=True: parse even without an application/json Content-Type (as the
    # old json.loads(request.data) did); silent=True: return None instead of
    # raising, so malformed input becomes a 400 rather than a 500.
    payload = request.get_json(force=True, silent=True)
    if not isinstance(payload, dict):
        return "Request body must be a JSON object.", 400

    prompt = payload.get('data')
    length = payload.get('length')
    if not isinstance(prompt, str):
        return "'data' must be a string.", 400
    # bool is a subclass of int in Python, so reject it explicitly.
    if isinstance(length, bool) or not isinstance(length, int) or length < 1:
        return "'length' must be a positive integer.", 400

    entries = poc.generate(model, tokenizer, prompt, entry_count=1, entry_length=length)
    # Debug output goes through Flask's logger instead of a bare print().
    app.logger.debug("Generated entries: %r", entries)
    return " ".join(entries)
if __name__ == "__main__":
    # Run Flask's built-in server on port 4999, listening on all interfaces.
    app.run(port=4999, host='0.0.0.0')