init gpt commit
This commit is contained in:
commit
2b78d73fee
9
README.md
Normal file
9
README.md
Normal file
@ -0,0 +1,9 @@
|
||||
Challenging America word-gap prediction
|
||||
===================================
|
||||
|
||||
Guess a word in a gap.
|
||||
|
||||
Evaluation metric
|
||||
-----------------
|
||||
|
||||
LikelihoodHashed is the metric
|
1
config.txt
Normal file
1
config.txt
Normal file
@ -0,0 +1 @@
|
||||
--metric PerplexityHashed --precision 2 --in-header in-header.tsv --out-header out-header.tsv
|
10519
dev-0/expected.tsv
Normal file
10519
dev-0/expected.tsv
Normal file
File diff suppressed because it is too large
Load Diff
10519
dev-0/hate-speech-info.tsv
Normal file
10519
dev-0/hate-speech-info.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
dev-0/in.tsv.xz
Normal file
BIN
dev-0/in.tsv.xz
Normal file
Binary file not shown.
10519
dev-0/out.tsv
Normal file
10519
dev-0/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
filename.pickle
Normal file
BIN
filename.pickle
Normal file
Binary file not shown.
14
gonito.yaml
Normal file
14
gonito.yaml
Normal file
@ -0,0 +1,14 @@
|
||||
description: zad8, trigram with left/right context embeddings
|
||||
tags:
|
||||
- neural-network
|
||||
- left-to-right
|
||||
params:
|
||||
epochs: 3
|
||||
learning-rate: 0.0003
|
||||
vocab-size: 20000
|
||||
batch_s: 3200
|
||||
top_k_words: 20
|
||||
param-files:
|
||||
- config/*.yaml
|
||||
links:
|
||||
- repo: "https://git.wmi.amu.edu.pl/s470618/challenging-america-word-gap-prediction"
|
697
gpt-2 finetune.ipynb
Normal file
697
gpt-2 finetune.ipynb
Normal file
@ -0,0 +1,697 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"id": "df9bffa6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_file = 'dev-0/in.tsv.xz'\n",
|
||||
"out_file = 'dev-0/out.tsv'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"id": "eb2fcfa4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch.nn.functional as F\n",
|
||||
"import torch\n",
|
||||
"import lzma"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 358,
|
||||
"id": "a03f19ae",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"###preprocessing\n",
|
||||
"def preprocess(line):\n",
|
||||
" line = get_rid_of_header(line)\n",
|
||||
" line = replace_endline(line)\n",
|
||||
" return line\n",
|
||||
"\n",
|
||||
"def get_rid_of_header(line):\n",
|
||||
" line = line.split('\\t')[6:]\n",
|
||||
" return \" \".join(line)\n",
|
||||
" \n",
|
||||
"def replace_endline(line):\n",
|
||||
" line = line.replace(\"\\\\n\", \" \")\n",
|
||||
" return line\n",
|
||||
"\n",
|
||||
"def get_first_word(text):\n",
|
||||
" \"\"\"Return the first word of a string.\"\"\"\n",
|
||||
" word = \"\"\n",
|
||||
" for i in range(len(text)-1):\n",
|
||||
"# if text[i] in [' ', ',', '.']\n",
|
||||
" if text[i] == ' ':\n",
|
||||
" return word.rstrip()\n",
|
||||
" else:\n",
|
||||
" word += text[i]\n",
|
||||
" return word.rstrip()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def summarize_probs_unk(dic, const_wildcard=True, scale_probs=False, wildcard_minweight=0.01):\n",
|
||||
" ''' \n",
|
||||
" dic: dictionary of probabilities returned by model \n",
|
||||
" returns: tab of probabilities, with <unk> specifically as last element\n",
|
||||
" '''\n",
|
||||
" if not scale_probs:\n",
|
||||
" if '' in dic.keys():\n",
|
||||
" del dic['']\n",
|
||||
" tab = [(key, val) for key, val in dic.items()]\n",
|
||||
" tab.append(('', 1-sum([val for val in dic.values()])))\n",
|
||||
" elif const_wildcard and scale_probs: #\n",
|
||||
" if '' in dic.keys():\n",
|
||||
" del dic['']\n",
|
||||
" probsum = sum(float(val) for key, val in dic.items())\n",
|
||||
" for key in dic:\n",
|
||||
" dic[key] = dic[key]/(probsum*(1+wildcard_minweight)) \n",
|
||||
" tab = [(key, val) for key, val in dic.items()]\n",
|
||||
" tab.append(('', 1-sum([val for val in dic.values()])))\n",
|
||||
"# else if '' not in dic.keys(): #no wildcard entry \n",
|
||||
" else:\n",
|
||||
" if '' not in dic.keys():\n",
|
||||
" wildcard_value = wildcard_minweight\n",
|
||||
" else:\n",
|
||||
" wildcard_value = dic['']\n",
|
||||
" del dic['']\n",
|
||||
" for key in dic:\n",
|
||||
" dic[key] = dic[key]/(1-wildcard_value) ###leave some space for wildcard\n",
|
||||
" tab = [(key, val) for key, val in dic.items()]\n",
|
||||
" tab.append(('', 1-sum([val for val in dic.values()])))\n",
|
||||
" \n",
|
||||
" return tab\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def gonito_format(tab):\n",
|
||||
"# tab = summarize_probs_unk(dic, const_wildcard=const_wildcard, wildcard_minweight=wildcard_minweight)\n",
|
||||
" result = ''\n",
|
||||
" for element in tab[:-1]:\n",
|
||||
" result+=str(element[0])+':'+str(element[1])+'\\t'\n",
|
||||
" result+=':'+ str(tab[-1][1])+'\\n'\n",
|
||||
" return result\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 207,
|
||||
"id": "20cd7089",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Defaulting to user installation because normal site-packages is not writeable\n",
|
||||
"Requirement already satisfied: transformers in /home/gedin/.local/lib/python3.10/site-packages (4.30.2)\n",
|
||||
"Requirement already satisfied: requests in /usr/lib/python3/dist-packages (from transformers) (2.25.1)\n",
|
||||
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.13.3)\n",
|
||||
"Requirement already satisfied: numpy>=1.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (1.24.3)\n",
|
||||
"Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from transformers) (5.4.1)\n",
|
||||
"Requirement already satisfied: tqdm>=4.27 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (4.65.0)\n",
|
||||
"Requirement already satisfied: safetensors>=0.3.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.3.1)\n",
|
||||
"Requirement already satisfied: filelock in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (3.12.0)\n",
|
||||
"Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.15.1)\n",
|
||||
"Requirement already satisfied: regex!=2019.12.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (2023.5.5)\n",
|
||||
"Requirement already satisfied: packaging>=20.0 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (23.1)\n",
|
||||
"Requirement already satisfied: fsspec in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n",
|
||||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.6.3)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"!pip install transformers\n",
|
||||
"from transformers import pipeline, set_seed, AutoTokenizer, AutoModel, AutoModelForCausalLM, GPT2Tokenizer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "ff4b3b42",
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Defaulting to user installation because normal site-packages is not writeable\n",
|
||||
"Requirement already satisfied: transformers in /home/gedin/.local/lib/python3.10/site-packages (4.30.2)\n",
|
||||
"Requirement already satisfied: tqdm>=4.27 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (4.65.0)\n",
|
||||
"Requirement already satisfied: regex!=2019.12.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (2023.5.5)\n",
|
||||
"Requirement already satisfied: pyyaml>=5.1 in /usr/lib/python3/dist-packages (from transformers) (5.4.1)\n",
|
||||
"Requirement already satisfied: requests in /usr/lib/python3/dist-packages (from transformers) (2.25.1)\n",
|
||||
"Requirement already satisfied: packaging>=20.0 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (23.1)\n",
|
||||
"Requirement already satisfied: numpy>=1.17 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (1.24.3)\n",
|
||||
"Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.15.1)\n",
|
||||
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.13.3)\n",
|
||||
"Requirement already satisfied: safetensors>=0.3.1 in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (0.3.1)\n",
|
||||
"Requirement already satisfied: filelock in /home/gedin/.local/lib/python3.10/site-packages (from transformers) (3.12.0)\n",
|
||||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.6.3)\n",
|
||||
"Requirement already satisfied: fsspec in /home/gedin/.local/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"id": "9e28e9f8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_name = \"gpt2\"\n",
|
||||
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
|
||||
"pt_model = AutoModelForCausalLM.from_pretrained(model_name)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"id": "1efb36b9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = 'Northeasterly hv the head of said .^corn’s and and the'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"id": "2e8b8b60",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"encoding = tokenizer(sentence, return_tensors='pt')\n",
|
||||
"output = pt_model(**encoding)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 135,
|
||||
"id": "94f321d5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/tmp/ipykernel_9135/861569571.py:1: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
||||
" probs = F.softmax(output.logits[0][-1])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"probs = F.softmax(output.logits[0][-1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 43,
|
||||
"id": "94986365",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"top = torch.topk(probs, 10)\n",
|
||||
"top_indices = top.indices.tolist()\n",
|
||||
"top_probs = top.values.tolist()\n",
|
||||
"top_words = [tokenizer.decode(idx).strip() for idx in top_indices] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 47,
|
||||
"id": "622c26f2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'\\n'"
|
||||
]
|
||||
},
|
||||
"execution_count": 47,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer.decode(198)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 49,
|
||||
"id": "fd4beaa3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['head', 'heads', '.', 'same', 'body', 'neck', 'other', 'whole', 'said', 'king']\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# print(top_indices)\n",
|
||||
"print((top_words))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 50,
|
||||
"id": "ec8bdc77",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[1182, 6665, 764, 976, 1767, 7393, 584, 2187, 531, 5822]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(top_indices)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 136,
|
||||
"id": "578257eb",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[0.18150091171264648, 0.011893990449607372, 0.011805753223598003, 0.011544686742126942, 0.007725409232079983]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(top_probs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 63,
|
||||
"id": "939e12dc",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983), ('neck', 0.007723228540271521), ('other', 0.006957209203392267), ('whole', 0.006453146692365408), ('said', 0.004757815971970558), ('king', 0.004543370567262173)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(list(zip(top_words, top_probs)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 66,
|
||||
"id": "c0d32e0d",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[('head', 0.7049117686793325), ('heads', 0.04619378363102533), ('.', 0.045851088608379394), ('same', 0.044837160725262136), ('body', 0.030003881711508234), ('neck', 0.02999541235835104), ('other', 0.027020352671284463), ('whole', 0.02506267877962141), ('said', 0.018478367079292513), ('king', 0.01764550575594297), ('', 0.01000000000000012)] 1.0\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/tmp/ipykernel_9135/698839240.py:4: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
|
||||
" probs = F.softmax(output.logits[0][-1]) #get the model prediction for the entire sentence ([-1]) no batching ([0])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"#################Do the probabilities from get_values_from_model sum up to 1?##################\n",
|
||||
"asdf = dict(get_values_from_model(sentence, model=pt_model, tokenizer=tokenizer, k=10))\n",
|
||||
"asdf = summarize_probs_unk(asdf)\n",
|
||||
"print([x for x in asdf] ,sum([x[1] for x in asdf]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 152,
|
||||
"id": "11e4e965",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_values_from_model(context: str, model, tokenizer, k=10):\n",
|
||||
" encoding = tokenizer(context, return_tensors='pt')\n",
|
||||
" output = model(**encoding)\n",
|
||||
" probs = F.softmax(output.logits[0][-1], dim=-1) #next-token distribution: first (only) batch item ([0]), last token position ([-1])\n",
|
||||
" top = torch.topk(probs, k)\n",
|
||||
" top_probs = top.values.tolist()\n",
|
||||
" top_indices =top.indices.tolist()\n",
|
||||
" top_words = [tokenizer.decode(idx).strip() for idx in top_indices] \n",
|
||||
"# print(context, \"probs: \\n\", list(zip(top_words, top_indices, top_probs)))\n",
|
||||
" return list(zip(top_words, top_probs))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 398,
|
||||
"id": "17884d9b",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with lzma.open(test_file, 'rt') as file:\n",
|
||||
" left_contexts = []\n",
|
||||
" results = []\n",
|
||||
" for line in file:\n",
|
||||
" line = replace_endline(line) #get only relevant\n",
|
||||
" line = line.split('\\t')[6:]\n",
|
||||
" context = ' '.join(line[0].rstrip().split(\" \")[-10:])\n",
|
||||
"# context = context + ' '\n",
|
||||
"# print(context)\n",
|
||||
" left_contexts.append(context)\n",
|
||||
" ###get results from gpt model###\n",
|
||||
" for left_context in left_contexts:\n",
|
||||
" results.append(dict(get_values_from_model(left_context, model=pt_model, tokenizer=tokenizer, k=10)))\n",
|
||||
" with open(out_file, 'w') as outfile:\n",
|
||||
" for elem in results:\n",
|
||||
" tab = summarize_probs_unk(elem, const_wildcard=False, scale_probs=True, wildcard_minweight=0.01)\n",
|
||||
" outfile.write(gonito_format(tab))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 95,
|
||||
"id": "b8846e25",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sentence = 'Northeasterly hv the head of said .^corn’s and and the'"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 148,
|
||||
"id": "37cdd8c5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"encoding = tokenizer(sentence, return_tensors='pt')\n",
|
||||
"output = pt_model(**encoding)\n",
|
||||
"probs = F.softmax(output.logits[0][-1], dim = -1) #next-token distribution: first (only) batch item ([0]), last token position ([-1])\n",
|
||||
"top = torch.topk(probs, 5)\n",
|
||||
"top_probs = top.values.tolist()\n",
|
||||
"top_indices =top.indices.tolist()\n",
|
||||
"top_words = [tokenizer.decode(idx).strip() for idx in top_indices] \n",
|
||||
"print(list(zip(top_words, top_probs)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 105,
|
||||
"id": "53c8c83c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"torch.Size([1, 20, 50257])"
|
||||
]
|
||||
},
|
||||
"execution_count": 105,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"output.logits.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 93,
|
||||
"id": "c6bf8faf",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[('head', 0.18150091171264648), ('heads', 0.011893990449607372), ('.', 0.011805753223598003), ('same', 0.011544686742126942), ('body', 0.007725409232079983)]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(list(zip(top_words, top_probs)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 153,
|
||||
"id": "e03cc143",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with lzma.open(test_file, 'rt') as file:\n",
|
||||
" left_contexts = []\n",
|
||||
" right_contexts = []\n",
|
||||
" results = []\n",
|
||||
" i=0\n",
|
||||
" for line in file:\n",
|
||||
" if i >20:\n",
|
||||
" break\n",
|
||||
" line = replace_endline(line) #get only relevant\n",
|
||||
" line = line.split('\\t')[6:]\n",
|
||||
" l_context = \" \".join(line[0].rstrip().split(\" \")[-10:])\n",
|
||||
" r_context = \" \".join(line[1].rstrip().split(\" \")[:5])\n",
|
||||
" left_contexts.append(l_context)\n",
|
||||
" right_contexts.append(r_context)\n",
|
||||
" i+=1;"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 154,
|
||||
"id": "20113d00",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"['id Seorgc Acorns laud in said Aina and riinninu from',\n",
|
||||
" 'true that the Republican par- ty is National in its',\n",
|
||||
" 'but here¬ abouts 1 have not seen or heard ot',\n",
|
||||
" 'They will not bend or break like the single spiing.',\n",
|
||||
" '208 miles long. J. 8 . Con way the Com-',\n",
|
||||
" 'They felt that He was the long-looked- for Messiah with',\n",
|
||||
" 'in the state that could produce just as good fruit',\n",
|
||||
" 'after house and managed to find shelter in the lazarette,',\n",
|
||||
" 'hundred buildings put up in the place this season. Mechanics',\n",
|
||||
" 'de\\xad testation. I no longer loved him, and I felt',\n",
|
||||
" 'ol the gentleman, by dis. cussing the constitutionality of these',\n",
|
||||
" 'the purest pa- triotism, and tin* most indent devotion to',\n",
|
||||
" 'now to i high, will come down, and those too',\n",
|
||||
" 'the \\'extension of MdKlnleJ\" street, and running back to the',\n",
|
||||
" 'sir, it is a subject ot peculiar delight to me',\n",
|
||||
" 'to affect a miserable scramble for offices in tiiis country.',\n",
|
||||
" 'Marj, the eldest, said lo the others: \"Lit us pray',\n",
|
||||
" 'not enquire—But what sum lie asked, could tempt men well',\n",
|
||||
" \"hete are 500, or 550 acres ol cleared land, '250\",\n",
|
||||
" 'petition h is lieeii |ui tty elos ly scru;iiii/e<| and',\n",
|
||||
" 'on the high reputation of his record as Chief Mag-']"
|
||||
]
|
||||
},
|
||||
"execution_count": 154,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"left_contexts"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f1d16811",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"results = []\n",
|
||||
"for left_context in left_contexts:\n",
|
||||
" results.append(dict(get_values_from_model(left_context, model=pt_model, tokenizer=tokenizer, k=10)))\n",
|
||||
"for idx, elem in enumerate(results):\n",
|
||||
" tab = summarize_probs_unk(elem, const_wildcard=True, scale_probs=True, wildcard_minweight=0.01)\n",
|
||||
" print(idx, \" \", gonito_format(tab))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 399,
|
||||
"id": "df38d415",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"a =dict(get_values_from_model(left_contexts[4], model=pt_model, tokenizer=tokenizer, k=10))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 400,
|
||||
"id": "d49824d4",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'p': 0.2211313545703888,\n",
|
||||
" 'm': 0.12113624811172485,\n",
|
||||
" 'mun': 0.07101781666278839,\n",
|
||||
" 'mon': 0.05452524498105049,\n",
|
||||
" 'mission': 0.05246562883257866,\n",
|
||||
" 'mand': 0.03528319671750069,\n",
|
||||
" 'pan': 0.02023007906973362,\n",
|
||||
" 'mer': 0.017731018364429474,\n",
|
||||
" 'pl': 0.015618585981428623,\n",
|
||||
" 'ple': 0.01504151988774538}"
|
||||
]
|
||||
},
|
||||
"execution_count": 400,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"a"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 401,
|
||||
"id": "149c090c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"0.624180693179369\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(sum(float(y) for x, y in a.items()))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 402,
|
||||
"id": "121595c5",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"b = summarize_probs_unk(a, const_wildcard=True, scale_probs=True, wildcard_minweight=0.1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 403,
|
||||
"id": "3acb48b9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"p:0.3220678024033818\tm:0.1764294588459882\tmun:0.1034342334152795\tmon:0.07941354974589608\tmission:0.07641381211021994\tmand:0.05138837796498232\tpan:0.02946419390001886\tmer:0.02582442517073288\tpl:0.022747763081629364\tple:0.021907292452780093\t:0.09090909090909094\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(gonito_format(b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 259,
|
||||
"id": "3c44cdbe",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(sum(float(y) for x, y in b))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "f09ce313",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
1
in-header.tsv
Normal file
1
in-header.tsv
Normal file
@ -0,0 +1 @@
|
||||
FileId Year LeftContext RightContext
|
|
1
out-header.tsv
Normal file
1
out-header.tsv
Normal file
@ -0,0 +1 @@
|
||||
Word
|
|
7414
test-A/hate-speech-info.tsv
Normal file
7414
test-A/hate-speech-info.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
test-A/in.tsv.xz
Normal file
BIN
test-A/in.tsv.xz
Normal file
Binary file not shown.
432022
train/expected.tsv
Normal file
432022
train/expected.tsv
Normal file
File diff suppressed because it is too large
Load Diff
432022
train/hate-speech-info.tsv
Normal file
432022
train/hate-speech-info.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user