aitech-eks-pub/cw/13_transformery2_ODPOWIEDZI...

2435 lines
19 MiB
Plaintext
Raw Normal View History

2021-06-16 15:14:42 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Wizualizacja atencji\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Do wizualizacji użyjemy biblioteki [BertViz](https://github.com/jessevig/bertviz)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# Pinned to the bertviz version this notebook was last executed with.\n",
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"%pip install -q bertviz==1.1.0"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModel\n",
"from bertviz import model_view, head_view"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"TEXT = \"This is a sample input sentence for a transformer model\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"MODEL = \"distilbert-base-uncased\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# Load the tokenizer and the model. output_attentions=True makes the forward pass\n",
"# return the attention weights in addition to the hidden states.\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
"model = AutoModel.from_pretrained(MODEL, output_attentions=True)\n",
"\n",
"# Encode TEXT into a batch holding a single sequence of token ids (PyTorch tensor).\n",
"inputs = tokenizer.encode(TEXT, return_tensors='pt')\n",
"outputs = model(inputs)\n",
"\n",
"# Select the attention weights by name instead of by position (outputs[-1]),\n",
"# so the cell keeps working if further outputs are enabled on the model.\n",
"attention = outputs.attentions\n",
"\n",
"# Token strings matching the ids; bertviz uses them as labels.\n",
"tokens = tokenizer.convert_ids_to_tokens(inputs[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SELF ATTENTION MODELS"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-588b8377843e46ceb5981df7aba63167'>\n",
" <span style=\"user-select:none\">\n",
" Layer: <select id=\"layer\"></select>\n",
" \n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 12/19/18 Jesse Vig Assorted cleanup. Changed orientation of attention matrices.\n",
" * 12/29/20 Jesse Vig Significant refactor.\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust height of visualization dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function ($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": null, \"attn\": [[[[0.04504745081067085, 0.06121315807104111, 0.03234885632991791, 0.07350576668977737, 0.08869655430316925, 0.06299575418233871, 0.13025709986686707, 0.0724874883890152, 0.0789855495095253, 0.04788164049386978, 0.0528176911175251, 0.06413961201906204, 0.1896234154701233], [0.1642618179321289, 0.045601703226566315, 0.08717474341392517, 0.020270587876439095, 0.06333516538143158, 0.13347627222537994, 0.15195965766906738, 0.07450311630964279, 0.020874062553048134, 0.0886143371462822, 0.03885190933942795, 0.054350875318050385, 0.05672568082809448], [0.19410867989063263, 0.08788827061653137, 0.02643328532576561, 0.027052758261561394, 0.0483211874961853, 0.10077042877674103, 0.14769119024276733, 0.0505562499165535, 0.0312521830201149, 0.11102262884378433, 0.05736817419528961, 0.045798979699611664, 0.0717359259724617], [0.1995920091867447, 0.09999717772006989, 0.0661521703004837, 0.05901019275188446, 0.04424671828746796, 0.0677744597196579, 0.12087972462177277, 0.050675857812166214, 0.06362631916999817, 0.07006122171878815, 0.03917881101369858, 0.04700592905282974, 0.07179943472146988], [0.04939603433012962, 0.024679943919181824, 0.04761141911149025, 0.01254973653703928, 0.07388506829738617, 0.23993298411369324, 0.15196150541305542, 0.00832447037100792, 0.012036890722811222, 0.20173829793930054, 0.021537913009524345, 0.09962346404790878, 0.05672222748398781], [0.05641084164381027, 0.05914205312728882, 0.09467293322086334, 0.022168248891830444, 0.09393802285194397, 0.084870345890522, 0.21269568800926208, 0.03484135493636131, 0.022969357669353485, 0.10200338810682297, 0.048721037805080414, 0.0935591459274292, 0.07400762289762497], [0.14331203699111938, 0.05361737683415413, 0.10078467428684235, 0.024218060076236725, 0.10775977373123169, 0.1332426518201828, 0.08465974032878876, 0.05863681063055992, 0.0241240207105875, 0.08344626426696777, 0.04309769719839096, 0.06576775014400482, 0.0773332417011261], 
[0.211670383810997, 0.11610246449708939, 0.11225587129592896, 0.045633524656295776, 0.059925444424152374, 0.06493557244539261, 0.08031120151281357, 0.03209289163351059, 0.04349662363529205, 0.07048093527555466, 0.05278267711400986, 0.05153229832649231, 0.05878004804253578], [0.19909143447875977, 0.1108265370130539, 0.06320541352033615, 0.07345421612262726, 0.044627320021390915, 0.06109814718365669, 0.10009115934371948, 0.04955925792455673, 0.07297047972679138, 0.06868425011634827, 0.03863503038883209, 0.044418469071388245, 0.07333831489086151], [0.09310530126094818, 0.06271722912788391, 0.05319841206073761, 0.018992554396390915, 0.12321215867996216, 0.13590605556964874, 0.19085046648979187, 0.011909686028957367, 0.01892036944627762, 0.1059010922908783, 0.035090457648038864, 0.11820467561483383, 0.03199157118797302], [0.12340845912694931, 0.01910078339278698, 0.021354133263230324, 0.00847071222960949, 0.09208827465772629, 0.12047737091779709, 0.19756117463111877, 0.01248408854007721, 0.008209108375012875, 0.18166890740394592, 0.04370952770113945, 0.12460346519947052, 0.046863969415426254], [0.1024712547659874, 0.02981560118496418, 0.042016737163066864, 0.00732750678434968, 0.19956143200397491, 0.16194269061088562, 0.13108232617378235, 0.03385123983025551, 0.006953119300305843, 0.10607277601957321, 0.09258976578712463, 0.04644336178898811, 0.03987228125333786], [0.14260147511959076, 0.0856451690196991, 0.052119556814432144, 0.1065417006611824, 0.03710906580090523, 0.038733918219804764, 0.07154049724340439, 0.04271101951599121, 0.10687444359064102, 0.02995859645307064, 0.034151963889598846, 0.03033091314136982, 0.22168178856372833]], [[0.9453150033950806, 0.0042615714482963085, 0.0033084359019994736, 0.008667639456689358, 0.002442109864205122, 0.0032682078890502453, 0.0018215697491541505, 0.005812120623886585, 0.008080368861556053, 0.0023478888906538486, 0.005752548109740019, 0.0029995867516845465, 0.005922860931605101], [0.0054322523064911366, 0.036913447082042694, 
0.10574596375226974, 0.050172608345746994, 0.21406637132167816, 0.1669832468032837, 0.06814419478178024
" const TEXT_SIZE = 15;\n",
" const BOXWIDTH = 110;\n",
" const BOXHEIGHT = 22.5;\n",
" const MATRIX_WIDTH = 115;\n",
" const CHECKBOX_SIZE = 20;\n",
" const TEXT_TOP = 30;\n",
"\n",
" console.log(\"d3 version\", d3.version)\n",
" let headColors;\n",
" try {\n",
" headColors = d3.scaleOrdinal(d3.schemeCategory10);\n",
" } catch (err) {\n",
" console.log('Older d3 version')\n",
" headColors = d3.scale.category10();\n",
" }\n",
" let config = {};\n",
" initialize();\n",
" renderVis();\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.rootDivId = params['root_div_id'];\n",
" config.nLayers = config.attention[config.filter]['attn'].length;\n",
" config.nHeads = config.attention[config.filter]['attn'][0].length;\n",
" if (params['heads']) {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" params['heads'].forEach(x => config.headVis[x] = true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" }\n",
" config.initialTextLength = config.attention[config.filter].right_text.length;\n",
" config.layer = (params['layer'] == null ? 0 : params['layer'])\n",
"\n",
"\n",
" let layerEl = $(`#${config.rootDivId} #layer`);\n",
" for (var i = 0; i < config.nLayers; i++) {\n",
" layerEl.append($(\"<option />\").val(i).text(i));\n",
" }\n",
" layerEl.val(config.layer).change();\n",
" layerEl.on('change', function (e) {\n",
" config.layer = +e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
"\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
" }\n",
"\n",
" function renderVis() {\n",
"\n",
" // Load parameters\n",
" const attnData = config.attention[config.filter];\n",
" const leftText = attnData.left_text;\n",
" const rightText = attnData.right_text;\n",
"\n",
" // Select attention for given layer\n",
" const layerAttention = attnData.attn[config.layer];\n",
"\n",
" // Clear vis\n",
" $(`#${config.rootDivId} #vis`).empty();\n",
"\n",
" // Determine size of visualization\n",
" const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;\n",
" const svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", \"100%\")\n",
" .attr(\"height\", height + \"px\");\n",
"\n",
" // Display tokens on left and right side of visualization\n",
" renderText(svg, leftText, true, layerAttention, 0);\n",
" renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);\n",
"\n",
" // Render attention arcs\n",
" renderAttention(svg, layerAttention);\n",
"\n",
" // Draw squares at top of visualization, one for each head\n",
" drawCheckboxes(0, svg, layerAttention);\n",
" }\n",
"\n",
" function renderText(svg, text, isLeft, attention, leftPos) {\n",
"\n",
" const textContainer = svg.append(\"svg:g\")\n",
" .attr(\"id\", isLeft ? \"left\" : \"right\");\n",
"\n",
" // Add attention highlights superimposed over words\n",
" textContainer.append(\"g\")\n",
" .classed(\"attentionBoxes\", true)\n",
" .selectAll(\"g\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"x\", function () {\n",
" var headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (+1) * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .attr(\"fill\", function () {\n",
" return headColors(+this.parentNode.getAttribute(\"head-index\"))\n",
" })\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" const tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" // Add gray background that appears when hovering over text\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"background\", true)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"fill\", \"lightgray\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH)\n",
" .attr(\"height\", BOXHEIGHT);\n",
"\n",
" // Add token text\n",
" const textEl = tokenContainer.append(\"text\")\n",
" .text(d => d)\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT);\n",
"\n",
" if (isLeft) {\n",
" textEl.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", BOXWIDTH - 0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" } else {\n",
" textEl.style(\"text-anchor\", \"start\")\n",
" .attr(\"dx\", +0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" }\n",
"\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
"\n",
" // Show gray background for moused-over token\n",
" textContainer.selectAll(\".background\")\n",
" .style(\"opacity\", (d, i) => i === index ? 1.0 : 0.0)\n",
"\n",
" // Reset visibility attribute for any previously highlighted attention arcs\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null)\n",
"\n",
" // Hide group containing attention arcs\n",
" svg.select(\"#attention\").attr(\"visibility\", \"hidden\");\n",
"\n",
" // Set to visible appropriate attention arcs to be highlighted\n",
" if (isLeft) {\n",
" svg.select(\"#attention\").selectAll(\"line[left-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" } else {\n",
" svg.select(\"#attention\").selectAll(\"line[right-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" }\n",
"\n",
" // Update color boxes superimposed over tokens\n",
" const id = isLeft ? \"right\" : \"left\";\n",
" const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;\n",
" svg.select(\"#\" + id)\n",
" .selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .attr(\"x\", function () {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .style(\"opacity\", function (d) {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" if (config.headVis[headIndex])\n",
" if (d) {\n",
" return d[index];\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" else\n",
" return 0.0;\n",
" });\n",
" });\n",
"\n",
" textContainer.on(\"mouseleave\", function () {\n",
"\n",
" // Unhighlight selected token\n",
" d3.select(this).selectAll(\".background\")\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" // Reset visibility attributes for previously selected lines\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null) ;\n",
" svg.select(\"#attention\").attr(\"visibility\", \"visible\");\n",
"\n",
" // Reset highlights superimposed over tokens\n",
" svg.selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .selectAll(\"rect\")\n",
" .style(\"opacity\", 0.0);\n",
" });\n",
" }\n",
"\n",
" function renderAttention(svg, attention) {\n",
"\n",
" // Remove previous dom elements\n",
" svg.select(\"#attention\").remove();\n",
"\n",
" // Add new elements\n",
" svg.append(\"g\")\n",
" .attr(\"id\", \"attention\") // Container for all attention arcs\n",
" .selectAll(\".headAttention\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"headAttention\", true) // Group attention arcs by head\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\".tokenAttention\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"tokenAttention\", true) // Group attention arcs by left token\n",
" .attr(\"left-token-index\", (d, i) => i)\n",
" .selectAll(\"line\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", BOXWIDTH)\n",
" .attr(\"y1\", function () {\n",
" const leftTokenIndex = +this.parentNode.getAttribute(\"left-token-index\")\n",
" return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)\n",
" })\n",
" .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n",
" .attr(\"y2\", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", function () {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" return headColors(headIndex)\n",
" })\n",
" .attr(\"left-token-index\", function () {\n",
" return +this.parentNode.getAttribute(\"left-token-index\")\n",
" })\n",
" .attr(\"right-token-index\", (d, i) => i)\n",
" ;\n",
" updateAttention(svg)\n",
" }\n",
"\n",
" function updateAttention(svg) {\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line\")\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" // If head is selected\n",
" if (config.headVis[headIndex]) {\n",
" // Set opacity to attention weight divided by number of active heads\n",
" return d / activeHeads()\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" })\n",
" }\n",
"\n",
" function boxOffsets(i) {\n",
" const numHeadsAbove = config.headVis.reduce(\n",
" function (acc, val, cur) {\n",
" return val && cur < i ? acc + 1 : acc;\n",
" }, 0);\n",
" return numHeadsAbove * (BOXWIDTH / activeHeads());\n",
" }\n",
"\n",
" function activeHeads() {\n",
" return config.headVis.reduce(function (acc, val) {\n",
" return val ? acc + 1 : acc;\n",
" }, 0);\n",
" }\n",
"\n",
" function drawCheckboxes(top, svg) {\n",
" const checkboxContainer = svg.append(\"g\");\n",
" const checkbox = checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"fill\", (d, i) => headColors(i))\n",
" .attr(\"x\", (d, i) => i * CHECKBOX_SIZE)\n",
" .attr(\"y\", top)\n",
" .attr(\"width\", CHECKBOX_SIZE)\n",
" .attr(\"height\", CHECKBOX_SIZE);\n",
"\n",
" function updateCheckboxes() {\n",
" checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .attr(\"fill\", (d, i) => d ? headColors(i): lighten(headColors(i)));\n",
" }\n",
"\n",
" updateCheckboxes();\n",
"\n",
" checkbox.on(\"click\", function (d, i) {\n",
" if (config.headVis[i] && activeHeads() === 1) return;\n",
" config.headVis[i] = !config.headVis[i];\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
"\n",
" checkbox.on(\"dblclick\", function (d, i) {\n",
" // If we double click on the only active head then reset\n",
" if (config.headVis[i] && activeHeads() === 1) {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" config.headVis[i] = true;\n",
" }\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
" }\n",
"\n",
" function lighten(color) {\n",
" const c = d3.hsl(color);\n",
" const increment = (1 - c.l) * 0.6;\n",
" c.l += increment;\n",
" c.s -= increment;\n",
" return c;\n",
" }\n",
"\n",
" function transpose(mat) {\n",
" return mat[0].map(function (col, i) {\n",
" return mat.map(function (row) {\n",
" return row[i];\n",
" });\n",
" });\n",
" }\n",
"\n",
"});"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"head_view(attention, tokens)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-61533a3712f543eba69f03bb563030ec'>\n",
" <span style=\"user-select:none\">\n",
" \n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 02/01/19 Jesse Vig Initial implementation\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 01/19/21 Jesse Vig Support light/dark modes\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust visualization height dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": null, \"attn\": [[[[0.04504745081067085, 0.06121315807104111, 0.03234885632991791, 0.07350576668977737, 0.08869655430316925, 0.06299575418233871, 0.13025709986686707, 0.0724874883890152, 0.0789855495095253, 0.04788164049386978, 0.0528176911175251, 0.06413961201906204, 0.1896234154701233], [0.1642618179321289, 0.045601703226566315, 0.08717474341392517, 0.020270587876439095, 0.06333516538143158, 0.13347627222537994, 0.15195965766906738, 0.07450311630964279, 0.020874062553048134, 0.0886143371462822, 0.03885190933942795, 0.054350875318050385, 0.05672568082809448], [0.19410867989063263, 0.08788827061653137, 0.02643328532576561, 0.027052758261561394, 0.0483211874961853, 0.10077042877674103, 0.14769119024276733, 0.0505562499165535, 0.0312521830201149, 0.11102262884378433, 0.05736817419528961, 0.045798979699611664, 0.0717359259724617], [0.1995920091867447, 0.09999717772006989, 0.0661521703004837, 0.05901019275188446, 0.04424671828746796, 0.0677744597196579, 0.12087972462177277, 0.050675857812166214, 0.06362631916999817, 0.07006122171878815, 0.03917881101369858, 0.04700592905282974, 0.07179943472146988], [0.04939603433012962, 0.024679943919181824, 0.04761141911149025, 0.01254973653703928, 0.07388506829738617, 0.23993298411369324, 0.15196150541305542, 0.00832447037100792, 0.012036890722811222, 0.20173829793930054, 0.021537913009524345, 0.09962346404790878, 0.05672222748398781], [0.05641084164381027, 0.05914205312728882, 0.09467293322086334, 0.022168248891830444, 0.09393802285194397, 0.084870345890522, 0.21269568800926208, 0.03484135493636131, 0.022969357669353485, 0.10200338810682297, 0.048721037805080414, 0.0935591459274292, 0.07400762289762497], [0.14331203699111938, 0.05361737683415413, 0.10078467428684235, 0.024218060076236725, 0.10775977373123169, 0.1332426518201828, 0.08465974032878876, 0.05863681063055992, 0.0241240207105875, 0.08344626426696777, 0.04309769719839096, 0.06576775014400482, 0.0773332417011261], 
[0.211670383810997, 0.11610246449708939, 0.11225587129592896, 0.045633524656295776, 0.059925444424152374, 0.06493557244539261, 0.08031120151281357, 0.03209289163351059, 0.04349662363529205, 0.07048093527555466, 0.05278267711400986, 0.05153229832649231, 0.05878004804253578], [0.19909143447875977, 0.1108265370130539, 0.06320541352033615, 0.07345421612262726, 0.044627320021390915, 0.06109814718365669, 0.10009115934371948, 0.04955925792455673, 0.07297047972679138, 0.06868425011634827, 0.03863503038883209, 0.044418469071388245, 0.07333831489086151], [0.09310530126094818, 0.06271722912788391, 0.05319841206073761, 0.018992554396390915, 0.12321215867996216, 0.13590605556964874, 0.19085046648979187, 0.011909686028957367, 0.01892036944627762, 0.1059010922908783, 0.035090457648038864, 0.11820467561483383, 0.03199157118797302], [0.12340845912694931, 0.01910078339278698, 0.021354133263230324, 0.00847071222960949, 0.09208827465772629, 0.12047737091779709, 0.19756117463111877, 0.01248408854007721, 0.008209108375012875, 0.18166890740394592, 0.04370952770113945, 0.12460346519947052, 0.046863969415426254], [0.1024712547659874, 0.02981560118496418, 0.042016737163066864, 0.00732750678434968, 0.19956143200397491, 0.16194269061088562, 0.13108232617378235, 0.03385123983025551, 0.006953119300305843, 0.10607277601957321, 0.09258976578712463, 0.04644336178898811, 0.03987228125333786], [0.14260147511959076, 0.0856451690196991, 0.052119556814432144, 0.1065417006611824, 0.03710906580090523, 0.038733918219804764, 0.07154049724340439, 0.04271101951599121, 0.10687444359064102, 0.02995859645307064, 0.034151963889598846, 0.03033091314136982, 0.22168178856372833]], [[0.9453150033950806, 0.0042615714482963085, 0.0033084359019994736, 0.008667639456689358, 0.002442109864205122, 0.0032682078890502453, 0.0018215697491541505, 0.005812120623886585, 0.008080368861556053, 0.0023478888906538486, 0.005752548109740019, 0.0029995867516845465, 0.005922860931605101], [0.0054322523064911366, 0.036913447082042694, 
0.10574596375226974, 0.050172608345746994, 0.21406637132167816, 0.1669832468032837, 0.0681441947817
" const config = {};\n",
"\n",
" const MIN_X = 0;\n",
" const MIN_Y = 0;\n",
" const DIV_WIDTH = 970;\n",
" const THUMBNAIL_PADDING = 5;\n",
" const DETAIL_WIDTH = 300;\n",
" const DETAIL_ATTENTION_WIDTH = 140;\n",
" const DETAIL_BOX_WIDTH = 80;\n",
" const DETAIL_BOX_HEIGHT = 18;\n",
" const DETAIL_PADDING = 28;\n",
" const ATTN_PADDING = 0;\n",
" const DETAIL_HEADING_HEIGHT = 25;\n",
" const DETAIL_HEADING_TEXT_SIZE = 15;\n",
" const TEXT_SIZE = 13;\n",
" const LAYER_COLORS = d3.schemeCategory10;\n",
" const PALETTE = {\n",
" 'light': {\n",
" 'text': 'black',\n",
" 'background': 'white',\n",
" 'highlight': '#F5F5F5'\n",
" },\n",
" 'dark': {\n",
" 'text': '#bbb',\n",
" 'background': 'black',\n",
" 'highlight': '#222'\n",
" }\n",
" }\n",
"\n",
" function render() {\n",
"\n",
" // Set global state variables\n",
"\n",
" var attData = config.attention[config.filter];\n",
" config.leftText = attData.left_text;\n",
" config.rightText = attData.right_text;\n",
" config.attn = attData.attn;\n",
" config.numLayers = config.attn.length;\n",
" config.numHeads = config.attn[0].length;\n",
" config.thumbnailBoxHeight = 7 * (12 / config.numHeads);\n",
" config.thumbnailHeight = Math.max(config.leftText.length, config.rightText.length) * config.thumbnailBoxHeight + 2 * THUMBNAIL_PADDING;\n",
" config.thumbnailWidth = DIV_WIDTH / config.numHeads;\n",
" config.detailHeight = Math.max(config.leftText.length, config.rightText.length) * DETAIL_BOX_HEIGHT + 2 * DETAIL_PADDING + DETAIL_HEADING_HEIGHT;\n",
" config.divHeight = config.numLayers * config.thumbnailHeight;\n",
"\n",
" const vis = $(`#${config.rootDivId} #vis`)\n",
" vis.empty();\n",
" vis.attr(\"height\", config.divHeight);\n",
" config.svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", DIV_WIDTH)\n",
" .attr(\"height\", config.divHeight)\n",
" .attr(\"fill\", getBackgroundColor());\n",
"\n",
" var i;\n",
" var j;\n",
" for (i = 0; i < config.numLayers; i++) {\n",
" for (j = 0; j < config.numHeads; j++) {\n",
" renderThumbnail(i, j);\n",
" }\n",
" }\n",
" }\n",
"\n",
" function renderThumbnail(layerIndex, headIndex) {\n",
" var x = headIndex * config.thumbnailWidth;\n",
" var y = layerIndex * config.thumbnailHeight;\n",
" renderThumbnailAttn(x, y, config.attn[layerIndex][headIndex], layerIndex, headIndex);\n",
" }\n",
"\n",
" function renderDetail(att, layerIndex, headIndex) {\n",
" var xOffset = .8 * config.thumbnailWidth;\n",
" var maxX = DIV_WIDTH;\n",
" var maxY = config.divHeight;\n",
" var leftPos = (headIndex / config.numHeads) * DIV_WIDTH\n",
" var x = leftPos + THUMBNAIL_PADDING + xOffset;\n",
" if (x < MIN_X) {\n",
" x = MIN_X;\n",
" } else if (x + DETAIL_WIDTH > maxX) {\n",
" x = leftPos + THUMBNAIL_PADDING - DETAIL_WIDTH + 8;\n",
" }\n",
" var posLeftText = x;\n",
" var posAttention = posLeftText + DETAIL_BOX_WIDTH;\n",
" var posRightText = posAttention + DETAIL_ATTENTION_WIDTH;\n",
" var thumbnailHeight = Math.max(config.leftText.length, config.rightText.length) * config.thumbnailBoxHeight + 2 * THUMBNAIL_PADDING;\n",
" var yOffset = 20;\n",
" var y = layerIndex * thumbnailHeight + THUMBNAIL_PADDING + yOffset;\n",
" if (y < MIN_Y) {\n",
" y = MIN_Y;\n",
" } else if (y + config.detailHeight > maxY) {\n",
" y = maxY - config.detailHeight;\n",
" }\n",
" renderDetailFrame(x, y, layerIndex);\n",
" renderDetailHeading(x, y + Math.max(config.leftText.length, config.rightText.length) *\n",
" DETAIL_BOX_HEIGHT, layerIndex, headIndex);\n",
" renderDetailText(config.leftText, \"leftText\", posLeftText, y + DETAIL_PADDING, layerIndex);\n",
" renderDetailAttn(posAttention, y + DETAIL_PADDING, att, layerIndex, headIndex);\n",
" renderDetailText(config.rightText, \"rightText\", posRightText, y + DETAIL_PADDING, layerIndex);\n",
" }\n",
"\n",
" function renderDetailHeading(x, y, layerIndex, headIndex) {\n",
" var fillColor = getTextColor();\n",
" config.svg.append(\"text\")\n",
" .classed(\"detail\", true)\n",
" .text('Layer ' + layerIndex + \", Head \" + headIndex)\n",
" .attr(\"font-size\", DETAIL_HEADING_TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"fill\", fillColor)\n",
" .attr(\"x\", x + DETAIL_WIDTH / 2)\n",
" .attr(\"text-anchor\", \"middle\")\n",
" .attr(\"y\", y + 40)\n",
" .attr(\"height\", DETAIL_HEADING_HEIGHT)\n",
" .attr(\"width\", DETAIL_WIDTH)\n",
" .attr(\"dy\", DETAIL_HEADING_TEXT_SIZE);\n",
" }\n",
"\n",
" function renderDetailText(text, id, x, y, layerIndex) {\n",
" var tokenContainer = config.svg.append(\"svg:g\")\n",
" .classed(\"detail\", true)\n",
" .selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" var fillColor = getTextColor();\n",
"\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"highlight\", true)\n",
" .attr(\"fill\", fillColor)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"height\", DETAIL_BOX_HEIGHT)\n",
" .attr(\"width\", DETAIL_BOX_WIDTH)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", function (d, i) {\n",
" return y + i * DETAIL_BOX_HEIGHT;\n",
" });\n",
"\n",
" var textContainer = tokenContainer.append(\"text\")\n",
" .classed(\"token\", true)\n",
" .text(function (d) {\n",
" return d;\n",
" })\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"fill\", fillColor)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", function (d, i) {\n",
" return i * DETAIL_BOX_HEIGHT + y;\n",
" })\n",
" .attr(\"height\", DETAIL_BOX_HEIGHT)\n",
" .attr(\"width\", DETAIL_BOX_WIDTH)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
"\n",
" if (id == \"leftText\") {\n",
" textContainer.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", DETAIL_BOX_WIDTH - 2);\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
" highlightSelection(index);\n",
" });\n",
" tokenContainer.on(\"mouseleave\", function () {\n",
" unhighlightSelection();\n",
" });\n",
" }\n",
" }\n",
"\n",
" function highlightSelection(index) {\n",
" config.svg.select(\"#leftText\")\n",
" .selectAll(\".highlight\")\n",
" .style(\"opacity\", function (d, i) {\n",
" return i == index ? 1.0 : 0.0;\n",
" });\n",
" config.svg.selectAll(\".attn-line-group\")\n",
" .style(\"opacity\", function (d, i) {\n",
" return i == index ? 1.0 : 0.0;\n",
" });\n",
" }\n",
"\n",
" function unhighlightSelection() {\n",
" config.svg.select(\"#leftText\")\n",
" .selectAll(\".highlight\")\n",
" .style(\"opacity\", 0.0);\n",
" config.svg.selectAll(\".attn-line-group\")\n",
" .style(\"opacity\", 1);\n",
" }\n",
"\n",
" function renderThumbnailAttn(x, y, att, layerIndex, headIndex) {\n",
"\n",
" var attnContainer = config.svg.append(\"svg:g\");\n",
"\n",
" var attnBackground = attnContainer.append(\"rect\")\n",
" .attr(\"id\", 'attn_background_' + layerIndex + \"_\" + headIndex)\n",
" .classed(\"attn_background\", true)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.thumbnailHeight)\n",
" .attr(\"width\", config.thumbnailWidth)\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", 0)\n",
" .attr(\"fill\", getBackgroundColor());\n",
" var x1 = x + THUMBNAIL_PADDING;\n",
" var x2 = x1 + config.thumbnailWidth - 14;\n",
" var y1 = y + THUMBNAIL_PADDING;\n",
"\n",
" attnContainer.selectAll(\"g\")\n",
" .data(att)\n",
" .enter()\n",
" .append(\"g\") // Add group for each source token\n",
" .attr(\"source-index\", function (d, i) { // Save index of source token\n",
" return i;\n",
" })\n",
" .selectAll(\"line\")\n",
" .data(function (d) { // Loop over all target tokens\n",
" return d;\n",
" })\n",
" .enter() // When entering\n",
" .append(\"line\")\n",
" .attr(\"x1\", x1)\n",
" .attr(\"y1\", function (d) {\n",
" var sourceIndex = +this.parentNode.getAttribute(\"source-index\");\n",
" return y1 + (sourceIndex + .5) * config.thumbnailBoxHeight;\n",
" })\n",
" .attr(\"x2\", x2)\n",
" .attr(\"y2\", function (d, targetIndex) {\n",
" return y1 + (targetIndex + .5) * config.thumbnailBoxHeight;\n",
" })\n",
" .attr(\"stroke-width\", 2.2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" return d;\n",
" });\n",
"\n",
" var clickRegion = attnContainer.append(\"rect\")\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.thumbnailHeight)\n",
" .attr(\"width\", config.thumbnailWidth)\n",
" .style(\"opacity\", 0);\n",
"\n",
" clickRegion.on(\"click\", function (d, index) {\n",
" var attnBackgroundOther = config.svg.selectAll(\".attn_background\");\n",
" attnBackgroundOther.attr(\"fill\", getBackgroundColor());\n",
" attnBackgroundOther.attr(\"stroke-opacity\", 0);\n",
"\n",
" config.svg.selectAll(\".detail\").remove();\n",
" if (config.detail_layer != layerIndex || config.detail_head != headIndex) {\n",
" renderDetail(att, layerIndex, headIndex);\n",
" config.detail_layer = layerIndex;\n",
" config.detail_head = headIndex;\n",
" attnBackground.attr(\"fill\", getHighlightColor());\n",
" attnBackground.attr(\"stroke-opacity\", .8);\n",
" } else {\n",
" config.detail_layer = null;\n",
" config.detail_head = null;\n",
" attnBackground.attr(\"fill\", getBackgroundColor());\n",
" attnBackground.attr(\"stroke-opacity\", 0);\n",
" }\n",
" });\n",
"\n",
" clickRegion.on(\"mouseover\", function (d) {\n",
" d3.select(this).style(\"cursor\", \"pointer\");\n",
" });\n",
" }\n",
"\n",
" function renderDetailFrame(x, y, layerIndex) {\n",
" var detailFrame = config.svg.append(\"rect\")\n",
" .classed(\"detail\", true)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.detailHeight)\n",
" .attr(\"width\", DETAIL_WIDTH)\n",
" .style(\"opacity\", 1)\n",
" .attr(\"stroke-width\", 1.5)\n",
" .attr(\"stroke-opacity\", 0.7)\n",
" .attr(\"stroke\", getLayerColor(layerIndex));\n",
" }\n",
"\n",
" function renderDetailAttn(x, y, att, layerIndex) {\n",
" var attnContainer = config.svg.append(\"svg:g\")\n",
" .classed(\"detail\", true)\n",
" .attr(\"pointer-events\", \"none\");\n",
" attnContainer.selectAll(\"g\")\n",
" .data(att)\n",
" .enter()\n",
" .append(\"g\") // Add group for each source token\n",
" .classed('attn-line-group', true)\n",
" .attr(\"source-index\", function (d, i) { // Save index of source token\n",
" return i;\n",
" })\n",
" .selectAll(\"line\")\n",
" .data(function (d) { // Loop over all target tokens\n",
" return d;\n",
" })\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", x + ATTN_PADDING)\n",
" .attr(\"y1\", function (d) {\n",
" var sourceIndex = +this.parentNode.getAttribute(\"source-index\");\n",
" return y + (sourceIndex + .5) * DETAIL_BOX_HEIGHT;\n",
" })\n",
" .attr(\"x2\", x + DETAIL_ATTENTION_WIDTH - ATTN_PADDING)\n",
" .attr(\"y2\", function (d, targetIndex) {\n",
" return y + (targetIndex + .5) * DETAIL_BOX_HEIGHT;\n",
" })\n",
" .attr(\"stroke-width\", 2.2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" return d;\n",
" });\n",
" }\n",
"\n",
" function getLayerColor(layer) {\n",
" return LAYER_COLORS[layer % 10];\n",
" }\n",
"\n",
" function getTextColor() {\n",
" return PALETTE[config.mode]['text']\n",
" }\n",
"\n",
" function getBackgroundColor() {\n",
" return PALETTE[config.mode]['background']\n",
" }\n",
"\n",
" function getHighlightColor() {\n",
" return PALETTE[config.mode]['highlight']\n",
" }\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.mode = params['display_mode'];\n",
" config.rootDivId = params['root_div_id'];\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" render();\n",
" });\n",
" }\n",
"\n",
" initialize();\n",
" render();\n",
"\n",
" });"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model_view(attention, tokens)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ENCODER-DECODER MODELS"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# Checkpoint identifier handed to `from_pretrained` for both the tokenizer and the model.\n",
"MODEL = \"Helsinki-NLP/opus-mt-en-de\""
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Sentence fed to the encoder, and the sentence fed to the decoder.\n",
"TEXT_ENCODER = \"She sees the small elephant.\"\n",
"TEXT_DECODER = \"Sie sieht den kleinen Elefanten.\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
"# output_attentions=True: the forward pass must return the attention tensors\n",
"# that the visualisation cells below read from `outputs`.\n",
"model = AutoModel.from_pretrained(MODEL, output_attentions=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import torch  # imported locally so this cell runs even if torch was not imported earlier\n",
"\n",
"# NOTE(review): both sentences are tokenized with the tokenizer's default settings.\n",
"# Marian checkpoints ship a separate target-side SentencePiece model; confirm whether\n",
"# TEXT_DECODER should use target-side tokenization in the installed transformers version.\n",
"encoder_input_ids = tokenizer(TEXT_ENCODER, return_tensors=\"pt\", add_special_tokens=True).input_ids\n",
"decoder_input_ids = tokenizer(TEXT_DECODER, return_tensors=\"pt\", add_special_tokens=True).input_ids\n",
"\n",
"# Inference only: the outputs are just visualised, so skip autograd bookkeeping.\n",
"with torch.no_grad():\n",
"    outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)\n",
"\n",
"# Token strings for the first (only) sequence of each batch, used as axis labels.\n",
"encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])\n",
"decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-8b17c300802a4ad9b6896ab4d4e8eba0'>\n",
" <span style=\"user-select:none\">\n",
" Layer: <select id=\"layer\"></select>\n",
" Attention: <select id=\"filter\"><option value=\"0\">Encoder</option>\n",
"<option value=\"1\">Decoder</option>\n",
"<option value=\"2\">Cross</option></select>\n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 12/19/18 Jesse Vig Assorted cleanup. Changed orientation of attention matrices.\n",
" * 12/29/20 Jesse Vig Significant refactor.\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust height of visualization dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function ($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": \"Encoder\", \"attn\": [[[[0.0015343179693445563, 0.06372829526662827, 0.3402520716190338, 0.014809303916990757, 0.11685427278280258, 0.09690533578395844, 0.3659164011478424], [0.008656513877213001, 0.08992935717105865, 0.25149786472320557, 0.07625745981931686, 0.20215430855751038, 0.10926036536693573, 0.26224416494369507], [0.01743713766336441, 0.19126100838184357, 0.019580217078328133, 0.10873700678348541, 0.5898742079734802, 0.013855690136551857, 0.05925469845533371], [0.006695661228150129, 0.06985638290643692, 0.03561459109187126, 0.03681695833802223, 0.7022322416305542, 0.054333802312612534, 0.09445030987262726], [0.0010159132070839405, 0.004103748593479395, 0.18814627826213837, 0.013669033534824848, 0.03894238546490669, 0.16656135022640228, 0.5875613689422607], [0.0033098733983933926, 0.06926336139440536, 0.006486027967184782, 0.033568885177373886, 0.7985183000564575, 0.014555910602211952, 0.07429762929677963], [0.006561059970408678, 0.028042230755090714, 0.06652601808309555, 0.04403909668326378, 0.1904589682817459, 0.17164954543113708, 0.49272310733795166]], [[0.0010659039253368974, 0.027911514043807983, 0.5827310681343079, 0.00583711639046669, 0.02539275959134102, 0.09940840303897858, 0.25765326619148254], [0.014391965232789516, 0.07798851281404495, 0.4047695994377136, 0.01874568685889244, 0.03196389973163605, 0.1091737151145935, 0.34296661615371704], [0.05386969447135925, 0.1397337168455124, 0.1042787954211235, 0.41253480315208435, 0.13313433527946472, 0.06411976367235184, 0.09232895076274872], [0.011524752713739872, 0.024617008864879608, 0.11605977267026901, 0.15868355333805084, 0.07823850214481354, 0.1350926011800766, 0.4757837951183319], [0.000765774748288095, 0.0007213741773739457, 0.15132009983062744, 0.007885550148785114, 0.01554951909929514, 0.1716994047164917, 0.6520583033561707], [0.00261464505456388, 0.01399530004709959, 0.0927773043513298, 0.05853908136487007, 0.12484410405158997, 0.22863952815532684, 
0.4785899519920349], [0.005509554408490658, 0.018335115164518356, 0.07280543446540833, 0.02366117760539055, 0.06281809508800507, 0.21027612686157227, 0.6065945625305176]], [[0.001894954708404839, 0.04986847937107086, 0.3366052806377411, 0.005386138800531626, 0.11281800270080566, 0.07878921926021576, 0.4146379232406616], [0.0033349881414324045, 0.03520013391971588, 0.5712409615516663, 0.031452734023332596, 0.0384756363928318, 0.08380278199911118, 0.2364927977323532], [0.032189127057790756, 0.08156478404998779, 0.22131586074829102, 0.08562739938497543, 0.08101073652505875, 0.1313415765762329, 0.36695045232772827], [0.0022780823055654764, 0.08724657446146011, 0.20255671441555023, 0.09033971279859543, 0.11631207168102264, 0.25404372811317444, 0.24722309410572052], [0.00028211530297994614, 0.0050274962559342384, 0.12840455770492554, 0.008146004751324654, 0.05242825299501419, 0.23907428979873657, 0.5666372776031494], [8.598285057814792e-05, 0.0020894501358270645, 0.02531367912888527, 0.001749282469972968, 0.01047285832464695, 0.10172442346811295, 0.8585643172264099], [0.0019090479472652078, 0.009371059946715832, 0.03571559488773346, 0.008757798001170158, 0.01917792297899723, 0.12495970726013184, 0.8001089096069336]], [[0.0002124253660440445, 0.07846645265817642, 0.6617725491523743, 0.0009571347036398947, 0.0028506950475275517, 0.0355488583445549, 0.2201918661594391], [0.0013146958081051707, 0.03666481375694275, 0.5017722249031067, 0.032994918525218964, 0.08658347278833389, 0.04107796028256416, 0.29959189891815186], [0.007319945842027664, 0.25871768593788147, 0.07154543697834015, 0.20867864787578583, 0.18836019933223724, 0.18986377120018005, 0.07551436871290207], [0.002156124683097005, 0.056250181049108505, 0.20182666182518005, 0.011035334318876266, 0.0871698409318924, 0.05659180134534836, 0.5849699974060059], [0.0009280861704610288, 0.004628014750778675, 0.20538683235645294, 0.005883386358618736, 0.012845118530094624, 0.10419470071792603, 0.6661338806152344], 
[0.0012555076973512769, 0.041694026440382004, 0.05578531697392464, 0.12980355322360992, 0.2015871
" const TEXT_SIZE = 15;\n",
" const BOXWIDTH = 110;\n",
" const BOXHEIGHT = 22.5;\n",
" const MATRIX_WIDTH = 115;\n",
" const CHECKBOX_SIZE = 20;\n",
" const TEXT_TOP = 30;\n",
"\n",
" console.log(\"d3 version\", d3.version)\n",
" let headColors;\n",
" try {\n",
" headColors = d3.scaleOrdinal(d3.schemeCategory10);\n",
" } catch (err) {\n",
" console.log('Older d3 version')\n",
" headColors = d3.scale.category10();\n",
" }\n",
" let config = {};\n",
" initialize();\n",
" renderVis();\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.rootDivId = params['root_div_id'];\n",
" config.nLayers = config.attention[config.filter]['attn'].length;\n",
" config.nHeads = config.attention[config.filter]['attn'][0].length;\n",
" if (params['heads']) {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" params['heads'].forEach(x => config.headVis[x] = true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" }\n",
" config.initialTextLength = config.attention[config.filter].right_text.length;\n",
" config.layer = (params['layer'] == null ? 0 : params['layer'])\n",
"\n",
"\n",
" let layerEl = $(`#${config.rootDivId} #layer`);\n",
" for (var i = 0; i < config.nLayers; i++) {\n",
" layerEl.append($(\"<option />\").val(i).text(i));\n",
" }\n",
" layerEl.val(config.layer).change();\n",
" layerEl.on('change', function (e) {\n",
" config.layer = +e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
"\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
" }\n",
"\n",
" function renderVis() {\n",
"\n",
" // Load parameters\n",
" const attnData = config.attention[config.filter];\n",
" const leftText = attnData.left_text;\n",
" const rightText = attnData.right_text;\n",
"\n",
" // Select attention for given layer\n",
" const layerAttention = attnData.attn[config.layer];\n",
"\n",
" // Clear vis\n",
" $(`#${config.rootDivId} #vis`).empty();\n",
"\n",
" // Determine size of visualization\n",
" const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;\n",
" const svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", \"100%\")\n",
" .attr(\"height\", height + \"px\");\n",
"\n",
" // Display tokens on left and right side of visualization\n",
" renderText(svg, leftText, true, layerAttention, 0);\n",
" renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);\n",
"\n",
" // Render attention arcs\n",
" renderAttention(svg, layerAttention);\n",
"\n",
" // Draw squares at top of visualization, one for each head\n",
" drawCheckboxes(0, svg, layerAttention);\n",
" }\n",
"\n",
" function renderText(svg, text, isLeft, attention, leftPos) {\n",
"\n",
" const textContainer = svg.append(\"svg:g\")\n",
" .attr(\"id\", isLeft ? \"left\" : \"right\");\n",
"\n",
" // Add attention highlights superimposed over words\n",
" textContainer.append(\"g\")\n",
" .classed(\"attentionBoxes\", true)\n",
" .selectAll(\"g\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"x\", function () {\n",
" var headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (+1) * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .attr(\"fill\", function () {\n",
" return headColors(+this.parentNode.getAttribute(\"head-index\"))\n",
" })\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" const tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" // Add gray background that appears when hovering over text\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"background\", true)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"fill\", \"lightgray\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH)\n",
" .attr(\"height\", BOXHEIGHT);\n",
"\n",
" // Add token text\n",
" const textEl = tokenContainer.append(\"text\")\n",
" .text(d => d)\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT);\n",
"\n",
" if (isLeft) {\n",
" textEl.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", BOXWIDTH - 0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" } else {\n",
" textEl.style(\"text-anchor\", \"start\")\n",
" .attr(\"dx\", +0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" }\n",
"\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
"\n",
" // Show gray background for moused-over token\n",
" textContainer.selectAll(\".background\")\n",
" .style(\"opacity\", (d, i) => i === index ? 1.0 : 0.0)\n",
"\n",
" // Reset visibility attribute for any previously highlighted attention arcs\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null)\n",
"\n",
" // Hide group containing attention arcs\n",
" svg.select(\"#attention\").attr(\"visibility\", \"hidden\");\n",
"\n",
" // Set to visible appropriate attention arcs to be highlighted\n",
" if (isLeft) {\n",
" svg.select(\"#attention\").selectAll(\"line[left-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" } else {\n",
" svg.select(\"#attention\").selectAll(\"line[right-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" }\n",
"\n",
" // Update color boxes superimposed over tokens\n",
" const id = isLeft ? \"right\" : \"left\";\n",
" const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;\n",
" svg.select(\"#\" + id)\n",
" .selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .attr(\"x\", function () {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .style(\"opacity\", function (d) {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" if (config.headVis[headIndex])\n",
" if (d) {\n",
" return d[index];\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" else\n",
" return 0.0;\n",
" });\n",
" });\n",
"\n",
" textContainer.on(\"mouseleave\", function () {\n",
"\n",
" // Unhighlight selected token\n",
" d3.select(this).selectAll(\".background\")\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" // Reset visibility attributes for previously selected lines\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null) ;\n",
" svg.select(\"#attention\").attr(\"visibility\", \"visible\");\n",
"\n",
" // Reset highlights superimposed over tokens\n",
" svg.selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .selectAll(\"rect\")\n",
" .style(\"opacity\", 0.0);\n",
" });\n",
" }\n",
"\n",
" function renderAttention(svg, attention) {\n",
"\n",
" // Remove previous dom elements\n",
" svg.select(\"#attention\").remove();\n",
"\n",
" // Add new elements\n",
" svg.append(\"g\")\n",
" .attr(\"id\", \"attention\") // Container for all attention arcs\n",
" .selectAll(\".headAttention\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"headAttention\", true) // Group attention arcs by head\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\".tokenAttention\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"tokenAttention\", true) // Group attention arcs by left token\n",
" .attr(\"left-token-index\", (d, i) => i)\n",
" .selectAll(\"line\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", BOXWIDTH)\n",
" .attr(\"y1\", function () {\n",
" const leftTokenIndex = +this.parentNode.getAttribute(\"left-token-index\")\n",
" return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)\n",
" })\n",
" .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n",
" .attr(\"y2\", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", function () {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" return headColors(headIndex)\n",
" })\n",
" .attr(\"left-token-index\", function () {\n",
" return +this.parentNode.getAttribute(\"left-token-index\")\n",
" })\n",
" .attr(\"right-token-index\", (d, i) => i)\n",
" ;\n",
" updateAttention(svg)\n",
" }\n",
"\n",
" function updateAttention(svg) {\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line\")\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" // If head is selected\n",
" if (config.headVis[headIndex]) {\n",
" // Set opacity to attention weight divided by number of active heads\n",
" return d / activeHeads()\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" })\n",
" }\n",
"\n",
" function boxOffsets(i) {\n",
" const numHeadsAbove = config.headVis.reduce(\n",
" function (acc, val, cur) {\n",
" return val && cur < i ? acc + 1 : acc;\n",
" }, 0);\n",
" return numHeadsAbove * (BOXWIDTH / activeHeads());\n",
" }\n",
"\n",
" function activeHeads() {\n",
" return config.headVis.reduce(function (acc, val) {\n",
" return val ? acc + 1 : acc;\n",
" }, 0);\n",
" }\n",
"\n",
" function drawCheckboxes(top, svg) {\n",
" const checkboxContainer = svg.append(\"g\");\n",
" const checkbox = checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"fill\", (d, i) => headColors(i))\n",
" .attr(\"x\", (d, i) => i * CHECKBOX_SIZE)\n",
" .attr(\"y\", top)\n",
" .attr(\"width\", CHECKBOX_SIZE)\n",
" .attr(\"height\", CHECKBOX_SIZE);\n",
"\n",
" function updateCheckboxes() {\n",
" checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .attr(\"fill\", (d, i) => d ? headColors(i): lighten(headColors(i)));\n",
" }\n",
"\n",
" updateCheckboxes();\n",
"\n",
" checkbox.on(\"click\", function (d, i) {\n",
" if (config.headVis[i] && activeHeads() === 1) return;\n",
" config.headVis[i] = !config.headVis[i];\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
"\n",
" checkbox.on(\"dblclick\", function (d, i) {\n",
" // If we double click on the only active head then reset\n",
" if (config.headVis[i] && activeHeads() === 1) {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" config.headVis[i] = true;\n",
" }\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
" }\n",
"\n",
" function lighten(color) {\n",
" const c = d3.hsl(color);\n",
" const increment = (1 - c.l) * 0.6;\n",
" c.l += increment;\n",
" c.s -= increment;\n",
" return c;\n",
" }\n",
"\n",
" function transpose(mat) {\n",
" return mat[0].map(function (col, i) {\n",
" return mat.map(function (row) {\n",
" return row[i];\n",
" });\n",
" });\n",
" }\n",
"\n",
"});"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"head_view(\n",
" encoder_attention=outputs.encoder_attentions,\n",
" decoder_attention=outputs.decoder_attentions,\n",
" cross_attention=outputs.cross_attentions,\n",
" encoder_tokens= encoder_text,\n",
" decoder_tokens = decoder_text\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-5e36904179de4109b01180d66af00b7a'>\n",
" <span style=\"user-select:none\">\n",
" Attention: <select id=\"filter\"><option value=\"0\">Encoder</option>\n",
"<option value=\"1\">Decoder</option>\n",
"<option value=\"2\">Cross</option></select>\n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 02/01/19 Jesse Vig Initial implementation\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 01/19/21 Jesse Vig Support light/dark modes\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust visualization height dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": \"Encoder\", \"attn\": [[[[0.0015343179693445563, 0.06372829526662827, 0.3402520716190338, 0.014809303916990757, 0.11685427278280258, 0.09690533578395844, 0.3659164011478424], [0.008656513877213001, 0.08992935717105865, 0.25149786472320557, 0.07625745981931686, 0.20215430855751038, 0.10926036536693573, 0.26224416494369507], [0.01743713766336441, 0.19126100838184357, 0.019580217078328133, 0.10873700678348541, 0.5898742079734802, 0.013855690136551857, 0.05925469845533371], [0.006695661228150129, 0.06985638290643692, 0.03561459109187126, 0.03681695833802223, 0.7022322416305542, 0.054333802312612534, 0.09445030987262726], [0.0010159132070839405, 0.004103748593479395, 0.18814627826213837, 0.013669033534824848, 0.03894238546490669, 0.16656135022640228, 0.5875613689422607], [0.0033098733983933926, 0.06926336139440536, 0.006486027967184782, 0.033568885177373886, 0.7985183000564575, 0.014555910602211952, 0.07429762929677963], [0.006561059970408678, 0.028042230755090714, 0.06652601808309555, 0.04403909668326378, 0.1904589682817459, 0.17164954543113708, 0.49272310733795166]], [[0.0010659039253368974, 0.027911514043807983, 0.5827310681343079, 0.00583711639046669, 0.02539275959134102, 0.09940840303897858, 0.25765326619148254], [0.014391965232789516, 0.07798851281404495, 0.4047695994377136, 0.01874568685889244, 0.03196389973163605, 0.1091737151145935, 0.34296661615371704], [0.05386969447135925, 0.1397337168455124, 0.1042787954211235, 0.41253480315208435, 0.13313433527946472, 0.06411976367235184, 0.09232895076274872], [0.011524752713739872, 0.024617008864879608, 0.11605977267026901, 0.15868355333805084, 0.07823850214481354, 0.1350926011800766, 0.4757837951183319], [0.000765774748288095, 0.0007213741773739457, 0.15132009983062744, 0.007885550148785114, 0.01554951909929514, 0.1716994047164917, 0.6520583033561707], [0.00261464505456388, 0.01399530004709959, 0.0927773043513298, 0.05853908136487007, 0.12484410405158997, 0.22863952815532684, 
0.4785899519920349], [0.005509554408490658, 0.018335115164518356, 0.07280543446540833, 0.02366117760539055, 0.06281809508800507, 0.21027612686157227, 0.6065945625305176]], [[0.001894954708404839, 0.04986847937107086, 0.3366052806377411, 0.005386138800531626, 0.11281800270080566, 0.07878921926021576, 0.4146379232406616], [0.0033349881414324045, 0.03520013391971588, 0.5712409615516663, 0.031452734023332596, 0.0384756363928318, 0.08380278199911118, 0.2364927977323532], [0.032189127057790756, 0.08156478404998779, 0.22131586074829102, 0.08562739938497543, 0.08101073652505875, 0.1313415765762329, 0.36695045232772827], [0.0022780823055654764, 0.08724657446146011, 0.20255671441555023, 0.09033971279859543, 0.11631207168102264, 0.25404372811317444, 0.24722309410572052], [0.00028211530297994614, 0.0050274962559342384, 0.12840455770492554, 0.008146004751324654, 0.05242825299501419, 0.23907428979873657, 0.5666372776031494], [8.598285057814792e-05, 0.0020894501358270645, 0.02531367912888527, 0.001749282469972968, 0.01047285832464695, 0.10172442346811295, 0.8585643172264099], [0.0019090479472652078, 0.009371059946715832, 0.03571559488773346, 0.008757798001170158, 0.01917792297899723, 0.12495970726013184, 0.8001089096069336]], [[0.0002124253660440445, 0.07846645265817642, 0.6617725491523743, 0.0009571347036398947, 0.0028506950475275517, 0.0355488583445549, 0.2201918661594391], [0.0013146958081051707, 0.03666481375694275, 0.5017722249031067, 0.032994918525218964, 0.08658347278833389, 0.04107796028256416, 0.29959189891815186], [0.007319945842027664, 0.25871768593788147, 0.07154543697834015, 0.20867864787578583, 0.18836019933223724, 0.18986377120018005, 0.07551436871290207], [0.002156124683097005, 0.056250181049108505, 0.20182666182518005, 0.011035334318876266, 0.0871698409318924, 0.05659180134534836, 0.5849699974060059], [0.0009280861704610288, 0.004628014750778675, 0.20538683235645294, 0.005883386358618736, 0.012845118530094624, 0.10419470071792603, 0.6661338806152344], 
[0.0012555076973512769, 0.041694026440382004, 0.05578531697392464, 0.12980355322360992, 0.201
" const config = {};\n",
"\n",
" const MIN_X = 0;\n",
" const MIN_Y = 0;\n",
" const DIV_WIDTH = 970;\n",
" const THUMBNAIL_PADDING = 5;\n",
" const DETAIL_WIDTH = 300;\n",
" const DETAIL_ATTENTION_WIDTH = 140;\n",
" const DETAIL_BOX_WIDTH = 80;\n",
" const DETAIL_BOX_HEIGHT = 18;\n",
" const DETAIL_PADDING = 28;\n",
" const ATTN_PADDING = 0;\n",
" const DETAIL_HEADING_HEIGHT = 25;\n",
" const DETAIL_HEADING_TEXT_SIZE = 15;\n",
" const TEXT_SIZE = 13;\n",
" const LAYER_COLORS = d3.schemeCategory10;\n",
" const PALETTE = {\n",
" 'light': {\n",
" 'text': 'black',\n",
" 'background': 'white',\n",
" 'highlight': '#F5F5F5'\n",
" },\n",
" 'dark': {\n",
" 'text': '#bbb',\n",
" 'background': 'black',\n",
" 'highlight': '#222'\n",
" }\n",
" }\n",
"\n",
" function render() {\n",
"\n",
" // Set global state variables\n",
"\n",
" var attData = config.attention[config.filter];\n",
" config.leftText = attData.left_text;\n",
" config.rightText = attData.right_text;\n",
" config.attn = attData.attn;\n",
" config.numLayers = config.attn.length;\n",
" config.numHeads = config.attn[0].length;\n",
" config.thumbnailBoxHeight = 7 * (12 / config.numHeads);\n",
" config.thumbnailHeight = Math.max(config.leftText.length, config.rightText.length) * config.thumbnailBoxHeight + 2 * THUMBNAIL_PADDING;\n",
" config.thumbnailWidth = DIV_WIDTH / config.numHeads;\n",
" config.detailHeight = Math.max(config.leftText.length, config.rightText.length) * DETAIL_BOX_HEIGHT + 2 * DETAIL_PADDING + DETAIL_HEADING_HEIGHT;\n",
" config.divHeight = config.numLayers * config.thumbnailHeight;\n",
"\n",
" const vis = $(`#${config.rootDivId} #vis`)\n",
" vis.empty();\n",
" vis.attr(\"height\", config.divHeight);\n",
" config.svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", DIV_WIDTH)\n",
" .attr(\"height\", config.divHeight)\n",
" .attr(\"fill\", getBackgroundColor());\n",
"\n",
" var i;\n",
" var j;\n",
" for (i = 0; i < config.numLayers; i++) {\n",
" for (j = 0; j < config.numHeads; j++) {\n",
" renderThumbnail(i, j);\n",
" }\n",
" }\n",
" }\n",
"\n",
" function renderThumbnail(layerIndex, headIndex) {\n",
" var x = headIndex * config.thumbnailWidth;\n",
" var y = layerIndex * config.thumbnailHeight;\n",
" renderThumbnailAttn(x, y, config.attn[layerIndex][headIndex], layerIndex, headIndex);\n",
" }\n",
"\n",
" function renderDetail(att, layerIndex, headIndex) {\n",
" var xOffset = .8 * config.thumbnailWidth;\n",
" var maxX = DIV_WIDTH;\n",
" var maxY = config.divHeight;\n",
" var leftPos = (headIndex / config.numHeads) * DIV_WIDTH\n",
" var x = leftPos + THUMBNAIL_PADDING + xOffset;\n",
" if (x < MIN_X) {\n",
" x = MIN_X;\n",
" } else if (x + DETAIL_WIDTH > maxX) {\n",
" x = leftPos + THUMBNAIL_PADDING - DETAIL_WIDTH + 8;\n",
" }\n",
" var posLeftText = x;\n",
" var posAttention = posLeftText + DETAIL_BOX_WIDTH;\n",
" var posRightText = posAttention + DETAIL_ATTENTION_WIDTH;\n",
" var thumbnailHeight = Math.max(config.leftText.length, config.rightText.length) * config.thumbnailBoxHeight + 2 * THUMBNAIL_PADDING;\n",
" var yOffset = 20;\n",
" var y = layerIndex * thumbnailHeight + THUMBNAIL_PADDING + yOffset;\n",
" if (y < MIN_Y) {\n",
" y = MIN_Y;\n",
" } else if (y + config.detailHeight > maxY) {\n",
" y = maxY - config.detailHeight;\n",
" }\n",
" renderDetailFrame(x, y, layerIndex);\n",
" renderDetailHeading(x, y + Math.max(config.leftText.length, config.rightText.length) *\n",
" DETAIL_BOX_HEIGHT, layerIndex, headIndex);\n",
" renderDetailText(config.leftText, \"leftText\", posLeftText, y + DETAIL_PADDING, layerIndex);\n",
" renderDetailAttn(posAttention, y + DETAIL_PADDING, att, layerIndex, headIndex);\n",
" renderDetailText(config.rightText, \"rightText\", posRightText, y + DETAIL_PADDING, layerIndex);\n",
" }\n",
"\n",
" function renderDetailHeading(x, y, layerIndex, headIndex) {\n",
" var fillColor = getTextColor();\n",
" config.svg.append(\"text\")\n",
" .classed(\"detail\", true)\n",
" .text('Layer ' + layerIndex + \", Head \" + headIndex)\n",
" .attr(\"font-size\", DETAIL_HEADING_TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"fill\", fillColor)\n",
" .attr(\"x\", x + DETAIL_WIDTH / 2)\n",
" .attr(\"text-anchor\", \"middle\")\n",
" .attr(\"y\", y + 40)\n",
" .attr(\"height\", DETAIL_HEADING_HEIGHT)\n",
" .attr(\"width\", DETAIL_WIDTH)\n",
" .attr(\"dy\", DETAIL_HEADING_TEXT_SIZE);\n",
" }\n",
"\n",
" function renderDetailText(text, id, x, y, layerIndex) {\n",
" var tokenContainer = config.svg.append(\"svg:g\")\n",
" .classed(\"detail\", true)\n",
" .selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" var fillColor = getTextColor();\n",
"\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"highlight\", true)\n",
" .attr(\"fill\", fillColor)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"height\", DETAIL_BOX_HEIGHT)\n",
" .attr(\"width\", DETAIL_BOX_WIDTH)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", function (d, i) {\n",
" return y + i * DETAIL_BOX_HEIGHT;\n",
" });\n",
"\n",
" var textContainer = tokenContainer.append(\"text\")\n",
" .classed(\"token\", true)\n",
" .text(function (d) {\n",
" return d;\n",
" })\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"fill\", fillColor)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", function (d, i) {\n",
" return i * DETAIL_BOX_HEIGHT + y;\n",
" })\n",
" .attr(\"height\", DETAIL_BOX_HEIGHT)\n",
" .attr(\"width\", DETAIL_BOX_WIDTH)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
"\n",
" if (id == \"leftText\") {\n",
" textContainer.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", DETAIL_BOX_WIDTH - 2);\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
" highlightSelection(index);\n",
" });\n",
" tokenContainer.on(\"mouseleave\", function () {\n",
" unhighlightSelection();\n",
" });\n",
" }\n",
" }\n",
"\n",
" function highlightSelection(index) {\n",
" config.svg.select(\"#leftText\")\n",
" .selectAll(\".highlight\")\n",
" .style(\"opacity\", function (d, i) {\n",
" return i == index ? 1.0 : 0.0;\n",
" });\n",
" config.svg.selectAll(\".attn-line-group\")\n",
" .style(\"opacity\", function (d, i) {\n",
" return i == index ? 1.0 : 0.0;\n",
" });\n",
" }\n",
"\n",
" function unhighlightSelection() {\n",
" config.svg.select(\"#leftText\")\n",
" .selectAll(\".highlight\")\n",
" .style(\"opacity\", 0.0);\n",
" config.svg.selectAll(\".attn-line-group\")\n",
" .style(\"opacity\", 1);\n",
" }\n",
"\n",
" function renderThumbnailAttn(x, y, att, layerIndex, headIndex) {\n",
"\n",
" var attnContainer = config.svg.append(\"svg:g\");\n",
"\n",
" var attnBackground = attnContainer.append(\"rect\")\n",
" .attr(\"id\", 'attn_background_' + layerIndex + \"_\" + headIndex)\n",
" .classed(\"attn_background\", true)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.thumbnailHeight)\n",
" .attr(\"width\", config.thumbnailWidth)\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", 0)\n",
" .attr(\"fill\", getBackgroundColor());\n",
" var x1 = x + THUMBNAIL_PADDING;\n",
" var x2 = x1 + config.thumbnailWidth - 14;\n",
" var y1 = y + THUMBNAIL_PADDING;\n",
"\n",
" attnContainer.selectAll(\"g\")\n",
" .data(att)\n",
" .enter()\n",
" .append(\"g\") // Add group for each source token\n",
" .attr(\"source-index\", function (d, i) { // Save index of source token\n",
" return i;\n",
" })\n",
" .selectAll(\"line\")\n",
" .data(function (d) { // Loop over all target tokens\n",
" return d;\n",
" })\n",
" .enter() // When entering\n",
" .append(\"line\")\n",
" .attr(\"x1\", x1)\n",
" .attr(\"y1\", function (d) {\n",
" var sourceIndex = +this.parentNode.getAttribute(\"source-index\");\n",
" return y1 + (sourceIndex + .5) * config.thumbnailBoxHeight;\n",
" })\n",
" .attr(\"x2\", x2)\n",
" .attr(\"y2\", function (d, targetIndex) {\n",
" return y1 + (targetIndex + .5) * config.thumbnailBoxHeight;\n",
" })\n",
" .attr(\"stroke-width\", 2.2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" return d;\n",
" });\n",
"\n",
" var clickRegion = attnContainer.append(\"rect\")\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.thumbnailHeight)\n",
" .attr(\"width\", config.thumbnailWidth)\n",
" .style(\"opacity\", 0);\n",
"\n",
" clickRegion.on(\"click\", function (d, index) {\n",
" var attnBackgroundOther = config.svg.selectAll(\".attn_background\");\n",
" attnBackgroundOther.attr(\"fill\", getBackgroundColor());\n",
" attnBackgroundOther.attr(\"stroke-opacity\", 0);\n",
"\n",
" config.svg.selectAll(\".detail\").remove();\n",
" if (config.detail_layer != layerIndex || config.detail_head != headIndex) {\n",
" renderDetail(att, layerIndex, headIndex);\n",
" config.detail_layer = layerIndex;\n",
" config.detail_head = headIndex;\n",
" attnBackground.attr(\"fill\", getHighlightColor());\n",
" attnBackground.attr(\"stroke-opacity\", .8);\n",
" } else {\n",
" config.detail_layer = null;\n",
" config.detail_head = null;\n",
" attnBackground.attr(\"fill\", getBackgroundColor());\n",
" attnBackground.attr(\"stroke-opacity\", 0);\n",
" }\n",
" });\n",
"\n",
" clickRegion.on(\"mouseover\", function (d) {\n",
" d3.select(this).style(\"cursor\", \"pointer\");\n",
" });\n",
" }\n",
"\n",
" function renderDetailFrame(x, y, layerIndex) {\n",
" var detailFrame = config.svg.append(\"rect\")\n",
" .classed(\"detail\", true)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.detailHeight)\n",
" .attr(\"width\", DETAIL_WIDTH)\n",
" .style(\"opacity\", 1)\n",
" .attr(\"stroke-width\", 1.5)\n",
" .attr(\"stroke-opacity\", 0.7)\n",
" .attr(\"stroke\", getLayerColor(layerIndex));\n",
" }\n",
"\n",
" function renderDetailAttn(x, y, att, layerIndex) {\n",
" var attnContainer = config.svg.append(\"svg:g\")\n",
" .classed(\"detail\", true)\n",
" .attr(\"pointer-events\", \"none\");\n",
" attnContainer.selectAll(\"g\")\n",
" .data(att)\n",
" .enter()\n",
" .append(\"g\") // Add group for each source token\n",
" .classed('attn-line-group', true)\n",
" .attr(\"source-index\", function (d, i) { // Save index of source token\n",
" return i;\n",
" })\n",
" .selectAll(\"line\")\n",
" .data(function (d) { // Loop over all target tokens\n",
" return d;\n",
" })\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", x + ATTN_PADDING)\n",
" .attr(\"y1\", function (d) {\n",
" var sourceIndex = +this.parentNode.getAttribute(\"source-index\");\n",
" return y + (sourceIndex + .5) * DETAIL_BOX_HEIGHT;\n",
" })\n",
" .attr(\"x2\", x + DETAIL_ATTENTION_WIDTH - ATTN_PADDING)\n",
" .attr(\"y2\", function (d, targetIndex) {\n",
" return y + (targetIndex + .5) * DETAIL_BOX_HEIGHT;\n",
" })\n",
" .attr(\"stroke-width\", 2.2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" return d;\n",
" });\n",
" }\n",
"\n",
" function getLayerColor(layer) {\n",
" return LAYER_COLORS[layer % 10];\n",
" }\n",
"\n",
" function getTextColor() {\n",
" return PALETTE[config.mode]['text']\n",
" }\n",
"\n",
" function getBackgroundColor() {\n",
" return PALETTE[config.mode]['background']\n",
" }\n",
"\n",
" function getHighlightColor() {\n",
" return PALETTE[config.mode]['highlight']\n",
" }\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.mode = params['display_mode'];\n",
" config.rootDivId = params['root_div_id'];\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" render();\n",
" });\n",
" }\n",
"\n",
" initialize();\n",
" render();\n",
"\n",
" });"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model_view(\n",
" encoder_attention=outputs.encoder_attentions,\n",
" decoder_attention=outputs.decoder_attentions,\n",
" cross_attention=outputs.cross_attentions,\n",
" encoder_tokens= encoder_text,\n",
" decoder_tokens = decoder_text\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Zadanie (10 minut)\n",
"\n",
"Za pomocą modelu en-fr przetłumacz dowolne zdanie z angielskiego na język francuski i sprawdź wagi atencji dla tego tłumaczenia"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"MODEL = \"Helsinki-NLP/opus-mt-en-fr\""
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"TEXT_ENCODER = \"Although I still have fresh memories of my brother the elder Hamlets death, and though it was proper to mourn him throughout our kingdom, life still goes on—I think its wise to mourn him while also thinking about my own well being.\""
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/kuba/anaconda3/lib/python3.8/site-packages/transformers/models/auto/modeling_auto.py:921: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from transformers import AutoModelWithLMHead, AutoTokenizer\n",
"\n",
"model = AutoModelWithLMHead.from_pretrained(MODEL)\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
"\n",
"inputs = tokenizer.encode(TEXT_ENCODER, return_tensors=\"pt\")\n",
"outputs = model.generate(inputs, max_length=40, num_beams=4, early_stopping=True)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"TEXT_DECODER = tokenizer.decode(outputs[0])"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\"<pad> Bien que j'aie encore de nouveaux souvenirs de la mort de mon frère Hamlet, l'aîné, et bien qu'il fût approprié de le pleurer dans tout notre royaume,\""
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"TEXT_DECODER"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
"model = AutoModel.from_pretrained(MODEL, output_attentions=True)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"encoder_input_ids = tokenizer(TEXT_ENCODER, return_tensors=\"pt\", add_special_tokens=True).input_ids\n",
"decoder_input_ids = tokenizer(TEXT_DECODER, return_tensors=\"pt\", add_special_tokens=True).input_ids\n",
"\n",
"outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)\n",
"\n",
"encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])\n",
"decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-1a9450f47eac4faabbbab2309ee58a83'>\n",
" <span style=\"user-select:none\">\n",
" Layer: <select id=\"layer\"></select>\n",
" Attention: <select id=\"filter\"><option value=\"0\">Encoder</option>\n",
"<option value=\"1\">Decoder</option>\n",
"<option value=\"2\">Cross</option></select>\n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 12/19/18 Jesse Vig Assorted cleanup. Changed orientation of attention matrices.\n",
" * 12/29/20 Jesse Vig Significant refactor.\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust height of visualization dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function ($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": \"Encoder\", \"attn\": [[[[0.013920878060162067, 0.023272190243005753, 0.00535218371078372, 0.018134916201233864, 0.007428539451211691, 0.011490718461573124, 0.1538427323102951, 0.001154032303020358, 0.007759112864732742, 0.21605004370212555, 0.006941711064428091, 0.0014982341090217233, 0.0015839363913983107, 0.0017187998164445162, 0.004308860283344984, 0.0005532937939278781, 0.03153957054018974, 0.1311151683330536, 0.003969202283769846, 0.0013689373154193163, 0.011080264113843441, 0.011827018111944199, 0.0761142447590828, 0.020440619438886642, 0.0038698972202837467, 0.001724320463836193, 0.001641443814150989, 0.0034779387060552835, 0.024852916598320007, 0.0009479467989876866, 0.0017434320179745555, 0.002454550703987479, 0.004841943271458149, 0.00019309124036226422, 0.002562746172770858, 0.0019224263960495591, 0.0006097023142501712, 0.0005324910744093359, 0.0012550262035802007, 0.0038530216552317142, 0.023665735498070717, 0.007264423184096813, 0.001738711609505117, 0.0004981358652003109, 0.0007907248218543828, 0.00159708340652287, 0.0011348234256729484, 0.0003918300790246576, 0.000999043695628643, 0.001242546597495675, 0.0002570441865827888, 0.017178047448396683, 0.12429371476173401], [0.14694058895111084, 0.00756652420386672, 0.13187594711780548, 0.03972889855504036, 0.011172577738761902, 0.18229877948760986, 0.013844842091202736, 0.004348167218267918, 0.02901429682970047, 0.012167918495833874, 0.005697580985724926, 0.0032914134208112955, 0.0021410395856946707, 0.0031064630020409822, 0.09043074399232864, 0.0023242426104843616, 0.004553386941552162, 0.015362486243247986, 0.006020396947860718, 0.003748489310964942, 0.05645429342985153, 0.0024631230626255274, 0.0038209601771086454, 0.03585977852344513, 0.004277764819562435, 0.0027697812765836716, 0.010888178832828999, 0.006062007509171963, 0.0022597338538616896, 0.002896188525483012, 0.008117432706058025, 0.002261913614347577, 0.005734980572015047, 0.0003934438282158226, 
0.0007230555056594312, 0.008459878154098988, 0.0030921015422791243, 0.002154270652681589, 0.05607650429010391, 0.006397769320756197, 0.002112159039825201, 0.019096804782748222, 0.002815938089042902, 0.0044755516573786736, 0.0003924140764866024, 0.009052095003426075, 0.0026376359164714813, 0.0011557287070900202, 0.00085056311218068, 0.00021846132585778832, 0.004174591973423958, 0.0019873864948749542, 0.014232764020562172], [0.0784544125199318, 0.010349954478442669, 0.018165580928325653, 0.056026097387075424, 0.025532223284244537, 0.10705547034740448, 0.07384178042411804, 0.0019790686201304197, 0.05679738149046898, 0.06699232757091522, 0.015509623102843761, 0.002267960924655199, 0.0043000755831599236, 0.015416891314089298, 0.005750738549977541, 0.06612744182348251, 0.015228715725243092, 0.034174833446741104, 0.019795557484030724, 0.0045778630301356316, 0.009220341220498085, 0.014981023035943508, 0.023301783949136734, 0.04659155756235123, 0.007269928231835365, 0.015406382270157337, 0.00252791540697217, 0.026706278324127197, 0.007078808266669512, 0.04796610772609711, 0.0011842040112242103, 0.008740173652768135, 0.0026490569580346346, 0.00133996841032058, 0.00043630602885968983, 0.009866753593087196, 0.002449900144711137, 0.004007149953395128, 0.001405092072673142, 0.0027358056977391243, 0.010372530668973923, 0.021819977089762688, 0.0035449026618152857, 0.003052458632737398, 0.0004477598995435983, 0.005744684021919966, 0.002297906670719385, 0.00021899335843045264, 0.001487305504269898, 0.000788830395322293, 0.002733449451625347, 0.002520727925002575, 0.030761927366256714], [0.04886134713888168, 0.013975716196000576, 0.11111365258693695, 0.019864728674292564, 0.23528097569942474, 0.11365130543708801, 0.037558287382125854, 0.025197455659508705, 0.018998432904481888, 0.013980338349938393, 0.01313688326627016, 0.0029896548949182034, 0.0010132866445928812, 0.0064606573432683945, 0.0050790454261004925, 0.0020759906619787216, 0.005735225975513458, 0.015039972960948944, 
0.09153125435113907, 0.002194924047216773, 0.013067496940493584, 0.018380729481577873, 0.008223
" const TEXT_SIZE = 15;\n",
" const BOXWIDTH = 110;\n",
" const BOXHEIGHT = 22.5;\n",
" const MATRIX_WIDTH = 115;\n",
" const CHECKBOX_SIZE = 20;\n",
" const TEXT_TOP = 30;\n",
"\n",
" console.log(\"d3 version\", d3.version)\n",
" let headColors;\n",
" try {\n",
" headColors = d3.scaleOrdinal(d3.schemeCategory10);\n",
" } catch (err) {\n",
" console.log('Older d3 version')\n",
" headColors = d3.scale.category10();\n",
" }\n",
" let config = {};\n",
" initialize();\n",
" renderVis();\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.rootDivId = params['root_div_id'];\n",
" config.nLayers = config.attention[config.filter]['attn'].length;\n",
" config.nHeads = config.attention[config.filter]['attn'][0].length;\n",
" if (params['heads']) {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" params['heads'].forEach(x => config.headVis[x] = true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" }\n",
" config.initialTextLength = config.attention[config.filter].right_text.length;\n",
" config.layer = (params['layer'] == null ? 0 : params['layer'])\n",
"\n",
"\n",
" let layerEl = $(`#${config.rootDivId} #layer`);\n",
" for (var i = 0; i < config.nLayers; i++) {\n",
" layerEl.append($(\"<option />\").val(i).text(i));\n",
" }\n",
" layerEl.val(config.layer).change();\n",
" layerEl.on('change', function (e) {\n",
" config.layer = +e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
"\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
" }\n",
"\n",
" function renderVis() {\n",
"\n",
" // Load parameters\n",
" const attnData = config.attention[config.filter];\n",
" const leftText = attnData.left_text;\n",
" const rightText = attnData.right_text;\n",
"\n",
" // Select attention for given layer\n",
" const layerAttention = attnData.attn[config.layer];\n",
"\n",
" // Clear vis\n",
" $(`#${config.rootDivId} #vis`).empty();\n",
"\n",
" // Determine size of visualization\n",
" const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;\n",
" const svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", \"100%\")\n",
" .attr(\"height\", height + \"px\");\n",
"\n",
" // Display tokens on left and right side of visualization\n",
" renderText(svg, leftText, true, layerAttention, 0);\n",
" renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);\n",
"\n",
" // Render attention arcs\n",
" renderAttention(svg, layerAttention);\n",
"\n",
" // Draw squares at top of visualization, one for each head\n",
" drawCheckboxes(0, svg, layerAttention);\n",
" }\n",
"\n",
" function renderText(svg, text, isLeft, attention, leftPos) {\n",
"\n",
" const textContainer = svg.append(\"svg:g\")\n",
" .attr(\"id\", isLeft ? \"left\" : \"right\");\n",
"\n",
" // Add attention highlights superimposed over words\n",
" textContainer.append(\"g\")\n",
" .classed(\"attentionBoxes\", true)\n",
" .selectAll(\"g\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"x\", function () {\n",
" var headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (+1) * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .attr(\"fill\", function () {\n",
" return headColors(+this.parentNode.getAttribute(\"head-index\"))\n",
" })\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" const tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" // Add gray background that appears when hovering over text\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"background\", true)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"fill\", \"lightgray\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH)\n",
" .attr(\"height\", BOXHEIGHT);\n",
"\n",
" // Add token text\n",
" const textEl = tokenContainer.append(\"text\")\n",
" .text(d => d)\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT);\n",
"\n",
" if (isLeft) {\n",
" textEl.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", BOXWIDTH - 0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" } else {\n",
" textEl.style(\"text-anchor\", \"start\")\n",
" .attr(\"dx\", +0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" }\n",
"\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
"\n",
" // Show gray background for moused-over token\n",
" textContainer.selectAll(\".background\")\n",
" .style(\"opacity\", (d, i) => i === index ? 1.0 : 0.0)\n",
"\n",
" // Reset visibility attribute for any previously highlighted attention arcs\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null)\n",
"\n",
" // Hide group containing attention arcs\n",
" svg.select(\"#attention\").attr(\"visibility\", \"hidden\");\n",
"\n",
" // Set to visible appropriate attention arcs to be highlighted\n",
" if (isLeft) {\n",
" svg.select(\"#attention\").selectAll(\"line[left-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" } else {\n",
" svg.select(\"#attention\").selectAll(\"line[right-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" }\n",
"\n",
" // Update color boxes superimposed over tokens\n",
" const id = isLeft ? \"right\" : \"left\";\n",
" const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;\n",
" svg.select(\"#\" + id)\n",
" .selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .attr(\"x\", function () {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .style(\"opacity\", function (d) {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" if (config.headVis[headIndex])\n",
" if (d) {\n",
" return d[index];\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" else\n",
" return 0.0;\n",
" });\n",
" });\n",
"\n",
" textContainer.on(\"mouseleave\", function () {\n",
"\n",
" // Unhighlight selected token\n",
" d3.select(this).selectAll(\".background\")\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" // Reset visibility attributes for previously selected lines\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null) ;\n",
" svg.select(\"#attention\").attr(\"visibility\", \"visible\");\n",
"\n",
" // Reset highlights superimposed over tokens\n",
" svg.selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .selectAll(\"rect\")\n",
" .style(\"opacity\", 0.0);\n",
" });\n",
" }\n",
"\n",
" function renderAttention(svg, attention) {\n",
"\n",
" // Remove previous dom elements\n",
" svg.select(\"#attention\").remove();\n",
"\n",
" // Add new elements\n",
" svg.append(\"g\")\n",
" .attr(\"id\", \"attention\") // Container for all attention arcs\n",
" .selectAll(\".headAttention\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"headAttention\", true) // Group attention arcs by head\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\".tokenAttention\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"tokenAttention\", true) // Group attention arcs by left token\n",
" .attr(\"left-token-index\", (d, i) => i)\n",
" .selectAll(\"line\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", BOXWIDTH)\n",
" .attr(\"y1\", function () {\n",
" const leftTokenIndex = +this.parentNode.getAttribute(\"left-token-index\")\n",
" return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)\n",
" })\n",
" .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n",
" .attr(\"y2\", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", function () {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" return headColors(headIndex)\n",
" })\n",
" .attr(\"left-token-index\", function () {\n",
" return +this.parentNode.getAttribute(\"left-token-index\")\n",
" })\n",
" .attr(\"right-token-index\", (d, i) => i)\n",
" ;\n",
" updateAttention(svg)\n",
" }\n",
"\n",
" function updateAttention(svg) {\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line\")\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" // If head is selected\n",
" if (config.headVis[headIndex]) {\n",
" // Set opacity to attention weight divided by number of active heads\n",
" return d / activeHeads()\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" })\n",
" }\n",
"\n",
" function boxOffsets(i) {\n",
" const numHeadsAbove = config.headVis.reduce(\n",
" function (acc, val, cur) {\n",
" return val && cur < i ? acc + 1 : acc;\n",
" }, 0);\n",
" return numHeadsAbove * (BOXWIDTH / activeHeads());\n",
" }\n",
"\n",
" function activeHeads() {\n",
" return config.headVis.reduce(function (acc, val) {\n",
" return val ? acc + 1 : acc;\n",
" }, 0);\n",
" }\n",
"\n",
" function drawCheckboxes(top, svg) {\n",
" const checkboxContainer = svg.append(\"g\");\n",
" const checkbox = checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"fill\", (d, i) => headColors(i))\n",
" .attr(\"x\", (d, i) => i * CHECKBOX_SIZE)\n",
" .attr(\"y\", top)\n",
" .attr(\"width\", CHECKBOX_SIZE)\n",
" .attr(\"height\", CHECKBOX_SIZE);\n",
"\n",
" function updateCheckboxes() {\n",
" checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .attr(\"fill\", (d, i) => d ? headColors(i): lighten(headColors(i)));\n",
" }\n",
"\n",
" updateCheckboxes();\n",
"\n",
" checkbox.on(\"click\", function (d, i) {\n",
" if (config.headVis[i] && activeHeads() === 1) return;\n",
" config.headVis[i] = !config.headVis[i];\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
"\n",
" checkbox.on(\"dblclick\", function (d, i) {\n",
" // If we double click on the only active head then reset\n",
" if (config.headVis[i] && activeHeads() === 1) {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" config.headVis[i] = true;\n",
" }\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
" }\n",
"\n",
" function lighten(color) {\n",
" const c = d3.hsl(color);\n",
" const increment = (1 - c.l) * 0.6;\n",
" c.l += increment;\n",
" c.s -= increment;\n",
" return c;\n",
" }\n",
"\n",
" function transpose(mat) {\n",
" return mat[0].map(function (col, i) {\n",
" return mat.map(function (row) {\n",
" return row[i];\n",
" });\n",
" });\n",
" }\n",
"\n",
"});"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"head_view(\n",
" encoder_attention=outputs.encoder_attentions,\n",
" decoder_attention=outputs.decoder_attentions,\n",
" cross_attention=outputs.cross_attentions,\n",
" encoder_tokens= encoder_text,\n",
" decoder_tokens = decoder_text\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PRZYKŁAD: GPT3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ZADANIE DOMOWE - POLEVAL"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}