changes
This commit is contained in:
parent
a1bd9211bc
commit
7a12d96497
3
.gitignore
vendored
3
.gitignore
vendored
@ -1,3 +1,6 @@
|
|||||||
venv/*
|
venv/*
|
||||||
dist/*
|
dist/*
|
||||||
*.egg-info
|
*.egg-info
|
||||||
|
.vscode
|
||||||
|
__pycache__
|
||||||
|
*.pt
|
13
Dockerfile
Normal file
13
Dockerfile
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
FROM anibali/pytorch:1.8.1-cuda11.1-ubuntu20.04
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
COPY requirements.txt /app/requirements.txt
|
||||||
|
|
||||||
|
RUN python3 -m pip install -r requirements.txt
|
||||||
|
RUN python3 -m pip install gdown
|
||||||
|
RUN python3 -m spacy download en_core_web_sm
|
||||||
|
RUN python3 -m pip install transformers
|
||||||
|
RUN gdown --fuzzy https://drive.google.com/file/d/1_zMsEVPOUzQe8yFzRj0-Q08BA8LN4H9x/view?usp=sharing
|
||||||
|
COPY *.py /app/
|
||||||
|
EXPOSE 4999
|
||||||
|
ENTRYPOINT ["python3", "app.py"]
|
38
app.py
Normal file
38
app.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from flask import Flask, request
|
||||||
|
from utils import POC
|
||||||
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
poc = POC()
|
||||||
|
|
||||||
|
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||||
|
model = GPT2LMHeadModel.from_pretrained('gpt2')
|
||||||
|
model_name = "model-4.pt"
|
||||||
|
|
||||||
|
is_available = torch.cuda.is_available()
|
||||||
|
if(is_available):
|
||||||
|
model.load_state_dict(torch.load(model_name))
|
||||||
|
else:
|
||||||
|
model.load_state_dict(torch.load(model_name, map_location=torch.device('cpu')))
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
Example:
|
||||||
|
{
|
||||||
|
"data": "asdasdas",
|
||||||
|
"length": 300
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
@app.route("/generate", methods=['POST'])
|
||||||
|
def hello_world():
|
||||||
|
data = request.data
|
||||||
|
data_dict = json.loads(data)
|
||||||
|
x = poc.generate(model, tokenizer, data_dict['data'], entry_count=1, entry_length=data_dict['length'])
|
||||||
|
print(x)
|
||||||
|
return " ".join(x)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app.run(host='0.0.0.0', port=4999)
|
736
data/BBC News Sample Solution.csv
Normal file
736
data/BBC News Sample Solution.csv
Normal file
@ -0,0 +1,736 @@
|
|||||||
|
ArticleId,Category
|
||||||
|
1018,sport
|
||||||
|
1319,tech
|
||||||
|
1138,business
|
||||||
|
459,entertainment
|
||||||
|
1020,politics
|
||||||
|
51,sport
|
||||||
|
2025,tech
|
||||||
|
1479,business
|
||||||
|
27,entertainment
|
||||||
|
397,politics
|
||||||
|
1644,sport
|
||||||
|
263,tech
|
||||||
|
765,business
|
||||||
|
2134,entertainment
|
||||||
|
297,politics
|
||||||
|
1712,sport
|
||||||
|
1631,tech
|
||||||
|
942,business
|
||||||
|
1549,entertainment
|
||||||
|
516,politics
|
||||||
|
2215,sport
|
||||||
|
531,tech
|
||||||
|
1541,business
|
||||||
|
1340,entertainment
|
||||||
|
56,politics
|
||||||
|
2138,sport
|
||||||
|
338,tech
|
||||||
|
133,business
|
||||||
|
215,entertainment
|
||||||
|
521,politics
|
||||||
|
1777,sport
|
||||||
|
185,tech
|
||||||
|
2105,business
|
||||||
|
1556,entertainment
|
||||||
|
659,politics
|
||||||
|
1837,sport
|
||||||
|
59,tech
|
||||||
|
383,business
|
||||||
|
1581,entertainment
|
||||||
|
2167,politics
|
||||||
|
1593,sport
|
||||||
|
1849,tech
|
||||||
|
1691,business
|
||||||
|
2033,entertainment
|
||||||
|
2092,politics
|
||||||
|
849,sport
|
||||||
|
1780,tech
|
||||||
|
1747,business
|
||||||
|
1509,entertainment
|
||||||
|
1305,politics
|
||||||
|
2055,sport
|
||||||
|
2139,tech
|
||||||
|
57,business
|
||||||
|
1496,entertainment
|
||||||
|
671,politics
|
||||||
|
1193,sport
|
||||||
|
2104,tech
|
||||||
|
872,business
|
||||||
|
606,entertainment
|
||||||
|
1502,politics
|
||||||
|
244,sport
|
||||||
|
381,tech
|
||||||
|
1944,business
|
||||||
|
725,entertainment
|
||||||
|
369,politics
|
||||||
|
211,sport
|
||||||
|
1861,tech
|
||||||
|
1044,business
|
||||||
|
976,entertainment
|
||||||
|
1673,politics
|
||||||
|
442,sport
|
||||||
|
144,tech
|
||||||
|
1954,business
|
||||||
|
48,entertainment
|
||||||
|
608,politics
|
||||||
|
2165,sport
|
||||||
|
913,tech
|
||||||
|
1580,business
|
||||||
|
1061,entertainment
|
||||||
|
529,politics
|
||||||
|
2137,sport
|
||||||
|
1284,tech
|
||||||
|
2040,business
|
||||||
|
349,entertainment
|
||||||
|
1385,politics
|
||||||
|
1307,sport
|
||||||
|
807,tech
|
||||||
|
333,business
|
||||||
|
786,entertainment
|
||||||
|
1775,politics
|
||||||
|
558,sport
|
||||||
|
592,tech
|
||||||
|
2172,business
|
||||||
|
1735,entertainment
|
||||||
|
1188,politics
|
||||||
|
511,sport
|
||||||
|
157,tech
|
||||||
|
646,business
|
||||||
|
8,entertainment
|
||||||
|
821,politics
|
||||||
|
1001,sport
|
||||||
|
309,tech
|
||||||
|
1732,business
|
||||||
|
1291,entertainment
|
||||||
|
2136,politics
|
||||||
|
999,sport
|
||||||
|
1258,tech
|
||||||
|
721,business
|
||||||
|
1922,entertainment
|
||||||
|
954,politics
|
||||||
|
847,sport
|
||||||
|
846,tech
|
||||||
|
313,business
|
||||||
|
1997,entertainment
|
||||||
|
322,politics
|
||||||
|
1221,sport
|
||||||
|
129,tech
|
||||||
|
981,business
|
||||||
|
330,entertainment
|
||||||
|
1172,politics
|
||||||
|
2123,sport
|
||||||
|
1165,tech
|
||||||
|
1592,business
|
||||||
|
2094,entertainment
|
||||||
|
841,politics
|
||||||
|
1224,sport
|
||||||
|
72,tech
|
||||||
|
1176,business
|
||||||
|
1285,entertainment
|
||||||
|
1671,politics
|
||||||
|
67,sport
|
||||||
|
952,tech
|
||||||
|
884,business
|
||||||
|
2011,entertainment
|
||||||
|
1979,politics
|
||||||
|
519,sport
|
||||||
|
1132,tech
|
||||||
|
1782,business
|
||||||
|
112,entertainment
|
||||||
|
1996,politics
|
||||||
|
683,sport
|
||||||
|
1447,tech
|
||||||
|
472,business
|
||||||
|
562,entertainment
|
||||||
|
1476,politics
|
||||||
|
1339,sport
|
||||||
|
1972,tech
|
||||||
|
875,business
|
||||||
|
1067,entertainment
|
||||||
|
1442,politics
|
||||||
|
708,sport
|
||||||
|
515,tech
|
||||||
|
118,business
|
||||||
|
1513,entertainment
|
||||||
|
1227,politics
|
||||||
|
1196,sport
|
||||||
|
635,tech
|
||||||
|
1669,business
|
||||||
|
1823,entertainment
|
||||||
|
483,politics
|
||||||
|
1477,sport
|
||||||
|
509,tech
|
||||||
|
667,business
|
||||||
|
127,entertainment
|
||||||
|
1316,politics
|
||||||
|
1983,sport
|
||||||
|
598,tech
|
||||||
|
1654,business
|
||||||
|
1232,entertainment
|
||||||
|
703,politics
|
||||||
|
2051,sport
|
||||||
|
451,tech
|
||||||
|
1924,business
|
||||||
|
852,entertainment
|
||||||
|
1889,politics
|
||||||
|
1695,sport
|
||||||
|
417,tech
|
||||||
|
2125,business
|
||||||
|
2087,entertainment
|
||||||
|
1274,politics
|
||||||
|
579,sport
|
||||||
|
398,tech
|
||||||
|
1898,business
|
||||||
|
87,entertainment
|
||||||
|
2024,politics
|
||||||
|
1927,sport
|
||||||
|
1019,tech
|
||||||
|
1892,business
|
||||||
|
357,entertainment
|
||||||
|
1161,politics
|
||||||
|
52,sport
|
||||||
|
1798,tech
|
||||||
|
2028,business
|
||||||
|
58,entertainment
|
||||||
|
258,politics
|
||||||
|
638,sport
|
||||||
|
541,tech
|
||||||
|
2190,business
|
||||||
|
1936,entertainment
|
||||||
|
1293,politics
|
||||||
|
2181,sport
|
||||||
|
1329,tech
|
||||||
|
146,business
|
||||||
|
89,entertainment
|
||||||
|
1627,politics
|
||||||
|
1123,sport
|
||||||
|
1233,tech
|
||||||
|
580,business
|
||||||
|
591,entertainment
|
||||||
|
1933,politics
|
||||||
|
911,sport
|
||||||
|
971,tech
|
||||||
|
1154,business
|
||||||
|
1622,entertainment
|
||||||
|
539,politics
|
||||||
|
1131,sport
|
||||||
|
1839,tech
|
||||||
|
545,business
|
||||||
|
1910,entertainment
|
||||||
|
799,politics
|
||||||
|
1016,sport
|
||||||
|
1460,tech
|
||||||
|
1286,business
|
||||||
|
296,entertainment
|
||||||
|
1345,politics
|
||||||
|
1568,sport
|
||||||
|
502,tech
|
||||||
|
1530,business
|
||||||
|
81,entertainment
|
||||||
|
2162,politics
|
||||||
|
1519,sport
|
||||||
|
1634,tech
|
||||||
|
1586,business
|
||||||
|
469,entertainment
|
||||||
|
382,politics
|
||||||
|
353,sport
|
||||||
|
1666,tech
|
||||||
|
1457,business
|
||||||
|
2182,entertainment
|
||||||
|
1591,politics
|
||||||
|
1643,sport
|
||||||
|
1822,tech
|
||||||
|
742,business
|
||||||
|
1956,entertainment
|
||||||
|
1829,politics
|
||||||
|
1675,sport
|
||||||
|
376,tech
|
||||||
|
1626,business
|
||||||
|
168,entertainment
|
||||||
|
321,politics
|
||||||
|
715,sport
|
||||||
|
1278,tech
|
||||||
|
653,business
|
||||||
|
1206,entertainment
|
||||||
|
411,politics
|
||||||
|
931,sport
|
||||||
|
1576,tech
|
||||||
|
1207,business
|
||||||
|
1729,entertainment
|
||||||
|
851,politics
|
||||||
|
994,sport
|
||||||
|
223,tech
|
||||||
|
763,business
|
||||||
|
977,entertainment
|
||||||
|
1094,politics
|
||||||
|
582,sport
|
||||||
|
1845,tech
|
||||||
|
965,business
|
||||||
|
1461,entertainment
|
||||||
|
945,politics
|
||||||
|
746,sport
|
||||||
|
1768,tech
|
||||||
|
1242,business
|
||||||
|
259,entertainment
|
||||||
|
866,politics
|
||||||
|
1288,sport
|
||||||
|
915,tech
|
||||||
|
337,business
|
||||||
|
1842,entertainment
|
||||||
|
975,politics
|
||||||
|
1890,sport
|
||||||
|
1694,tech
|
||||||
|
367,business
|
||||||
|
1270,entertainment
|
||||||
|
1679,politics
|
||||||
|
717,sport
|
||||||
|
838,tech
|
||||||
|
1857,business
|
||||||
|
798,entertainment
|
||||||
|
1915,politics
|
||||||
|
1051,sport
|
||||||
|
99,tech
|
||||||
|
810,business
|
||||||
|
1662,entertainment
|
||||||
|
1521,politics
|
||||||
|
1693,sport
|
||||||
|
347,tech
|
||||||
|
60,business
|
||||||
|
1958,entertainment
|
||||||
|
723,politics
|
||||||
|
402,sport
|
||||||
|
662,tech
|
||||||
|
3,business
|
||||||
|
84,entertainment
|
||||||
|
1551,politics
|
||||||
|
1684,sport
|
||||||
|
806,tech
|
||||||
|
1721,business
|
||||||
|
446,entertainment
|
||||||
|
326,politics
|
||||||
|
2211,sport
|
||||||
|
2037,tech
|
||||||
|
1158,business
|
||||||
|
498,entertainment
|
||||||
|
2122,politics
|
||||||
|
1905,sport
|
||||||
|
564,tech
|
||||||
|
939,business
|
||||||
|
1,entertainment
|
||||||
|
567,politics
|
||||||
|
1219,sport
|
||||||
|
1720,tech
|
||||||
|
1015,business
|
||||||
|
1459,entertainment
|
||||||
|
633,politics
|
||||||
|
1384,sport
|
||||||
|
242,tech
|
||||||
|
1334,business
|
||||||
|
1456,entertainment
|
||||||
|
1298,politics
|
||||||
|
1567,sport
|
||||||
|
205,tech
|
||||||
|
385,business
|
||||||
|
795,entertainment
|
||||||
|
1083,politics
|
||||||
|
1953,sport
|
||||||
|
1987,tech
|
||||||
|
1748,business
|
||||||
|
1869,entertainment
|
||||||
|
34,politics
|
||||||
|
604,sport
|
||||||
|
142,tech
|
||||||
|
1832,business
|
||||||
|
1865,entertainment
|
||||||
|
862,politics
|
||||||
|
1191,sport
|
||||||
|
22,tech
|
||||||
|
1883,business
|
||||||
|
737,entertainment
|
||||||
|
1262,politics
|
||||||
|
1601,sport
|
||||||
|
1404,tech
|
||||||
|
1296,business
|
||||||
|
2163,entertainment
|
||||||
|
1168,politics
|
||||||
|
438,sport
|
||||||
|
1888,tech
|
||||||
|
694,business
|
||||||
|
1663,entertainment
|
||||||
|
861,politics
|
||||||
|
704,sport
|
||||||
|
1010,tech
|
||||||
|
2039,business
|
||||||
|
854,entertainment
|
||||||
|
1562,politics
|
||||||
|
1422,sport
|
||||||
|
874,tech
|
||||||
|
1272,business
|
||||||
|
1405,entertainment
|
||||||
|
2049,politics
|
||||||
|
620,sport
|
||||||
|
2022,tech
|
||||||
|
925,business
|
||||||
|
1573,entertainment
|
||||||
|
1187,politics
|
||||||
|
1471,sport
|
||||||
|
504,tech
|
||||||
|
576,business
|
||||||
|
222,entertainment
|
||||||
|
95,politics
|
||||||
|
332,sport
|
||||||
|
1895,tech
|
||||||
|
2106,business
|
||||||
|
1181,entertainment
|
||||||
|
776,politics
|
||||||
|
9,sport
|
||||||
|
808,tech
|
||||||
|
1002,business
|
||||||
|
784,entertainment
|
||||||
|
845,politics
|
||||||
|
1416,sport
|
||||||
|
1762,tech
|
||||||
|
525,business
|
||||||
|
596,entertainment
|
||||||
|
1800,politics
|
||||||
|
1341,sport
|
||||||
|
113,tech
|
||||||
|
134,business
|
||||||
|
270,entertainment
|
||||||
|
1324,politics
|
||||||
|
12,sport
|
||||||
|
782,tech
|
||||||
|
1357,business
|
||||||
|
1608,entertainment
|
||||||
|
1134,politics
|
||||||
|
276,sport
|
||||||
|
41,tech
|
||||||
|
1705,business
|
||||||
|
732,entertainment
|
||||||
|
968,politics
|
||||||
|
1006,sport
|
||||||
|
1140,tech
|
||||||
|
966,business
|
||||||
|
1399,entertainment
|
||||||
|
586,politics
|
||||||
|
868,sport
|
||||||
|
1698,tech
|
||||||
|
452,business
|
||||||
|
600,entertainment
|
||||||
|
354,politics
|
||||||
|
1281,sport
|
||||||
|
914,tech
|
||||||
|
1430,business
|
||||||
|
554,entertainment
|
||||||
|
1659,politics
|
||||||
|
1727,sport
|
||||||
|
640,tech
|
||||||
|
691,business
|
||||||
|
1531,entertainment
|
||||||
|
1389,politics
|
||||||
|
1545,sport
|
||||||
|
462,tech
|
||||||
|
2101,business
|
||||||
|
1907,entertainment
|
||||||
|
1373,politics
|
||||||
|
1787,sport
|
||||||
|
2202,tech
|
||||||
|
961,business
|
||||||
|
680,entertainment
|
||||||
|
1059,politics
|
||||||
|
626,sport
|
||||||
|
1961,tech
|
||||||
|
1048,business
|
||||||
|
336,entertainment
|
||||||
|
719,politics
|
||||||
|
1563,sport
|
||||||
|
963,tech
|
||||||
|
179,business
|
||||||
|
1438,entertainment
|
||||||
|
1189,politics
|
||||||
|
2126,sport
|
||||||
|
895,tech
|
||||||
|
2005,business
|
||||||
|
819,entertainment
|
||||||
|
23,politics
|
||||||
|
457,sport
|
||||||
|
1423,tech
|
||||||
|
100,business
|
||||||
|
1448,entertainment
|
||||||
|
97,politics
|
||||||
|
172,sport
|
||||||
|
1918,tech
|
||||||
|
938,business
|
||||||
|
752,entertainment
|
||||||
|
2222,politics
|
||||||
|
2038,sport
|
||||||
|
24,tech
|
||||||
|
1325,business
|
||||||
|
409,entertainment
|
||||||
|
624,politics
|
||||||
|
1367,sport
|
||||||
|
621,tech
|
||||||
|
1302,business
|
||||||
|
505,entertainment
|
||||||
|
2057,politics
|
||||||
|
265,sport
|
||||||
|
262,tech
|
||||||
|
2176,business
|
||||||
|
1304,entertainment
|
||||||
|
1539,politics
|
||||||
|
700,sport
|
||||||
|
2124,tech
|
||||||
|
2203,business
|
||||||
|
781,entertainment
|
||||||
|
1558,politics
|
||||||
|
1419,sport
|
||||||
|
1311,tech
|
||||||
|
124,business
|
||||||
|
664,entertainment
|
||||||
|
282,politics
|
||||||
|
1700,sport
|
||||||
|
1107,tech
|
||||||
|
922,business
|
||||||
|
652,entertainment
|
||||||
|
713,politics
|
||||||
|
1068,sport
|
||||||
|
1116,tech
|
||||||
|
869,business
|
||||||
|
1703,entertainment
|
||||||
|
1668,politics
|
||||||
|
501,sport
|
||||||
|
2053,tech
|
||||||
|
908,business
|
||||||
|
1306,entertainment
|
||||||
|
348,politics
|
||||||
|
149,sport
|
||||||
|
1173,tech
|
||||||
|
6,business
|
||||||
|
1180,entertainment
|
||||||
|
107,politics
|
||||||
|
1629,sport
|
||||||
|
131,tech
|
||||||
|
1046,business
|
||||||
|
698,entertainment
|
||||||
|
751,politics
|
||||||
|
1501,sport
|
||||||
|
73,tech
|
||||||
|
1463,business
|
||||||
|
1856,entertainment
|
||||||
|
1074,politics
|
||||||
|
2225,sport
|
||||||
|
1313,tech
|
||||||
|
933,business
|
||||||
|
1136,entertainment
|
||||||
|
779,politics
|
||||||
|
1315,sport
|
||||||
|
1628,tech
|
||||||
|
432,business
|
||||||
|
546,entertainment
|
||||||
|
1779,politics
|
||||||
|
1514,sport
|
||||||
|
1365,tech
|
||||||
|
1776,business
|
||||||
|
109,entertainment
|
||||||
|
235,politics
|
||||||
|
162,sport
|
||||||
|
2071,tech
|
||||||
|
412,business
|
||||||
|
1092,entertainment
|
||||||
|
1352,politics
|
||||||
|
294,sport
|
||||||
|
1973,tech
|
||||||
|
493,business
|
||||||
|
45,entertainment
|
||||||
|
2189,politics
|
||||||
|
1417,sport
|
||||||
|
2140,tech
|
||||||
|
1265,business
|
||||||
|
1353,entertainment
|
||||||
|
830,politics
|
||||||
|
448,sport
|
||||||
|
535,tech
|
||||||
|
1656,business
|
||||||
|
996,entertainment
|
||||||
|
1246,politics
|
||||||
|
161,sport
|
||||||
|
1267,tech
|
||||||
|
138,business
|
||||||
|
1575,entertainment
|
||||||
|
1810,politics
|
||||||
|
2065,sport
|
||||||
|
768,tech
|
||||||
|
1507,business
|
||||||
|
1520,entertainment
|
||||||
|
71,politics
|
||||||
|
1484,sport
|
||||||
|
163,tech
|
||||||
|
1171,business
|
||||||
|
769,entertainment
|
||||||
|
77,politics
|
||||||
|
1252,sport
|
||||||
|
532,tech
|
||||||
|
1975,business
|
||||||
|
178,entertainment
|
||||||
|
1237,politics
|
||||||
|
1899,sport
|
||||||
|
676,tech
|
||||||
|
2008,business
|
||||||
|
1807,entertainment
|
||||||
|
481,politics
|
||||||
|
17,sport
|
||||||
|
837,tech
|
||||||
|
523,business
|
||||||
|
2099,entertainment
|
||||||
|
1957,politics
|
||||||
|
1470,sport
|
||||||
|
1151,tech
|
||||||
|
489,business
|
||||||
|
823,entertainment
|
||||||
|
1743,politics
|
||||||
|
1035,sport
|
||||||
|
2103,tech
|
||||||
|
1230,business
|
||||||
|
1636,entertainment
|
||||||
|
103,politics
|
||||||
|
927,sport
|
||||||
|
379,tech
|
||||||
|
1789,business
|
||||||
|
278,entertainment
|
||||||
|
887,politics
|
||||||
|
1300,sport
|
||||||
|
394,tech
|
||||||
|
625,business
|
||||||
|
1770,entertainment
|
||||||
|
1959,politics
|
||||||
|
1371,sport
|
||||||
|
1359,tech
|
||||||
|
921,business
|
||||||
|
404,entertainment
|
||||||
|
422,politics
|
||||||
|
503,sport
|
||||||
|
1773,tech
|
||||||
|
2174,business
|
||||||
|
1086,entertainment
|
||||||
|
1986,politics
|
||||||
|
2007,sport
|
||||||
|
13,tech
|
||||||
|
796,business
|
||||||
|
585,entertainment
|
||||||
|
1200,politics
|
||||||
|
642,sport
|
||||||
|
237,tech
|
||||||
|
1826,business
|
||||||
|
43,entertainment
|
||||||
|
1594,politics
|
||||||
|
594,sport
|
||||||
|
1084,tech
|
||||||
|
1859,business
|
||||||
|
524,entertainment
|
||||||
|
463,politics
|
||||||
|
1085,sport
|
||||||
|
668,tech
|
||||||
|
2132,business
|
||||||
|
1970,entertainment
|
||||||
|
260,politics
|
||||||
|
141,sport
|
||||||
|
468,tech
|
||||||
|
269,business
|
||||||
|
414,entertainment
|
||||||
|
2097,politics
|
||||||
|
1420,sport
|
||||||
|
1472,tech
|
||||||
|
1757,business
|
||||||
|
243,entertainment
|
||||||
|
1142,politics
|
||||||
|
1295,sport
|
||||||
|
1660,tech
|
||||||
|
1635,business
|
||||||
|
2200,entertainment
|
||||||
|
2146,politics
|
||||||
|
5,sport
|
||||||
|
1256,tech
|
||||||
|
2090,business
|
||||||
|
1816,entertainment
|
||||||
|
1991,politics
|
||||||
|
1310,sport
|
||||||
|
1162,tech
|
||||||
|
499,business
|
||||||
|
285,entertainment
|
||||||
|
1216,politics
|
||||||
|
745,sport
|
||||||
|
421,tech
|
||||||
|
2068,business
|
||||||
|
1145,entertainment
|
||||||
|
681,politics
|
||||||
|
556,sport
|
||||||
|
1063,tech
|
||||||
|
55,business
|
||||||
|
910,entertainment
|
||||||
|
2066,politics
|
||||||
|
1025,sport
|
||||||
|
2154,tech
|
||||||
|
64,business
|
||||||
|
2218,entertainment
|
||||||
|
2013,politics
|
||||||
|
1105,sport
|
||||||
|
669,tech
|
||||||
|
827,business
|
||||||
|
1335,entertainment
|
||||||
|
1426,politics
|
||||||
|
2196,sport
|
||||||
|
368,tech
|
||||||
|
1535,business
|
||||||
|
1672,entertainment
|
||||||
|
599,politics
|
||||||
|
2085,sport
|
||||||
|
1577,tech
|
||||||
|
724,business
|
||||||
|
568,entertainment
|
||||||
|
2018,politics
|
||||||
|
1543,sport
|
||||||
|
308,tech
|
||||||
|
1356,business
|
||||||
|
573,entertainment
|
||||||
|
1746,politics
|
||||||
|
1290,sport
|
||||||
|
272,tech
|
||||||
|
139,business
|
||||||
|
2191,entertainment
|
||||||
|
1255,politics
|
||||||
|
1699,sport
|
||||||
|
1616,tech
|
||||||
|
1619,business
|
||||||
|
1032,entertainment
|
||||||
|
287,politics
|
||||||
|
2194,sport
|
||||||
|
928,tech
|
||||||
|
584,business
|
||||||
|
1722,entertainment
|
||||||
|
1685,politics
|
||||||
|
1828,sport
|
||||||
|
692,tech
|
||||||
|
1393,business
|
||||||
|
540,entertainment
|
||||||
|
1317,politics
|
||||||
|
152,sport
|
||||||
|
650,tech
|
||||||
|
191,business
|
||||||
|
1719,entertainment
|
||||||
|
1564,politics
|
||||||
|
350,sport
|
||||||
|
116,tech
|
||||||
|
1000,business
|
||||||
|
36,entertainment
|
||||||
|
1225,politics
|
||||||
|
2155,sport
|
||||||
|
553,tech
|
||||||
|
551,business
|
||||||
|
1724,entertainment
|
||||||
|
1512,politics
|
||||||
|
1923,sport
|
||||||
|
373,tech
|
||||||
|
1704,business
|
||||||
|
206,entertainment
|
||||||
|
471,politics
|
|
736
data/BBC_News_Test.csv
Normal file
736
data/BBC_News_Test.csv
Normal file
File diff suppressed because one or more lines are too long
1491
data/BBC_News_Train.csv
Normal file
1491
data/BBC_News_Train.csv
Normal file
File diff suppressed because one or more lines are too long
10
docker-compose.yaml
Normal file
10
docker-compose.yaml
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
version: "3.9"
|
||||||
|
services:
|
||||||
|
ayct_core:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
image: ayct_core:latest
|
||||||
|
ports:
|
||||||
|
- "4999:4999"
|
||||||
|
container_name: ayct_core
|
@ -1,6 +0,0 @@
|
|||||||
[build-system]
|
|
||||||
requires = [
|
|
||||||
"setuptools>=42",
|
|
||||||
"wheel"
|
|
||||||
]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
10
readme.md
10
readme.md
@ -1,10 +0,0 @@
|
|||||||
To build project type:
|
|
||||||
|
|
||||||
```
|
|
||||||
pip3 install virtualenv
|
|
||||||
virtuvalenv venv
|
|
||||||
source venv/bin/activate
|
|
||||||
pip3 install setuptools
|
|
||||||
python3 -m pip install --upgrade build
|
|
||||||
python3 -m build
|
|
||||||
```
|
|
6
requirements.txt
Normal file
6
requirements.txt
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
Flask==2.0.2
|
||||||
|
nltk==3.6.7
|
||||||
|
numpy==1.22.1
|
||||||
|
pandas==1.3.5
|
||||||
|
spacy==3.2.1
|
||||||
|
tqdm==4.62.3
|
24
setup.cfg
24
setup.cfg
@ -1,24 +0,0 @@
|
|||||||
[metadata]
|
|
||||||
name = pbrAyctCore
|
|
||||||
version = 0.0.1
|
|
||||||
author = Ramon Dyzman
|
|
||||||
author_email = ramon.dyzman@gmail.com
|
|
||||||
description = A small example package
|
|
||||||
long_description = file: readme.md
|
|
||||||
long_description_content_type = text/markdown
|
|
||||||
url = https://git.wmi.amu.edu.pl/s415366/pbr-ayct-core
|
|
||||||
project_urls =
|
|
||||||
Bug Tracker = https://github.com/pypa/sampleproject/issues
|
|
||||||
classifiers =
|
|
||||||
Programming Language :: Python :: 3
|
|
||||||
License :: OSI Approved :: MIT License
|
|
||||||
Operating System :: OS Independent
|
|
||||||
|
|
||||||
[options]
|
|
||||||
package_dir =
|
|
||||||
= src
|
|
||||||
packages = find:
|
|
||||||
python_requires = >=3.6
|
|
||||||
|
|
||||||
[options.packages.find]
|
|
||||||
where = src
|
|
@ -1,4 +0,0 @@
|
|||||||
|
|
||||||
|
|
||||||
def getTestString() -> str:
|
|
||||||
return "All You Can Tweet"
|
|
47
test.py
Normal file
47
test.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import pandas as pd
|
||||||
|
from nltk import word_tokenize
|
||||||
|
from utils import Vocabulary
|
||||||
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||||
|
from utils import POC
|
||||||
|
from utils import NewsEntry
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# def encode_dataset(data):
|
||||||
|
# return data.apply(lambda x: vocabulary.sentence_to_numbers(x))
|
||||||
|
|
||||||
|
# mode = "train"
|
||||||
|
# mode = "test"
|
||||||
|
mode = "generate"
|
||||||
|
prompt = """
|
||||||
|
Mr Johnson's bakerys is bad. My bakery is good.
|
||||||
|
"""
|
||||||
|
model_name = "model-4.pt"
|
||||||
|
|
||||||
|
train = pd.read_csv("data/BBC_News_Train.csv")
|
||||||
|
test = pd.read_csv("data/BBC_News_Test.csv")
|
||||||
|
length_limit = 1024
|
||||||
|
|
||||||
|
train = train[train["Text"].apply(lambda x: len(word_tokenize(x)) < length_limit)]
|
||||||
|
test = test[test["Text"].apply(lambda x: len(word_tokenize(x)) < length_limit)]
|
||||||
|
|
||||||
|
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||||
|
model = GPT2LMHeadModel.from_pretrained('gpt2')
|
||||||
|
|
||||||
|
|
||||||
|
if mode == "train":
|
||||||
|
dataset = NewsEntry(train["Text"], train["Text"], truncate=False, gpt2_type="gpt2")
|
||||||
|
poc = POC()
|
||||||
|
print("Starting training")
|
||||||
|
poc.train(dataset, model, tokenizer, output_dir="./", save_model_on_epoch=True)
|
||||||
|
|
||||||
|
elif mode == "generate":
|
||||||
|
# dataset = NewsEntry(test["Text"], test["Text"], truncate=False, gpt2_type="gpt2")
|
||||||
|
poc = POC()
|
||||||
|
model.load_state_dict(torch.load(model_name))
|
||||||
|
x = poc.generate(model, tokenizer, prompt, entry_count=1, entry_length=300)
|
||||||
|
print(x)
|
||||||
|
|
||||||
|
elif mode == "test":
|
||||||
|
dataset = NewsEntry(test["Text"], test["Text"], truncate=False, gpt2_type="gpt2")
|
||||||
|
poc = POC()
|
||||||
|
model.load_state_dict(torch.load(model_name))
|
206
utils.py
Normal file
206
utils.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
import spacy
|
||||||
|
import pandas as pd
|
||||||
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||||
|
import numpy as np
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
|
||||||
|
nlp = spacy.load("en_core_web_sm")
|
||||||
|
|
||||||
|
class NewsEntry(Dataset):
|
||||||
|
def __init__(self, control_code, data, truncate=False, gpt2_type="gpt2", max_length=1024):
|
||||||
|
|
||||||
|
self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
|
||||||
|
self.data = []
|
||||||
|
|
||||||
|
for row in data:
|
||||||
|
self.data.append(torch.tensor(
|
||||||
|
self.tokenizer.encode(f"<|{control_code}|>{row[:max_length]}<|endoftext|>")
|
||||||
|
))
|
||||||
|
if truncate:
|
||||||
|
self.data = self.data[:20000]
|
||||||
|
self.data_count = len(self.data)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.data_count
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.data[item]
|
||||||
|
|
||||||
|
class Vocabulary():
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.__UKNOWN__ = -1
|
||||||
|
self.vocab = {}
|
||||||
|
|
||||||
|
def create_vocab(self, data):
|
||||||
|
counter = 0
|
||||||
|
vocab = {}
|
||||||
|
for row in data:
|
||||||
|
for word in nlp(row):
|
||||||
|
ex = word.lemma_
|
||||||
|
if ex in vocab:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
vocab[ex] = counter
|
||||||
|
counter += 1
|
||||||
|
|
||||||
|
self.vocab = vocab
|
||||||
|
|
||||||
|
def word_to_number(self, word):
|
||||||
|
word = nlp(word)
|
||||||
|
for token in word:
|
||||||
|
ex = token.lemma_
|
||||||
|
if ex in self.vocab:
|
||||||
|
return self.vocab[ex]
|
||||||
|
else:
|
||||||
|
return self.__UKNOWN__
|
||||||
|
|
||||||
|
def sentence_to_numbers(self, seq):
|
||||||
|
result = []
|
||||||
|
for word in nlp(seq):
|
||||||
|
ex = word.lemma_
|
||||||
|
if ex in self.vocab:
|
||||||
|
result.append(self.vocab[ex])
|
||||||
|
else:
|
||||||
|
result.append(self.__UKNOWN__)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def sequence_to_numbers(self, seq):
|
||||||
|
return [self.word_to_number(x) for x in seq]
|
||||||
|
|
||||||
|
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
|
||||||
|
if packed_tensor is None:
|
||||||
|
return new_tensor, True, None
|
||||||
|
if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
|
||||||
|
return packed_tensor, False, new_tensor
|
||||||
|
else:
|
||||||
|
packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
|
||||||
|
return packed_tensor, True, None
|
||||||
|
|
||||||
|
class POC():
|
||||||
|
def __init__(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def train(
|
||||||
|
self, dataset, model, tokenizer,
|
||||||
|
batch_size=16, epochs=5, lr=2e-5,
|
||||||
|
max_seq_len=400, warmup_steps=200,
|
||||||
|
gpt2_type="gpt2", output_dir=".", output_prefix="model",
|
||||||
|
test_mode=False,save_model_on_epoch=False,
|
||||||
|
):
|
||||||
|
acc_steps = 100
|
||||||
|
device=torch.device("cuda")
|
||||||
|
model = model.cuda()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
optimizer = AdamW(model.parameters(), lr=lr)
|
||||||
|
scheduler = get_linear_schedule_with_warmup(
|
||||||
|
optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
|
||||||
|
loss=0
|
||||||
|
accumulating_batch_count = 0
|
||||||
|
input_tensor = None
|
||||||
|
|
||||||
|
for epoch in range(epochs):
|
||||||
|
|
||||||
|
print(f"Training epoch {epoch}")
|
||||||
|
print(loss)
|
||||||
|
for idx, entry in tqdm(enumerate(train_dataloader)):
|
||||||
|
(input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, 768)
|
||||||
|
|
||||||
|
if carry_on and idx != len(train_dataloader) - 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
input_tensor = input_tensor.to(device)
|
||||||
|
outputs = model(input_tensor, labels=input_tensor)
|
||||||
|
loss = outputs[0]
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
if (accumulating_batch_count % batch_size) == 0:
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
model.zero_grad()
|
||||||
|
|
||||||
|
accumulating_batch_count += 1
|
||||||
|
input_tensor = None
|
||||||
|
if save_model_on_epoch:
|
||||||
|
torch.save(
|
||||||
|
model.state_dict(),
|
||||||
|
os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
tokenizer,
|
||||||
|
prompt,
|
||||||
|
entry_count=10,
|
||||||
|
entry_length=30, #maximum number of words
|
||||||
|
top_p=0.8,
|
||||||
|
temperature=1.,
|
||||||
|
):
|
||||||
|
model.eval()
|
||||||
|
generated_num = 0
|
||||||
|
generated_list = []
|
||||||
|
|
||||||
|
filter_value = -float("Inf")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
|
||||||
|
for entry_idx in trange(entry_count):
|
||||||
|
|
||||||
|
entry_finished = False
|
||||||
|
generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
|
||||||
|
|
||||||
|
if entry_idx % 100 == 0:
|
||||||
|
print(entry_idx)
|
||||||
|
|
||||||
|
for i in range(entry_length):
|
||||||
|
outputs = model(generated, labels=generated)
|
||||||
|
loss, logits = outputs[:2]
|
||||||
|
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
||||||
|
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||||
|
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
|
||||||
|
sorted_indices_to_remove = cumulative_probs > top_p
|
||||||
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
||||||
|
..., :-1
|
||||||
|
].clone()
|
||||||
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
|
|
||||||
|
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||||||
|
logits[:, indices_to_remove] = filter_value
|
||||||
|
|
||||||
|
next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
|
||||||
|
generated = torch.cat((generated, next_token), dim=1)
|
||||||
|
|
||||||
|
if next_token in tokenizer.encode("<|endoftext|>"):
|
||||||
|
entry_finished = True
|
||||||
|
|
||||||
|
if entry_finished:
|
||||||
|
|
||||||
|
generated_num = generated_num + 1
|
||||||
|
|
||||||
|
output_list = list(generated.squeeze().numpy())
|
||||||
|
output_text = tokenizer.decode(output_list)
|
||||||
|
generated_list.append(output_text)
|
||||||
|
break
|
||||||
|
|
||||||
|
if not entry_finished:
|
||||||
|
output_list = list(generated.squeeze().numpy())
|
||||||
|
output_text = f"{tokenizer.decode(output_list)}<|endoftext|>"
|
||||||
|
generated_list.append(output_text)
|
||||||
|
|
||||||
|
return generated_list
|
Loading…
Reference in New Issue
Block a user