{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/kubakaczmarek/anaconda3/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n", "import torch\n", "import lzma\n", "from tqdm import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "DEVICE = torch.device(\"mps\") if torch.backends.mps.is_available() else torch.device(\"cpu\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "T5_PATH = 't5-base'" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/kubakaczmarek/anaconda3/lib/python3.10/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", "For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", "- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", "- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", "- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", " warnings.warn(\n" ] } ], "source": [ "t5_tokenizer = T5Tokenizer.from_pretrained(T5_PATH)\n", "t5_config = T5Config.from_pretrained(T5_PATH)\n", "t5_mlm = 
T5ForConditionalGeneration.from_pretrained(T5_PATH, config=t5_config).to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "MASK_TOKEN = '<extra_id_0>'\n",
    "\n",
    "\n",
    "def preprocess_data(X, max_words=330, context_words=100):\n",
    "    \"\"\"Build one T5 fill-in-the-gap prompt per raw TSV line.\n",
    "\n",
    "    Columns 6 and 7 (0-based) of each tab-separated line are used as the\n",
    "    text to the left and to the right of the gap; literal backslash-n\n",
    "    sequences inside them are replaced with spaces.\n",
    "\n",
    "    Args:\n",
    "        X: iterable of raw TSV lines.\n",
    "        max_words: when both contexts together hold more words than this,\n",
    "            they are trimmed (default 330).\n",
    "        context_words: words kept on each side of the gap when trimming\n",
    "            (default 100).\n",
    "\n",
    "    Returns:\n",
    "        A list with one '<left> <extra_id_0> <right>' string per line.\n",
    "    \"\"\"\n",
    "    parsed_data = []\n",
    "\n",
    "    for line in X:\n",
    "        # Drop only the line terminator: strip() also removes leading and\n",
    "        # trailing tabs, which shifts the columns when a field is empty.\n",
    "        fields = line.rstrip('\\r\\n').split('\\t')\n",
    "        left = fields[6].replace('\\\\n', ' ').split(' ')\n",
    "        right = fields[7].replace('\\\\n', ' ').split(' ')\n",
    "\n",
    "        if len(left) + len(right) > max_words:\n",
    "            # Keep the words closest to the gap on both sides.\n",
    "            left, right = left[-context_words:], right[:context_words]\n",
    "\n",
    "        # The sentinel marks the span T5 has to fill in; without it the\n",
    "        # model is never asked to predict the missing word.\n",
    "        parsed_data.append(f\"{' '.join(left)} {MASK_TOKEN} {' '.join(right)}\")\n",
    "\n",
    "    return parsed_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "END_MARKERS = ('<extra_id_1>', '</s>', '<pad>')\n",
    "\n",
    "\n",
    "def decode(output):\n",
    "    \"\"\"Return the text generated for the gap in one output sequence.\n",
    "\n",
    "    The first two ids are skipped (presumably the decoder start token and\n",
    "    the <extra_id_0> sentinel - confirm against the tokenizer). The rest is\n",
    "    decoded with special tokens kept and cut at the first end marker.\n",
    "    \"\"\"\n",
    "    _txt = t5_tokenizer.decode(output[2:], skip_special_tokens=False, clean_up_tokenization_spaces=False)\n",
    "\n",
    "    # str.index('') always returns 0, which made every prediction empty, so\n",
    "    # cut at the real terminators instead. partition() returns the whole text\n",
    "    # when a marker is missing (max_length may cut it off) instead of raising.\n",
    "    for marker in END_MARKERS:\n",
    "        _txt = _txt.partition(marker)[0]\n",
    "\n",
    "    return _txt.strip()\n",
    "\n",
    "\n",
    "def parse_output(outputs):\n",
    "    \"\"\"Format generated candidates as 'word:prob word:prob ... :rest'.\n",
    "\n",
    "    Candidates are deduplicated in the order given, the i-th unique word is\n",
    "    weighted 1 / (i + 4), and the trailing ':rest' entry carries whatever\n",
    "    probability mass is left.\n",
    "    \"\"\"\n",
    "    words = []\n",
    "    for output in outputs:\n",
    "        # A space or a colon inside a word would corrupt the line format, so\n",
    "        # keep only the first chunk that contains neither; skip empty results.\n",
    "        chunks = decode(output).replace(':', ' ').split()\n",
    "        if chunks:\n",
    "            words.append(chunks[0])\n",
    "\n",
    "    res = ''\n",
    "    total = 0  # not named `sum`, to avoid shadowing the builtin\n",
    "    # dict.fromkeys() drops duplicates but keeps the order; a set iterates in\n",
    "    # arbitrary order, so the weights would not follow the candidate rank.\n",
    "    for i, token in enumerate(dict.fromkeys(words)):\n",
    "        weight = 1 / (i + 4)\n",
    "        if total + weight >= 1:\n",
    "            break  # keep the remainder strictly positive\n",
    "        res += f\"{token}:{weight} \"\n",
    "        total += weight\n",
    "\n",
    "    res += f\":{1-total}\"\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "with lzma.open('test-A/in.tsv.xz', 'rt', encoding='utf-8') as f:\n",
    "    raw_lines = f.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the raw lines under their own name so this cell can be re-run safely.\n",
    "X = preprocess_data(raw_lines)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Token indices sequence length is longer than the specified maximum sequence length for this model (556 > 512). 
Running this sequence through the model will result in indexing errors\n"
     ]
    }
   ],
   "source": [
    "failures = 0\n",
    "\n",
    "with open('test-A/out.tsv', mode='wt', encoding='utf-8') as f:\n",
    "    for line in tqdm(X):\n",
    "        try:\n",
    "            encoded = t5_tokenizer.encode_plus(line, add_special_tokens=True, return_tensors='pt')\n",
    "            input_ids = encoded['input_ids'].to(DEVICE)\n",
    "            outputs = t5_mlm.generate(input_ids=input_ids,\n",
    "                                      num_beams=5, num_return_sequences=5,\n",
    "                                      max_length=5)\n",
    "            prediction = parse_output(outputs)\n",
    "        except Exception:\n",
    "            # Best effort: every input line needs exactly one output line, so\n",
    "            # fall back to a fixed guess. Catching Exception (not a bare\n",
    "            # except) lets KeyboardInterrupt still stop the loop.\n",
    "            failures += 1\n",
    "            prediction = 'the:0.9 :0.1'\n",
    "        f.write(prediction + '\\n')\n",
    "\n",
    "print(f'{failures} of {len(X)} lines used the fallback prediction')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_text = \"anywhere un-\\nless its somewhere. Well, I says,\\nI'm glad to hear that, but, accord-\\ning to your figures, I left myself\\nwhere 1 was, which is five miles near-\\ner to myself than I was when we\\nwere where we are now.\\nWe have now reached Slidell.\\nThat's a fine place. The people\\ndown there remind me of bananas-\\nthey come and go in bunches. 811-\\ndell used to be noted for her tough\\npeople. Now she is noted for be,\\ntough steaks. Well, I certainly got\\none there. When the waiter brought\\nit in it was so small I thought. It\\nwas a crack in the plate. I skid,\\nwaiter what else have you got? +He\\nbrought me in two codfish and one\\nsmelt. I said, waiter have you got\\npigs feet? He said no, rheumatism\\nmakes me walk that way. I sald,\\nhow is the pumpkin pie?\tsaid\\nit's all squash. The best I could get\\nin that hotel was a soup sandwich.\\nAfter the table battle the waiter and\\nI signed an armistice. I then went\\nover to the hotel clerk and asked for\\na room. He said with or without a\\nbed? I said, with a bed. He said,\\nI don't think I 'have' a bed long\\nenough for you. I said, well, I'll\\naddtwo feettoitwhenIgetinit.\\nHe gave me a lovely room on the\\ntop floor. It was one of those rooms\\nthat stands on each side. 
If you\\nhappen to get up in the middle of\\nthe night you want to be sure and\\nget up in the middle of the room.\\nThat night I dreamt I was eating\\nflannel cakes. When I woke up half\\nof the blanket was gone. I must\\nhave got up on the wrong side of the\\nbed, for next morning I had an awful\\nheadache. I told the manager about\\nit. He said, you have rheumatic\\npains. I said, no, I think it is on,\\nof those attic room pains. I nad to\\ngetupat5a.m.inthemorningso\\nthey could use the sheet to set the\\nbreakfast table.\".replace('\\n', ' ').split('\\t')\n",
    "# The tab marks the gap: put the T5 sentinel there so the model is asked to fill it.\n",
    "input_text = f\"{input_text[0]} <extra_id_0> {input_text[1]}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded = t5_tokenizer.encode_plus(input_text, add_special_tokens=True, return_tensors='pt')\n",
    "input_ids = encoded['input_ids'].to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = t5_mlm.generate(input_ids=input_ids,\n",
    "                          num_beams=10, num_return_sequences=5,\n",
    "                          max_length=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "parsed = parse_output(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'+He:0.25 He:0.2 I:0.16666666666666666 :0.3833333333333333'"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "parsed"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
"orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }