{ "cells": [ { "cell_type": "code", "execution_count": 35, "id": "df9bffa6", "metadata": {}, "outputs": [], "source": [ "test_file = 'dev-0/in.tsv.xz'\n", "out_file = 'dev-0/out.tsv'" ] }, { "cell_type": "code", "execution_count": 36, "id": "eb2fcfa4", "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "import torch\n", "import lzma" ] }, { "cell_type": "code", "execution_count": 358, "id": "a03f19ae", "metadata": {}, "outputs": [], "source": [ "###preprocessing\n", "def preprocess(line):\n", " line = get_rid_of_header(line)\n", " line = replace_endline(line)\n", " return line\n", "\n", "def get_rid_of_header(line):\n", " line = line.split('\\t')[6:]\n", " return \" \".join(line)\n", " \n", "def replace_endline(line):\n", " line = line.replace(\"\\\\n\", \" \")\n", " return line\n", "\n", "def get_first_word(text):\n", " \"\"\"Return the first word of a string.\"\"\"\n", " word = \"\"\n", " for i in range(len(text)-1):\n", "# if text[i] in [' ', ',', '.']\n", " if text[i] == ' ':\n", " return word.rstrip()\n", " else:\n", " word += text[i]\n", " return word.rstrip()\n", "\n", "\n", "\n", "def summarize_probs_unk(dic, const_wildcard=True, scale_probs=False, wildcard_minweight=0.01):\n", " ''' \n", " dic: dictionary of probabilities returned by model \n", " returns: tab of probabilities, with specificly as last element\n", " '''\n", " if not scale_probs:\n", " if '' in dic.keys():\n", " del dic['']\n", " tab = [(key, val) for key, val in dic.items()]\n", " tab.append(('', 1-sum([val for val in dic.values()])))\n", " elif const_wildcard and scale_probs: #\n", " if '' in dic.keys():\n", " del dic['']\n", " probsum = sum(float(val) for key, val in dic.items())\n", " for key in dic:\n", " dic[key] = dic[key]/(probsum*(1+wildcard_minweight)) \n", " tab = [(key, val) for key, val in dic.items()]\n", " tab.append(('', 1-sum([val for val in dic.values()])))\n", "# else if '' not in dic.keys(): #no wildcard entry \n", " else:\n", " if '' not in dic.keys():\n", " wildcard_value = wildcard_minweight\n", " else:\n", " wildcard_value = dic['']\n", " del dic['']\n", " for key in dic:\n", " dic[key] = dic[key]/(1-wildcard_value) ###leave some space for wildcar\n", " tab = [(key, val) for key, val in dic.items()]\n", " tab.append(('', 1-sum([val for val in dic.values()])))\n", " \n", " return tab\n", "\n", "\n", "def gonito_format(tab):\n", "# tab = summarize_probs_unk(dic, const_wildcard=const_wildcard, wildcard_minweight=wildcard_minweight)\n", " result = ''\n", " for element in tab[:-1]:\n", " result+=str(element[0])+':'+str(element[1])+'\\t'\n", " result+=':'+ str(tab[-1][1])+'\\n'\n", " return result\n" ] }, { "cell_type": "code", "execution_count": 207, "id": "20cd7089", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Defaulting to user installation because normal site-packages is not writeable\n", "Requirement already satisfied: transformers in /home/gedin/.local/lib/python3.10/site-packages (4.30.2)\n", "Requirement already satisfied: requests in /usr/lib/python3/dist-packages (from transformers) (2.25.1)\n", "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.13.3)\n", "Requirement already satisfied: numpy>=1.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (1.24.3)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from transformers) (5.4.1)\n", "Requirement already 
satisfied: tqdm>=4.27 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (4.65.0)\n", "Requirement already satisfied: safetensors>=0.3.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.3.1)\n", "Requirement already satisfied: filelock in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (3.12.0)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.15.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (2023.5.5)\n", "Requirement already satisfied: packaging>=20.0 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (23.1)\n", "Requirement already satisfied: fsspec in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.6.3)\n" ] } ], "source": [ "!pip install transformers\n", "from transformers import pipeline, set_seed, AutoTokenizer, AutoModel, AutoModelForCausalLM, GPT2Tokenizer" ] }, { "cell_type": "code", "execution_count": 39, "id": "9e28e9f8", "metadata": {}, "outputs": [], "source": [ "model_name = \"gpt2\"\n", "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n", "pt_model = AutoModelForCausalLM.from_pretrained(model_name)" ] }, 
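{ "cell_type": "code", "execution_count": null, "id": "0b1c2d3e", "metadata": {}, "outputs": [], "source": [ "# Quick sanity check (not part of the original run): on a toy distribution,\n", "# summarize_probs_unk should return (word, prob) pairs whose probabilities sum\n", "# to 1.0 with the wildcard '' last, and gonito_format should render them as one\n", "# tab-separated line. The `toy` values here are made up for illustration.\n", "toy = {'head': 0.5, 'heads': 0.2, '': 0.05}\n", "toy_tab = summarize_probs_unk(toy, const_wildcard=False, scale_probs=True, wildcard_minweight=0.01)\n", "print(toy_tab)\n", "print(gonito_format(toy_tab))" ] }, { "cell_type": "code", 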
"execution_count": 40, "id": "1efb36b9", "metadata": {}, "outputs": [], "source": [ "sentence = 'Northeasterly hv the head of said .^corn’s and and the'" ] }, { "cell_type": "code", "execution_count": 41, "id": "2e8b8b60", "metadata": {}, "outputs": [], "source": [ "encoding = tokenizer(sentence, return_tensors='pt')\n", "output = pt_model(**encoding)" ] }, { "cell_type": "code", "execution_count": 135, "id": "94f321d5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_9135/861569571.py:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", " probs = F.softmax(output.logits[0][-1])\n" ] } ], "source": [ "probs = F.softmax(output.logits[0][-1])" ] }, { "cell_type": "code", "execution_count": 43, "id": "94986365", "metadata": {}, "outputs": [], "source": [ "top = torch.topk(probs, 10)\n", "top_indices = top.indices.tolist()\n", "top_probs = top.values.tolist()\n", "top_words = [tokenizer.decode(idx).strip() for idx in top_indices] " ] }, { "cell_type": "code", "execution_count": 47, "id": "622c26f2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'\\n'" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode(198)" ] }, { "cell_type": "code", "execution_count": 49, "id": "fd4beaa3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['head', 'heads', '.', 'same', 'body', 'neck', 'other', 'whole', 'said', 'king']\n" ] } ], "source": [ "# print(top_indices)\n", "print((top_words))" ] }, { "cell_type": "code", "execution_count": 50, "id": "ec8bdc77", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[1182, 6665, 764, 976, 1767, 7393, 584, 2187, 531, 5822]\n" ] } ], "source": [ "print(top_indices)" ] }, { "cell_type": "code", "execution_count": 136, "id": "578257eb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[0.18150091171264648, 0.011893990449607372, 0.011805753223598003, 0.011544686742126942, 0.007725409232079983]\n" ] } ], "source": [ "print(top_probs)" ] }, { "cell_type": "code", "execution_count": 63, "id": "939e12dc", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983), ('neck', 0.007723228540271521), ('other', 0.006957209203392267), ('whole', 0.006453146692365408), ('said', 0.004757815971970558), ('king', 0.004543370567262173)]\n" ] } ], "source": [ "print(list(zip(top_words, top_probs)))" ] }, { "cell_type": "code", "execution_count": 66, "id": "c0d32e0d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('head', 0.7049117686793325), ('heads', 0.04619378363102533), ('.', 0.045851088608379394), ('same', 0.044837160725262136), ('body', 0.030003881711508234), ('neck', 0.02999541235835104), ('other', 0.027020352671284463), ('whole', 0.02506267877962141), ('said', 0.018478367079292513), ('king', 0.01764550575594297), ('', 0.01000000000000012)] 1.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_9135/698839240.py:4: UserWarning: Implicit dimension choice for softmax has been deprecated. 
Change the call to include dim=X as an argument.\n", " probs = F.softmax(output.logits[0][-1]) #get the model prediction for the entire sentence ([-1]) no batching ([0])\n" ] } ], "source": [ "#################Does probabilities from get_values_from_model sum up to 1##################\n", "asdf = dict(get_values_from_model(sentence, model=pt_model, tokenizer=tokenizer, k=10))\n", "asdf = summarize_probs_unk(asdf)\n", "print([x for x in asdf] ,sum([x[1] for x in asdf]))" ] }, { "cell_type": "code", "execution_count": 152, "id": "11e4e965", "metadata": {}, "outputs": [], "source": [ "def get_values_from_model(context: str, model, tokenizer, k=10):\n", " encoding = tokenizer(context, return_tensors='pt')\n", " output = model(**encoding)\n", " probs = F.softmax(output.logits[0][-1], dim=-1) #get the model prediction for the entire sentence ([-1]) no batching ([0])\n", " top = torch.topk(probs, k)\n", " top_probs = top.values.tolist()\n", " top_indices =top.indices.tolist()\n", " top_words = [tokenizer.decode(idx).strip() for idx in top_indices] \n", "# print(context, \"probs: \\n\", list(zip(top_words, top_indices, top_probs)))\n", " return list(zip(top_words, top_probs))" ] }, { "cell_type": "code", "execution_count": 398, "id": "17884d9b", "metadata": { "scrolled": true }, "outputs": [], "source": [ "with lzma.open(test_file, 'rt') as file:\n", " left_contexts = []\n", " results = []\n", " for line in file:\n", " line = replace_endline(line) #get only relevant\n", " line = line.split('\\t')[6:]\n", " context = ' '.join(line[0].rstrip().split(\" \")[-10:])\n", "# context = context + ' '\n", "# print(context)\n", " left_contexts.append(context)\n", " ###get results from gpt model###\n", " for left_context in left_contexts:\n", " results.append(dict(get_values_from_model(left_context, model=pt_model, tokenizer=tokenizer, k=10)))\n", " with open(out_file, 'w') as outfile:\n", " for elem in results:\n", " tab = summarize_probs_unk(elem, const_wildcard=False, scale_probs=True, wildcard_minweight=0.01)\n", " outfile.write(gonito_format(tab))\n" ] }, { "cell_type": "code", "execution_count": 95, "id": "b8846e25", "metadata": {}, "outputs": [], "source": [ "sentence = 'Northeasterly hv the head of said .^corn’s and and the'" ] }, { "cell_type": "code", "execution_count": 148, "id": "37cdd8c5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983)]\n" ] } ], "source": [ "encoding = tokenizer(sentence, return_tensors='pt')\n", "output = pt_model(**encoding)\n", "probs = F.softmax(output.logits[0][-1], dim = -1) #get the model prediction for the entire sentence ([-1]) no batching ([0])\n", "top = torch.topk(probs, 5)\n", "top_probs = top.values.tolist()\n", "top_indices =top.indices.tolist()\n", "top_words = [tokenizer.decode(idx).strip() for idx in top_indices] \n", "print(list(zip(top_words, top_probs)))" ] }, { "cell_type": "code", "execution_count": 105, "id": "53c8c83c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([1, 20, 50257])" ] }, "execution_count": 105, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output.logits.shape" ] }, { "cell_type": "code", "execution_count": 93, "id": "c6bf8faf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), 
('same', 0.011544686742126942), ('body', 0.007725409232079983)]\n" ] } ], "source": [ "print(list(zip(top_words, top_probs)))" ] }, { "cell_type": "code", "execution_count": 153, "id": "e03cc143", "metadata": {}, "outputs": [], "source": [ "with lzma.open(test_file, 'rt') as file:\n", "    left_contexts = []\n", "    right_contexts = []\n", "    results = []\n", "    i = 0\n", "    for line in file:\n", "        if i > 20:\n", "            break\n", "        line = replace_endline(line)\n", "        line = line.split('\\t')[6:]  # keep only the relevant columns\n", "        l_context = \" \".join(line[0].rstrip().split(\" \")[-10:])\n", "        r_context = \" \".join(line[1].rstrip().split(\" \")[:5])\n", "        left_contexts.append(l_context)\n", "        right_contexts.append(r_context)\n", "        i += 1" ] }, { "cell_type": "code", "execution_count": 154, "id": "20113d00", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['id Seorgc Acorns laud in said Aina and riinninu from',\n", " 'true that the Republican par- ty is National in its',\n", " 'but here¬ abouts 1 have not seen or heard ot',\n", " 'They will not bend or break like the single spiing.',\n", " '208 miles long. J. 8 . Con way the Com-',\n", " 'They felt that He was the long-looked- for Messiah with',\n", " 'in the state that could produce just as good fruit',\n", " 'after house and managed to find shelter in the lazarette,',\n", " 'hundred buildings put up in the place this season. Mechanics',\n", " 'de\\xad testation. I no longer loved him, and I felt',\n", " 'ol the gentleman, by dis. cussing the constitutionality of these',\n", " 'the purest pa- triotism, and tin* most indent devotion to',\n", " 'now to i high, will come down, and those too',\n", " 'the \\'extension of MdKlnleJ\" street, and running back to the',\n", " 'sir, it is a subject ot peculiar delight to me',\n", " 'to affect a miserable scramble for offices in tiiis country.',\n", " 'Marj, the eldest, said lo the others: \"Lit us pray',\n", " 'not enquire—But what sum lie asked, could tempt men well',\n", " \"hete are 500, or 550 acres ol cleared land, '250\",\n", " 'petition h is lieeii |ui tty elos ly scru;iiii/e<| and',\n", " 'on the high reputation of his record as Chief Mag-']" ] }, "execution_count": 154, "metadata": {}, "output_type": "execute_result" } ], "source": [ "left_contexts" ] }, { "cell_type": "code", "execution_count": null, "id": "f1d16811", "metadata": { "scrolled": true }, "outputs": [], "source": [ "results = []\n", "for left_context in left_contexts:\n", "    results.append(dict(get_values_from_model(left_context, model=pt_model, tokenizer=tokenizer, k=10)))\n", "for idx, elem in enumerate(results):\n", "    tab = summarize_probs_unk(elem, const_wildcard=True, scale_probs=True, wildcard_minweight=0.01)\n", "    print(idx, \" \", gonito_format(tab))" ] }, { "cell_type": "code", "execution_count": 399, "id": "df38d415", "metadata": {}, "outputs": [], "source": [ "a = dict(get_values_from_model(left_contexts[4], model=pt_model, tokenizer=tokenizer, k=10))" ] }, { "cell_type": "code", "execution_count": 400, "id": "d49824d4", "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "{'p': 0.2211313545703888,\n", " 'm': 0.12113624811172485,\n", " 'mun': 0.07101781666278839,\n", " 'mon': 0.05452524498105049,\n", " 'mission': 0.05246562883257866,\n", " 'mand': 0.03528319671750069,\n", " 'pan': 0.02023007906973362,\n", " 'mer': 0.017731018364429474,\n", " 'pl': 0.015618585981428623,\n", " 'ple': 0.01504151988774538}" ] }, "execution_count": 400, "metadata": {}, "output_type": "execute_result" } ], "source": [ "a" ] }, 
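{ "cell_type": "code", "execution_count": null, "id": "5f6a7b8c", "metadata": {}, "outputs": [], "source": [ "# The keys of `a` above are single BPE tokens rather than whole words: the\n", "# context ends in \"the Com-\", so GPT-2 predicts subword continuations such as\n", "# 'mission' or 'p'. A small illustration (not part of the original run) of how\n", "# the tokenizer splits a word into such pieces:\n", "print(tokenizer.tokenize('Commission'))" ] }, { 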
"cell_type": "code", "execution_count": 401, "id": "149c090c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.624180693179369\n" ] } ], "source": [ "print(sum(float(y) for x, y in a.items()))" ] }, { "cell_type": "code", "execution_count": 402, "id": "121595c5", "metadata": {}, "outputs": [], "source": [ "b = summarize_probs_unk(a, const_wildcard=True, scale_probs=True, wildcard_minweight=0.1)" ] }, { "cell_type": "code", "execution_count": 403, "id": "3acb48b9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "p:0.3220678024033818\tm:0.1764294588459882\tmun:0.1034342334152795\tmon:0.07941354974589608\tmission:0.07641381211021994\tmand:0.05138837796498232\tpan:0.02946419390001886\tmer:0.02582442517073288\tpl:0.022747763081629364\tple:0.021907292452780093\t:0.09090909090909094\n", "\n" ] } ], "source": [ "print(gonito_format(b))" ] }, { "cell_type": "code", "execution_count": 259, "id": "3c44cdbe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1.0\n" ] } ], "source": [ "print(sum(float(y) for x, y in b))" ] }, { "cell_type": "code", "execution_count": null, "id": "f09ce313", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }