changes
This commit is contained in:
parent
a1bd9211bc
commit
7a12d96497
5
.gitignore
vendored
5
.gitignore
vendored
@ -1,3 +1,6 @@
|
||||
venv/*
|
||||
dist/*
|
||||
*.egg-info
|
||||
*.egg-info
|
||||
.vscode
|
||||
__pycache__
|
||||
*.pt
|
13
Dockerfile
Normal file
13
Dockerfile
Normal file
@ -0,0 +1,13 @@
|
||||
# Image for the Flask text-generation service defined in app.py.
FROM anibali/pytorch:1.8.1-cuda11.1-ubuntu20.04

WORKDIR /app
# Copy the dependency list on its own so the install layers below do not
# depend on the application sources copied later.
COPY requirements.txt /app/requirements.txt

RUN python3 -m pip install -r requirements.txt
RUN python3 -m pip install gdown
RUN python3 -m spacy download en_core_web_sm
RUN python3 -m pip install transformers
# Download the file shared on Google Drive into /app (presumably the
# model-4.pt checkpoint that app.py loads -- confirm). The URL is quoted
# because it contains `?`, which the shell treats as a glob character.
RUN gdown --fuzzy "https://drive.google.com/file/d/1_zMsEVPOUzQe8yFzRj0-Q08BA8LN4H9x/view?usp=sharing"
COPY *.py /app/
EXPOSE 4999
ENTRYPOINT ["python3", "app.py"]
|
38
app.py
Normal file
38
app.py
Normal file
@ -0,0 +1,38 @@
|
||||
from flask import Flask, request
|
||||
from utils import POC
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
import torch
|
||||
import json
|
||||
|
||||
# Flask application object and the generation helper shared by all requests.
app = Flask(__name__)
poc = POC()

# Start from the stock GPT-2 tokenizer and model; the model weights are
# replaced by the checkpoint loaded below.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model_name = "model-4.pt"

# Without CUDA the checkpoint's tensors must be remapped onto the CPU;
# with CUDA torch.load is called with its defaults.
is_available = torch.cuda.is_available()
_load_kwargs = {} if is_available else {'map_location': torch.device('cpu')}
model.load_state_dict(torch.load(model_name, **_load_kwargs))
|
||||
|
||||
|
||||
"""
|
||||
Example:
|
||||
{
|
||||
"data": "asdasdas",
|
||||
"length": 300
|
||||
}
|
||||
"""
|
||||
@app.route("/generate", methods=['POST'])
def hello_world():
    """Generate text that continues the prompt sent in the request body.

    The raw body is parsed as JSON (the Content-Type header is not
    consulted) and must be an object such as
    ``{"data": "some prompt", "length": 300}``. ``data`` is the prompt and
    ``length`` is forwarded to ``poc.generate`` as ``entry_length``.

    Returns:
        The generated entries joined with single spaces, or a plain-text
        message with HTTP status 400 when the body is not valid JSON, is not
        an object, or a field is missing or has the wrong type.
    """
    try:
        data_dict = json.loads(request.data)
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return "Request body must be valid JSON", 400

    if not isinstance(data_dict, dict):
        return "Request body must be a JSON object", 400

    prompt = data_dict.get('data')
    length = data_dict.get('length')
    if not isinstance(prompt, str) or not prompt:
        return "'data' must be a non-empty string", 400
    # bool is a subclass of int, so it has to be rejected explicitly.
    if isinstance(length, bool) or not isinstance(length, int) or length < 0:
        return "'length' must be a non-negative integer", 400

    x = poc.generate(model, tokenizer, prompt, entry_count=1, entry_length=length)
    print(x)
    return " ".join(x)
|
||||
|
||||
# Start Flask's built-in server only when this file is executed directly.
if __name__ == "__main__":
    # Listen on all interfaces on port 4999.
    app.run(host='0.0.0.0', port=4999)
|
736
data/BBC News Sample Solution.csv
Normal file
736
data/BBC News Sample Solution.csv
Normal file
@ -0,0 +1,736 @@
|
||||
ArticleId,Category
|
||||
1018,sport
|
||||
1319,tech
|
||||
1138,business
|
||||
459,entertainment
|
||||
1020,politics
|
||||
51,sport
|
||||
2025,tech
|
||||
1479,business
|
||||
27,entertainment
|
||||
397,politics
|
||||
1644,sport
|
||||
263,tech
|
||||
765,business
|
||||
2134,entertainment
|
||||
297,politics
|
||||
1712,sport
|
||||
1631,tech
|
||||
942,business
|
||||
1549,entertainment
|
||||
516,politics
|
||||
2215,sport
|
||||
531,tech
|
||||
1541,business
|
||||
1340,entertainment
|
||||
56,politics
|
||||
2138,sport
|
||||
338,tech
|
||||
133,business
|
||||
215,entertainment
|
||||
521,politics
|
||||
1777,sport
|
||||
185,tech
|
||||
2105,business
|
||||
1556,entertainment
|
||||
659,politics
|
||||
1837,sport
|
||||
59,tech
|
||||
383,business
|
||||
1581,entertainment
|
||||
2167,politics
|
||||
1593,sport
|
||||
1849,tech
|
||||
1691,business
|
||||
2033,entertainment
|
||||
2092,politics
|
||||
849,sport
|
||||
1780,tech
|
||||
1747,business
|
||||
1509,entertainment
|
||||
1305,politics
|
||||
2055,sport
|
||||
2139,tech
|
||||
57,business
|
||||
1496,entertainment
|
||||
671,politics
|
||||
1193,sport
|
||||
2104,tech
|
||||
872,business
|
||||
606,entertainment
|
||||
1502,politics
|
||||
244,sport
|
||||
381,tech
|
||||
1944,business
|
||||
725,entertainment
|
||||
369,politics
|
||||
211,sport
|
||||
1861,tech
|
||||
1044,business
|
||||
976,entertainment
|
||||
1673,politics
|
||||
442,sport
|
||||
144,tech
|
||||
1954,business
|
||||
48,entertainment
|
||||
608,politics
|
||||
2165,sport
|
||||
913,tech
|
||||
1580,business
|
||||
1061,entertainment
|
||||
529,politics
|
||||
2137,sport
|
||||
1284,tech
|
||||
2040,business
|
||||
349,entertainment
|
||||
1385,politics
|
||||
1307,sport
|
||||
807,tech
|
||||
333,business
|
||||
786,entertainment
|
||||
1775,politics
|
||||
558,sport
|
||||
592,tech
|
||||
2172,business
|
||||
1735,entertainment
|
||||
1188,politics
|
||||
511,sport
|
||||
157,tech
|
||||
646,business
|
||||
8,entertainment
|
||||
821,politics
|
||||
1001,sport
|
||||
309,tech
|
||||
1732,business
|
||||
1291,entertainment
|
||||
2136,politics
|
||||
999,sport
|
||||
1258,tech
|
||||
721,business
|
||||
1922,entertainment
|
||||
954,politics
|
||||
847,sport
|
||||
846,tech
|
||||
313,business
|
||||
1997,entertainment
|
||||
322,politics
|
||||
1221,sport
|
||||
129,tech
|
||||
981,business
|
||||
330,entertainment
|
||||
1172,politics
|
||||
2123,sport
|
||||
1165,tech
|
||||
1592,business
|
||||
2094,entertainment
|
||||
841,politics
|
||||
1224,sport
|
||||
72,tech
|
||||
1176,business
|
||||
1285,entertainment
|
||||
1671,politics
|
||||
67,sport
|
||||
952,tech
|
||||
884,business
|
||||
2011,entertainment
|
||||
1979,politics
|
||||
519,sport
|
||||
1132,tech
|
||||
1782,business
|
||||
112,entertainment
|
||||
1996,politics
|
||||
683,sport
|
||||
1447,tech
|
||||
472,business
|
||||
562,entertainment
|
||||
1476,politics
|
||||
1339,sport
|
||||
1972,tech
|
||||
875,business
|
||||
1067,entertainment
|
||||
1442,politics
|
||||
708,sport
|
||||
515,tech
|
||||
118,business
|
||||
1513,entertainment
|
||||
1227,politics
|
||||
1196,sport
|
||||
635,tech
|
||||
1669,business
|
||||
1823,entertainment
|
||||
483,politics
|
||||
1477,sport
|
||||
509,tech
|
||||
667,business
|
||||
127,entertainment
|
||||
1316,politics
|
||||
1983,sport
|
||||
598,tech
|
||||
1654,business
|
||||
1232,entertainment
|
||||
703,politics
|
||||
2051,sport
|
||||
451,tech
|
||||
1924,business
|
||||
852,entertainment
|
||||
1889,politics
|
||||
1695,sport
|
||||
417,tech
|
||||
2125,business
|
||||
2087,entertainment
|
||||
1274,politics
|
||||
579,sport
|
||||
398,tech
|
||||
1898,business
|
||||
87,entertainment
|
||||
2024,politics
|
||||
1927,sport
|
||||
1019,tech
|
||||
1892,business
|
||||
357,entertainment
|
||||
1161,politics
|
||||
52,sport
|
||||
1798,tech
|
||||
2028,business
|
||||
58,entertainment
|
||||
258,politics
|
||||
638,sport
|
||||
541,tech
|
||||
2190,business
|
||||
1936,entertainment
|
||||
1293,politics
|
||||
2181,sport
|
||||
1329,tech
|
||||
146,business
|
||||
89,entertainment
|
||||
1627,politics
|
||||
1123,sport
|
||||
1233,tech
|
||||
580,business
|
||||
591,entertainment
|
||||
1933,politics
|
||||
911,sport
|
||||
971,tech
|
||||
1154,business
|
||||
1622,entertainment
|
||||
539,politics
|
||||
1131,sport
|
||||
1839,tech
|
||||
545,business
|
||||
1910,entertainment
|
||||
799,politics
|
||||
1016,sport
|
||||
1460,tech
|
||||
1286,business
|
||||
296,entertainment
|
||||
1345,politics
|
||||
1568,sport
|
||||
502,tech
|
||||
1530,business
|
||||
81,entertainment
|
||||
2162,politics
|
||||
1519,sport
|
||||
1634,tech
|
||||
1586,business
|
||||
469,entertainment
|
||||
382,politics
|
||||
353,sport
|
||||
1666,tech
|
||||
1457,business
|
||||
2182,entertainment
|
||||
1591,politics
|
||||
1643,sport
|
||||
1822,tech
|
||||
742,business
|
||||
1956,entertainment
|
||||
1829,politics
|
||||
1675,sport
|
||||
376,tech
|
||||
1626,business
|
||||
168,entertainment
|
||||
321,politics
|
||||
715,sport
|
||||
1278,tech
|
||||
653,business
|
||||
1206,entertainment
|
||||
411,politics
|
||||
931,sport
|
||||
1576,tech
|
||||
1207,business
|
||||
1729,entertainment
|
||||
851,politics
|
||||
994,sport
|
||||
223,tech
|
||||
763,business
|
||||
977,entertainment
|
||||
1094,politics
|
||||
582,sport
|
||||
1845,tech
|
||||
965,business
|
||||
1461,entertainment
|
||||
945,politics
|
||||
746,sport
|
||||
1768,tech
|
||||
1242,business
|
||||
259,entertainment
|
||||
866,politics
|
||||
1288,sport
|
||||
915,tech
|
||||
337,business
|
||||
1842,entertainment
|
||||
975,politics
|
||||
1890,sport
|
||||
1694,tech
|
||||
367,business
|
||||
1270,entertainment
|
||||
1679,politics
|
||||
717,sport
|
||||
838,tech
|
||||
1857,business
|
||||
798,entertainment
|
||||
1915,politics
|
||||
1051,sport
|
||||
99,tech
|
||||
810,business
|
||||
1662,entertainment
|
||||
1521,politics
|
||||
1693,sport
|
||||
347,tech
|
||||
60,business
|
||||
1958,entertainment
|
||||
723,politics
|
||||
402,sport
|
||||
662,tech
|
||||
3,business
|
||||
84,entertainment
|
||||
1551,politics
|
||||
1684,sport
|
||||
806,tech
|
||||
1721,business
|
||||
446,entertainment
|
||||
326,politics
|
||||
2211,sport
|
||||
2037,tech
|
||||
1158,business
|
||||
498,entertainment
|
||||
2122,politics
|
||||
1905,sport
|
||||
564,tech
|
||||
939,business
|
||||
1,entertainment
|
||||
567,politics
|
||||
1219,sport
|
||||
1720,tech
|
||||
1015,business
|
||||
1459,entertainment
|
||||
633,politics
|
||||
1384,sport
|
||||
242,tech
|
||||
1334,business
|
||||
1456,entertainment
|
||||
1298,politics
|
||||
1567,sport
|
||||
205,tech
|
||||
385,business
|
||||
795,entertainment
|
||||
1083,politics
|
||||
1953,sport
|
||||
1987,tech
|
||||
1748,business
|
||||
1869,entertainment
|
||||
34,politics
|
||||
604,sport
|
||||
142,tech
|
||||
1832,business
|
||||
1865,entertainment
|
||||
862,politics
|
||||
1191,sport
|
||||
22,tech
|
||||
1883,business
|
||||
737,entertainment
|
||||
1262,politics
|
||||
1601,sport
|
||||
1404,tech
|
||||
1296,business
|
||||
2163,entertainment
|
||||
1168,politics
|
||||
438,sport
|
||||
1888,tech
|
||||
694,business
|
||||
1663,entertainment
|
||||
861,politics
|
||||
704,sport
|
||||
1010,tech
|
||||
2039,business
|
||||
854,entertainment
|
||||
1562,politics
|
||||
1422,sport
|
||||
874,tech
|
||||
1272,business
|
||||
1405,entertainment
|
||||
2049,politics
|
||||
620,sport
|
||||
2022,tech
|
||||
925,business
|
||||
1573,entertainment
|
||||
1187,politics
|
||||
1471,sport
|
||||
504,tech
|
||||
576,business
|
||||
222,entertainment
|
||||
95,politics
|
||||
332,sport
|
||||
1895,tech
|
||||
2106,business
|
||||
1181,entertainment
|
||||
776,politics
|
||||
9,sport
|
||||
808,tech
|
||||
1002,business
|
||||
784,entertainment
|
||||
845,politics
|
||||
1416,sport
|
||||
1762,tech
|
||||
525,business
|
||||
596,entertainment
|
||||
1800,politics
|
||||
1341,sport
|
||||
113,tech
|
||||
134,business
|
||||
270,entertainment
|
||||
1324,politics
|
||||
12,sport
|
||||
782,tech
|
||||
1357,business
|
||||
1608,entertainment
|
||||
1134,politics
|
||||
276,sport
|
||||
41,tech
|
||||
1705,business
|
||||
732,entertainment
|
||||
968,politics
|
||||
1006,sport
|
||||
1140,tech
|
||||
966,business
|
||||
1399,entertainment
|
||||
586,politics
|
||||
868,sport
|
||||
1698,tech
|
||||
452,business
|
||||
600,entertainment
|
||||
354,politics
|
||||
1281,sport
|
||||
914,tech
|
||||
1430,business
|
||||
554,entertainment
|
||||
1659,politics
|
||||
1727,sport
|
||||
640,tech
|
||||
691,business
|
||||
1531,entertainment
|
||||
1389,politics
|
||||
1545,sport
|
||||
462,tech
|
||||
2101,business
|
||||
1907,entertainment
|
||||
1373,politics
|
||||
1787,sport
|
||||
2202,tech
|
||||
961,business
|
||||
680,entertainment
|
||||
1059,politics
|
||||
626,sport
|
||||
1961,tech
|
||||
1048,business
|
||||
336,entertainment
|
||||
719,politics
|
||||
1563,sport
|
||||
963,tech
|
||||
179,business
|
||||
1438,entertainment
|
||||
1189,politics
|
||||
2126,sport
|
||||
895,tech
|
||||
2005,business
|
||||
819,entertainment
|
||||
23,politics
|
||||
457,sport
|
||||
1423,tech
|
||||
100,business
|
||||
1448,entertainment
|
||||
97,politics
|
||||
172,sport
|
||||
1918,tech
|
||||
938,business
|
||||
752,entertainment
|
||||
2222,politics
|
||||
2038,sport
|
||||
24,tech
|
||||
1325,business
|
||||
409,entertainment
|
||||
624,politics
|
||||
1367,sport
|
||||
621,tech
|
||||
1302,business
|
||||
505,entertainment
|
||||
2057,politics
|
||||
265,sport
|
||||
262,tech
|
||||
2176,business
|
||||
1304,entertainment
|
||||
1539,politics
|
||||
700,sport
|
||||
2124,tech
|
||||
2203,business
|
||||
781,entertainment
|
||||
1558,politics
|
||||
1419,sport
|
||||
1311,tech
|
||||
124,business
|
||||
664,entertainment
|
||||
282,politics
|
||||
1700,sport
|
||||
1107,tech
|
||||
922,business
|
||||
652,entertainment
|
||||
713,politics
|
||||
1068,sport
|
||||
1116,tech
|
||||
869,business
|
||||
1703,entertainment
|
||||
1668,politics
|
||||
501,sport
|
||||
2053,tech
|
||||
908,business
|
||||
1306,entertainment
|
||||
348,politics
|
||||
149,sport
|
||||
1173,tech
|
||||
6,business
|
||||
1180,entertainment
|
||||
107,politics
|
||||
1629,sport
|
||||
131,tech
|
||||
1046,business
|
||||
698,entertainment
|
||||
751,politics
|
||||
1501,sport
|
||||
73,tech
|
||||
1463,business
|
||||
1856,entertainment
|
||||
1074,politics
|
||||
2225,sport
|
||||
1313,tech
|
||||
933,business
|
||||
1136,entertainment
|
||||
779,politics
|
||||
1315,sport
|
||||
1628,tech
|
||||
432,business
|
||||
546,entertainment
|
||||
1779,politics
|
||||
1514,sport
|
||||
1365,tech
|
||||
1776,business
|
||||
109,entertainment
|
||||
235,politics
|
||||
162,sport
|
||||
2071,tech
|
||||
412,business
|
||||
1092,entertainment
|
||||
1352,politics
|
||||
294,sport
|
||||
1973,tech
|
||||
493,business
|
||||
45,entertainment
|
||||
2189,politics
|
||||
1417,sport
|
||||
2140,tech
|
||||
1265,business
|
||||
1353,entertainment
|
||||
830,politics
|
||||
448,sport
|
||||
535,tech
|
||||
1656,business
|
||||
996,entertainment
|
||||
1246,politics
|
||||
161,sport
|
||||
1267,tech
|
||||
138,business
|
||||
1575,entertainment
|
||||
1810,politics
|
||||
2065,sport
|
||||
768,tech
|
||||
1507,business
|
||||
1520,entertainment
|
||||
71,politics
|
||||
1484,sport
|
||||
163,tech
|
||||
1171,business
|
||||
769,entertainment
|
||||
77,politics
|
||||
1252,sport
|
||||
532,tech
|
||||
1975,business
|
||||
178,entertainment
|
||||
1237,politics
|
||||
1899,sport
|
||||
676,tech
|
||||
2008,business
|
||||
1807,entertainment
|
||||
481,politics
|
||||
17,sport
|
||||
837,tech
|
||||
523,business
|
||||
2099,entertainment
|
||||
1957,politics
|
||||
1470,sport
|
||||
1151,tech
|
||||
489,business
|
||||
823,entertainment
|
||||
1743,politics
|
||||
1035,sport
|
||||
2103,tech
|
||||
1230,business
|
||||
1636,entertainment
|
||||
103,politics
|
||||
927,sport
|
||||
379,tech
|
||||
1789,business
|
||||
278,entertainment
|
||||
887,politics
|
||||
1300,sport
|
||||
394,tech
|
||||
625,business
|
||||
1770,entertainment
|
||||
1959,politics
|
||||
1371,sport
|
||||
1359,tech
|
||||
921,business
|
||||
404,entertainment
|
||||
422,politics
|
||||
503,sport
|
||||
1773,tech
|
||||
2174,business
|
||||
1086,entertainment
|
||||
1986,politics
|
||||
2007,sport
|
||||
13,tech
|
||||
796,business
|
||||
585,entertainment
|
||||
1200,politics
|
||||
642,sport
|
||||
237,tech
|
||||
1826,business
|
||||
43,entertainment
|
||||
1594,politics
|
||||
594,sport
|
||||
1084,tech
|
||||
1859,business
|
||||
524,entertainment
|
||||
463,politics
|
||||
1085,sport
|
||||
668,tech
|
||||
2132,business
|
||||
1970,entertainment
|
||||
260,politics
|
||||
141,sport
|
||||
468,tech
|
||||
269,business
|
||||
414,entertainment
|
||||
2097,politics
|
||||
1420,sport
|
||||
1472,tech
|
||||
1757,business
|
||||
243,entertainment
|
||||
1142,politics
|
||||
1295,sport
|
||||
1660,tech
|
||||
1635,business
|
||||
2200,entertainment
|
||||
2146,politics
|
||||
5,sport
|
||||
1256,tech
|
||||
2090,business
|
||||
1816,entertainment
|
||||
1991,politics
|
||||
1310,sport
|
||||
1162,tech
|
||||
499,business
|
||||
285,entertainment
|
||||
1216,politics
|
||||
745,sport
|
||||
421,tech
|
||||
2068,business
|
||||
1145,entertainment
|
||||
681,politics
|
||||
556,sport
|
||||
1063,tech
|
||||
55,business
|
||||
910,entertainment
|
||||
2066,politics
|
||||
1025,sport
|
||||
2154,tech
|
||||
64,business
|
||||
2218,entertainment
|
||||
2013,politics
|
||||
1105,sport
|
||||
669,tech
|
||||
827,business
|
||||
1335,entertainment
|
||||
1426,politics
|
||||
2196,sport
|
||||
368,tech
|
||||
1535,business
|
||||
1672,entertainment
|
||||
599,politics
|
||||
2085,sport
|
||||
1577,tech
|
||||
724,business
|
||||
568,entertainment
|
||||
2018,politics
|
||||
1543,sport
|
||||
308,tech
|
||||
1356,business
|
||||
573,entertainment
|
||||
1746,politics
|
||||
1290,sport
|
||||
272,tech
|
||||
139,business
|
||||
2191,entertainment
|
||||
1255,politics
|
||||
1699,sport
|
||||
1616,tech
|
||||
1619,business
|
||||
1032,entertainment
|
||||
287,politics
|
||||
2194,sport
|
||||
928,tech
|
||||
584,business
|
||||
1722,entertainment
|
||||
1685,politics
|
||||
1828,sport
|
||||
692,tech
|
||||
1393,business
|
||||
540,entertainment
|
||||
1317,politics
|
||||
152,sport
|
||||
650,tech
|
||||
191,business
|
||||
1719,entertainment
|
||||
1564,politics
|
||||
350,sport
|
||||
116,tech
|
||||
1000,business
|
||||
36,entertainment
|
||||
1225,politics
|
||||
2155,sport
|
||||
553,tech
|
||||
551,business
|
||||
1724,entertainment
|
||||
1512,politics
|
||||
1923,sport
|
||||
373,tech
|
||||
1704,business
|
||||
206,entertainment
|
||||
471,politics
|
|
736
data/BBC_News_Test.csv
Normal file
736
data/BBC_News_Test.csv
Normal file
File diff suppressed because one or more lines are too long
1491
data/BBC_News_Train.csv
Normal file
1491
data/BBC_News_Train.csv
Normal file
File diff suppressed because one or more lines are too long
10
docker-compose.yaml
Normal file
10
docker-compose.yaml
Normal file
@ -0,0 +1,10 @@
|
||||
version: "3.9"
|
||||
services:
|
||||
ayct_core:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
image: ayct_core:latest
|
||||
ports:
|
||||
- "4999:4999"
|
||||
container_name: ayct_core
|
@ -1,6 +0,0 @@
|
||||
[build-system]
|
||||
requires = [
|
||||
"setuptools>=42",
|
||||
"wheel"
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
10
readme.md
10
readme.md
@ -1,10 +0,0 @@
|
||||
To build project type:
|
||||
|
||||
```
|
||||
pip3 install virtualenv
|
||||
virtualenv venv
|
||||
source venv/bin/activate
|
||||
pip3 install setuptools
|
||||
python3 -m pip install --upgrade build
|
||||
python3 -m build
|
||||
```
|
6
requirements.txt
Normal file
6
requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
Flask==2.0.2
|
||||
nltk==3.6.7
|
||||
numpy==1.22.1
|
||||
pandas==1.3.5
|
||||
spacy==3.2.1
|
||||
tqdm==4.62.3
|
24
setup.cfg
24
setup.cfg
@ -1,24 +0,0 @@
|
||||
[metadata]
|
||||
name = pbrAyctCore
|
||||
version = 0.0.1
|
||||
author = Ramon Dyzman
|
||||
author_email = ramon.dyzman@gmail.com
|
||||
description = A small example package
|
||||
long_description = file: readme.md
|
||||
long_description_content_type = text/markdown
|
||||
url = https://git.wmi.amu.edu.pl/s415366/pbr-ayct-core
|
||||
project_urls =
|
||||
Bug Tracker = https://github.com/pypa/sampleproject/issues
|
||||
classifiers =
|
||||
Programming Language :: Python :: 3
|
||||
License :: OSI Approved :: MIT License
|
||||
Operating System :: OS Independent
|
||||
|
||||
[options]
|
||||
package_dir =
|
||||
= src
|
||||
packages = find:
|
||||
python_requires = >=3.6
|
||||
|
||||
[options.packages.find]
|
||||
where = src
|
@ -1,4 +0,0 @@
|
||||
|
||||
|
||||
def getTestString() -> str:
    """Return the fixed test string "All You Can Tweet"."""
    test_string = "All You Can Tweet"
    return test_string
|
47
test.py
Normal file
47
test.py
Normal file
@ -0,0 +1,47 @@
|
||||
import pandas as pd
|
||||
from nltk import word_tokenize
|
||||
from utils import Vocabulary
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
from utils import POC
|
||||
from utils import NewsEntry
|
||||
import torch
|
||||
|
||||
# Which action to run:
#   "train"    fine-tunes GPT-2 on the training articles,
#   "generate" samples a continuation of `prompt` from a saved checkpoint,
#   "test"     currently only builds the test dataset and loads the checkpoint.
# mode = "train"
# mode = "test"
mode = "generate"
prompt = """
Mr Johnson's bakerys is bad. My bakery is good.
"""
model_name = "model-4.pt"

train = pd.read_csv("data/BBC_News_Train.csv")
test = pd.read_csv("data/BBC_News_Test.csv")
length_limit = 1024

# Keep only the articles with fewer than `length_limit` NLTK word tokens.
train = train[train["Text"].apply(lambda x: len(word_tokenize(x)) < length_limit)]
test = test[test["Text"].apply(lambda x: len(word_tokenize(x)) < length_limit)]

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Remap checkpoint tensors onto the CPU when CUDA is unavailable (the same
# fallback app.py uses); None keeps torch.load's default behaviour.
map_location = None if torch.cuda.is_available() else torch.device('cpu')

if mode == "train":
    # NOTE(review): the first argument is NewsEntry's `control_code`, which is
    # interpolated into every sample; passing the whole "Text" column here
    # looks unintended -- confirm the intended control code.
    dataset = NewsEntry(train["Text"], train["Text"], truncate=False, gpt2_type="gpt2")
    poc = POC()
    print("Starting training")
    poc.train(dataset, model, tokenizer, output_dir="./", save_model_on_epoch=True)

elif mode == "generate":
    # dataset = NewsEntry(test["Text"], test["Text"], truncate=False, gpt2_type="gpt2")
    poc = POC()
    model.load_state_dict(torch.load(model_name, map_location=map_location))
    x = poc.generate(model, tokenizer, prompt, entry_count=1, entry_length=300)
    print(x)

elif mode == "test":
    dataset = NewsEntry(test["Text"], test["Text"], truncate=False, gpt2_type="gpt2")
    poc = POC()
    model.load_state_dict(torch.load(model_name, map_location=map_location))
|
206
utils.py
Normal file
206
utils.py
Normal file
@ -0,0 +1,206 @@
|
||||
import spacy
|
||||
import pandas as pd
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
import numpy as np
|
||||
import random
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
|
||||
from tqdm import tqdm, trange
|
||||
import torch.nn.functional as F
|
||||
import csv
|
||||
import os
|
||||
|
||||
nlp = spacy.load("en_core_web_sm")
|
||||
|
||||
class NewsEntry(Dataset):
    """Dataset of GPT-2 token-id tensors, one per text row.

    Each row is cut to its first ``max_length`` characters, wrapped as
    ``<|control_code|>row<|endoftext|>`` and encoded with the GPT-2
    tokenizer named by ``gpt2_type``. When ``truncate`` is set, only the
    first 20000 encoded rows are kept.
    """

    def __init__(self, control_code, data, truncate=False, gpt2_type="gpt2", max_length=1024):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)

        encode = self.tokenizer.encode
        encoded_rows = [
            torch.tensor(encode(f"<|{control_code}|>{row[:max_length]}<|endoftext|>"))
            for row in data
        ]
        # Truncation is applied after encoding, to the list of tensors.
        self.data = encoded_rows[:20000] if truncate else encoded_rows
        self.data_count = len(self.data)

    def __len__(self):
        """Return the number of stored samples."""
        return self.data_count

    def __getitem__(self, item):
        """Return the encoded sample stored at position ``item``."""
        return self.data[item]
|
||||
|
||||
class Vocabulary():
    """Maps spaCy lemmas to integer ids.

    Texts are tokenized with the module-level ``nlp`` pipeline; lemmas that
    are not in the vocabulary map to the unknown id ``-1``.
    """

    def __init__(self) -> None:
        # Id reported for lemmas that are missing from the vocabulary.
        self.__UKNOWN__ = -1
        self.vocab = {}

    def create_vocab(self, data):
        """Rebuild ``self.vocab`` from an iterable of texts.

        Every distinct lemma receives the next free id, starting at 0, in
        order of first appearance. Any previous vocabulary is replaced.
        """
        vocab = {}
        for row in data:
            for word in nlp(row):
                # setdefault only inserts unseen lemmas, so len(vocab) is
                # always the next unused id.
                vocab.setdefault(word.lemma_, len(vocab))

        self.vocab = vocab

    def word_to_number(self, word):
        """Return the id of the first token's lemma in ``word``.

        Only the first token produced by ``nlp`` is looked up. The unknown
        id is returned when that lemma is not in the vocabulary, and also
        when ``word`` yields no tokens at all (e.g. an empty string), so the
        result is always an integer.
        """
        for token in nlp(word):
            return self.vocab.get(token.lemma_, self.__UKNOWN__)
        return self.__UKNOWN__

    def sentence_to_numbers(self, seq):
        """Return the ids of all token lemmas in the text ``seq``, in order."""
        return [self.vocab.get(word.lemma_, self.__UKNOWN__) for word in nlp(seq)]

    def sequence_to_numbers(self, seq):
        """Return ``word_to_number`` applied to each element of ``seq``."""
        return [self.word_to_number(x) for x in seq]
|
||||
|
||||
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    """Try to merge ``new_tensor`` into the running ``packed_tensor``.

    Both tensors are indexed along dimension 1 for their length.

    Returns a ``(tensor, fits, remainder)`` tuple:
      * nothing packed yet: ``(new_tensor, True, None)``;
      * combined length exceeds ``max_seq_len``:
        ``(packed_tensor, False, new_tensor)``;
      * otherwise ``new_tensor`` followed by ``packed_tensor`` without its
        first column, with ``True`` and ``None``.
    """
    # Nothing accumulated yet: the incoming tensor starts a new pack.
    if packed_tensor is None:
        return new_tensor, True, None

    combined_len = new_tensor.size(1) + packed_tensor.size(1)
    if combined_len <= max_seq_len:
        merged = torch.cat((new_tensor, packed_tensor[:, 1:]), dim=1)
        return merged, True, None

    # Too long: hand the untouched pack back along with the leftover tensor.
    return packed_tensor, False, new_tensor
|
||||
|
||||
class POC():
    """Fine-tuning and sampling helpers for a GPT-2 language model."""

    def __init__(self) -> None:
        pass

    def train(
        self, dataset, model, tokenizer,
        batch_size=16, epochs=5, lr=2e-5,
        max_seq_len=400, warmup_steps=200,
        gpt2_type="gpt2", output_dir=".", output_prefix="model",
        test_mode=False, save_model_on_epoch=False,
    ):
        """Fine-tune ``model`` on ``dataset`` and return it.

        Entries are read one at a time in shuffled order and concatenated
        with ``pack_tensor`` into sequences of at most 768 tokens. Each
        packed sequence is one forward/backward pass; the optimizer and the
        scheduler step on the very first pass and then once every
        ``batch_size`` passes.

        Args:
            dataset: Dataset of token-id tensors (e.g. ``NewsEntry``).
            model: Language model called as ``model(ids, labels=ids)`` whose
                first output is the loss.
            tokenizer: Accepted for interface compatibility; unused here.
            batch_size: Number of passes between optimizer steps.
            epochs: Number of passes over the dataset.
            lr: AdamW learning rate.
            max_seq_len: Accepted for interface compatibility; the packing
                limit is currently fixed at 768.
            warmup_steps: Scheduler warmup steps.
            gpt2_type: Accepted for interface compatibility; unused here.
            output_dir: Directory for per-epoch checkpoints.
            output_prefix: Checkpoint file name prefix.
            test_mode: Accepted for interface compatibility; unused here.
            save_model_on_epoch: Save ``{output_prefix}-{epoch}.pt`` after
                every epoch when true.

        Returns:
            The trained model, moved to the training device.
        """
        pack_limit = 768
        # Train on the GPU when there is one instead of failing without CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model.train()

        train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

        # The number of packed passes is only known after packing, so size
        # the schedule with its upper bound (one pass per entry). The former
        # value of -1 made the learning rate 0 as soon as warmup finished.
        total_passes = len(train_dataloader) * epochs
        total_steps = max(1, -(-total_passes // batch_size))

        optimizer = AdamW(model.parameters(), lr=lr)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps
        )

        loss = 0
        accumulating_batch_count = 0
        input_tensor = None

        for epoch in range(epochs):

            print(f"Training epoch {epoch}")
            print(loss)
            last_idx = len(train_dataloader) - 1
            for idx, entry in tqdm(enumerate(train_dataloader)):
                (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, pack_limit)
                is_last = idx == last_idx

                # Keep packing until the sequence is full or the data ends.
                if carry_on and not is_last:
                    continue

                # On the final entry also flush the overflow, so nothing is
                # dropped and nothing leaks into the next epoch.
                batches = [input_tensor]
                if is_last and remainder is not None:
                    batches.append(remainder)
                    remainder = None

                for batch in batches:
                    batch = batch.to(device)
                    outputs = model(batch, labels=batch)
                    loss = outputs[0]
                    loss.backward()

                    if (accumulating_batch_count % batch_size) == 0:
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()
                        model.zero_grad()

                    accumulating_batch_count += 1

                # The entry that did not fit starts the next packed sequence
                # (it used to be discarded).
                input_tensor = remainder
            if save_model_on_epoch:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
                )
        return model

    def generate(
        self,
        model,
        tokenizer,
        prompt,
        entry_count=10,
        entry_length=30,  # maximum number of generated tokens per entry
        top_p=0.8,
        temperature=1.,
    ):
        """Sample ``entry_count`` continuations of ``prompt`` with top-p sampling.

        Args:
            model: Language model whose first output, when called without
                labels, is the logits tensor.
            tokenizer: Tokenizer providing ``encode`` and ``decode``.
            prompt: Text every entry starts from.
            entry_count: Number of entries to generate.
            entry_length: Maximum number of tokens sampled per entry.
            top_p: Cumulative-probability threshold of the nucleus filter.
            temperature: Logit divisor; values <= 0 are treated as 1.0.

        Returns:
            A list of decoded strings (prompt plus continuation). Entries
            that hit ``entry_length`` before sampling the end-of-text token
            get ``<|endoftext|>`` appended.

        Raises:
            ValueError: If ``prompt`` encodes to no tokens.
        """
        # Encode once; every entry starts from the same token ids.
        prompt_ids = tokenizer.encode(prompt)
        if not prompt_ids:
            raise ValueError("prompt must encode to at least one token")

        model.eval()
        # Build the inputs on the model's device so this also works after
        # the model has been moved to the GPU.
        device = next(model.parameters()).device
        eos_ids = set(tokenizer.encode("<|endoftext|>"))
        generated_list = []

        filter_value = -float("Inf")

        with torch.no_grad():

            for entry_idx in trange(entry_count):

                entry_finished = False
                generated = torch.tensor(prompt_ids, device=device).unsqueeze(0)

                if entry_idx % 100 == 0:
                    print(entry_idx)

                for _ in range(entry_length):
                    # Only the logits are needed for sampling, so no labels
                    # are passed and no loss is computed.
                    logits = model(generated)[0]
                    logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Shift the mask right by one so the first token that
                    # crosses top_p is kept, and always keep the best token.
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                        ..., :-1
                    ].clone()
                    sorted_indices_to_remove[..., 0] = 0

                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    logits[:, indices_to_remove] = filter_value

                    next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                    generated = torch.cat((generated, next_token), dim=1)

                    if next_token.item() in eos_ids:
                        entry_finished = True
                        break

                output_text = tokenizer.decode(generated.squeeze(0).tolist())
                if not entry_finished:
                    output_text = f"{output_text}<|endoftext|>"
                generated_list.append(output_text)

        return generated_list
|
Loading…
Reference in New Issue
Block a user