aitech-eks-pub/cw/15_similarity_search.ipynb

2067 lines
130 KiB
Plaintext
Raw Normal View History

2021-06-23 10:01:55 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"T5 paper, *Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer*: https://arxiv.org/pdf/1910.10683.pdf"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Kleister NDA dataset repository (source of the documents and expected values used below): https://github.com/applicaai/kleister-nda"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from transformers import T5Tokenizer, T5ForConditionalGeneration"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"text = \"translate English to French: My name is Azeem and I live in India\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"text = \"summarize: Machine learning involves computers discovering how they can perform tasks without being explicitly programmed to do so. It involves computers learning from data provided so that they carry out certain tasks. For simple tasks assigned to computers, it is possible to program algorithms telling the machine how to execute all steps required to solve the problem at hand; on the computer's part, no learning is needed. For more advanced tasks, it can be challenging for a human to manually create the needed algorithms. In practice, it can turn out to be more effective to help the machine develop its own algorithm, rather than having human programmers specify every needed step.\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"machine learning involves computers learning from data provided so that they carry out certain tasks without being explicitly programme\n"
]
}
],
"source": [
"from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
"\n",
"# Load the pretrained t5-small tokenizer and model; the model is moved to the GPU.\n",
"tokenizer = T5Tokenizer.from_pretrained('t5-small')\n",
"\n",
"model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True).to('cuda')\n",
"\n",
"\n",
"# `text` comes from an earlier cell and starts with a task prefix such as \"summarize: \".\n",
"# You can also use \"translate English to French\" and \"translate English to Romanian\"\n",
"input_ids = tokenizer(text, return_tensors=\"pt\").input_ids.to('cuda')  # Batch size 1\n",
"\n",
"# Set the generation length limit explicitly: with the default limit the summary recorded\n",
"# for this cell stopped in the middle of a word.\n",
"outputs = model.generate(input_ids, max_length=64)\n",
"\n",
"decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"print(decoded)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"T5ForConditionalGeneration(\n",
" (shared): Embedding(32128, 512)\n",
" (encoder): T5Stack(\n",
" (embed_tokens): Embedding(32128, 512)\n",
" (block): ModuleList(\n",
" (0): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" (relative_attention_bias): Embedding(32, 8)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (2): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (3): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (4): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (5): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (final_layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (decoder): T5Stack(\n",
" (embed_tokens): Embedding(32128, 512)\n",
" (block): ModuleList(\n",
" (0): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" (relative_attention_bias): Embedding(32, 8)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (2): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (3): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (4): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (5): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (final_layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lm_head): Linear(in_features=512, out_features=32128, bias=False)\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Root directory of the kleister-nda checkout. Set the KLEISTER_PATH environment variable\n",
"# to point at another location; otherwise the original local path is used.\n",
"# os.path.join(..., '') guarantees the trailing separator that the 'train/...' and\n",
"# 'dev-0/...' suffixes appended to this value rely on.\n",
"KLEISTER_PATH = os.path.join(\n",
"    os.environ.get(\n",
"        'KLEISTER_PATH',\n",
"        '/media/kuba/ssdsam/Syncthing/Syncthing/przedmioty/2020-02/IE/applica/kleister-nda/',\n",
"    ),\n",
"    '',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# Build the training targets: one 'jurisdiction: <value>' string per line of train/expected.tsv.\n",
"train_exp = []\n",
"# The context manager closes the file once it has been read; the encoding is explicit so the\n",
"# result does not depend on the platform default.\n",
"with open(KLEISTER_PATH + 'train/expected.tsv', encoding='utf-8') as train_exp_f:\n",
"    for line in train_exp_f:\n",
"        # Each line holds space-separated key=value fields.\n",
"        for elem in line.strip('\\n').split(' '):\n",
"            if elem.startswith('jurisdiction='):\n",
"                # Split on the first '=' only, so a value containing '=' is kept whole.\n",
"                train_exp.append('jurisdiction: ' + elem.split('=', 1)[1])\n",
"                break\n",
"        else:\n",
"            # No jurisdiction field on this line: add a placeholder so the list keeps\n",
"            # exactly one entry per line of the file.\n",
"            train_exp.append('jurisdiction: NONE')"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Build the dev targets: one 'jurisdiction: <value>' string per line of dev-0/expected.tsv.\n",
"dev_exp = []\n",
"# The context manager closes the file once it has been read; the encoding is explicit so the\n",
"# result does not depend on the platform default.\n",
"with open(KLEISTER_PATH + 'dev-0/expected.tsv', encoding='utf-8') as dev_exp_f:\n",
"    for line in dev_exp_f:\n",
"        # Each line holds space-separated key=value fields.\n",
"        for elem in line.strip('\\n').split(' '):\n",
"            if elem.startswith('jurisdiction='):\n",
"                # Split on the first '=' only, so a value containing '=' is kept whole.\n",
"                dev_exp.append('jurisdiction: ' + elem.split('=', 1)[1])\n",
"                break\n",
"        else:\n",
"            # No jurisdiction field on this line: add a placeholder so the list keeps\n",
"            # exactly one entry per line of the file.\n",
"            dev_exp.append('jurisdiction: NONE')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['jurisdiction: Oregon',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Iowa',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Indiana',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Michigan',\n",
" 'jurisdiction: Indiana',\n",
" 'jurisdiction: Colorado',\n",
" 'jurisdiction: Georgia',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Oregon',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: Oregon',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Connecticut',\n",
" 'jurisdiction: Nevada',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: Idaho',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Minnesota',\n",
" 'jurisdiction: Virginia',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Nevada',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Washington',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Ohio',\n",
" 'jurisdiction: Nevada',\n",
" 'jurisdiction: Georgia',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Virginia',\n",
" 'jurisdiction: Wisconsin',\n",
" 'jurisdiction: Colorado',\n",
" 'jurisdiction: Oregon',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Ohio',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: South_Dakota',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Indiana',\n",
" 'jurisdiction: Minnesota',\n",
" 'jurisdiction: Maine',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: Indiana',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: New_Jersey',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Maine',\n",
" 'jurisdiction: North_Carolina',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: Georgia',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Georgia',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Kansas',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Oregon',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Connecticut',\n",
" 'jurisdiction: Utah',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Ohio',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: South_Carolina',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: New_Jersey',\n",
" 'jurisdiction: Georgia',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Oregon',\n",
" 'jurisdiction: North_Carolina',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: Virginia',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Wisconsin',\n",
" 'jurisdiction: Washington',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Ohio',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: New_Jersey',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Utah',\n",
" 'jurisdiction: Washington',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Colorado',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Ohio',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: New_Jersey',\n",
" 'jurisdiction: Virginia',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Nevada',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: New_Jersey',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: New_Jersey',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Nevada',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: Kansas',\n",
" 'jurisdiction: Oregon',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: New_Jersey',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Oregon',\n",
" 'jurisdiction: Minnesota',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Colorado',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Indiana',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Ohio',\n",
" 'jurisdiction: Illinois',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Oregon',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: Michigan',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Ohio',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Georgia',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: Massachusetts',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: Michigan',\n",
" 'jurisdiction: Washington',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Missouri',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Texas',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: Ohio',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Pennsylvania',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Rhode_Island',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Florida',\n",
" 'jurisdiction: New_York',\n",
" 'jurisdiction: Delaware',\n",
" 'jurisdiction: California',\n",
" 'jurisdiction: Delaware']"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_exp"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Read the training inputs: one raw line of train/in.tsv per list entry, with the trailing\n",
"# newline removed. The context manager closes the file after reading.\n",
"with open(KLEISTER_PATH + 'train/in.tsv', encoding='utf-8') as train_in_f:\n",
"    train_in = [line.rstrip('\\n') for line in train_in_f]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Read the dev inputs: one raw line of dev-0/in.tsv per list entry, with the trailing\n",
"# newline removed. The context manager closes the file after reading.\n",
"with open(KLEISTER_PATH + 'dev-0/in.tsv', encoding='utf-8') as dev_in_f:\n",
"    dev_in = [line.rstrip('\\n') for line in dev_in_f]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'00a1d238e37ac225b8045a97953e845d.pdf\\teffective_date jurisdiction party term\\tEX-10.23 5 dex1023.htm COVENANT NOT TO COMPETE AND NON-DISCLOSURE AGREEMENT\\\\nExhibit 10.23\\\\nCOVENANT NOT TO COMPETE\\\\nAND NON-DISCLOSURE AGREEMENT\\\\nPARTIES:\\\\nEric Dean Sprunk (“EMPLOYEE”)\\\\nand\\\\nNIKE, Inc., divisions, subsidiaries\\\\nand affiliates. (“NIKE”):\\\\nRECITALS:\\\\nA. This Covenant Not to Compete and Non-Disclosure Agreement is executed upon initial employment or upon the EMPLOYEEs\\\\nadvancement with NIKE and is a condition of such employment or advancement.\\\\nB. Over the course of EMPLOYEEs employment with NIKE, EMPLOYEE will be or has been exposed to and/or is in a position to\\\\ndevelop confidential information peculiar to NIKEs business and not generally known to the public as defined below (“Protected Information”). It is\\\\nanticipated that EMPLOYEE will continue to be exposed to Protected Information of greater sensitivity as EMPLOYEE advances in the company.\\\\nC. The nature of NIKEs business is highly competitive and disclosure of any Protected Information would result in severe damage to NIKE\\\\nand be difficult to measure.\\\\nD. NIKE makes use of its Protected Information throughout the world. Protected Information of NIKE can be used to NIKEs detriment\\\\nanywhere in the world.\\\\nAGREEMENT:\\\\nIn consideration of the foregoing, and the terms and conditions set forth below, the parties agree as follows:\\\\n1. Covenant Not to Compete.\\\\n(a) Competition Restriction. 
During EMPLOYEEs employment by NIKE, under the terms of any employment contract or\\\\notherwise, and for one year thereafter, (the “Restriction Period”), EMPLOYEE will not directly or indirectly, own, manage, control, or participate in\\\\nthe ownership,\\\\nmanagement or control of, or be employed by, consult for, or be connected in any manner with, any business engaged anywhere in the world in the\\\\nathletic footwear, athletic apparel or sports equipment and accessories business, or any other business which directly competes with NIKE or any of\\\\nits parent, subsidiaries or affiliated corporations ( “Competitor”). By way of illustration only, examples of NIKE competitors include, but are not\\\\nlimited to: Adidas, FILA, Reebok, Puma, Champion, Oakley, DKNY, Converse, Asics, Saucony, New Balance, Ralph Lauren/Polo Sport, B.U.M,\\\\nFUBU, The Gap, Tommy Hilfiger, Umbro, Northface, Venator (Foot lockers), Sports Authority, Columbia Sportswear, Wilson, Mizuno, Callaway\\\\nGolf and Titleist. This provision is subject to NIKEs option to waive all or any portion of the Restriction Period as more specifically provided\\\\nbelow.\\\\n(b) Extension of Time. In the event EMPLOYEE breaches this covenant not to compete, the Restriction Period shall automatically\\\\ntoll from the date of the first breach, and all subsequent breaches, until the resolution of the breach through private settlement, judicial or other\\\\naction, including all appeals. The Restriction Period shall continue upon the effective date of any such settlement judicial or other resolution. NIKE\\\\nshall not be obligated to pay EMPLOYEE the additional compensation described in paragraph 1(d) below during any period of time in which this\\\\nAgreement is tolled due to EMPLOYEEs breach. 
In the event EMPLOYEE receives such additional compensation after any such breach,\\\\nEMPLOYEE must immediately reimburse NIKE in the amount of all such compensation upon the receipt of a written request by NIKE.\\\\n(c) Waiver of Non-Compete. NIKE has the option, in its sole discretion, to elect to waive all or a portion of the Restriction Period or\\\\nto limit the definition of Competitor, by giving EMPLOYEE seven (7) days prior notice of such election. In the event all or a portion of the\\\\nRestriction Period is waived, NIKE shall not be obligated to pay EMPLOYEE for any period of time as to which the covenant not to compete has\\\\nbeen waived.\\\\n(d) Additional Consideration. As additional consideration for the cov
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_in[0]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda', index=0)"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.device"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Token indices sequence length is longer than the specified maximum sequence length for this model (11717 > 512). Running this sequence through the model will result in indexing errors\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"and non-disclosure Agreement.n(a) Competition Restriction.\n"
]
}
],
"source": [
"# NOTE(review): `input` shadows the Python builtin; the name is kept in case other cells read it.\n",
"input = train_in[0]\n",
"\n",
"# Let the tokenizer truncate to 512 tokens instead of slicing the id tensor afterwards:\n",
"# this avoids the over-length warning and keeps the end-of-sequence token as the last id.\n",
"input_ids = tokenizer(\n",
"    input, return_tensors=\"pt\", truncation=True, max_length=512\n",
").input_ids.to('cuda')  # Batch size 1\n",
"\n",
"outputs = model.generate(input_ids)\n",
"\n",
"decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"print(decoded)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# One (source, target) pair, tokenized and moved to the GPU.\n",
"input_ids = tokenizer(\n",
"    'translate English to German: The house is wonderful.',\n",
"    return_tensors='pt',\n",
").input_ids.to('cuda')\n",
"labels = tokenizer(\n",
"    'Das Haus ist wunderbar.',\n",
"    return_tensors='pt',\n",
").input_ids.to('cuda')\n",
"# When labels are passed, the forward call builds the decoder_input_ids itself;\n",
"# the loss is read from the returned output.\n",
"loss = model(input_ids=input_ids, labels=labels).loss"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.2543, device='cuda:0', grad_fn=<NllLossBackward>)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loss"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# transformers.AdamW is deprecated (and removed in recent releases);\n",
"# torch.optim.AdamW is the supported replacement.\n",
"from torch.optim import AdamW\n",
"\n",
"# eps and weight_decay are passed explicitly to keep the former transformers.AdamW\n",
"# defaults (torch defaults are eps=1e-8, weight_decay=0.01).\n",
"optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"T5ForConditionalGeneration(\n",
" (shared): Embedding(32128, 512)\n",
" (encoder): T5Stack(\n",
" (embed_tokens): Embedding(32128, 512)\n",
" (block): ModuleList(\n",
" (0): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" (relative_attention_bias): Embedding(32, 8)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (2): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (3): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (4): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (5): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (final_layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (decoder): T5Stack(\n",
" (embed_tokens): Embedding(32128, 512)\n",
" (block): ModuleList(\n",
" (0): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" (relative_attention_bias): Embedding(32, 8)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (2): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (3): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (4): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (5): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (final_layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lm_head): Linear(in_features=512, out_features=32128, bias=False)\n",
")"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Switch the model to training mode. train() returns the module itself, so the\n",
"# trailing ';' keeps its very long repr out of the cell output.\n",
"model.train();"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"13.828309059143066\n",
"11.455500602722168\n",
"12.591864585876465\n",
"11.697681427001953\n",
"9.457676887512207\n",
"10.367218017578125\n",
"7.407022953033447\n",
"8.830719947814941\n",
"10.031709671020508\n",
"6.843804359436035\n",
"9.030264854431152\n",
"8.841073989868164\n",
"9.884418487548828\n",
"8.1090087890625\n",
"5.866975784301758\n",
"8.52608585357666\n",
"5.992447853088379\n",
"7.147337436676025\n",
"6.601171970367432\n",
"8.028266906738281\n",
"6.183577060699463\n",
"5.559406280517578\n",
"6.755654335021973\n",
"5.919793128967285\n",
"5.167813301086426\n",
"5.351068496704102\n",
"5.7952165603637695\n",
"6.730508804321289\n",
"5.469816207885742\n",
"4.3772478103637695\n",
"4.868475914001465\n",
"5.726585865020752\n",
"3.966099739074707\n",
"5.961289405822754\n",
"5.155783653259277\n",
"4.634646892547607\n",
"4.736303806304932\n",
"4.152906894683838\n",
"4.373996257781982\n",
"4.358081340789795\n",
"4.958395957946777\n",
"3.8232321739196777\n",
"4.142550945281982\n",
"2.666247606277466\n",
"4.235062122344971\n",
"4.233397483825684\n",
"3.8168039321899414\n",
"3.1151959896087646\n",
"1.9562475681304932\n",
"3.445767641067505\n",
"4.4933247566223145\n",
"3.4922804832458496\n",
"2.250882625579834\n",
"2.4218058586120605\n",
"2.260007858276367\n",
"2.5280778408050537\n",
"2.7701780796051025\n",
"3.8142340183258057\n",
"3.0554733276367188\n",
"1.8644142150878906\n",
"3.2941484451293945\n",
"2.286688804626465\n",
"3.366548538208008\n",
"1.0562607049942017\n",
"1.8493285179138184\n",
"2.8790605068206787\n",
"4.513855934143066\n",
"2.9482157230377197\n",
"2.0251893997192383\n",
"1.5018310546875\n",
"1.8084921836853027\n",
"1.7678613662719727\n",
"1.0362716913223267\n",
"1.6407744884490967\n",
"1.2443599700927734\n",
"2.2683565616607666\n",
"1.4040197134017944\n",
"3.9230520725250244\n",
"0.8626512289047241\n",
"0.7241716384887695\n",
"0.8391153812408447\n",
"3.9508471488952637\n",
"1.4111053943634033\n",
"1.333533525466919\n",
"0.38448166847229004\n",
"2.132805109024048\n",
"1.7784374952316284\n",
"2.150501251220703\n",
"2.3192851543426514\n",
"1.4407600164413452\n",
"1.4160407781600952\n",
"0.5990514159202576\n",
"1.2548216581344604\n",
"1.1115673780441284\n",
"1.957241177558899\n",
"1.2597360610961914\n",
"1.0772262811660767\n",
"1.1419639587402344\n",
"0.30694711208343506\n",
"2.0387325286865234\n",
"2.2052383422851562\n",
"4.552682399749756\n",
"1.1284838914871216\n",
"1.628050446510315\n",
"2.827632188796997\n",
"1.256350040435791\n",
"1.5137629508972168\n",
"0.17800401151180267\n",
"1.1130807399749756\n",
"1.4471491575241089\n",
"1.4046872854232788\n",
"1.5159196853637695\n",
"1.5683913230895996\n",
"0.9050359725952148\n",
"0.2453073114156723\n",
"0.829986572265625\n",
"1.342026948928833\n",
"0.697879433631897\n",
"0.8360342383384705\n",
"3.773777723312378\n",
"1.0000628232955933\n",
"1.163111925125122\n",
"0.636287271976471\n",
"0.6960057616233826\n",
"1.2984236478805542\n",
"1.4369347095489502\n",
"1.2260591983795166\n",
"1.1619309186935425\n",
"1.2387232780456543\n",
"0.4039798974990845\n",
"1.261201024055481\n",
"2.0990383625030518\n",
"0.6930045485496521\n",
"1.9684548377990723\n",
"0.41637909412384033\n",
"1.5580865144729614\n",
"0.935876727104187\n",
"0.5318026542663574\n",
"1.207798719406128\n",
"0.5434905290603638\n",
"0.10893465578556061\n",
"0.8033742904663086\n",
"0.25061750411987305\n",
"0.9297510981559753\n",
"1.1515181064605713\n",
"2.179370641708374\n",
"0.912304699420929\n",
"0.9962441325187683\n",
"1.3243765830993652\n",
"1.5690778493881226\n",
"1.0356395244598389\n",
"1.3098541498184204\n",
"0.2543454170227051\n",
"0.7984715104103088\n",
"0.10885466635227203\n",
"1.5388046503067017\n",
"1.3934229612350464\n",
"1.0405352115631104\n",
"1.744563341140747\n",
"0.9149143695831299\n",
"0.4559175670146942\n",
"0.7720739841461182\n",
"1.6526525020599365\n",
"0.5373530387878418\n",
"0.5430313348770142\n",
"0.5173842310905457\n",
"0.7213934659957886\n",
"0.6729367971420288\n",
"0.8275019526481628\n",
"1.3139863014221191\n",
"1.1809828281402588\n",
"1.423504114151001\n",
"0.4956137537956238\n",
"1.2472567558288574\n",
"0.3318641185760498\n",
"0.3209134638309479\n",
"0.09695105999708176\n",
"0.6424573063850403\n",
"1.224516749382019\n",
"0.13458161056041718\n",
"1.1670427322387695\n",
"1.1272934675216675\n",
"1.0477215051651\n",
"0.7291663289070129\n",
"0.6467929482460022\n",
"0.924201488494873\n",
"1.455331563949585\n",
"0.6269064545631409\n",
"0.7512378692626953\n",
"0.5907666087150574\n",
"0.8808064460754395\n",
"0.5326775312423706\n",
"0.4754364490509033\n",
"0.5422216653823853\n",
"0.9144468307495117\n",
"0.6809101700782776\n",
"0.1790292114019394\n",
"0.7104746103286743\n",
"0.41490861773490906\n",
"1.4695433378219604\n",
"1.381641149520874\n",
"0.34390121698379517\n",
"0.5615295171737671\n",
"0.4991306960582733\n",
"1.755591630935669\n",
"0.02876635640859604\n",
"0.06847237050533295\n",
"1.4051387310028076\n",
"0.3321903944015503\n",
"0.5550190210342407\n",
"0.8398134708404541\n",
"0.6281668543815613\n",
"0.7955247759819031\n",
"0.4672299921512604\n",
"1.0951168537139893\n",
"0.6541656255722046\n",
"0.8140543699264526\n",
"0.043958500027656555\n",
"0.04899679496884346\n",
"0.8996919989585876\n",
"0.275490403175354\n",
"0.2666592597961426\n",
"0.09318255633115768\n",
"0.3718479871749878\n",
"1.495982050895691\n",
"0.0595063678920269\n",
"1.7708230018615723\n",
"0.7092909216880798\n",
"0.9086990356445312\n",
"0.010129873640835285\n",
"0.7636302709579468\n",
"1.0733331441879272\n",
"0.060608845204114914\n",
"1.3388985395431519\n",
"0.4673462510108948\n",
"0.21733486652374268\n",
"0.5459968447685242\n",
"0.050972938537597656\n",
"0.4641537666320801\n",
"0.7601963877677917\n",
"0.44411876797676086\n",
"0.09443528205156326\n",
"1.623687982559204\n",
"0.5162641406059265\n",
"0.6031121611595154\n",
"0.8987085223197937\n",
"0.3393983840942383\n",
"2.8573479652404785\n",
"0.8427947759628296\n",
"1.0764878988265991\n",
"0.4185052812099457\n",
"0.6308793425559998\n",
"0.01906685344874859\n",
"0.141354501247406\n"
]
}
],
"source": [
"# Fine-tune with one optimizer step per (input document, expected output) pair,\n",
"# printing the loss of every step.\n",
"for line_in, line_exp in zip(train_in, train_exp):\n",
"    # Clear gradients first so an interrupted or re-run cell never reuses stale ones.\n",
"    optimizer.zero_grad()\n",
"    # Inputs are cut to the first 512 token ids; labels are used in full.\n",
"    input_ids = tokenizer(line_in, return_tensors='pt').input_ids[:, :512].to(model.device)\n",
"    labels = tokenizer(line_exp, return_tensors='pt').input_ids.to(model.device)\n",
"    # the forward function automatically creates the correct decoder_input_ids\n",
"    loss = model(input_ids=input_ids, labels=labels).loss\n",
"    loss.backward()\n",
"    optimizer.step()\n",
"    print(loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"T5ForConditionalGeneration(\n",
" (shared): Embedding(32128, 512)\n",
" (encoder): T5Stack(\n",
" (embed_tokens): Embedding(32128, 512)\n",
" (block): ModuleList(\n",
" (0): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" (relative_attention_bias): Embedding(32, 8)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (2): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (3): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (4): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (5): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (final_layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (decoder): T5Stack(\n",
" (embed_tokens): Embedding(32128, 512)\n",
" (block): ModuleList(\n",
" (0): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" (relative_attention_bias): Embedding(32, 8)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (1): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (2): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (3): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (4): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (5): T5Block(\n",
" (layer): ModuleList(\n",
" (0): T5LayerSelfAttention(\n",
" (SelfAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (1): T5LayerCrossAttention(\n",
" (EncDecAttention): T5Attention(\n",
" (q): Linear(in_features=512, out_features=512, bias=False)\n",
" (k): Linear(in_features=512, out_features=512, bias=False)\n",
" (v): Linear(in_features=512, out_features=512, bias=False)\n",
" (o): Linear(in_features=512, out_features=512, bias=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (2): T5LayerFF(\n",
" (DenseReluDense): T5DenseReluDense(\n",
" (wi): Linear(in_features=512, out_features=2048, bias=False)\n",
" (wo): Linear(in_features=2048, out_features=512, bias=False)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (final_layer_norm): T5LayerNorm()\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (lm_head): Linear(in_features=512, out_features=32128, bias=False)\n",
")"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.eval()\n"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jurisdiction: Colorado\n"
]
}
],
"source": [
"# Generate a prediction for dev_in[0] and print the decoded text.\n",
"# The text is bound to `document` rather than `input` so the built-in input() is not shadowed.\n",
"document = dev_in[0]\n",
"\n",
"# Tokenize to a PyTorch tensor and keep at most the first 512 token ids (batch size 1).\n",
"# NOTE(review): slicing after tokenization presumably also cuts the trailing EOS token for long\n",
"# inputs; tokenizer(..., truncation=True, max_length=512) would keep it - confirm how the\n",
"# training inputs were built before switching.\n",
"input_ids = tokenizer(document, return_tensors=\"pt\").input_ids[:, :512].to('cuda')\n",
"\n",
"outputs = model.generate(input_ids)\n",
"\n",
"# Decode the first generated sequence, dropping special tokens.\n",
"decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"print(decoded)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"'jurisdiction: New_York'"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dev_exp[0]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jurisdiction: Delaware\n"
]
}
],
"source": [
"# Generate a prediction for dev_in[2] and print the decoded text.\n",
"# The text is bound to `document` rather than `input` so the built-in input() is not shadowed.\n",
"document = dev_in[2]\n",
"\n",
"# Tokenize to a PyTorch tensor and keep at most the first 512 token ids (batch size 1).\n",
"input_ids = tokenizer(document, return_tensors=\"pt\").input_ids[:, :512].to('cuda')\n",
"\n",
"outputs = model.generate(input_ids)\n",
"\n",
"# Decode the first generated sequence, dropping special tokens.\n",
"decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
"\n",
"print(decoded)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'jurisdiction: Delaware'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dev_exp[2]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## pytanie:\n",
"- co można poprawić w istniejącym rozwiązaniu?"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}