ISI-transformers/transformers2.ipynb

1907 lines
2.2 MiB
Plaintext
Raw Normal View History

2021-06-13 21:53:53 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Wizualizacja atencji\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Do wizualizacji atencji wykorzystamy bibliotekę [BertViz](https://github.com/jessevig/bertviz)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: bertviz in /home/kuba/anaconda3/lib/python3.8/site-packages (1.1.0)\n",
"Requirement already satisfied: boto3 in /home/kuba/anaconda3/lib/python3.8/site-packages (from bertviz) (1.17.93)\n",
"Requirement already satisfied: requests in /home/kuba/anaconda3/lib/python3.8/site-packages (from bertviz) (2.24.0)\n",
"Requirement already satisfied: torch>=1.0 in /home/kuba/anaconda3/lib/python3.8/site-packages (from bertviz) (1.8.1)\n",
"Requirement already satisfied: sentencepiece in /home/kuba/anaconda3/lib/python3.8/site-packages (from bertviz) (0.1.95)\n",
"Requirement already satisfied: tqdm in /home/kuba/anaconda3/lib/python3.8/site-packages (from bertviz) (4.47.0)\n",
"Requirement already satisfied: transformers>=2.0 in /home/kuba/anaconda3/lib/python3.8/site-packages (from bertviz) (4.2.2)\n",
"Requirement already satisfied: regex in /home/kuba/anaconda3/lib/python3.8/site-packages (from bertviz) (2020.6.8)\n",
"Requirement already satisfied: botocore<1.21.0,>=1.20.93 in /home/kuba/anaconda3/lib/python3.8/site-packages (from boto3->bertviz) (1.20.93)\n",
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /home/kuba/anaconda3/lib/python3.8/site-packages (from boto3->bertviz) (0.10.0)\n",
"Requirement already satisfied: s3transfer<0.5.0,>=0.4.0 in /home/kuba/anaconda3/lib/python3.8/site-packages (from boto3->bertviz) (0.4.2)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/kuba/anaconda3/lib/python3.8/site-packages (from requests->bertviz) (2020.6.20)\n",
"Requirement already satisfied: idna<3,>=2.5 in /home/kuba/anaconda3/lib/python3.8/site-packages (from requests->bertviz) (2.10)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /home/kuba/anaconda3/lib/python3.8/site-packages (from requests->bertviz) (1.25.9)\n",
"Requirement already satisfied: chardet<4,>=3.0.2 in /home/kuba/anaconda3/lib/python3.8/site-packages (from requests->bertviz) (3.0.4)\n",
"Requirement already satisfied: typing-extensions in /home/kuba/anaconda3/lib/python3.8/site-packages (from torch>=1.0->bertviz) (3.7.4.2)\n",
"Requirement already satisfied: numpy in /home/kuba/anaconda3/lib/python3.8/site-packages (from torch>=1.0->bertviz) (1.18.5)\n",
"Requirement already satisfied: sacremoses in /home/kuba/anaconda3/lib/python3.8/site-packages (from transformers>=2.0->bertviz) (0.0.43)\n",
"Requirement already satisfied: tokenizers==0.9.4 in /home/kuba/anaconda3/lib/python3.8/site-packages (from transformers>=2.0->bertviz) (0.9.4)\n",
"Requirement already satisfied: packaging in /home/kuba/anaconda3/lib/python3.8/site-packages (from transformers>=2.0->bertviz) (20.4)\n",
"Requirement already satisfied: filelock in /home/kuba/anaconda3/lib/python3.8/site-packages (from transformers>=2.0->bertviz) (3.0.12)\n",
"Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /home/kuba/anaconda3/lib/python3.8/site-packages (from botocore<1.21.0,>=1.20.93->boto3->bertviz) (2.8.1)\n",
"Requirement already satisfied: six in /home/kuba/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers>=2.0->bertviz) (1.15.0)\n",
"Requirement already satisfied: joblib in /home/kuba/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers>=2.0->bertviz) (0.16.0)\n",
"Requirement already satisfied: click in /home/kuba/anaconda3/lib/python3.8/site-packages (from sacremoses->transformers>=2.0->bertviz) (7.1.2)\n",
"Requirement already satisfied: pyparsing>=2.0.2 in /home/kuba/anaconda3/lib/python3.8/site-packages (from packaging->transformers>=2.0->bertviz) (2.4.7)\n"
]
}
],
"source": [
"!pip install bertviz"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer, AutoModel\n",
"from bertviz import model_view, head_view"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"TEXT = \"This is a sample input sentence for a transformer model\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"MODEL = \"distilbert-base-uncased\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
"# output_attentions=True asks the model to return its attention weights as well.\n",
"model = AutoModel.from_pretrained(MODEL, output_attentions=True)\n",
"inputs = tokenizer.encode(TEXT, return_tensors='pt')\n",
"# Inference only: no gradients are needed for visualisation, so skip autograd tracking.\n",
"with torch.no_grad():\n",
"    outputs = model(inputs)\n",
"# The last element of the model output is taken as the attention weights.\n",
"attention = outputs[-1]\n",
"# Token strings of the single encoded sequence; passed to head_view with the attention.\n",
"tokens = tokenizer.convert_ids_to_tokens(inputs[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SELF-ATTENTION MODELS"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-588b8377843e46ceb5981df7aba63167'>\n",
" <span style=\"user-select:none\">\n",
" Layer: <select id=\"layer\"></select>\n",
" \n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 12/19/18 Jesse Vig Assorted cleanup. Changed orientation of attention matrices.\n",
" * 12/29/20 Jesse Vig Significant refactor.\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust height of visualization dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function ($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": null, \"attn\": [[[[0.04504745081067085, 0.06121315807104111, 0.03234885632991791, 0.07350576668977737, 0.08869655430316925, 0.06299575418233871, 0.13025709986686707, 0.0724874883890152, 0.0789855495095253, 0.04788164049386978, 0.0528176911175251, 0.06413961201906204, 0.1896234154701233], [0.1642618179321289, 0.045601703226566315, 0.08717474341392517, 0.020270587876439095, 0.06333516538143158, 0.13347627222537994, 0.15195965766906738, 0.07450311630964279, 0.020874062553048134, 0.0886143371462822, 0.03885190933942795, 0.054350875318050385, 0.05672568082809448], [0.19410867989063263, 0.08788827061653137, 0.02643328532576561, 0.027052758261561394, 0.0483211874961853, 0.10077042877674103, 0.14769119024276733, 0.0505562499165535, 0.0312521830201149, 0.11102262884378433, 0.05736817419528961, 0.045798979699611664, 0.0717359259724617], [0.1995920091867447, 0.09999717772006989, 0.0661521703004837, 0.05901019275188446, 0.04424671828746796, 0.0677744597196579, 0.12087972462177277, 0.050675857812166214, 0.06362631916999817, 0.07006122171878815, 0.03917881101369858, 0.04700592905282974, 0.07179943472146988], [0.04939603433012962, 0.024679943919181824, 0.04761141911149025, 0.01254973653703928, 0.07388506829738617, 0.23993298411369324, 0.15196150541305542, 0.00832447037100792, 0.012036890722811222, 0.20173829793930054, 0.021537913009524345, 0.09962346404790878, 0.05672222748398781], [0.05641084164381027, 0.05914205312728882, 0.09467293322086334, 0.022168248891830444, 0.09393802285194397, 0.084870345890522, 0.21269568800926208, 0.03484135493636131, 0.022969357669353485, 0.10200338810682297, 0.048721037805080414, 0.0935591459274292, 0.07400762289762497], [0.14331203699111938, 0.05361737683415413, 0.10078467428684235, 0.024218060076236725, 0.10775977373123169, 0.1332426518201828, 0.08465974032878876, 0.05863681063055992, 0.0241240207105875, 0.08344626426696777, 0.04309769719839096, 0.06576775014400482, 0.0773332417011261], 
[0.211670383810997, 0.11610246449708939, 0.11225587129592896, 0.045633524656295776, 0.059925444424152374, 0.06493557244539261, 0.08031120151281357, 0.03209289163351059, 0.04349662363529205, 0.07048093527555466, 0.05278267711400986, 0.05153229832649231, 0.05878004804253578], [0.19909143447875977, 0.1108265370130539, 0.06320541352033615, 0.07345421612262726, 0.044627320021390915, 0.06109814718365669, 0.10009115934371948, 0.04955925792455673, 0.07297047972679138, 0.06868425011634827, 0.03863503038883209, 0.044418469071388245, 0.07333831489086151], [0.09310530126094818, 0.06271722912788391, 0.05319841206073761, 0.018992554396390915, 0.12321215867996216, 0.13590605556964874, 0.19085046648979187, 0.011909686028957367, 0.01892036944627762, 0.1059010922908783, 0.035090457648038864, 0.11820467561483383, 0.03199157118797302], [0.12340845912694931, 0.01910078339278698, 0.021354133263230324, 0.00847071222960949, 0.09208827465772629, 0.12047737091779709, 0.19756117463111877, 0.01248408854007721, 0.008209108375012875, 0.18166890740394592, 0.04370952770113945, 0.12460346519947052, 0.046863969415426254], [0.1024712547659874, 0.02981560118496418, 0.042016737163066864, 0.00732750678434968, 0.19956143200397491, 0.16194269061088562, 0.13108232617378235, 0.03385123983025551, 0.006953119300305843, 0.10607277601957321, 0.09258976578712463, 0.04644336178898811, 0.03987228125333786], [0.14260147511959076, 0.0856451690196991, 0.052119556814432144, 0.1065417006611824, 0.03710906580090523, 0.038733918219804764, 0.07154049724340439, 0.04271101951599121, 0.10687444359064102, 0.02995859645307064, 0.034151963889598846, 0.03033091314136982, 0.22168178856372833]], [[0.9453150033950806, 0.0042615714482963085, 0.0033084359019994736, 0.008667639456689358, 0.002442109864205122, 0.0032682078890502453, 0.0018215697491541505, 0.005812120623886585, 0.008080368861556053, 0.0023478888906538486, 0.005752548109740019, 0.0029995867516845465, 0.005922860931605101], [0.0054322523064911366, 0.036913447082042694, 
0.10574596375226974, 0.050172608345746994, 0.21406637132167816, 0.1669832468032837, 0.06814419478178024
" const TEXT_SIZE = 15;\n",
" const BOXWIDTH = 110;\n",
" const BOXHEIGHT = 22.5;\n",
" const MATRIX_WIDTH = 115;\n",
" const CHECKBOX_SIZE = 20;\n",
" const TEXT_TOP = 30;\n",
"\n",
" console.log(\"d3 version\", d3.version)\n",
" let headColors;\n",
" try {\n",
" headColors = d3.scaleOrdinal(d3.schemeCategory10);\n",
" } catch (err) {\n",
" console.log('Older d3 version')\n",
" headColors = d3.scale.category10();\n",
" }\n",
" let config = {};\n",
" initialize();\n",
" renderVis();\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.rootDivId = params['root_div_id'];\n",
" config.nLayers = config.attention[config.filter]['attn'].length;\n",
" config.nHeads = config.attention[config.filter]['attn'][0].length;\n",
" if (params['heads']) {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" params['heads'].forEach(x => config.headVis[x] = true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" }\n",
" config.initialTextLength = config.attention[config.filter].right_text.length;\n",
" config.layer = (params['layer'] == null ? 0 : params['layer'])\n",
"\n",
"\n",
" let layerEl = $(`#${config.rootDivId} #layer`);\n",
" for (var i = 0; i < config.nLayers; i++) {\n",
" layerEl.append($(\"<option />\").val(i).text(i));\n",
" }\n",
" layerEl.val(config.layer).change();\n",
" layerEl.on('change', function (e) {\n",
" config.layer = +e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
"\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
" }\n",
"\n",
" function renderVis() {\n",
"\n",
" // Load parameters\n",
" const attnData = config.attention[config.filter];\n",
" const leftText = attnData.left_text;\n",
" const rightText = attnData.right_text;\n",
"\n",
" // Select attention for given layer\n",
" const layerAttention = attnData.attn[config.layer];\n",
"\n",
" // Clear vis\n",
" $(`#${config.rootDivId} #vis`).empty();\n",
"\n",
" // Determine size of visualization\n",
" const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;\n",
" const svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", \"100%\")\n",
" .attr(\"height\", height + \"px\");\n",
"\n",
" // Display tokens on left and right side of visualization\n",
" renderText(svg, leftText, true, layerAttention, 0);\n",
" renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);\n",
"\n",
" // Render attention arcs\n",
" renderAttention(svg, layerAttention);\n",
"\n",
" // Draw squares at top of visualization, one for each head\n",
" drawCheckboxes(0, svg, layerAttention);\n",
" }\n",
"\n",
" function renderText(svg, text, isLeft, attention, leftPos) {\n",
"\n",
" const textContainer = svg.append(\"svg:g\")\n",
" .attr(\"id\", isLeft ? \"left\" : \"right\");\n",
"\n",
" // Add attention highlights superimposed over words\n",
" textContainer.append(\"g\")\n",
" .classed(\"attentionBoxes\", true)\n",
" .selectAll(\"g\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"x\", function () {\n",
" var headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (+1) * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .attr(\"fill\", function () {\n",
" return headColors(+this.parentNode.getAttribute(\"head-index\"))\n",
" })\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" const tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" // Add gray background that appears when hovering over text\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"background\", true)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"fill\", \"lightgray\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH)\n",
" .attr(\"height\", BOXHEIGHT);\n",
"\n",
" // Add token text\n",
" const textEl = tokenContainer.append(\"text\")\n",
" .text(d => d)\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT);\n",
"\n",
" if (isLeft) {\n",
" textEl.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", BOXWIDTH - 0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" } else {\n",
" textEl.style(\"text-anchor\", \"start\")\n",
" .attr(\"dx\", +0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" }\n",
"\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
"\n",
" // Show gray background for moused-over token\n",
" textContainer.selectAll(\".background\")\n",
" .style(\"opacity\", (d, i) => i === index ? 1.0 : 0.0)\n",
"\n",
" // Reset visibility attribute for any previously highlighted attention arcs\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null)\n",
"\n",
" // Hide group containing attention arcs\n",
" svg.select(\"#attention\").attr(\"visibility\", \"hidden\");\n",
"\n",
" // Set to visible appropriate attention arcs to be highlighted\n",
" if (isLeft) {\n",
" svg.select(\"#attention\").selectAll(\"line[left-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" } else {\n",
" svg.select(\"#attention\").selectAll(\"line[right-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" }\n",
"\n",
" // Update color boxes superimposed over tokens\n",
" const id = isLeft ? \"right\" : \"left\";\n",
" const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;\n",
" svg.select(\"#\" + id)\n",
" .selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .attr(\"x\", function () {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .style(\"opacity\", function (d) {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" if (config.headVis[headIndex])\n",
" if (d) {\n",
" return d[index];\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" else\n",
" return 0.0;\n",
" });\n",
" });\n",
"\n",
" textContainer.on(\"mouseleave\", function () {\n",
"\n",
" // Unhighlight selected token\n",
" d3.select(this).selectAll(\".background\")\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" // Reset visibility attributes for previously selected lines\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null) ;\n",
" svg.select(\"#attention\").attr(\"visibility\", \"visible\");\n",
"\n",
" // Reset highlights superimposed over tokens\n",
" svg.selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .selectAll(\"rect\")\n",
" .style(\"opacity\", 0.0);\n",
" });\n",
" }\n",
"\n",
" function renderAttention(svg, attention) {\n",
"\n",
" // Remove previous dom elements\n",
" svg.select(\"#attention\").remove();\n",
"\n",
" // Add new elements\n",
" svg.append(\"g\")\n",
" .attr(\"id\", \"attention\") // Container for all attention arcs\n",
" .selectAll(\".headAttention\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"headAttention\", true) // Group attention arcs by head\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\".tokenAttention\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"tokenAttention\", true) // Group attention arcs by left token\n",
" .attr(\"left-token-index\", (d, i) => i)\n",
" .selectAll(\"line\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", BOXWIDTH)\n",
" .attr(\"y1\", function () {\n",
" const leftTokenIndex = +this.parentNode.getAttribute(\"left-token-index\")\n",
" return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)\n",
" })\n",
" .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n",
" .attr(\"y2\", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", function () {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" return headColors(headIndex)\n",
" })\n",
" .attr(\"left-token-index\", function () {\n",
" return +this.parentNode.getAttribute(\"left-token-index\")\n",
" })\n",
" .attr(\"right-token-index\", (d, i) => i)\n",
" ;\n",
" updateAttention(svg)\n",
" }\n",
"\n",
" function updateAttention(svg) {\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line\")\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" // If head is selected\n",
" if (config.headVis[headIndex]) {\n",
" // Set opacity to attention weight divided by number of active heads\n",
" return d / activeHeads()\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" })\n",
" }\n",
"\n",
" function boxOffsets(i) {\n",
" const numHeadsAbove = config.headVis.reduce(\n",
" function (acc, val, cur) {\n",
" return val && cur < i ? acc + 1 : acc;\n",
" }, 0);\n",
" return numHeadsAbove * (BOXWIDTH / activeHeads());\n",
" }\n",
"\n",
" function activeHeads() {\n",
" return config.headVis.reduce(function (acc, val) {\n",
" return val ? acc + 1 : acc;\n",
" }, 0);\n",
" }\n",
"\n",
" function drawCheckboxes(top, svg) {\n",
" const checkboxContainer = svg.append(\"g\");\n",
" const checkbox = checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"fill\", (d, i) => headColors(i))\n",
" .attr(\"x\", (d, i) => i * CHECKBOX_SIZE)\n",
" .attr(\"y\", top)\n",
" .attr(\"width\", CHECKBOX_SIZE)\n",
" .attr(\"height\", CHECKBOX_SIZE);\n",
"\n",
" function updateCheckboxes() {\n",
" checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .attr(\"fill\", (d, i) => d ? headColors(i): lighten(headColors(i)));\n",
" }\n",
"\n",
" updateCheckboxes();\n",
"\n",
" checkbox.on(\"click\", function (d, i) {\n",
" if (config.headVis[i] && activeHeads() === 1) return;\n",
" config.headVis[i] = !config.headVis[i];\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
"\n",
" checkbox.on(\"dblclick\", function (d, i) {\n",
" // If we double click on the only active head then reset\n",
" if (config.headVis[i] && activeHeads() === 1) {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" config.headVis[i] = true;\n",
" }\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
" }\n",
"\n",
" function lighten(color) {\n",
" const c = d3.hsl(color);\n",
" const increment = (1 - c.l) * 0.6;\n",
" c.l += increment;\n",
" c.s -= increment;\n",
" return c;\n",
" }\n",
"\n",
" function transpose(mat) {\n",
" return mat[0].map(function (col, i) {\n",
" return mat.map(function (row) {\n",
" return row[i];\n",
" });\n",
" });\n",
" }\n",
"\n",
"});"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"head_view(attention, tokens)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-61533a3712f543eba69f03bb563030ec'>\n",
" <span style=\"user-select:none\">\n",
" \n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 02/01/19 Jesse Vig Initial implementation\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 01/19/21 Jesse Vig Support light/dark modes\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust visualization height dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": null, \"attn\": [[[[0.04504745081067085, 0.06121315807104111, 0.03234885632991791, 0.07350576668977737, 0.08869655430316925, 0.06299575418233871, 0.13025709986686707, 0.0724874883890152, 0.0789855495095253, 0.04788164049386978, 0.0528176911175251, 0.06413961201906204, 0.1896234154701233], [0.1642618179321289, 0.045601703226566315, 0.08717474341392517, 0.020270587876439095, 0.06333516538143158, 0.13347627222537994, 0.15195965766906738, 0.07450311630964279, 0.020874062553048134, 0.0886143371462822, 0.03885190933942795, 0.054350875318050385, 0.05672568082809448], [0.19410867989063263, 0.08788827061653137, 0.02643328532576561, 0.027052758261561394, 0.0483211874961853, 0.10077042877674103, 0.14769119024276733, 0.0505562499165535, 0.0312521830201149, 0.11102262884378433, 0.05736817419528961, 0.045798979699611664, 0.0717359259724617], [0.1995920091867447, 0.09999717772006989, 0.0661521703004837, 0.05901019275188446, 0.04424671828746796, 0.0677744597196579, 0.12087972462177277, 0.050675857812166214, 0.06362631916999817, 0.07006122171878815, 0.03917881101369858, 0.04700592905282974, 0.07179943472146988], [0.04939603433012962, 0.024679943919181824, 0.04761141911149025, 0.01254973653703928, 0.07388506829738617, 0.23993298411369324, 0.15196150541305542, 0.00832447037100792, 0.012036890722811222, 0.20173829793930054, 0.021537913009524345, 0.09962346404790878, 0.05672222748398781], [0.05641084164381027, 0.05914205312728882, 0.09467293322086334, 0.022168248891830444, 0.09393802285194397, 0.084870345890522, 0.21269568800926208, 0.03484135493636131, 0.022969357669353485, 0.10200338810682297, 0.048721037805080414, 0.0935591459274292, 0.07400762289762497], [0.14331203699111938, 0.05361737683415413, 0.10078467428684235, 0.024218060076236725, 0.10775977373123169, 0.1332426518201828, 0.08465974032878876, 0.05863681063055992, 0.0241240207105875, 0.08344626426696777, 0.04309769719839096, 0.06576775014400482, 0.0773332417011261], 
[0.211670383810997, 0.11610246449708939, 0.11225587129592896, 0.045633524656295776, 0.059925444424152374, 0.06493557244539261, 0.08031120151281357, 0.03209289163351059, 0.04349662363529205, 0.07048093527555466, 0.05278267711400986, 0.05153229832649231, 0.05878004804253578], [0.19909143447875977, 0.1108265370130539, 0.06320541352033615, 0.07345421612262726, 0.044627320021390915, 0.06109814718365669, 0.10009115934371948, 0.04955925792455673, 0.07297047972679138, 0.06868425011634827, 0.03863503038883209, 0.044418469071388245, 0.07333831489086151], [0.09310530126094818, 0.06271722912788391, 0.05319841206073761, 0.018992554396390915, 0.12321215867996216, 0.13590605556964874, 0.19085046648979187, 0.011909686028957367, 0.01892036944627762, 0.1059010922908783, 0.035090457648038864, 0.11820467561483383, 0.03199157118797302], [0.12340845912694931, 0.01910078339278698, 0.021354133263230324, 0.00847071222960949, 0.09208827465772629, 0.12047737091779709, 0.19756117463111877, 0.01248408854007721, 0.008209108375012875, 0.18166890740394592, 0.04370952770113945, 0.12460346519947052, 0.046863969415426254], [0.1024712547659874, 0.02981560118496418, 0.042016737163066864, 0.00732750678434968, 0.19956143200397491, 0.16194269061088562, 0.13108232617378235, 0.03385123983025551, 0.006953119300305843, 0.10607277601957321, 0.09258976578712463, 0.04644336178898811, 0.03987228125333786], [0.14260147511959076, 0.0856451690196991, 0.052119556814432144, 0.1065417006611824, 0.03710906580090523, 0.038733918219804764, 0.07154049724340439, 0.04271101951599121, 0.10687444359064102, 0.02995859645307064, 0.034151963889598846, 0.03033091314136982, 0.22168178856372833]], [[0.9453150033950806, 0.0042615714482963085, 0.0033084359019994736, 0.008667639456689358, 0.002442109864205122, 0.0032682078890502453, 0.0018215697491541505, 0.005812120623886585, 0.008080368861556053, 0.0023478888906538486, 0.005752548109740019, 0.0029995867516845465, 0.005922860931605101], [0.0054322523064911366, 0.036913447082042694, 
0.10574596375226974, 0.050172608345746994, 0.21406637132167816, 0.1669832468032837, 0.0681441947817
" const config = {};\n",
"\n",
" const MIN_X = 0;\n",
" const MIN_Y = 0;\n",
" const DIV_WIDTH = 970;\n",
" const THUMBNAIL_PADDING = 5;\n",
" const DETAIL_WIDTH = 300;\n",
" const DETAIL_ATTENTION_WIDTH = 140;\n",
" const DETAIL_BOX_WIDTH = 80;\n",
" const DETAIL_BOX_HEIGHT = 18;\n",
" const DETAIL_PADDING = 28;\n",
" const ATTN_PADDING = 0;\n",
" const DETAIL_HEADING_HEIGHT = 25;\n",
" const DETAIL_HEADING_TEXT_SIZE = 15;\n",
" const TEXT_SIZE = 13;\n",
" const LAYER_COLORS = d3.schemeCategory10;\n",
" const PALETTE = {\n",
" 'light': {\n",
" 'text': 'black',\n",
" 'background': 'white',\n",
" 'highlight': '#F5F5F5'\n",
" },\n",
" 'dark': {\n",
" 'text': '#bbb',\n",
" 'background': 'black',\n",
" 'highlight': '#222'\n",
" }\n",
" }\n",
"\n",
" function render() {\n",
"\n",
" // Set global state variables\n",
"\n",
" var attData = config.attention[config.filter];\n",
" config.leftText = attData.left_text;\n",
" config.rightText = attData.right_text;\n",
" config.attn = attData.attn;\n",
" config.numLayers = config.attn.length;\n",
" config.numHeads = config.attn[0].length;\n",
" config.thumbnailBoxHeight = 7 * (12 / config.numHeads);\n",
" config.thumbnailHeight = Math.max(config.leftText.length, config.rightText.length) * config.thumbnailBoxHeight + 2 * THUMBNAIL_PADDING;\n",
" config.thumbnailWidth = DIV_WIDTH / config.numHeads;\n",
" config.detailHeight = Math.max(config.leftText.length, config.rightText.length) * DETAIL_BOX_HEIGHT + 2 * DETAIL_PADDING + DETAIL_HEADING_HEIGHT;\n",
" config.divHeight = config.numLayers * config.thumbnailHeight;\n",
"\n",
" const vis = $(`#${config.rootDivId} #vis`)\n",
" vis.empty();\n",
" vis.attr(\"height\", config.divHeight);\n",
" config.svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", DIV_WIDTH)\n",
" .attr(\"height\", config.divHeight)\n",
" .attr(\"fill\", getBackgroundColor());\n",
"\n",
" var i;\n",
" var j;\n",
" for (i = 0; i < config.numLayers; i++) {\n",
" for (j = 0; j < config.numHeads; j++) {\n",
" renderThumbnail(i, j);\n",
" }\n",
" }\n",
" }\n",
"\n",
" function renderThumbnail(layerIndex, headIndex) {\n",
" var x = headIndex * config.thumbnailWidth;\n",
" var y = layerIndex * config.thumbnailHeight;\n",
" renderThumbnailAttn(x, y, config.attn[layerIndex][headIndex], layerIndex, headIndex);\n",
" }\n",
"\n",
" function renderDetail(att, layerIndex, headIndex) {\n",
" var xOffset = .8 * config.thumbnailWidth;\n",
" var maxX = DIV_WIDTH;\n",
" var maxY = config.divHeight;\n",
" var leftPos = (headIndex / config.numHeads) * DIV_WIDTH\n",
" var x = leftPos + THUMBNAIL_PADDING + xOffset;\n",
" if (x < MIN_X) {\n",
" x = MIN_X;\n",
" } else if (x + DETAIL_WIDTH > maxX) {\n",
" x = leftPos + THUMBNAIL_PADDING - DETAIL_WIDTH + 8;\n",
" }\n",
" var posLeftText = x;\n",
" var posAttention = posLeftText + DETAIL_BOX_WIDTH;\n",
" var posRightText = posAttention + DETAIL_ATTENTION_WIDTH;\n",
" var thumbnailHeight = Math.max(config.leftText.length, config.rightText.length) * config.thumbnailBoxHeight + 2 * THUMBNAIL_PADDING;\n",
" var yOffset = 20;\n",
" var y = layerIndex * thumbnailHeight + THUMBNAIL_PADDING + yOffset;\n",
" if (y < MIN_Y) {\n",
" y = MIN_Y;\n",
" } else if (y + config.detailHeight > maxY) {\n",
" y = maxY - config.detailHeight;\n",
" }\n",
" renderDetailFrame(x, y, layerIndex);\n",
" renderDetailHeading(x, y + Math.max(config.leftText.length, config.rightText.length) *\n",
" DETAIL_BOX_HEIGHT, layerIndex, headIndex);\n",
" renderDetailText(config.leftText, \"leftText\", posLeftText, y + DETAIL_PADDING, layerIndex);\n",
" renderDetailAttn(posAttention, y + DETAIL_PADDING, att, layerIndex, headIndex);\n",
" renderDetailText(config.rightText, \"rightText\", posRightText, y + DETAIL_PADDING, layerIndex);\n",
" }\n",
"\n",
" function renderDetailHeading(x, y, layerIndex, headIndex) {\n",
" var fillColor = getTextColor();\n",
" config.svg.append(\"text\")\n",
" .classed(\"detail\", true)\n",
" .text('Layer ' + layerIndex + \", Head \" + headIndex)\n",
" .attr(\"font-size\", DETAIL_HEADING_TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"fill\", fillColor)\n",
" .attr(\"x\", x + DETAIL_WIDTH / 2)\n",
" .attr(\"text-anchor\", \"middle\")\n",
" .attr(\"y\", y + 40)\n",
" .attr(\"height\", DETAIL_HEADING_HEIGHT)\n",
" .attr(\"width\", DETAIL_WIDTH)\n",
" .attr(\"dy\", DETAIL_HEADING_TEXT_SIZE);\n",
" }\n",
"\n",
" function renderDetailText(text, id, x, y, layerIndex) {\n",
" var tokenContainer = config.svg.append(\"svg:g\")\n",
" .classed(\"detail\", true)\n",
" .selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" var fillColor = getTextColor();\n",
"\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"highlight\", true)\n",
" .attr(\"fill\", fillColor)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"height\", DETAIL_BOX_HEIGHT)\n",
" .attr(\"width\", DETAIL_BOX_WIDTH)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", function (d, i) {\n",
" return y + i * DETAIL_BOX_HEIGHT;\n",
" });\n",
"\n",
" var textContainer = tokenContainer.append(\"text\")\n",
" .classed(\"token\", true)\n",
" .text(function (d) {\n",
" return d;\n",
" })\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"fill\", fillColor)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", function (d, i) {\n",
" return i * DETAIL_BOX_HEIGHT + y;\n",
" })\n",
" .attr(\"height\", DETAIL_BOX_HEIGHT)\n",
" .attr(\"width\", DETAIL_BOX_WIDTH)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
"\n",
" if (id == \"leftText\") {\n",
" textContainer.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", DETAIL_BOX_WIDTH - 2);\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
" highlightSelection(index);\n",
" });\n",
" tokenContainer.on(\"mouseleave\", function () {\n",
" unhighlightSelection();\n",
" });\n",
" }\n",
" }\n",
"\n",
" function highlightSelection(index) {\n",
" config.svg.select(\"#leftText\")\n",
" .selectAll(\".highlight\")\n",
" .style(\"opacity\", function (d, i) {\n",
" return i == index ? 1.0 : 0.0;\n",
" });\n",
" config.svg.selectAll(\".attn-line-group\")\n",
" .style(\"opacity\", function (d, i) {\n",
" return i == index ? 1.0 : 0.0;\n",
" });\n",
" }\n",
"\n",
" function unhighlightSelection() {\n",
" config.svg.select(\"#leftText\")\n",
" .selectAll(\".highlight\")\n",
" .style(\"opacity\", 0.0);\n",
" config.svg.selectAll(\".attn-line-group\")\n",
" .style(\"opacity\", 1);\n",
" }\n",
"\n",
" function renderThumbnailAttn(x, y, att, layerIndex, headIndex) {\n",
"\n",
" var attnContainer = config.svg.append(\"svg:g\");\n",
"\n",
" var attnBackground = attnContainer.append(\"rect\")\n",
" .attr(\"id\", 'attn_background_' + layerIndex + \"_\" + headIndex)\n",
" .classed(\"attn_background\", true)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.thumbnailHeight)\n",
" .attr(\"width\", config.thumbnailWidth)\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", 0)\n",
" .attr(\"fill\", getBackgroundColor());\n",
" var x1 = x + THUMBNAIL_PADDING;\n",
" var x2 = x1 + config.thumbnailWidth - 14;\n",
" var y1 = y + THUMBNAIL_PADDING;\n",
"\n",
" attnContainer.selectAll(\"g\")\n",
" .data(att)\n",
" .enter()\n",
" .append(\"g\") // Add group for each source token\n",
" .attr(\"source-index\", function (d, i) { // Save index of source token\n",
" return i;\n",
" })\n",
" .selectAll(\"line\")\n",
" .data(function (d) { // Loop over all target tokens\n",
" return d;\n",
" })\n",
" .enter() // When entering\n",
" .append(\"line\")\n",
" .attr(\"x1\", x1)\n",
" .attr(\"y1\", function (d) {\n",
" var sourceIndex = +this.parentNode.getAttribute(\"source-index\");\n",
" return y1 + (sourceIndex + .5) * config.thumbnailBoxHeight;\n",
" })\n",
" .attr(\"x2\", x2)\n",
" .attr(\"y2\", function (d, targetIndex) {\n",
" return y1 + (targetIndex + .5) * config.thumbnailBoxHeight;\n",
" })\n",
" .attr(\"stroke-width\", 2.2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" return d;\n",
" });\n",
"\n",
" var clickRegion = attnContainer.append(\"rect\")\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.thumbnailHeight)\n",
" .attr(\"width\", config.thumbnailWidth)\n",
" .style(\"opacity\", 0);\n",
"\n",
" clickRegion.on(\"click\", function (d, index) {\n",
" var attnBackgroundOther = config.svg.selectAll(\".attn_background\");\n",
" attnBackgroundOther.attr(\"fill\", getBackgroundColor());\n",
" attnBackgroundOther.attr(\"stroke-opacity\", 0);\n",
"\n",
" config.svg.selectAll(\".detail\").remove();\n",
" if (config.detail_layer != layerIndex || config.detail_head != headIndex) {\n",
" renderDetail(att, layerIndex, headIndex);\n",
" config.detail_layer = layerIndex;\n",
" config.detail_head = headIndex;\n",
" attnBackground.attr(\"fill\", getHighlightColor());\n",
" attnBackground.attr(\"stroke-opacity\", .8);\n",
" } else {\n",
" config.detail_layer = null;\n",
" config.detail_head = null;\n",
" attnBackground.attr(\"fill\", getBackgroundColor());\n",
" attnBackground.attr(\"stroke-opacity\", 0);\n",
" }\n",
" });\n",
"\n",
" clickRegion.on(\"mouseover\", function (d) {\n",
" d3.select(this).style(\"cursor\", \"pointer\");\n",
" });\n",
" }\n",
"\n",
" function renderDetailFrame(x, y, layerIndex) {\n",
" var detailFrame = config.svg.append(\"rect\")\n",
" .classed(\"detail\", true)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.detailHeight)\n",
" .attr(\"width\", DETAIL_WIDTH)\n",
" .style(\"opacity\", 1)\n",
" .attr(\"stroke-width\", 1.5)\n",
" .attr(\"stroke-opacity\", 0.7)\n",
" .attr(\"stroke\", getLayerColor(layerIndex));\n",
" }\n",
"\n",
" function renderDetailAttn(x, y, att, layerIndex) {\n",
" var attnContainer = config.svg.append(\"svg:g\")\n",
" .classed(\"detail\", true)\n",
" .attr(\"pointer-events\", \"none\");\n",
" attnContainer.selectAll(\"g\")\n",
" .data(att)\n",
" .enter()\n",
" .append(\"g\") // Add group for each source token\n",
" .classed('attn-line-group', true)\n",
" .attr(\"source-index\", function (d, i) { // Save index of source token\n",
" return i;\n",
" })\n",
" .selectAll(\"line\")\n",
" .data(function (d) { // Loop over all target tokens\n",
" return d;\n",
" })\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", x + ATTN_PADDING)\n",
" .attr(\"y1\", function (d) {\n",
" var sourceIndex = +this.parentNode.getAttribute(\"source-index\");\n",
" return y + (sourceIndex + .5) * DETAIL_BOX_HEIGHT;\n",
" })\n",
" .attr(\"x2\", x + DETAIL_ATTENTION_WIDTH - ATTN_PADDING)\n",
" .attr(\"y2\", function (d, targetIndex) {\n",
" return y + (targetIndex + .5) * DETAIL_BOX_HEIGHT;\n",
" })\n",
" .attr(\"stroke-width\", 2.2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" return d;\n",
" });\n",
" }\n",
"\n",
" function getLayerColor(layer) {\n",
" return LAYER_COLORS[layer % 10];\n",
" }\n",
"\n",
" function getTextColor() {\n",
" return PALETTE[config.mode]['text']\n",
" }\n",
"\n",
" function getBackgroundColor() {\n",
" return PALETTE[config.mode]['background']\n",
" }\n",
"\n",
" function getHighlightColor() {\n",
" return PALETTE[config.mode]['highlight']\n",
" }\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.mode = params['display_mode'];\n",
" config.rootDivId = params['root_div_id'];\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" render();\n",
" });\n",
" }\n",
"\n",
" initialize();\n",
" render();\n",
"\n",
" });"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"model_view(attention, tokens)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ENCODER-DECODER MODELS"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"MODEL = \"Helsinki-NLP/opus-mt-en-de\""
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"TEXT_ENCODER = \"She sees the small elephant.\"\n",
"TEXT_DECODER = \"Sie sieht den kleinen Elefanten.\""
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
"model = AutoModel.from_pretrained(MODEL, output_attentions=True)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"encoder_input_ids = tokenizer(TEXT_ENCODER, return_tensors=\"pt\", add_special_tokens=True).input_ids\n",
"decoder_input_ids = tokenizer(TEXT_DECODER, return_tensors=\"pt\", add_special_tokens=True).input_ids\n",
"\n",
"outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)\n",
"\n",
"encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])\n",
"decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-8b17c300802a4ad9b6896ab4d4e8eba0'>\n",
" <span style=\"user-select:none\">\n",
" Layer: <select id=\"layer\"></select>\n",
" Attention: <select id=\"filter\"><option value=\"0\">Encoder</option>\n",
"<option value=\"1\">Decoder</option>\n",
"<option value=\"2\">Cross</option></select>\n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 12/19/18 Jesse Vig Assorted cleanup. Changed orientation of attention matrices.\n",
" * 12/29/20 Jesse Vig Significant refactor.\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust height of visualization dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function ($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": \"Encoder\", \"attn\": [[[[0.0015343179693445563, 0.06372829526662827, 0.3402520716190338, 0.014809303916990757, 0.11685427278280258, 0.09690533578395844, 0.3659164011478424], [0.008656513877213001, 0.08992935717105865, 0.25149786472320557, 0.07625745981931686, 0.20215430855751038, 0.10926036536693573, 0.26224416494369507], [0.01743713766336441, 0.19126100838184357, 0.019580217078328133, 0.10873700678348541, 0.5898742079734802, 0.013855690136551857, 0.05925469845533371], [0.006695661228150129, 0.06985638290643692, 0.03561459109187126, 0.03681695833802223, 0.7022322416305542, 0.054333802312612534, 0.09445030987262726], [0.0010159132070839405, 0.004103748593479395, 0.18814627826213837, 0.013669033534824848, 0.03894238546490669, 0.16656135022640228, 0.5875613689422607], [0.0033098733983933926, 0.06926336139440536, 0.006486027967184782, 0.033568885177373886, 0.7985183000564575, 0.014555910602211952, 0.07429762929677963], [0.006561059970408678, 0.028042230755090714, 0.06652601808309555, 0.04403909668326378, 0.1904589682817459, 0.17164954543113708, 0.49272310733795166]], [[0.0010659039253368974, 0.027911514043807983, 0.5827310681343079, 0.00583711639046669, 0.02539275959134102, 0.09940840303897858, 0.25765326619148254], [0.014391965232789516, 0.07798851281404495, 0.4047695994377136, 0.01874568685889244, 0.03196389973163605, 0.1091737151145935, 0.34296661615371704], [0.05386969447135925, 0.1397337168455124, 0.1042787954211235, 0.41253480315208435, 0.13313433527946472, 0.06411976367235184, 0.09232895076274872], [0.011524752713739872, 0.024617008864879608, 0.11605977267026901, 0.15868355333805084, 0.07823850214481354, 0.1350926011800766, 0.4757837951183319], [0.000765774748288095, 0.0007213741773739457, 0.15132009983062744, 0.007885550148785114, 0.01554951909929514, 0.1716994047164917, 0.6520583033561707], [0.00261464505456388, 0.01399530004709959, 0.0927773043513298, 0.05853908136487007, 0.12484410405158997, 0.22863952815532684, 
0.4785899519920349], [0.005509554408490658, 0.018335115164518356, 0.07280543446540833, 0.02366117760539055, 0.06281809508800507, 0.21027612686157227, 0.6065945625305176]], [[0.001894954708404839, 0.04986847937107086, 0.3366052806377411, 0.005386138800531626, 0.11281800270080566, 0.07878921926021576, 0.4146379232406616], [0.0033349881414324045, 0.03520013391971588, 0.5712409615516663, 0.031452734023332596, 0.0384756363928318, 0.08380278199911118, 0.2364927977323532], [0.032189127057790756, 0.08156478404998779, 0.22131586074829102, 0.08562739938497543, 0.08101073652505875, 0.1313415765762329, 0.36695045232772827], [0.0022780823055654764, 0.08724657446146011, 0.20255671441555023, 0.09033971279859543, 0.11631207168102264, 0.25404372811317444, 0.24722309410572052], [0.00028211530297994614, 0.0050274962559342384, 0.12840455770492554, 0.008146004751324654, 0.05242825299501419, 0.23907428979873657, 0.5666372776031494], [8.598285057814792e-05, 0.0020894501358270645, 0.02531367912888527, 0.001749282469972968, 0.01047285832464695, 0.10172442346811295, 0.8585643172264099], [0.0019090479472652078, 0.009371059946715832, 0.03571559488773346, 0.008757798001170158, 0.01917792297899723, 0.12495970726013184, 0.8001089096069336]], [[0.0002124253660440445, 0.07846645265817642, 0.6617725491523743, 0.0009571347036398947, 0.0028506950475275517, 0.0355488583445549, 0.2201918661594391], [0.0013146958081051707, 0.03666481375694275, 0.5017722249031067, 0.032994918525218964, 0.08658347278833389, 0.04107796028256416, 0.29959189891815186], [0.007319945842027664, 0.25871768593788147, 0.07154543697834015, 0.20867864787578583, 0.18836019933223724, 0.18986377120018005, 0.07551436871290207], [0.002156124683097005, 0.056250181049108505, 0.20182666182518005, 0.011035334318876266, 0.0871698409318924, 0.05659180134534836, 0.5849699974060059], [0.0009280861704610288, 0.004628014750778675, 0.20538683235645294, 0.005883386358618736, 0.012845118530094624, 0.10419470071792603, 0.6661338806152344], 
[0.0012555076973512769, 0.041694026440382004, 0.05578531697392464, 0.12980355322360992, 0.2015871
" const TEXT_SIZE = 15;\n",
" const BOXWIDTH = 110;\n",
" const BOXHEIGHT = 22.5;\n",
" const MATRIX_WIDTH = 115;\n",
" const CHECKBOX_SIZE = 20;\n",
" const TEXT_TOP = 30;\n",
"\n",
" console.log(\"d3 version\", d3.version)\n",
" let headColors;\n",
" try {\n",
" headColors = d3.scaleOrdinal(d3.schemeCategory10);\n",
" } catch (err) {\n",
" console.log('Older d3 version')\n",
" headColors = d3.scale.category10();\n",
" }\n",
" let config = {};\n",
" initialize();\n",
" renderVis();\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.rootDivId = params['root_div_id'];\n",
" config.nLayers = config.attention[config.filter]['attn'].length;\n",
" config.nHeads = config.attention[config.filter]['attn'][0].length;\n",
" if (params['heads']) {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" params['heads'].forEach(x => config.headVis[x] = true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" }\n",
" config.initialTextLength = config.attention[config.filter].right_text.length;\n",
" config.layer = (params['layer'] == null ? 0 : params['layer'])\n",
"\n",
"\n",
" let layerEl = $(`#${config.rootDivId} #layer`);\n",
" for (var i = 0; i < config.nLayers; i++) {\n",
" layerEl.append($(\"<option />\").val(i).text(i));\n",
" }\n",
" layerEl.val(config.layer).change();\n",
" layerEl.on('change', function (e) {\n",
" config.layer = +e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
"\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" renderVis();\n",
" });\n",
" }\n",
"\n",
" function renderVis() {\n",
"\n",
" // Load parameters\n",
" const attnData = config.attention[config.filter];\n",
" const leftText = attnData.left_text;\n",
" const rightText = attnData.right_text;\n",
"\n",
" // Select attention for given layer\n",
" const layerAttention = attnData.attn[config.layer];\n",
"\n",
" // Clear vis\n",
" $(`#${config.rootDivId} #vis`).empty();\n",
"\n",
" // Determine size of visualization\n",
" const height = Math.max(leftText.length, rightText.length) * BOXHEIGHT + TEXT_TOP;\n",
" const svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", \"100%\")\n",
" .attr(\"height\", height + \"px\");\n",
"\n",
" // Display tokens on left and right side of visualization\n",
" renderText(svg, leftText, true, layerAttention, 0);\n",
" renderText(svg, rightText, false, layerAttention, MATRIX_WIDTH + BOXWIDTH);\n",
"\n",
" // Render attention arcs\n",
" renderAttention(svg, layerAttention);\n",
"\n",
" // Draw squares at top of visualization, one for each head\n",
" drawCheckboxes(0, svg, layerAttention);\n",
" }\n",
"\n",
" function renderText(svg, text, isLeft, attention, leftPos) {\n",
"\n",
" const textContainer = svg.append(\"svg:g\")\n",
" .attr(\"id\", isLeft ? \"left\" : \"right\");\n",
"\n",
" // Add attention highlights superimposed over words\n",
" textContainer.append(\"g\")\n",
" .classed(\"attentionBoxes\", true)\n",
" .selectAll(\"g\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .data(d => isLeft ? d : transpose(d)) // if right text, transpose attention to get right-to-left weights\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"x\", function () {\n",
" var headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (+1) * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .attr(\"fill\", function () {\n",
" return headColors(+this.parentNode.getAttribute(\"head-index\"))\n",
" })\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" const tokenContainer = textContainer.append(\"g\").selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" // Add gray background that appears when hovering over text\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"background\", true)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"fill\", \"lightgray\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH)\n",
" .attr(\"height\", BOXHEIGHT);\n",
"\n",
" // Add token text\n",
" const textEl = tokenContainer.append(\"text\")\n",
" .text(d => d)\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"x\", leftPos)\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT);\n",
"\n",
" if (isLeft) {\n",
" textEl.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", BOXWIDTH - 0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" } else {\n",
" textEl.style(\"text-anchor\", \"start\")\n",
" .attr(\"dx\", +0.5 * TEXT_SIZE)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
" }\n",
"\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
"\n",
" // Show gray background for moused-over token\n",
" textContainer.selectAll(\".background\")\n",
" .style(\"opacity\", (d, i) => i === index ? 1.0 : 0.0)\n",
"\n",
" // Reset visibility attribute for any previously highlighted attention arcs\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null)\n",
"\n",
" // Hide group containing attention arcs\n",
" svg.select(\"#attention\").attr(\"visibility\", \"hidden\");\n",
"\n",
" // Set to visible appropriate attention arcs to be highlighted\n",
" if (isLeft) {\n",
" svg.select(\"#attention\").selectAll(\"line[left-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" } else {\n",
" svg.select(\"#attention\").selectAll(\"line[right-token-index='\" + index + \"']\").attr(\"visibility\", \"visible\");\n",
" }\n",
"\n",
" // Update color boxes superimposed over tokens\n",
" const id = isLeft ? \"right\" : \"left\";\n",
" const leftPos = isLeft ? MATRIX_WIDTH + BOXWIDTH : 0;\n",
" svg.select(\"#\" + id)\n",
" .selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\"rect\")\n",
" .attr(\"x\", function () {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" return leftPos + boxOffsets(headIndex);\n",
" })\n",
" .attr(\"y\", (d, i) => TEXT_TOP + i * BOXHEIGHT)\n",
" .attr(\"width\", BOXWIDTH / activeHeads())\n",
" .attr(\"height\", BOXHEIGHT)\n",
" .style(\"opacity\", function (d) {\n",
" const headIndex = +this.parentNode.getAttribute(\"head-index\");\n",
" if (config.headVis[headIndex])\n",
" if (d) {\n",
" return d[index];\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" else\n",
" return 0.0;\n",
" });\n",
" });\n",
"\n",
" textContainer.on(\"mouseleave\", function () {\n",
"\n",
" // Unhighlight selected token\n",
" d3.select(this).selectAll(\".background\")\n",
" .style(\"opacity\", 0.0);\n",
"\n",
" // Reset visibility attributes for previously selected lines\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line[visibility='visible']\")\n",
" .attr(\"visibility\", null) ;\n",
" svg.select(\"#attention\").attr(\"visibility\", \"visible\");\n",
"\n",
" // Reset highlights superimposed over tokens\n",
" svg.selectAll(\".attentionBoxes\")\n",
" .selectAll(\"g\")\n",
" .selectAll(\"rect\")\n",
" .style(\"opacity\", 0.0);\n",
" });\n",
" }\n",
"\n",
" function renderAttention(svg, attention) {\n",
"\n",
" // Remove previous dom elements\n",
" svg.select(\"#attention\").remove();\n",
"\n",
" // Add new elements\n",
" svg.append(\"g\")\n",
" .attr(\"id\", \"attention\") // Container for all attention arcs\n",
" .selectAll(\".headAttention\")\n",
" .data(attention)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"headAttention\", true) // Group attention arcs by head\n",
" .attr(\"head-index\", (d, i) => i)\n",
" .selectAll(\".tokenAttention\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"g\")\n",
" .classed(\"tokenAttention\", true) // Group attention arcs by left token\n",
" .attr(\"left-token-index\", (d, i) => i)\n",
" .selectAll(\"line\")\n",
" .data(d => d)\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", BOXWIDTH)\n",
" .attr(\"y1\", function () {\n",
" const leftTokenIndex = +this.parentNode.getAttribute(\"left-token-index\")\n",
" return TEXT_TOP + leftTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2)\n",
" })\n",
" .attr(\"x2\", BOXWIDTH + MATRIX_WIDTH)\n",
" .attr(\"y2\", (d, rightTokenIndex) => TEXT_TOP + rightTokenIndex * BOXHEIGHT + (BOXHEIGHT / 2))\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", function () {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" return headColors(headIndex)\n",
" })\n",
" .attr(\"left-token-index\", function () {\n",
" return +this.parentNode.getAttribute(\"left-token-index\")\n",
" })\n",
" .attr(\"right-token-index\", (d, i) => i)\n",
" ;\n",
" updateAttention(svg)\n",
" }\n",
"\n",
" function updateAttention(svg) {\n",
" svg.select(\"#attention\")\n",
" .selectAll(\"line\")\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" const headIndex = +this.parentNode.parentNode.getAttribute(\"head-index\");\n",
" // If head is selected\n",
" if (config.headVis[headIndex]) {\n",
" // Set opacity to attention weight divided by number of active heads\n",
" return d / activeHeads()\n",
" } else {\n",
" return 0.0;\n",
" }\n",
" })\n",
" }\n",
"\n",
" function boxOffsets(i) {\n",
" const numHeadsAbove = config.headVis.reduce(\n",
" function (acc, val, cur) {\n",
" return val && cur < i ? acc + 1 : acc;\n",
" }, 0);\n",
" return numHeadsAbove * (BOXWIDTH / activeHeads());\n",
" }\n",
"\n",
" function activeHeads() {\n",
" return config.headVis.reduce(function (acc, val) {\n",
" return val ? acc + 1 : acc;\n",
" }, 0);\n",
" }\n",
"\n",
" function drawCheckboxes(top, svg) {\n",
" const checkboxContainer = svg.append(\"g\");\n",
" const checkbox = checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .enter()\n",
" .append(\"rect\")\n",
" .attr(\"fill\", (d, i) => headColors(i))\n",
" .attr(\"x\", (d, i) => i * CHECKBOX_SIZE)\n",
" .attr(\"y\", top)\n",
" .attr(\"width\", CHECKBOX_SIZE)\n",
" .attr(\"height\", CHECKBOX_SIZE);\n",
"\n",
" function updateCheckboxes() {\n",
" checkboxContainer.selectAll(\"rect\")\n",
" .data(config.headVis)\n",
" .attr(\"fill\", (d, i) => d ? headColors(i): lighten(headColors(i)));\n",
" }\n",
"\n",
" updateCheckboxes();\n",
"\n",
" checkbox.on(\"click\", function (d, i) {\n",
" if (config.headVis[i] && activeHeads() === 1) return;\n",
" config.headVis[i] = !config.headVis[i];\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
"\n",
" checkbox.on(\"dblclick\", function (d, i) {\n",
" // If we double click on the only active head then reset\n",
" if (config.headVis[i] && activeHeads() === 1) {\n",
" config.headVis = new Array(config.nHeads).fill(true);\n",
" } else {\n",
" config.headVis = new Array(config.nHeads).fill(false);\n",
" config.headVis[i] = true;\n",
" }\n",
" updateCheckboxes();\n",
" updateAttention(svg);\n",
" });\n",
" }\n",
"\n",
" function lighten(color) {\n",
" const c = d3.hsl(color);\n",
" const increment = (1 - c.l) * 0.6;\n",
" c.l += increment;\n",
" c.s -= increment;\n",
" return c;\n",
" }\n",
"\n",
" function transpose(mat) {\n",
" return mat[0].map(function (col, i) {\n",
" return mat.map(function (row) {\n",
" return row[i];\n",
" });\n",
" });\n",
" }\n",
"\n",
"});"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"head_view(\n",
" encoder_attention=outputs.encoder_attentions,\n",
" decoder_attention=outputs.decoder_attentions,\n",
" cross_attention=outputs.cross_attentions,\n",
" encoder_tokens= encoder_text,\n",
" decoder_tokens = decoder_text\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/html": [
"<script src=\"https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.6/require.min.js\"></script>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
" \n",
" <div id='bertviz-5e36904179de4109b01180d66af00b7a'>\n",
" <span style=\"user-select:none\">\n",
" Attention: <select id=\"filter\"><option value=\"0\">Encoder</option>\n",
"<option value=\"1\">Decoder</option>\n",
"<option value=\"2\">Cross</option></select>\n",
" </span>\n",
" <div id='vis'></div>\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/javascript": [
"/**\n",
" * @fileoverview Transformer Visualization D3 javascript code.\n",
" *\n",
" * Based on: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/visualization/attention.js\n",
" *\n",
" * Change log:\n",
" *\n",
" * 02/01/19 Jesse Vig Initial implementation\n",
" * 12/31/20 Jesse Vig Support multiple visualizations in single notebook.\n",
" * 01/19/21 Jesse Vig Support light/dark modes\n",
" * 02/06/21 Jesse Vig Move require config from separate jupyter notebook step\n",
" * 05/03/21 Jesse Vig Adjust visualization height dynamically\n",
" **/\n",
"\n",
"require.config({\n",
" paths: {\n",
" d3: '//cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min',\n",
" jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',\n",
" }\n",
"});\n",
"\n",
"requirejs(['jquery', 'd3'], function($, d3) {\n",
"\n",
" const params = {\"attention\": [{\"name\": \"Encoder\", \"attn\": [[[[0.0015343179693445563, 0.06372829526662827, 0.3402520716190338, 0.014809303916990757, 0.11685427278280258, 0.09690533578395844, 0.3659164011478424], [0.008656513877213001, 0.08992935717105865, 0.25149786472320557, 0.07625745981931686, 0.20215430855751038, 0.10926036536693573, 0.26224416494369507], [0.01743713766336441, 0.19126100838184357, 0.019580217078328133, 0.10873700678348541, 0.5898742079734802, 0.013855690136551857, 0.05925469845533371], [0.006695661228150129, 0.06985638290643692, 0.03561459109187126, 0.03681695833802223, 0.7022322416305542, 0.054333802312612534, 0.09445030987262726], [0.0010159132070839405, 0.004103748593479395, 0.18814627826213837, 0.013669033534824848, 0.03894238546490669, 0.16656135022640228, 0.5875613689422607], [0.0033098733983933926, 0.06926336139440536, 0.006486027967184782, 0.033568885177373886, 0.7985183000564575, 0.014555910602211952, 0.07429762929677963], [0.006561059970408678, 0.028042230755090714, 0.06652601808309555, 0.04403909668326378, 0.1904589682817459, 0.17164954543113708, 0.49272310733795166]], [[0.0010659039253368974, 0.027911514043807983, 0.5827310681343079, 0.00583711639046669, 0.02539275959134102, 0.09940840303897858, 0.25765326619148254], [0.014391965232789516, 0.07798851281404495, 0.4047695994377136, 0.01874568685889244, 0.03196389973163605, 0.1091737151145935, 0.34296661615371704], [0.05386969447135925, 0.1397337168455124, 0.1042787954211235, 0.41253480315208435, 0.13313433527946472, 0.06411976367235184, 0.09232895076274872], [0.011524752713739872, 0.024617008864879608, 0.11605977267026901, 0.15868355333805084, 0.07823850214481354, 0.1350926011800766, 0.4757837951183319], [0.000765774748288095, 0.0007213741773739457, 0.15132009983062744, 0.007885550148785114, 0.01554951909929514, 0.1716994047164917, 0.6520583033561707], [0.00261464505456388, 0.01399530004709959, 0.0927773043513298, 0.05853908136487007, 0.12484410405158997, 0.22863952815532684, 
0.4785899519920349], [0.005509554408490658, 0.018335115164518356, 0.07280543446540833, 0.02366117760539055, 0.06281809508800507, 0.21027612686157227, 0.6065945625305176]], [[0.001894954708404839, 0.04986847937107086, 0.3366052806377411, 0.005386138800531626, 0.11281800270080566, 0.07878921926021576, 0.4146379232406616], [0.0033349881414324045, 0.03520013391971588, 0.5712409615516663, 0.031452734023332596, 0.0384756363928318, 0.08380278199911118, 0.2364927977323532], [0.032189127057790756, 0.08156478404998779, 0.22131586074829102, 0.08562739938497543, 0.08101073652505875, 0.1313415765762329, 0.36695045232772827], [0.0022780823055654764, 0.08724657446146011, 0.20255671441555023, 0.09033971279859543, 0.11631207168102264, 0.25404372811317444, 0.24722309410572052], [0.00028211530297994614, 0.0050274962559342384, 0.12840455770492554, 0.008146004751324654, 0.05242825299501419, 0.23907428979873657, 0.5666372776031494], [8.598285057814792e-05, 0.0020894501358270645, 0.02531367912888527, 0.001749282469972968, 0.01047285832464695, 0.10172442346811295, 0.8585643172264099], [0.0019090479472652078, 0.009371059946715832, 0.03571559488773346, 0.008757798001170158, 0.01917792297899723, 0.12495970726013184, 0.8001089096069336]], [[0.0002124253660440445, 0.07846645265817642, 0.6617725491523743, 0.0009571347036398947, 0.0028506950475275517, 0.0355488583445549, 0.2201918661594391], [0.0013146958081051707, 0.03666481375694275, 0.5017722249031067, 0.032994918525218964, 0.08658347278833389, 0.04107796028256416, 0.29959189891815186], [0.007319945842027664, 0.25871768593788147, 0.07154543697834015, 0.20867864787578583, 0.18836019933223724, 0.18986377120018005, 0.07551436871290207], [0.002156124683097005, 0.056250181049108505, 0.20182666182518005, 0.011035334318876266, 0.0871698409318924, 0.05659180134534836, 0.5849699974060059], [0.0009280861704610288, 0.004628014750778675, 0.20538683235645294, 0.005883386358618736, 0.012845118530094624, 0.10419470071792603, 0.6661338806152344], 
[0.0012555076973512769, 0.041694026440382004, 0.05578531697392464, 0.12980355322360992, 0.201
" const config = {};\n",
"\n",
" const MIN_X = 0;\n",
" const MIN_Y = 0;\n",
" const DIV_WIDTH = 970;\n",
" const THUMBNAIL_PADDING = 5;\n",
" const DETAIL_WIDTH = 300;\n",
" const DETAIL_ATTENTION_WIDTH = 140;\n",
" const DETAIL_BOX_WIDTH = 80;\n",
" const DETAIL_BOX_HEIGHT = 18;\n",
" const DETAIL_PADDING = 28;\n",
" const ATTN_PADDING = 0;\n",
" const DETAIL_HEADING_HEIGHT = 25;\n",
" const DETAIL_HEADING_TEXT_SIZE = 15;\n",
" const TEXT_SIZE = 13;\n",
" const LAYER_COLORS = d3.schemeCategory10;\n",
" const PALETTE = {\n",
" 'light': {\n",
" 'text': 'black',\n",
" 'background': 'white',\n",
" 'highlight': '#F5F5F5'\n",
" },\n",
" 'dark': {\n",
" 'text': '#bbb',\n",
" 'background': 'black',\n",
" 'highlight': '#222'\n",
" }\n",
" }\n",
"\n",
" function render() {\n",
"\n",
" // Set global state variables\n",
"\n",
" var attData = config.attention[config.filter];\n",
" config.leftText = attData.left_text;\n",
" config.rightText = attData.right_text;\n",
" config.attn = attData.attn;\n",
" config.numLayers = config.attn.length;\n",
" config.numHeads = config.attn[0].length;\n",
" config.thumbnailBoxHeight = 7 * (12 / config.numHeads);\n",
" config.thumbnailHeight = Math.max(config.leftText.length, config.rightText.length) * config.thumbnailBoxHeight + 2 * THUMBNAIL_PADDING;\n",
" config.thumbnailWidth = DIV_WIDTH / config.numHeads;\n",
" config.detailHeight = Math.max(config.leftText.length, config.rightText.length) * DETAIL_BOX_HEIGHT + 2 * DETAIL_PADDING + DETAIL_HEADING_HEIGHT;\n",
" config.divHeight = config.numLayers * config.thumbnailHeight;\n",
"\n",
" const vis = $(`#${config.rootDivId} #vis`)\n",
" vis.empty();\n",
" vis.attr(\"height\", config.divHeight);\n",
" config.svg = d3.select(`#${config.rootDivId} #vis`)\n",
" .append('svg')\n",
" .attr(\"width\", DIV_WIDTH)\n",
" .attr(\"height\", config.divHeight)\n",
" .attr(\"fill\", getBackgroundColor());\n",
"\n",
" var i;\n",
" var j;\n",
" for (i = 0; i < config.numLayers; i++) {\n",
" for (j = 0; j < config.numHeads; j++) {\n",
" renderThumbnail(i, j);\n",
" }\n",
" }\n",
" }\n",
"\n",
" function renderThumbnail(layerIndex, headIndex) {\n",
" var x = headIndex * config.thumbnailWidth;\n",
" var y = layerIndex * config.thumbnailHeight;\n",
" renderThumbnailAttn(x, y, config.attn[layerIndex][headIndex], layerIndex, headIndex);\n",
" }\n",
"\n",
" function renderDetail(att, layerIndex, headIndex) {\n",
" var xOffset = .8 * config.thumbnailWidth;\n",
" var maxX = DIV_WIDTH;\n",
" var maxY = config.divHeight;\n",
" var leftPos = (headIndex / config.numHeads) * DIV_WIDTH\n",
" var x = leftPos + THUMBNAIL_PADDING + xOffset;\n",
" if (x < MIN_X) {\n",
" x = MIN_X;\n",
" } else if (x + DETAIL_WIDTH > maxX) {\n",
" x = leftPos + THUMBNAIL_PADDING - DETAIL_WIDTH + 8;\n",
" }\n",
" var posLeftText = x;\n",
" var posAttention = posLeftText + DETAIL_BOX_WIDTH;\n",
" var posRightText = posAttention + DETAIL_ATTENTION_WIDTH;\n",
" var thumbnailHeight = Math.max(config.leftText.length, config.rightText.length) * config.thumbnailBoxHeight + 2 * THUMBNAIL_PADDING;\n",
" var yOffset = 20;\n",
" var y = layerIndex * thumbnailHeight + THUMBNAIL_PADDING + yOffset;\n",
" if (y < MIN_Y) {\n",
" y = MIN_Y;\n",
" } else if (y + config.detailHeight > maxY) {\n",
" y = maxY - config.detailHeight;\n",
" }\n",
" renderDetailFrame(x, y, layerIndex);\n",
" renderDetailHeading(x, y + Math.max(config.leftText.length, config.rightText.length) *\n",
" DETAIL_BOX_HEIGHT, layerIndex, headIndex);\n",
" renderDetailText(config.leftText, \"leftText\", posLeftText, y + DETAIL_PADDING, layerIndex);\n",
" renderDetailAttn(posAttention, y + DETAIL_PADDING, att, layerIndex, headIndex);\n",
" renderDetailText(config.rightText, \"rightText\", posRightText, y + DETAIL_PADDING, layerIndex);\n",
" }\n",
"\n",
" function renderDetailHeading(x, y, layerIndex, headIndex) {\n",
" var fillColor = getTextColor();\n",
" config.svg.append(\"text\")\n",
" .classed(\"detail\", true)\n",
" .text('Layer ' + layerIndex + \", Head \" + headIndex)\n",
" .attr(\"font-size\", DETAIL_HEADING_TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"fill\", fillColor)\n",
" .attr(\"x\", x + DETAIL_WIDTH / 2)\n",
" .attr(\"text-anchor\", \"middle\")\n",
" .attr(\"y\", y + 40)\n",
" .attr(\"height\", DETAIL_HEADING_HEIGHT)\n",
" .attr(\"width\", DETAIL_WIDTH)\n",
" .attr(\"dy\", DETAIL_HEADING_TEXT_SIZE);\n",
" }\n",
"\n",
" function renderDetailText(text, id, x, y, layerIndex) {\n",
" var tokenContainer = config.svg.append(\"svg:g\")\n",
" .classed(\"detail\", true)\n",
" .selectAll(\"g\")\n",
" .data(text)\n",
" .enter()\n",
" .append(\"g\");\n",
"\n",
" var fillColor = getTextColor();\n",
"\n",
" tokenContainer.append(\"rect\")\n",
" .classed(\"highlight\", true)\n",
" .attr(\"fill\", fillColor)\n",
" .style(\"opacity\", 0.0)\n",
" .attr(\"height\", DETAIL_BOX_HEIGHT)\n",
" .attr(\"width\", DETAIL_BOX_WIDTH)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", function (d, i) {\n",
" return y + i * DETAIL_BOX_HEIGHT;\n",
" });\n",
"\n",
" var textContainer = tokenContainer.append(\"text\")\n",
" .classed(\"token\", true)\n",
" .text(function (d) {\n",
" return d;\n",
" })\n",
" .attr(\"font-size\", TEXT_SIZE + \"px\")\n",
" .style(\"cursor\", \"default\")\n",
" .style(\"-webkit-user-select\", \"none\")\n",
" .attr(\"fill\", fillColor)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", function (d, i) {\n",
" return i * DETAIL_BOX_HEIGHT + y;\n",
" })\n",
" .attr(\"height\", DETAIL_BOX_HEIGHT)\n",
" .attr(\"width\", DETAIL_BOX_WIDTH)\n",
" .attr(\"dy\", TEXT_SIZE);\n",
"\n",
" if (id == \"leftText\") {\n",
" textContainer.style(\"text-anchor\", \"end\")\n",
" .attr(\"dx\", DETAIL_BOX_WIDTH - 2);\n",
" tokenContainer.on(\"mouseover\", function (d, index) {\n",
" highlightSelection(index);\n",
" });\n",
" tokenContainer.on(\"mouseleave\", function () {\n",
" unhighlightSelection();\n",
" });\n",
" }\n",
" }\n",
"\n",
" function highlightSelection(index) {\n",
" config.svg.select(\"#leftText\")\n",
" .selectAll(\".highlight\")\n",
" .style(\"opacity\", function (d, i) {\n",
" return i == index ? 1.0 : 0.0;\n",
" });\n",
" config.svg.selectAll(\".attn-line-group\")\n",
" .style(\"opacity\", function (d, i) {\n",
" return i == index ? 1.0 : 0.0;\n",
" });\n",
" }\n",
"\n",
" function unhighlightSelection() {\n",
" config.svg.select(\"#leftText\")\n",
" .selectAll(\".highlight\")\n",
" .style(\"opacity\", 0.0);\n",
" config.svg.selectAll(\".attn-line-group\")\n",
" .style(\"opacity\", 1);\n",
" }\n",
"\n",
" function renderThumbnailAttn(x, y, att, layerIndex, headIndex) {\n",
"\n",
" var attnContainer = config.svg.append(\"svg:g\");\n",
"\n",
" var attnBackground = attnContainer.append(\"rect\")\n",
" .attr(\"id\", 'attn_background_' + layerIndex + \"_\" + headIndex)\n",
" .classed(\"attn_background\", true)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.thumbnailHeight)\n",
" .attr(\"width\", config.thumbnailWidth)\n",
" .attr(\"stroke-width\", 2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", 0)\n",
" .attr(\"fill\", getBackgroundColor());\n",
" var x1 = x + THUMBNAIL_PADDING;\n",
" var x2 = x1 + config.thumbnailWidth - 14;\n",
" var y1 = y + THUMBNAIL_PADDING;\n",
"\n",
" attnContainer.selectAll(\"g\")\n",
" .data(att)\n",
" .enter()\n",
" .append(\"g\") // Add group for each source token\n",
" .attr(\"source-index\", function (d, i) { // Save index of source token\n",
" return i;\n",
" })\n",
" .selectAll(\"line\")\n",
" .data(function (d) { // Loop over all target tokens\n",
" return d;\n",
" })\n",
" .enter() // When entering\n",
" .append(\"line\")\n",
" .attr(\"x1\", x1)\n",
" .attr(\"y1\", function (d) {\n",
" var sourceIndex = +this.parentNode.getAttribute(\"source-index\");\n",
" return y1 + (sourceIndex + .5) * config.thumbnailBoxHeight;\n",
" })\n",
" .attr(\"x2\", x2)\n",
" .attr(\"y2\", function (d, targetIndex) {\n",
" return y1 + (targetIndex + .5) * config.thumbnailBoxHeight;\n",
" })\n",
" .attr(\"stroke-width\", 2.2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" return d;\n",
" });\n",
"\n",
" var clickRegion = attnContainer.append(\"rect\")\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.thumbnailHeight)\n",
" .attr(\"width\", config.thumbnailWidth)\n",
" .style(\"opacity\", 0);\n",
"\n",
" clickRegion.on(\"click\", function (d, index) {\n",
" var attnBackgroundOther = config.svg.selectAll(\".attn_background\");\n",
" attnBackgroundOther.attr(\"fill\", getBackgroundColor());\n",
" attnBackgroundOther.attr(\"stroke-opacity\", 0);\n",
"\n",
" config.svg.selectAll(\".detail\").remove();\n",
" if (config.detail_layer != layerIndex || config.detail_head != headIndex) {\n",
" renderDetail(att, layerIndex, headIndex);\n",
" config.detail_layer = layerIndex;\n",
" config.detail_head = headIndex;\n",
" attnBackground.attr(\"fill\", getHighlightColor());\n",
" attnBackground.attr(\"stroke-opacity\", .8);\n",
" } else {\n",
" config.detail_layer = null;\n",
" config.detail_head = null;\n",
" attnBackground.attr(\"fill\", getBackgroundColor());\n",
" attnBackground.attr(\"stroke-opacity\", 0);\n",
" }\n",
" });\n",
"\n",
" clickRegion.on(\"mouseover\", function (d) {\n",
" d3.select(this).style(\"cursor\", \"pointer\");\n",
" });\n",
" }\n",
"\n",
" function renderDetailFrame(x, y, layerIndex) {\n",
" var detailFrame = config.svg.append(\"rect\")\n",
" .classed(\"detail\", true)\n",
" .attr(\"x\", x)\n",
" .attr(\"y\", y)\n",
" .attr(\"height\", config.detailHeight)\n",
" .attr(\"width\", DETAIL_WIDTH)\n",
" .style(\"opacity\", 1)\n",
" .attr(\"stroke-width\", 1.5)\n",
" .attr(\"stroke-opacity\", 0.7)\n",
" .attr(\"stroke\", getLayerColor(layerIndex));\n",
" }\n",
"\n",
" function renderDetailAttn(x, y, att, layerIndex) {\n",
" var attnContainer = config.svg.append(\"svg:g\")\n",
" .classed(\"detail\", true)\n",
" .attr(\"pointer-events\", \"none\");\n",
" attnContainer.selectAll(\"g\")\n",
" .data(att)\n",
" .enter()\n",
" .append(\"g\") // Add group for each source token\n",
" .classed('attn-line-group', true)\n",
" .attr(\"source-index\", function (d, i) { // Save index of source token\n",
" return i;\n",
" })\n",
" .selectAll(\"line\")\n",
" .data(function (d) { // Loop over all target tokens\n",
" return d;\n",
" })\n",
" .enter()\n",
" .append(\"line\")\n",
" .attr(\"x1\", x + ATTN_PADDING)\n",
" .attr(\"y1\", function (d) {\n",
" var sourceIndex = +this.parentNode.getAttribute(\"source-index\");\n",
" return y + (sourceIndex + .5) * DETAIL_BOX_HEIGHT;\n",
" })\n",
" .attr(\"x2\", x + DETAIL_ATTENTION_WIDTH - ATTN_PADDING)\n",
" .attr(\"y2\", function (d, targetIndex) {\n",
" return y + (targetIndex + .5) * DETAIL_BOX_HEIGHT;\n",
" })\n",
" .attr(\"stroke-width\", 2.2)\n",
" .attr(\"stroke\", getLayerColor(layerIndex))\n",
" .attr(\"stroke-opacity\", function (d) {\n",
" return d;\n",
" });\n",
" }\n",
"\n",
" function getLayerColor(layer) {\n",
" return LAYER_COLORS[layer % 10];\n",
" }\n",
"\n",
" function getTextColor() {\n",
" return PALETTE[config.mode]['text']\n",
" }\n",
"\n",
" function getBackgroundColor() {\n",
" return PALETTE[config.mode]['background']\n",
" }\n",
"\n",
" function getHighlightColor() {\n",
" return PALETTE[config.mode]['highlight']\n",
" }\n",
"\n",
" function initialize() {\n",
" config.attention = params['attention'];\n",
" config.filter = params['default_filter'];\n",
" config.mode = params['display_mode'];\n",
" config.rootDivId = params['root_div_id'];\n",
" $(`#${config.rootDivId} #filter`).on('change', function (e) {\n",
" config.filter = e.currentTarget.value;\n",
" render();\n",
" });\n",
" }\n",
"\n",
" initialize();\n",
" render();\n",
"\n",
" });"
],
"text/plain": [
"<IPython.core.display.Javascript object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Visualize the encoder, decoder and cross-attention weights held in `outputs`.\n",
"model_view(\n",
"    encoder_attention=outputs.encoder_attentions,\n",
"    decoder_attention=outputs.decoder_attentions,\n",
"    cross_attention=outputs.cross_attentions,\n",
"    encoder_tokens=encoder_text,\n",
"    decoder_tokens=decoder_text,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Zadanie (10 minut)\n",
"\n",
"Za pomocą modelu en-fr przetłumacz dowolne zdanie z angielskiego na język francuski i sprawdź wagi atencji dla tego tłumaczenia."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### PRZYKŁAD: GPT3"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ZADANIE DOMOWE - POLEVAL"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}