{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "W8-j-5oV0o46", "outputId": "5cf81efc-7e9b-46a6-d3bd-792a4b4b39b9" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Collecting transformers\n", " Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.1/7.1 MB\u001b[0m \u001b[31m105.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.0)\n", "Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)\n", " Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m33.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.22.4)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2022.10.31)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.27.1)\n", "Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)\n", " Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m117.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.65.0)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.4.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.5.0)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (1.26.15)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2022.12.7)\n", "Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.12)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n", "Installing collected packages: tokenizers, huggingface-hub, transformers\n", "Successfully installed huggingface-hub-0.15.1 tokenizers-0.13.3 transformers-4.29.2\n", "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.0.1+cu118)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.12.0)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch) (4.5.0)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.11.1)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.2)\n", "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch) (2.0.0)\n", "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (3.25.2)\n", "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch) (16.0.5)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.2)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n" ] } ], "source": [ "!pip install transformers\n", "!pip install torch" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "O6aa5mpE0s6H", "outputId": "18112d31-6a14-4b91-b9db-44ea197c8d0c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'challenging-america-word-gap-prediction'...\n", "remote: Wymienianie obiektów: 27, gotowe.\u001b[K\n", "remote: Zliczanie obiektów: 100% (27/27), gotowe.\u001b[K\n", "remote: Kompresowanie obiektów: 100% (23/23), gotowe.\u001b[K\n", "remote: Razem 27 (delty 2), użyte ponownie 17 (delty 0), paczki użyte ponownie 0\u001b[K\n", "Receiving objects: 100% (27/27), 278.33 MiB | 8.52 MiB/s, done.\n", "Resolving deltas: 100% (2/2), done.\n" ] } ], "source": [ "!git clone --single-branch git://gonito.net/challenging-america-word-gap-prediction -b master" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "uHkXCRs-0iSr" }, "outputs": [], "source": [ "import torch\n", "import sys\n", "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "HyKM4zn41YvQ" }, "outputs": [], "source": [ "import lzma\n", "from itertools import islice\n", "import regex as re\n", "import sys\n", "from torchtext.vocab import build_vocab_from_iterator\n", "from torch import nn\n", "from torch.utils.data import IterableDataset\n", "import itertools" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "-k8RhlmI06mQ", "outputId": "e2ef4117-5d5b-40e9-f774-9faba825042c" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/content/challenging-america-word-gap-prediction\n" ] } ], "source": [ "%cd /content/challenging-america-word-gap-prediction" ] }, { "cell_type": "code", "execution_count": 28, "metadata": { "id": "PCA7Ank2dnwM" }, "outputs": [], "source": [ "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "U0kG_W5AY7uE" }, "outputs": [], "source": [ "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2-medium\")\n", "model = GPT2LMHeadModel.from_pretrained(\"gpt2-medium\").to(device)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "id": "My_1_4L5MMc3" }, "outputs": [], "source": [ "def create_outputs(folder_name):\n", " print(f'Creating outputs in {folder_name}')\n", " with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:\n", " with open(f'{folder_name}/out-{folder_name}.tsv', 'w', encoding='utf-8', newline='\\n') as f:\n", " for line in fid:\n", " separated = line.split('\\t')\n", " prefix = separated[6].replace(r'\\n', ' ')\n", " left_context =tokenizer.encode(prefix, return_tensors=\"pt\").to(device)\n", " out = model(left_context)\n", " prob_dist=torch.softmax(out[0][0][-1],dim=0)\n", " values,index =prob_dist.topk(5)\n", " token = [] \n", " for x in index:\n", " token.append(tokenizer.decode(x))\n", "\n", " zipped = list(zip(values, token))\n", " output_line = ' '.join([f'{x[1]}:{x[0]}' for x in zipped])\n", " f.write(output_line + '\\n')" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4VQPcLF-OChJ", "outputId": "7a7408b2-ad26-4041-887e-99a8e9d36d0f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating outputs in dev-0\n", "Creating outputs in test-A\n" ] } ], "source": [ "create_outputs('dev-0')\n", "create_outputs('test-A')" ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }