challenging-america-word-ga.../gpt-2 finetune.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 35,
"id": "df9bffa6",
"metadata": {},
"outputs": [],
"source": [
"test_file = 'dev-0/in.tsv.xz'\n",
"out_file = 'dev-0/out.tsv'"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "eb2fcfa4",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"import torch\n",
"import lzma"
]
},
{
"cell_type": "code",
"execution_count": 358,
"id": "a03f19ae",
"metadata": {},
"outputs": [],
"source": [
"###preprocessing\n",
"def preprocess(line):\n",
" line = get_rid_of_header(line)\n",
" line = replace_endline(line)\n",
" return line\n",
"\n",
"def get_rid_of_header(line):\n",
" line = line.split('\\t')[6:]\n",
" return \" \".join(line)\n",
" \n",
"def replace_endline(line):\n",
" line = line.replace(\"\\\\n\", \" \")\n",
" return line\n",
"\n",
"def get_first_word(text):\n",
" \"\"\"Return the first word of a string.\"\"\"\n",
" word = \"\"\n",
" for i in range(len(text)-1):\n",
"# if text[i] in [' ', ',', '.']\n",
" if text[i] == ' ':\n",
" return word.rstrip()\n",
" else:\n",
" word += text[i]\n",
" return word.rstrip()\n",
"\n",
"\n",
"\n",
"def summarize_probs_unk(dic, const_wildcard=True, scale_probs=False, wildcard_minweight=0.01):\n",
" ''' \n",
" dic: dictionary of probabilities returned by model \n",
" returns: tab of probabilities, with <unk> specificly as last element\n",
" '''\n",
" if not scale_probs:\n",
" if '' in dic.keys():\n",
" del dic['']\n",
" tab = [(key, val) for key, val in dic.items()]\n",
" tab.append(('', 1-sum([val for val in dic.values()])))\n",
" elif const_wildcard and scale_probs: #\n",
" if '' in dic.keys():\n",
" del dic['']\n",
" probsum = sum(float(val) for key, val in dic.items())\n",
" for key in dic:\n",
" dic[key] = dic[key]/(probsum*(1+wildcard_minweight)) \n",
" tab = [(key, val) for key, val in dic.items()]\n",
" tab.append(('', 1-sum([val for val in dic.values()])))\n",
"# else if '' not in dic.keys(): #no wildcard entry \n",
" else:\n",
" if '' not in dic.keys():\n",
" wildcard_value = wildcard_minweight\n",
" else:\n",
" wildcard_value = dic['']\n",
" del dic['']\n",
" for key in dic:\n",
" dic[key] = dic[key]/(1-wildcard_value) ###leave some space for wildcar\n",
" tab = [(key, val) for key, val in dic.items()]\n",
" tab.append(('', 1-sum([val for val in dic.values()])))\n",
" \n",
" return tab\n",
"\n",
"\n",
"def gonito_format(tab):\n",
"# tab = summarize_probs_unk(dic, const_wildcard=const_wildcard, wildcard_minweight=wildcard_minweight)\n",
" result = ''\n",
" for element in tab[:-1]:\n",
" result+=str(element[0])+':'+str(element[1])+'\\t'\n",
" result+=':'+ str(tab[-1][1])+'\\n'\n",
" return result\n"
]
},
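{
"cell_type": "code",
"execution_count": null,
"id": "unk-sanity-check",
"metadata": {},
"outputs": [],
"source": [
"### Illustrative sanity check (not part of the original run): whichever flags are chosen,\n",
"### summarize_probs_unk should return pairs that sum to 1 with the '' wildcard last.\n",
"### The toy dictionary below is made up for demonstration only.\n",
"toy = {'head': 0.18, 'heads': 0.012, '': 0.005}\n",
"for flags in [dict(const_wildcard=True, scale_probs=False),\n",
"              dict(const_wildcard=True, scale_probs=True),\n",
"              dict(const_wildcard=False, scale_probs=True)]:\n",
"    tab = summarize_probs_unk(dict(toy), **flags)  # pass a copy: the function mutates its input\n",
"    assert abs(sum(p for _, p in tab) - 1.0) < 1e-9, flags\n",
"    assert tab[-1][0] == ''\n",
"print('summarize_probs_unk sums to 1 for all tested flag combinations')"
]
},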
{
"cell_type": "code",
"execution_count": 207,
"id": "20cd7089",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: transformers in /home/gedin/.local/lib/python3.10/site-packages (4.30.2)\n",
"Requirement already satisfied: requests in /usr/lib/python3/dist-packages (from transformers) (2.25.1)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: numpy>=1.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (1.24.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from transformers) (5.4.1)\n",
"Requirement already satisfied: tqdm>=4.27 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (4.65.0)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.3.1)\n",
"Requirement already satisfied: filelock in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (3.12.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.15.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (2023.5.5)\n",
"Requirement already satisfied: packaging>=20.0 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: fsspec in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.6.3)\n"
]
}
],
"source": [
"!pip install transformers\n",
"from transformers import pipeline, set_seed, AutoTokenizer, AutoModel, AutoModelForCausalLM, GPT2Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "9e28e9f8",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"gpt2\"\n",
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
"pt_model = AutoModelForCausalLM.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "1efb36b9",
"metadata": {},
"outputs": [],
"source": [
"sentence = 'Northeasterly hv the head of said .^corns and and the'"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "2e8b8b60",
"metadata": {},
"outputs": [],
"source": [
"encoding = tokenizer(sentence, return_tensors='pt')\n",
"output = pt_model(**encoding)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"id": "94f321d5",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_9135/861569571.py:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" probs = F.softmax(output.logits[0][-1])\n"
]
}
],
"source": [
"probs = F.softmax(output.logits[0][-1])"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "94986365",
"metadata": {},
"outputs": [],
"source": [
"top = torch.topk(probs, 10)\n",
"top_indices = top.indices.tolist()\n",
"top_probs = top.values.tolist()\n",
"top_words = [tokenizer.decode(idx).strip() for idx in top_indices] "
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "622c26f2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n'"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer.decode(198)"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "fd4beaa3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['head', 'heads', '.', 'same', 'body', 'neck', 'other', 'whole', 'said', 'king']\n"
]
}
],
"source": [
"# print(top_indices)\n",
"print((top_words))"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "ec8bdc77",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1182, 6665, 764, 976, 1767, 7393, 584, 2187, 531, 5822]\n"
]
}
],
"source": [
"print(top_indices)"
]
},
{
"cell_type": "code",
"execution_count": 136,
"id": "578257eb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.18150091171264648, 0.011893990449607372, 0.011805753223598003, 0.011544686742126942, 0.007725409232079983]\n"
]
}
],
"source": [
"print(top_probs)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "939e12dc",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983), ('neck', 0.007723228540271521), ('other', 0.006957209203392267), ('whole', 0.006453146692365408), ('said', 0.004757815971970558), ('king', 0.004543370567262173)]\n"
]
}
],
"source": [
"print(list(zip(top_words, top_probs)))"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "c0d32e0d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('head', 0.7049117686793325), ('heads', 0.04619378363102533), ('.', 0.045851088608379394), ('same', 0.044837160725262136), ('body', 0.030003881711508234), ('neck', 0.02999541235835104), ('other', 0.027020352671284463), ('whole', 0.02506267877962141), ('said', 0.018478367079292513), ('king', 0.01764550575594297), ('', 0.01000000000000012)] 1.0\n"
]
}
],
"source": [
"#################Does probabilities from get_values_from_model sum up to 1##################\n",
"asdf = dict(get_values_from_model(sentence, model=pt_model, tokenizer=tokenizer, k=10))\n",
"asdf = summarize_probs_unk(asdf)\n",
"print([x for x in asdf] ,sum([x[1] for x in asdf]))"
]
},
{
"cell_type": "code",
"execution_count": 152,
"id": "11e4e965",
"metadata": {},
"outputs": [],
"source": [
"def get_values_from_model(context: str, model, tokenizer, k=10):\n",
" encoding = tokenizer(context, return_tensors='pt')\n",
" output = model(**encoding)\n",
" probs = F.softmax(output.logits[0][-1], dim=-1) #get the model prediction for the entire sentence ([-1]) no batching ([0])\n",
" top = torch.topk(probs, k)\n",
" top_probs = top.values.tolist()\n",
" top_indices =top.indices.tolist()\n",
" top_words = [tokenizer.decode(idx).strip() for idx in top_indices] \n",
"# print(context, \"probs: \\n\", list(zip(top_words, top_indices, top_probs)))\n",
" return list(zip(top_words, top_probs))"
]
},
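{
"cell_type": "code",
"execution_count": null,
"id": "no-grad-variant",
"metadata": {},
"outputs": [],
"source": [
"### Optional speed-up (a sketch, not the original code): inference needs no gradients, so the\n",
"### same top-k query can run under torch.no_grad(). `get_topk_no_grad` is a hypothetical name;\n",
"### everything else reuses the objects defined above.\n",
"def get_topk_no_grad(context: str, model, tokenizer, k=10):\n",
"    with torch.no_grad():  # skip building the autograd graph: less memory, faster forward pass\n",
"        encoding = tokenizer(context, return_tensors='pt')\n",
"        probs = F.softmax(model(**encoding).logits[0][-1], dim=-1)\n",
"        top = torch.topk(probs, k)\n",
"    return [(tokenizer.decode(idx).strip(), prob)\n",
"            for idx, prob in zip(top.indices.tolist(), top.values.tolist())]"
]
},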
{
"cell_type": "code",
"execution_count": 398,
"id": "17884d9b",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"with lzma.open(test_file, 'rt') as file:\n",
" left_contexts = []\n",
" results = []\n",
" for line in file:\n",
" line = replace_endline(line) #get only relevant\n",
" line = line.split('\\t')[6:]\n",
" context = ' '.join(line[0].rstrip().split(\" \")[-10:])\n",
"# context = context + ' '\n",
"# print(context)\n",
" left_contexts.append(context)\n",
" ###get results from gpt model###\n",
" for left_context in left_contexts:\n",
" results.append(dict(get_values_from_model(left_context, model=pt_model, tokenizer=tokenizer, k=10)))\n",
" with open(out_file, 'w') as outfile:\n",
" for elem in results:\n",
" tab = summarize_probs_unk(elem, const_wildcard=False, scale_probs=True, wildcard_minweight=0.01)\n",
" outfile.write(gonito_format(tab))\n"
]
},
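{
"cell_type": "code",
"execution_count": null,
"id": "test-a-paths",
"metadata": {},
"outputs": [],
"source": [
"### To score the challenge's hidden test set, the same pipeline can be pointed at test-A and\n",
"### the cell above rerun. The paths below assume the usual gonito layout (an in.tsv.xz next\n",
"### to dev-0); adjust them if the repository differs.\n",
"# test_file = 'test-A/in.tsv.xz'\n",
"# out_file = 'test-A/out.tsv'"
]
},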
{
"cell_type": "code",
"execution_count": 95,
"id": "b8846e25",
"metadata": {},
"outputs": [],
"source": [
"sentence = 'Northeasterly hv the head of said .^corns and and the'"
]
},
{
"cell_type": "code",
"execution_count": 148,
"id": "37cdd8c5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983)]\n"
]
}
],
"source": [
"encoding = tokenizer(sentence, return_tensors='pt')\n",
"output = pt_model(**encoding)\n",
"probs = F.softmax(output.logits[0][-1], dim = -1) #get the model prediction for the entire sentence ([-1]) no batching ([0])\n",
"top = torch.topk(probs, 5)\n",
"top_probs = top.values.tolist()\n",
"top_indices =top.indices.tolist()\n",
"top_words = [tokenizer.decode(idx).strip() for idx in top_indices] \n",
"print(list(zip(top_words, top_probs)))"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "53c8c83c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 20, 50257])"
]
},
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output.logits.shape"
]
},
{
"cell_type": "code",
"execution_count": 93,
"id": "c6bf8faf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983)]\n"
]
}
],
"source": [
"print(list(zip(top_words, top_probs)))"
]
},
{
"cell_type": "code",
"execution_count": 153,
"id": "e03cc143",
"metadata": {},
"outputs": [],
"source": [
"with lzma.open(test_file, 'rt') as file:\n",
" left_contexts = []\n",
" right_contexts = []\n",
" results = []\n",
" i=0\n",
" for line in file:\n",
" if i >20:\n",
" break\n",
" line = replace_endline(line) #get only relevant\n",
" line = line.split('\\t')[6:]\n",
" l_context = \" \".join(line[0].rstrip().split(\" \")[-10:])\n",
" r_context = \" \".join(line[1].rstrip().split(\" \")[:5])\n",
" left_contexts.append(l_context)\n",
" right_contexts.append(r_context)\n",
" i+=1;"
]
},
{
"cell_type": "code",
"execution_count": 154,
"id": "20113d00",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['id Seorgc Acorns laud in said Aina and riinninu from',\n",
" 'true that the Republican par- ty is National in its',\n",
" 'but here¬ abouts 1 have not seen or heard ot',\n",
" 'They will not bend or break like the single spiing.',\n",
" '208 miles long. J. 8 . Con way the Com-',\n",
" 'They felt that He was the long-looked- for Messiah with',\n",
" 'in the state that could produce just as good fruit',\n",
" 'after house and managed to find shelter in the lazarette,',\n",
" 'hundred buildings put up in the place this season. Mechanics',\n",
" 'de\\xad testation. I no longer loved him, and I felt',\n",
" 'ol the gentleman, by dis. cussing the constitutionality of these',\n",
" 'the purest pa- triotism, and tin* most indent devotion to',\n",
" 'now to i high, will come down, and those too',\n",
" 'the \\'extension of MdKlnleJ\" street, and running back to the',\n",
" 'sir, it is a subject ot peculiar delight to me',\n",
" 'to affect a miserable scramble for offices in tiiis country.',\n",
" 'Marj, the eldest, said lo the others: \"Lit us pray',\n",
" 'not enquire—But what sum lie asked, could tempt men well',\n",
" \"hete are 500, or 550 acres ol cleared land, '250\",\n",
" 'petition h is lieeii |ui tty elos ly scru;iiii/e<| and',\n",
" 'on the high reputation of his record as Chief Mag-']"
]
},
"execution_count": 154,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"left_contexts"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1d16811",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"results = []\n",
"for left_context in left_contexts:\n",
" results.append(dict(get_values_from_model(left_context, model=pt_model, tokenizer=tokenizer, k=10)))\n",
"for idx, elem in enumerate(results):\n",
" tab = summarize_probs_unk(elem, const_wildcard=True, scale_probs=True, wildcard_minweight=0.01)\n",
" print(idx, \" \", gonito_format(tab))"
]
},
{
"cell_type": "code",
"execution_count": 399,
"id": "df38d415",
"metadata": {},
"outputs": [],
"source": [
"a =dict(get_values_from_model(left_contexts[4], model=pt_model, tokenizer=tokenizer, k=10))"
]
},
{
"cell_type": "code",
"execution_count": 400,
"id": "d49824d4",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"{'p': 0.2211313545703888,\n",
" 'm': 0.12113624811172485,\n",
" 'mun': 0.07101781666278839,\n",
" 'mon': 0.05452524498105049,\n",
" 'mission': 0.05246562883257866,\n",
" 'mand': 0.03528319671750069,\n",
" 'pan': 0.02023007906973362,\n",
" 'mer': 0.017731018364429474,\n",
" 'pl': 0.015618585981428623,\n",
" 'ple': 0.01504151988774538}"
]
},
"execution_count": 400,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a"
]
},
{
"cell_type": "code",
"execution_count": 401,
"id": "149c090c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.624180693179369\n"
]
}
],
"source": [
"print(sum(float(y) for x, y in a.items()))"
]
},
{
"cell_type": "code",
"execution_count": 402,
"id": "121595c5",
"metadata": {},
"outputs": [],
"source": [
"b = summarize_probs_unk(a, const_wildcard=True, scale_probs=True, wildcard_minweight=0.1)"
]
},
{
"cell_type": "code",
"execution_count": 403,
"id": "3acb48b9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"p:0.3220678024033818\tm:0.1764294588459882\tmun:0.1034342334152795\tmon:0.07941354974589608\tmission:0.07641381211021994\tmand:0.05138837796498232\tpan:0.02946419390001886\tmer:0.02582442517073288\tpl:0.022747763081629364\tple:0.021907292452780093\t:0.09090909090909094\n",
"\n"
]
}
],
"source": [
"print(gonito_format(b))"
]
},
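{
"cell_type": "code",
"execution_count": null,
"id": "gonito-roundtrip",
"metadata": {},
"outputs": [],
"source": [
"### Round-trip check (illustrative): parse the gonito line back into (word, prob) pairs to\n",
"### confirm the serialization in gonito_format. `parse_gonito_line` is a hypothetical helper.\n",
"def parse_gonito_line(line):\n",
"    pairs = []\n",
"    for chunk in line.rstrip('\\n').split('\\t'):\n",
"        word, _, prob = chunk.rpartition(':')  # rpartition tolerates words containing ':'\n",
"        pairs.append((word, float(prob)))\n",
"    return pairs\n",
"\n",
"parsed = parse_gonito_line(gonito_format(b))\n",
"print(abs(sum(p for _, p in parsed) - 1.0) < 1e-9)  # expect True"
]
},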
{
"cell_type": "code",
"execution_count": 259,
"id": "3c44cdbe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.0\n"
]
}
],
"source": [
"print(sum(float(y) for x, y in b))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f09ce313",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}