fix
This commit is contained in:
parent
3b0cab7eef
commit
c015d96b23
21695
dev-0/out.tsv
21695
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
254
lab12.ipynb
254
lab12.ipynb
@ -12,8 +12,8 @@
|
|||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
|
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
|
||||||
"Collecting transformers\n",
|
"Collecting transformers\n",
|
||||||
@ -72,8 +72,8 @@
|
|||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Cloning into 'challenging-america-word-gap-prediction'...\n",
|
"Cloning into 'challenging-america-word-gap-prediction'...\n",
|
||||||
"remote: Wymienianie obiektów: 27, gotowe.\u001b[K\n",
|
"remote: Wymienianie obiektów: 27, gotowe.\u001b[K\n",
|
||||||
@ -132,8 +132,8 @@
|
|||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"/content/challenging-america-word-gap-prediction\n"
|
"/content/challenging-america-word-gap-prediction\n"
|
||||||
]
|
]
|
||||||
@ -145,54 +145,26 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"execution_count": 28,
|
||||||
"\n",
|
|
||||||
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "PCA7Ank2dnwM"
|
"id": "PCA7Ank2dnwM"
|
||||||
},
|
},
|
||||||
"execution_count": 28,
|
"outputs": [],
|
||||||
"outputs": []
|
"source": [
|
||||||
|
"\n",
|
||||||
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"source": [
|
"execution_count": 32,
|
||||||
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
|
|
||||||
"model = GPT2LMHeadModel.from_pretrained(\"gpt2\").to(device)"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"id": "U0kG_W5AY7uE"
|
"id": "U0kG_W5AY7uE"
|
||||||
},
|
},
|
||||||
"execution_count": 32,
|
|
||||||
"outputs": []
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": 33,
|
|
||||||
"metadata": {
|
|
||||||
"id": "F4MXeKLxMQ4N"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"def prediction(word: str) -> str:\n",
|
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2-medium\")\n",
|
||||||
" left_context =tokenizer.encode(word, return_tensors=\"pt\").to(device)\n",
|
"model = GPT2LMHeadModel.from_pretrained(\"gpt2-medium\").to(device)"
|
||||||
" out = model(left_context)\n",
|
|
||||||
" prob_dist=torch.softmax(out[0][-1],dim=1)\n",
|
|
||||||
" values,index =prob_dist.topk(5)\n",
|
|
||||||
" token = [] \n",
|
|
||||||
" for x in index[-1]:\n",
|
|
||||||
" token.append(tokenizer.decode(x))\n",
|
|
||||||
" zipped = list(zip(values[-1], token))\n",
|
|
||||||
" for index, element in enumerate(zipped):\n",
|
|
||||||
" unk = None\n",
|
|
||||||
" if '<unk>' in element:\n",
|
|
||||||
" unk = zipped.pop(index)\n",
|
|
||||||
" zipped.append(('', unk[1]))\n",
|
|
||||||
" break\n",
|
|
||||||
" if unk is None:\n",
|
|
||||||
" zipped[-1] = ('', zipped[-1][1])\n",
|
|
||||||
" return ' '.join([f'{x[0]}:{x[1]}' for x in zipped])"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -206,11 +178,20 @@
|
|||||||
"def create_outputs(folder_name):\n",
|
"def create_outputs(folder_name):\n",
|
||||||
" print(f'Creating outputs in {folder_name}')\n",
|
" print(f'Creating outputs in {folder_name}')\n",
|
||||||
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
|
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
|
||||||
" with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
|
" with open(f'{folder_name}/out-{folder_name}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
|
||||||
" for line in fid:\n",
|
" for line in fid:\n",
|
||||||
" separated = line.split('\\t')\n",
|
" separated = line.split('\\t')\n",
|
||||||
" prefix = separated[6].replace(r'\\n', ' ').split()[-1]\n",
|
" prefix = separated[6].replace(r'\\n', ' ')\n",
|
||||||
" output_line = prediction(prefix)\n",
|
" left_context =tokenizer.encode(prefix, return_tensors=\"pt\").to(device)\n",
|
||||||
|
" out = model(left_context)\n",
|
||||||
|
" prob_dist=torch.softmax(out[0][0][-1],dim=0)\n",
|
||||||
|
" values,index =prob_dist.topk(5)\n",
|
||||||
|
" token = [] \n",
|
||||||
|
" for x in index:\n",
|
||||||
|
" token.append(tokenizer.decode(x))\n",
|
||||||
|
"\n",
|
||||||
|
" zipped = list(zip(values, token))\n",
|
||||||
|
" output_line = ' '.join([f'{x[1]}:{x[0]}' for x in zipped])\n",
|
||||||
" f.write(output_line + '\\n')"
|
" f.write(output_line + '\\n')"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -226,8 +207,8 @@
|
|||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"output_type": "stream",
|
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Creating outputs in dev-0\n",
|
"Creating outputs in dev-0\n",
|
||||||
"Creating outputs in test-A\n"
|
"Creating outputs in test-A\n"
|
||||||
@ -238,191 +219,6 @@
|
|||||||
"create_outputs('dev-0')\n",
|
"create_outputs('dev-0')\n",
|
||||||
"create_outputs('test-A')"
|
"create_outputs('test-A')"
|
||||||
]
|
]
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "YCGOd41pzfAC"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
|
|
||||||
"model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
|
|
||||||
"def get_words_from_line(line):\n",
|
|
||||||
" line = line.rstrip()\n",
|
|
||||||
" yield '<s>'\n",
|
|
||||||
" for t in line.split():\n",
|
|
||||||
" yield t\n",
|
|
||||||
" yield '</s>'\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"def get_word_lines_from_file(file_name):\n",
|
|
||||||
" with lzma.open(file_name, encoding='utf8', mode=\"rt\") as fh:\n",
|
|
||||||
" for line in fh:\n",
|
|
||||||
" pattern = r'\\^\\^|\\n|\\\\|[<>]|[()]'\n",
|
|
||||||
" line = re.sub(pattern, '', line)\n",
|
|
||||||
" yield line\n",
|
|
||||||
"\n",
|
|
||||||
"for line in get_word_lines_from_file(\"train/in.tsv.xz\"):\n",
|
|
||||||
" # line = line.strip('\\n')\n",
|
|
||||||
" # fields = line.split(\"\\t\")\n",
|
|
||||||
" # print(line)\n",
|
|
||||||
" left_context = str(line)\n",
|
|
||||||
" input_ids = tokenizer.encode(left_context, return_tensors=\"pt\")\n",
|
|
||||||
" # print(input_ids)\n",
|
|
||||||
" output = model(input_ids)\n",
|
|
||||||
" # print(output[0].shape())\n",
|
|
||||||
" prob_dist=torch.softmax(output[0][-1],dim=1)\n",
|
|
||||||
" values,index =prob_dist.topk(20) \n",
|
|
||||||
" print(left_context[-100:])\n",
|
|
||||||
" print(values.size())\n",
|
|
||||||
" print(index.size())\n",
|
|
||||||
" break\n",
|
|
||||||
" for x,indx in zip(values,index):\n",
|
|
||||||
" for i in range(20):\n",
|
|
||||||
" token = tokenizer.decode(indx[i])\n",
|
|
||||||
" print(f'{x[i]} {indx[i]} {token}')\n",
|
|
||||||
" print('-------------------------')"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "01zkM5giNUR3"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"\n",
|
|
||||||
"# line = line.strip('\\n')\n",
|
|
||||||
"# fields = line.split(\"\\t\")\n",
|
|
||||||
"# print(line)\n",
|
|
||||||
"left_context = \"he\"\n",
|
|
||||||
"input_ids = tokenizer.encode(left_context, return_tensors=\"pt\")\n",
|
|
||||||
"# print(input_ids)\n",
|
|
||||||
"output = model(input_ids)\n",
|
|
||||||
"# print(output[0].shape())\n",
|
|
||||||
"prob_dist=torch.softmax(output[0][-1],dim=1)\n",
|
|
||||||
"values,index =prob_dist.topk(5) \n",
|
|
||||||
"token = []\n",
|
|
||||||
"for x in index[-1]:\n",
|
|
||||||
" token.append(tokenizer.decode(x))\n",
|
|
||||||
" # print(token)\n",
|
|
||||||
"for x,token in zip(values[-1],token):\n",
|
|
||||||
" # token = tokenizer.decode(indx)\n",
|
|
||||||
" print(f'{x} {token}')"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "lDc9Nw40C3dr"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"for line in get_word_lines_from_file(\"dev-0/in.tsv.xz\"):\n",
|
|
||||||
" # line = line.strip('\\n')\n",
|
|
||||||
" # fields = line.split(\"\\t\")\n",
|
|
||||||
" # print(line)\n",
|
|
||||||
" left_context = str(line)\n",
|
|
||||||
" input_ids = tokenizer.encode(left_context, return_tensors=\"pt\")\n",
|
|
||||||
" # print(input_ids)\n",
|
|
||||||
" output = model(input_ids)\n",
|
|
||||||
" # print(output[0].shape())\n",
|
|
||||||
" prob_dist=torch.softmax(output[0][-1],dim=1)\n",
|
|
||||||
" values,index =prob_dist.topk(20) \n",
|
|
||||||
" print(left_context[-100:])\n",
|
|
||||||
" # print(values.size())\n",
|
|
||||||
" # print(index.size())\n",
|
|
||||||
" # print(values[])\n",
|
|
||||||
" # break\n",
|
|
||||||
" for x,indx in zip(values[-1],index[-1]):\n",
|
|
||||||
" token = tokenizer.decode(indx)\n",
|
|
||||||
" print(f'{x} {indx} {token}')\n",
|
|
||||||
" print('-------------------------')"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "si7wLC2Tx-kg"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"token = tokenizer.decode(256 )\n",
|
|
||||||
"print(token)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "lJoE0Cwz0JCM"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"top_indices[0]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "tgmT1vG20U_1"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"top_probs[0]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "U9GVSAZz4SlW"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"top =prob_dist.topk(20) \n",
|
|
||||||
"top_indices = top.indices.tolist()\n",
|
|
||||||
"top_probs = top.values.tolist()\n",
|
|
||||||
"top_words = tokenizer.decode(top_indices)\n",
|
|
||||||
"print(top_words,'\\n',top_indices,'\\n',top_probs)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "8_WSZ_v99xSH"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"print(index[1])"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "OAiJNMNMwNNg"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"print(prob_dist.topk(2)[0].size())"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"id": "PIUjH8-ow1y9"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": []
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
23288
test-A/out.tsv
23288
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user