challenging-america-word-ga.../gpt-2 finetune.ipynb

698 lines
22 KiB
Raw Normal View History

2023-06-29 18:36:47 +02:00
"cells": [
"cell_type": "code",
"execution_count": 35,
"id": "df9bffa6",
"metadata": {},
"outputs": [],
"source": [
"SPLIT = 'dev-0'  # data split directory to read contexts from and write predictions to\n",
"test_file = f'{SPLIT}/in.tsv.xz'\n",
"out_file = f'{SPLIT}/out.tsv'"
"cell_type": "code",
"execution_count": 36,
"id": "eb2fcfa4",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"import torch\n",
"import lzma"
"cell_type": "code",
"execution_count": 358,
"id": "a03f19ae",
"metadata": {},
"outputs": [],
"source": [
"def preprocess(line):\n",
"    \"\"\"Drop the first six tab-separated fields of a raw record, then turn escaped newline sequences into spaces.\"\"\"\n",
"    return replace_endline(get_rid_of_header(line))\n",
"def get_rid_of_header(line):\n",
"    \"\"\"Discard the first six tab-separated fields of `line` and rejoin the rest with single spaces.\"\"\"\n",
"    remaining_fields = line.split('\\t')[6:]\n",
"    return \" \".join(remaining_fields)\n",
"\n",
"def replace_endline(line):\n",
"    \"\"\"Replace every literal backslash-n sequence (an escaped newline) in `line` with a space.\"\"\"\n",
"    return line.replace(\"\\\\n\", \" \")\n",
"def get_first_word(text):\n",
"    \"\"\"Return the first space-delimited word of `text`, with trailing whitespace stripped.\n",
"\n",
"    The whole string is scanned, so input without a space keeps its last\n",
"    character (a loop bounded by len(text) - 1 would turn 'hello' into 'hell').\n",
"    \"\"\"\n",
"    word, _, _ = text.partition(' ')\n",
"    return word.rstrip()\n",
"def summarize_probs_unk(dic, const_wildcard=True, scale_probs=False, wildcard_minweight=0.01):\n",
"    '''\n",
"    Convert the model's {word: probability} dict into a list for gonito_format.\n",
"\n",
"    dic: dictionary of probabilities returned by the model. The entry under the\n",
"        empty-string key (if any) is never listed as a word; with scale_probs=True\n",
"        and const_wildcard=False its value is used as the wildcard weight.\n",
"        `dic` itself is not modified.\n",
"    const_wildcard: only used when scale_probs is True. If True, the words are\n",
"        rescaled to sum to 1 / (1 + wildcard_minweight). If False, they are divided\n",
"        by (1 - wildcard weight), then capped so the wildcard keeps at least\n",
"        wildcard_minweight.\n",
"    scale_probs: if False, probabilities are kept as-is and the wildcard gets\n",
"        the remaining mass 1 - sum.\n",
"    wildcard_minweight: wildcard weight used by the rescaling modes.\n",
"    returns: list of (word, probability) pairs with the wildcard entry as the\n",
"        last element (empty word); all probabilities sum to 1.\n",
"    '''\n",
"    probs = {key: val for key, val in dic.items() if key != ''}  # copy without the wildcard key\n",
"    if scale_probs and const_wildcard:\n",
"        probsum = sum(float(val) for val in probs.values())\n",
"        if probsum > 0:  # an all-zero dict would otherwise divide by zero\n",
"            scale = probsum * (1 + wildcard_minweight)\n",
"            probs = {key: val / scale for key, val in probs.items()}\n",
"    elif scale_probs:\n",
"        wildcard_value = dic.get('', wildcard_minweight)\n",
"        probs = {key: val / (1 - wildcard_value) for key, val in probs.items()}  # leave some space for the wildcard\n",
"        # Dividing by (1 - wildcard_value) can push the listed mass to 1 or beyond\n",
"        # (e.g. one word at 0.995 and no empty-key entry), which would give the\n",
"        # wildcard a zero or negative probability; cap the listed mass at\n",
"        # 1 - wildcard_minweight.\n",
"        listed_mass = sum(probs.values())\n",
"        if listed_mass > 1 - wildcard_minweight:\n",
"            factor = (1 - wildcard_minweight) / listed_mass\n",
"            probs = {key: val * factor for key, val in probs.items()}\n",
"    tab = list(probs.items())\n",
"    tab.append(('', 1 - sum(probs.values())))\n",
"    return tab\n",
"def gonito_format(tab):\n",
"    \"\"\"Render (word, probability) pairs as one tab-separated Gonito line.\n",
"\n",
"    Every pair except the last is written as word:prob; the last pair is the\n",
"    wildcard and is written as :prob with its word omitted.\n",
"    \"\"\"\n",
"    listed, wildcard = tab[:-1], tab[-1]\n",
"    fields = [str(pair[0]) + ':' + str(pair[1]) for pair in listed]\n",
"    fields.append(':' + str(wildcard[1]))\n",
"    return '\\t'.join(fields) + '\\n'\n"
"cell_type": "code",
"execution_count": 207,
"id": "20cd7089",
"metadata": {
"scrolled": true
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: transformers in /home/gedin/.local/lib/python3.10/site-packages (4.30.2)\n",
"Requirement already satisfied: requests in /usr/lib/python3/dist-packages (from transformers) (2.25.1)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: numpy>=1.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (1.24.3)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from transformers) (5.4.1)\n",
"Requirement already satisfied: tqdm>=4.27 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (4.65.0)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.3.1)\n",
"Requirement already satisfied: filelock in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (3.12.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.15.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (2023.5.5)\n",
"Requirement already satisfied: packaging>=20.0 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: fsspec in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n",
"Requirement already satisfied: typing-extensions>= in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.6.3)\n"
"source": [
"%pip install -q transformers==4.30.2  # pinned to the version this notebook was run with; %pip targets the kernel's environment\n",
"from transformers import pipeline, set_seed, AutoTokenizer, AutoModel, AutoModelForCausalLM, GPT2Tokenizer"
"cell_type": "code",
"execution_count": 3,
"id": "ff4b3b42",
"metadata": {
"collapsed": true
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"Defaulting to user installation because normal site-packages is not writeable\n",
"Requirement already satisfied: transformers in /home/gedin/.local/lib/python3.10/site-packages (4.30.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (4.65.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (2023.5.5)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from transformers) (5.4.1)\n",
"Requirement already satisfied: requests in /usr/lib/python3/dist-packages (from transformers) (2.25.1)\n",
"Requirement already satisfied: packaging>=20.0 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: numpy>=1.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (1.24.3)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.15.1)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.3.1)\n",
"Requirement already satisfied: filelock in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (3.12.0)\n",
"Requirement already satisfied: typing-extensions>= in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.6.3)\n",
"Requirement already satisfied: fsspec in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n"
"source": []
"cell_type": "code",
"execution_count": 39,
"id": "9e28e9f8",
"metadata": {},
"outputs": [],
"source": [
"model_name = \"gpt2\"\n",
"# load tokenizer and model from the same checkpoint name so they cannot drift apart\n",
"tokenizer = GPT2Tokenizer.from_pretrained(model_name)\n",
"pt_model = AutoModelForCausalLM.from_pretrained(model_name)"
"cell_type": "code",
"execution_count": 40,
"id": "1efb36b9",
"metadata": {},
"outputs": [],
"source": [
"# Sample left context used below to inspect the model's next-word predictions\n",
"sentence = 'Northeasterly hv the head of said .^corns and and the'"
"cell_type": "code",
"execution_count": 41,
"id": "2e8b8b60",
"metadata": {},
"outputs": [],
"source": [
"encoding = tokenizer(sentence, return_tensors='pt')\n",
"with torch.no_grad():  # inference only: no autograd graph needed\n",
"    output = pt_model(**encoding)"
"cell_type": "code",
"execution_count": 135,
"id": "94f321d5",
"metadata": {},
"outputs": [
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_9135/ UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" probs = F.softmax(output.logits[0][-1])\n"
"source": [
"# next-token distribution over the vocabulary at the last position of the only sequence\n",
"probs = F.softmax(output.logits[0][-1], dim=-1)"
"cell_type": "code",
"execution_count": 43,
"id": "94986365",
"metadata": {},
"outputs": [],
"source": [
"top = torch.topk(probs, k=10)\n",
"top_probs, top_indices = top.values.tolist(), top.indices.tolist()\n",
"# decode each token id on its own and trim surrounding whitespace\n",
"top_words = [tokenizer.decode(token_id).strip() for token_id in top_indices]"
"cell_type": "code",
"execution_count": 47,
"id": "622c26f2",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 49,
"id": "fd4beaa3",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"['head', 'heads', '.', 'same', 'body', 'neck', 'other', 'whole', 'said', 'king']\n"
"source": [
"# print(top_indices)\n",
"cell_type": "code",
"execution_count": 50,
"id": "ec8bdc77",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[1182, 6665, 764, 976, 1767, 7393, 584, 2187, 531, 5822]\n"
"source": [
"cell_type": "code",
"execution_count": 136,
"id": "578257eb",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[0.18150091171264648, 0.011893990449607372, 0.011805753223598003, 0.011544686742126942, 0.007725409232079983]\n"
"source": [
"cell_type": "code",
"execution_count": 63,
"id": "939e12dc",
"metadata": {
"scrolled": true
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983), ('neck', 0.007723228540271521), ('other', 0.006957209203392267), ('whole', 0.006453146692365408), ('said', 0.004757815971970558), ('king', 0.004543370567262173)]\n"
"source": [
"# (word, probability) pairs for the current top-k predictions\n",
"list(zip(top_words, top_probs))"
"cell_type": "code",
"execution_count": 66,
"id": "c0d32e0d",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[('head', 0.7049117686793325), ('heads', 0.04619378363102533), ('.', 0.045851088608379394), ('same', 0.044837160725262136), ('body', 0.030003881711508234), ('neck', 0.02999541235835104), ('other', 0.027020352671284463), ('whole', 0.02506267877962141), ('said', 0.018478367079292513), ('king', 0.01764550575594297), ('', 0.01000000000000012)] 1.0\n"
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_9135/ UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
" probs = F.softmax(output.logits[0][-1]) #get the model prediction for the entire sentence ([-1]) no batching ([0])\n"
"source": [
"# Sanity check: listed probabilities plus the wildcard from summarize_probs_unk should sum to 1.\n",
"# NOTE: get_values_from_model is defined in a later cell, which must be run first.\n",
"check_tab = summarize_probs_unk(dict(get_values_from_model(sentence, model=pt_model, tokenizer=tokenizer, k=10)))\n",
"check_total = sum(prob for _, prob in check_tab)\n",
"assert abs(check_total - 1) < 1e-6, check_total\n",
"check_tab, check_total"
"cell_type": "code",
"execution_count": 152,
"id": "11e4e965",
"metadata": {},
"outputs": [],
"source": [
"def get_values_from_model(context: str, model, tokenizer, k=10):\n",
"    \"\"\"Return the model's top-k next-token predictions for `context` as (word, probability) pairs.\n",
"\n",
"    context: left context; the model scores the token following its last position.\n",
"    model, tokenizer: a causal language model and its matching tokenizer.\n",
"    k: how many of the highest-probability tokens to inspect.\n",
"    returns: list of (word, probability) pairs, ordered by each word's best rank.\n",
"        Decoded tokens are stripped of surrounding whitespace; tokens that strip\n",
"        to the same word (e.g. 'head' and ' head') have their probabilities summed,\n",
"        so a dict built from the result keeps their combined mass and fewer than\n",
"        k pairs may be returned.\n",
"    \"\"\"\n",
"    encoding = tokenizer(context, return_tensors='pt')\n",
"    with torch.no_grad():  # inference only: no autograd graph needed\n",
"        output = model(**encoding)\n",
"    # logits[0]: the only sequence in the batch; [-1]: its last position\n",
"    probs = F.softmax(output.logits[0][-1], dim=-1)\n",
"    top = torch.topk(probs, k)\n",
"    merged = {}\n",
"    for idx, prob in zip(top.indices.tolist(), top.values.tolist()):\n",
"        word = tokenizer.decode(idx).strip()\n",
"        merged[word] = merged.get(word, 0.0) + prob\n",
"    return list(merged.items())"
"cell_type": "code",
"execution_count": 398,
"id": "17884d9b",
"metadata": {
"scrolled": true
"outputs": [],
"source": [
"# Predict the next word for every line of test_file and write one Gonito-formatted row per line to out_file.\n",
"with lzma.open(test_file, 'rt', encoding='utf-8') as file:\n",
"    left_contexts = []\n",
"    for line in file:\n",
"        line = replace_endline(line)  # escaped newlines -> spaces\n",
"        fields = line.split('\\t')[6:]  # drop the first six columns; fields[0] is the left context\n",
"        # prompt the model with (at most) the last 10 words of the left context\n",
"        left_contexts.append(' '.join(fields[0].rstrip().split(\" \")[-10:]))\n",
"\n",
"# the input file is already closed before the (slow) model pass\n",
"results = [dict(get_values_from_model(left_context, model=pt_model, tokenizer=tokenizer, k=10))\n",
"           for left_context in left_contexts]\n",
"with open(out_file, 'w', encoding='utf-8') as outfile:\n",
"    for elem in results:\n",
"        tab = summarize_probs_unk(elem, const_wildcard=False, scale_probs=True, wildcard_minweight=0.01)\n",
"        outfile.write(gonito_format(tab))\n"
"cell_type": "code",
"execution_count": 95,
"id": "b8846e25",
"metadata": {},
"outputs": [],
"source": [
"# Same sample left context as defined earlier, repeated for the cells below\n",
"sentence = 'Northeasterly hv the head of said .^corns and and the'"
"cell_type": "code",
"execution_count": 148,
"id": "37cdd8c5",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983)]\n"
"source": [
"encoding = tokenizer(sentence, return_tensors='pt')\n",
"with torch.no_grad():  # inference only: no autograd graph needed\n",
"    output = pt_model(**encoding)\n",
"# next-token distribution after the last position ([-1]) of the only sequence in the batch ([0])\n",
"probs = F.softmax(output.logits[0][-1], dim=-1)\n",
"top = torch.topk(probs, 5)\n",
"top_probs = top.values.tolist()\n",
"top_indices = top.indices.tolist()\n",
"top_words = [tokenizer.decode(idx).strip() for idx in top_indices]\n",
"list(zip(top_words, top_probs))"
"cell_type": "code",
"execution_count": 105,
"id": "53c8c83c",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"torch.Size([1, 20, 50257])"
"execution_count": 105,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 93,
"id": "c6bf8faf",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983)]\n"
"source": [
"# (word, probability) pairs from the most recently run top-k cell\n",
"list(zip(top_words, top_probs))"
"cell_type": "code",
"execution_count": 153,
"id": "e03cc143",
"metadata": {},
"outputs": [],
"source": [
"N_PREVIEW_LINES = 21  # number of input lines loaded for quick experiments\n",
"with lzma.open(test_file, 'rt', encoding='utf-8') as file:\n",
"    left_contexts = []\n",
"    right_contexts = []\n",
"    results = []\n",
"    for i, line in enumerate(file):\n",
"        if i >= N_PREVIEW_LINES:\n",
"            break\n",
"        line = replace_endline(line)  # escaped newlines -> spaces\n",
"        line = line.split('\\t')[6:]  # drop the first six columns\n",
"        l_context = \" \".join(line[0].rstrip().split(\" \")[-10:])  # last 10 words of the left context\n",
"        r_context = \" \".join(line[1].rstrip().split(\" \")[:5])  # first 5 words of the right context\n",
"        left_contexts.append(l_context)\n",
"        right_contexts.append(r_context)"
"cell_type": "code",
"execution_count": 154,
"id": "20113d00",
"metadata": {},
"outputs": [
"data": {
"text/plain": [
"['id Seorgc Acorns laud in said Aina and riinninu from',\n",
" 'true that the Republican par- ty is National in its',\n",
" 'but here¬ abouts 1 have not seen or heard ot',\n",
" 'They will not bend or break like the single spiing.',\n",
" '208 miles long. J. 8 . Con way the Com-',\n",
" 'They felt that He was the long-looked- for Messiah with',\n",
" 'in the state that could produce just as good fruit',\n",
" 'after house and managed to find shelter in the lazarette,',\n",
" 'hundred buildings put up in the place this season. Mechanics',\n",
" 'de\\xad testation. I no longer loved him, and I felt',\n",
" 'ol the gentleman, by dis. cussing the constitutionality of these',\n",
" 'the purest pa- triotism, and tin* most indent devotion to',\n",
" 'now to i high, will come down, and those too',\n",
" 'the \\'extension of MdKlnleJ\" street, and running back to the',\n",
" 'sir, it is a subject ot peculiar delight to me',\n",
" 'to affect a miserable scramble for offices in tiiis country.',\n",
" 'Marj, the eldest, said lo the others: \"Lit us pray',\n",
" 'not enquire—But what sum lie asked, could tempt men well',\n",
" \"hete are 500, or 550 acres ol cleared land, '250\",\n",
" 'petition h is lieeii |ui tty elos ly scru;iiii/e<| and',\n",
" 'on the high reputation of his record as Chief Mag-']"
"execution_count": 154,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": null,
"id": "f1d16811",
"metadata": {
"scrolled": true
"outputs": [],
"source": [
"# Score every previewed left context, then print its Gonito line with a constant wildcard share.\n",
"results = [dict(get_values_from_model(left_context, model=pt_model, tokenizer=tokenizer, k=10))\n",
"           for left_context in left_contexts]\n",
"for idx, elem in enumerate(results):\n",
"    tab = summarize_probs_unk(elem, const_wildcard=True, scale_probs=True, wildcard_minweight=0.01)\n",
"    print(idx, \" \", gonito_format(tab))"
"cell_type": "code",
"execution_count": 399,
"id": "df38d415",
"metadata": {},
"outputs": [],
"source": [
"# Raw top-10 predictions (word -> probability) for the fifth previewed left context\n",
"a =dict(get_values_from_model(left_contexts[4], model=pt_model, tokenizer=tokenizer, k=10))"
"cell_type": "code",
"execution_count": 400,
"id": "d49824d4",
"metadata": {
"scrolled": true
"outputs": [
"data": {
"text/plain": [
"{'p': 0.2211313545703888,\n",
" 'm': 0.12113624811172485,\n",
" 'mun': 0.07101781666278839,\n",
" 'mon': 0.05452524498105049,\n",
" 'mission': 0.05246562883257866,\n",
" 'mand': 0.03528319671750069,\n",
" 'pan': 0.02023007906973362,\n",
" 'mer': 0.017731018364429474,\n",
" 'pl': 0.015618585981428623,\n",
" 'ple': 0.01504151988774538}"
"execution_count": 400,
"metadata": {},
"output_type": "execute_result"
"source": [
"cell_type": "code",
"execution_count": 401,
"id": "149c090c",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"source": [
"print(sum(float(prob) for prob in a.values()))"
"cell_type": "code",
"execution_count": 402,
"id": "121595c5",
"metadata": {},
"outputs": [],
"source": [
"# Rescale `a` with a 0.1 wildcard weight (the output file is written with 0.01).\n",
"# NOTE(review): summarize_probs_unk may rescale `a` in place; re-run the previous cells before re-running this one.\n",
"b = summarize_probs_unk(a, const_wildcard=True, scale_probs=True, wildcard_minweight=0.1)"
"cell_type": "code",
"execution_count": 403,
"id": "3acb48b9",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"source": [
"cell_type": "code",
"execution_count": 259,
"id": "3c44cdbe",
"metadata": {},
"outputs": [
"name": "stdout",
"output_type": "stream",
"text": [
"source": [
"print(sum(float(prob) for _, prob in b))"
"cell_type": "code",
"execution_count": null,
"id": "f09ce313",
"metadata": {},
"outputs": [],
"source": []
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"nbformat": 4,
"nbformat_minor": 5