This commit is contained in:
JulianZablonski 2023-06-08 20:06:03 +02:00
parent 3b0cab7eef
commit c015d96b23
3 changed files with 17964 additions and 27275 deletions

File diff suppressed because it is too large Load Diff

View File

@ -12,8 +12,8 @@
}, },
"outputs": [ "outputs": [
{ {
"output_type": "stream",
"name": "stdout", "name": "stdout",
"output_type": "stream",
"text": [ "text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting transformers\n", "Collecting transformers\n",
@ -72,8 +72,8 @@
}, },
"outputs": [ "outputs": [
{ {
"output_type": "stream",
"name": "stdout", "name": "stdout",
"output_type": "stream",
"text": [ "text": [
"Cloning into 'challenging-america-word-gap-prediction'...\n", "Cloning into 'challenging-america-word-gap-prediction'...\n",
"remote: Wymienianie obiektów: 27, gotowe.\u001b[K\n", "remote: Wymienianie obiektów: 27, gotowe.\u001b[K\n",
@ -132,8 +132,8 @@
}, },
"outputs": [ "outputs": [
{ {
"output_type": "stream",
"name": "stdout", "name": "stdout",
"output_type": "stream",
"text": [ "text": [
"/content/challenging-america-word-gap-prediction\n" "/content/challenging-america-word-gap-prediction\n"
] ]
@ -145,54 +145,26 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": 28,
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
],
"metadata": { "metadata": {
"id": "PCA7Ank2dnwM" "id": "PCA7Ank2dnwM"
}, },
"execution_count": 28, "outputs": [],
"outputs": [] "source": [
"\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"source": [ "execution_count": 32,
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
"model = GPT2LMHeadModel.from_pretrained(\"gpt2\").to(device)"
],
"metadata": { "metadata": {
"id": "U0kG_W5AY7uE" "id": "U0kG_W5AY7uE"
}, },
"execution_count": 32,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"id": "F4MXeKLxMQ4N"
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def prediction(word: str) -> str:\n", "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2-medium\")\n",
" left_context =tokenizer.encode(word, return_tensors=\"pt\").to(device)\n", "model = GPT2LMHeadModel.from_pretrained(\"gpt2-medium\").to(device)"
" out = model(left_context)\n",
" prob_dist=torch.softmax(out[0][-1],dim=1)\n",
" values,index =prob_dist.topk(5)\n",
" token = [] \n",
" for x in index[-1]:\n",
" token.append(tokenizer.decode(x))\n",
" zipped = list(zip(values[-1], token))\n",
" for index, element in enumerate(zipped):\n",
" unk = None\n",
" if '<unk>' in element:\n",
" unk = zipped.pop(index)\n",
" zipped.append(('', unk[1]))\n",
" break\n",
" if unk is None:\n",
" zipped[-1] = ('', zipped[-1][1])\n",
" return ' '.join([f'{x[0]}:{x[1]}' for x in zipped])"
] ]
}, },
{ {
@ -206,11 +178,20 @@
"def create_outputs(folder_name):\n", "def create_outputs(folder_name):\n",
" print(f'Creating outputs in {folder_name}')\n", " print(f'Creating outputs in {folder_name}')\n",
" with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n", " with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n",
" with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n", " with open(f'{folder_name}/out-{folder_name}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n",
" for line in fid:\n", " for line in fid:\n",
" separated = line.split('\\t')\n", " separated = line.split('\\t')\n",
" prefix = separated[6].replace(r'\\n', ' ').split()[-1]\n", " prefix = separated[6].replace(r'\\n', ' ')\n",
" output_line = prediction(prefix)\n", " left_context =tokenizer.encode(prefix, return_tensors=\"pt\").to(device)\n",
" out = model(left_context)\n",
" prob_dist=torch.softmax(out[0][0][-1],dim=0)\n",
" values,index =prob_dist.topk(5)\n",
" token = [] \n",
" for x in index:\n",
" token.append(tokenizer.decode(x))\n",
"\n",
" zipped = list(zip(values, token))\n",
" output_line = ' '.join([f'{x[1]}:{x[0]}' for x in zipped])\n",
" f.write(output_line + '\\n')" " f.write(output_line + '\\n')"
] ]
}, },
@ -226,8 +207,8 @@
}, },
"outputs": [ "outputs": [
{ {
"output_type": "stream",
"name": "stdout", "name": "stdout",
"output_type": "stream",
"text": [ "text": [
"Creating outputs in dev-0\n", "Creating outputs in dev-0\n",
"Creating outputs in test-A\n" "Creating outputs in test-A\n"
@ -238,191 +219,6 @@
"create_outputs('dev-0')\n", "create_outputs('dev-0')\n",
"create_outputs('test-A')" "create_outputs('test-A')"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YCGOd41pzfAC"
},
"outputs": [],
"source": [
"tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
"model = GPT2LMHeadModel.from_pretrained(\"gpt2\")\n",
"def get_words_from_line(line):\n",
" line = line.rstrip()\n",
" yield '<s>'\n",
" for t in line.split():\n",
" yield t\n",
" yield '</s>'\n",
"\n",
"\n",
"def get_word_lines_from_file(file_name):\n",
" with lzma.open(file_name, encoding='utf8', mode=\"rt\") as fh:\n",
" for line in fh:\n",
" pattern = r'\\^\\^|\\n|\\\\|[<>]|[()]'\n",
" line = re.sub(pattern, '', line)\n",
" yield line\n",
"\n",
"for line in get_word_lines_from_file(\"train/in.tsv.xz\"):\n",
" # line = line.strip('\\n')\n",
" # fields = line.split(\"\\t\")\n",
" # print(line)\n",
" left_context = str(line)\n",
" input_ids = tokenizer.encode(left_context, return_tensors=\"pt\")\n",
" # print(input_ids)\n",
" output = model(input_ids)\n",
" # print(output[0].shape())\n",
" prob_dist=torch.softmax(output[0][-1],dim=1)\n",
" values,index =prob_dist.topk(20) \n",
" print(left_context[-100:])\n",
" print(values.size())\n",
" print(index.size())\n",
" break\n",
" for x,indx in zip(values,index):\n",
" for i in range(20):\n",
" token = tokenizer.decode(indx[i])\n",
" print(f'{x[i]} {indx[i]} {token}')\n",
" print('-------------------------')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "01zkM5giNUR3"
},
"outputs": [],
"source": [
"\n",
"# line = line.strip('\\n')\n",
"# fields = line.split(\"\\t\")\n",
"# print(line)\n",
"left_context = \"he\"\n",
"input_ids = tokenizer.encode(left_context, return_tensors=\"pt\")\n",
"# print(input_ids)\n",
"output = model(input_ids)\n",
"# print(output[0].shape())\n",
"prob_dist=torch.softmax(output[0][-1],dim=1)\n",
"values,index =prob_dist.topk(5) \n",
"token = []\n",
"for x in index[-1]:\n",
" token.append(tokenizer.decode(x))\n",
" # print(token)\n",
"for x,token in zip(values[-1],token):\n",
" # token = tokenizer.decode(indx)\n",
" print(f'{x} {token}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lDc9Nw40C3dr"
},
"outputs": [],
"source": [
"for line in get_word_lines_from_file(\"dev-0/in.tsv.xz\"):\n",
" # line = line.strip('\\n')\n",
" # fields = line.split(\"\\t\")\n",
" # print(line)\n",
" left_context = str(line)\n",
" input_ids = tokenizer.encode(left_context, return_tensors=\"pt\")\n",
" # print(input_ids)\n",
" output = model(input_ids)\n",
" # print(output[0].shape())\n",
" prob_dist=torch.softmax(output[0][-1],dim=1)\n",
" values,index =prob_dist.topk(20) \n",
" print(left_context[-100:])\n",
" # print(values.size())\n",
" # print(index.size())\n",
" # print(values[])\n",
" # break\n",
" for x,indx in zip(values[-1],index[-1]):\n",
" token = tokenizer.decode(indx)\n",
" print(f'{x} {indx} {token}')\n",
" print('-------------------------')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "si7wLC2Tx-kg"
},
"outputs": [],
"source": [
"token = tokenizer.decode(256 )\n",
"print(token)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lJoE0Cwz0JCM"
},
"outputs": [],
"source": [
"top_indices[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tgmT1vG20U_1"
},
"outputs": [],
"source": [
"top_probs[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "U9GVSAZz4SlW"
},
"outputs": [],
"source": [
"top =prob_dist.topk(20) \n",
"top_indices = top.indices.tolist()\n",
"top_probs = top.values.tolist()\n",
"top_words = tokenizer.decode(top_indices)\n",
"print(top_words,'\\n',top_indices,'\\n',top_probs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "8_WSZ_v99xSH"
},
"outputs": [],
"source": [
"print(index[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "OAiJNMNMwNNg"
},
"outputs": [],
"source": [
"print(prob_dist.topk(2)[0].size())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PIUjH8-ow1y9"
},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {

File diff suppressed because it is too large Load Diff