# Flask service exposing GPT-2 text generation via POST /generate.
import json

import torch
from flask import Flask, request
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from utils import POC

app = Flask(__name__)

poc = POC()

# Pretrained GPT-2 vocabulary and architecture; the fine-tuned weights
# are loaded from the local checkpoint below.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

model_name = "model-4.pt"

# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints from a trusted source.
is_available = torch.cuda.is_available()
if is_available:
    model.load_state_dict(torch.load(model_name))
else:
    # No GPU: remap tensors stored on CUDA devices onto the CPU.
    model.load_state_dict(torch.load(model_name, map_location=torch.device('cpu')))
|
|
|
|
|
|
"""
|
|
Example:
|
|
{
|
|
"data": "asdasdas",
|
|
"length": 300
|
|
}
|
|
"""
|
|
@app.route("/generate", methods=['POST'])
|
|
def hello_world():
|
|
data = request.data
|
|
data_dict = json.loads(data)
|
|
x = poc.generate(model, tokenizer, data_dict['data'], entry_count=1, entry_length=data_dict['length'])
|
|
print(x)
|
|
return " ".join(x)
|
|
|
|
if __name__ == "__main__":
|
|
app.run(host='0.0.0.0', port=4999) |