From 240c16b495e4850fabed4060082ef8031fdac816 Mon Sep 17 00:00:00 2001 From: Patryk Bartkowiak Date: Mon, 21 Oct 2024 20:05:56 +0000 Subject: [PATCH] redesigned code, added functions and granularity --- code/.gitignore | 4 + code/data/.gitignore | 2 + code/models/.gitignore | 2 + code/pdm.lock | 2 +- code/pyproject.toml | 14 +- code/src/config.json | 11 + code/src/download_the_stack.ipynb | 262 ------------------ code/src/preprocess_data.py | 58 ---- code/src/the_stack_test.ipynb | 246 ----------------- code/src/train_codebert_mlm.py | 429 +++++++++++++++--------------- 10 files changed, 237 insertions(+), 793 deletions(-) create mode 100644 code/data/.gitignore create mode 100644 code/models/.gitignore create mode 100644 code/src/config.json delete mode 100644 code/src/download_the_stack.ipynb delete mode 100644 code/src/preprocess_data.py delete mode 100644 code/src/the_stack_test.ipynb diff --git a/code/.gitignore b/code/.gitignore index 3a8816c..d51ef0e 100644 --- a/code/.gitignore +++ b/code/.gitignore @@ -160,3 +160,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# Weights & Biases +wandb/ +outputs/ diff --git a/code/data/.gitignore b/code/data/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/code/data/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/code/models/.gitignore b/code/models/.gitignore new file mode 100644 index 0000000..d6b7ef3 --- /dev/null +++ b/code/models/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore diff --git a/code/pdm.lock b/code/pdm.lock index d359502..5c28202 100644 --- a/code/pdm.lock +++ b/code/pdm.lock @@ -5,7 +5,7 @@ groups = ["default"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:ac6621f3bd9193d786ab94f80f8b1711100fe418959f2e131ae03afeab616788" +content_hash = "sha256:bf0a0ea826769cf12a84888d394edd8c3c5599c4d369b8b19b75c2fa5e16f5f0" [[metadata.targets]] requires_python = "==3.11.*" diff --git a/code/pyproject.toml b/code/pyproject.toml index cafd29f..ac23f58 100644 --- a/code/pyproject.toml +++ b/code/pyproject.toml @@ -6,13 +6,13 @@ authors = [ {name = "Patryk Bartkowiak", email = "patbar15@st.amu.edu.pl"}, ] dependencies = [ - "wandb>=0.18.5", - "torch>=2.5.0", - "tqdm>=4.66.5", - "tree-sitter>=0.23.1", - "transformers>=4.45.2", - "datasets>=3.0.1", - "huggingface-hub>=0.26.0", + "wandb==0.18.5", + "torch==2.5.0", + "tqdm==4.66.5", + "tree-sitter==0.23.1", + "transformers==4.45.2", + "datasets==3.0.1", + "huggingface-hub==0.26.0", ] requires-python = "==3.11.*" readme = "README.md" diff --git a/code/src/config.json b/code/src/config.json new file mode 100644 index 0000000..7991809 --- /dev/null +++ b/code/src/config.json @@ -0,0 +1,11 @@ +{ + "seed": 42, + "mlm_probability": 0.15, + "batch": 32, + "epochs": 1, + "eval_every": 10000, + "learning_rate": 5e-4, + "weight_decay": 0.01, + "max_grad_norm": 1.0, + "warmup_steps": 10000 +} \ No newline at end of file diff --git a/code/src/download_the_stack.ipynb b/code/src/download_the_stack.ipynb deleted file mode 100644 index 82c518e..0000000 --- a/code/src/download_the_stack.ipynb +++ /dev/null @@ -1,262 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from huggingface_hub import list_repo_files, hf_hub_download\n", - "import os" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 5144 files in the repository.\n", - "['data/python/data-00000-of-00144.parquet', 'data/python/data-00001-of-00144.parquet', 'data/python/data-00002-of-00144.parquet', 'data/python/data-00003-of-00144.parquet', 'data/python/data-00004-of-00144.parquet', 'data/python/data-00005-of-00144.parquet', 'data/python/data-00006-of-00144.parquet', 'data/python/data-00007-of-00144.parquet', 'data/python/data-00008-of-00144.parquet', 'data/python/data-00009-of-00144.parquet', 'data/python/data-00010-of-00144.parquet', 'data/python/data-00011-of-00144.parquet', 'data/python/data-00012-of-00144.parquet', 'data/python/data-00013-of-00144.parquet', 'data/python/data-00014-of-00144.parquet', 'data/python/data-00015-of-00144.parquet', 'data/python/data-00016-of-00144.parquet', 'data/python/data-00017-of-00144.parquet', 'data/python/data-00018-of-00144.parquet', 'data/python/data-00019-of-00144.parquet', 'data/python/data-00020-of-00144.parquet', 'data/python/data-00021-of-00144.parquet', 'data/python/data-00022-of-00144.parquet', 'data/python/data-00023-of-00144.parquet', 'data/python/data-00024-of-00144.parquet', 'data/python/data-00025-of-00144.parquet', 'data/python/data-00026-of-00144.parquet', 'data/python/data-00027-of-00144.parquet', 'data/python/data-00028-of-00144.parquet', 'data/python/data-00029-of-00144.parquet', 'data/python/data-00030-of-00144.parquet', 'data/python/data-00031-of-00144.parquet', 'data/python/data-00032-of-00144.parquet', 'data/python/data-00033-of-00144.parquet', 'data/python/data-00034-of-00144.parquet', 'data/python/data-00035-of-00144.parquet', 'data/python/data-00036-of-00144.parquet', 'data/python/data-00037-of-00144.parquet', 'data/python/data-00038-of-00144.parquet', 'data/python/data-00039-of-00144.parquet', 'data/python/data-00040-of-00144.parquet', 'data/python/data-00041-of-00144.parquet', 'data/python/data-00042-of-00144.parquet', 'data/python/data-00043-of-00144.parquet', 'data/python/data-00044-of-00144.parquet', 'data/python/data-00045-of-00144.parquet', 'data/python/data-00046-of-00144.parquet', 'data/python/data-00047-of-00144.parquet', 'data/python/data-00048-of-00144.parquet', 'data/python/data-00049-of-00144.parquet', 'data/python/data-00050-of-00144.parquet', 'data/python/data-00051-of-00144.parquet', 'data/python/data-00052-of-00144.parquet', 'data/python/data-00053-of-00144.parquet', 'data/python/data-00054-of-00144.parquet', 'data/python/data-00055-of-00144.parquet', 'data/python/data-00056-of-00144.parquet', 'data/python/data-00057-of-00144.parquet', 'data/python/data-00058-of-00144.parquet', 'data/python/data-00059-of-00144.parquet', 'data/python/data-00060-of-00144.parquet', 'data/python/data-00061-of-00144.parquet', 'data/python/data-00062-of-00144.parquet', 'data/python/data-00063-of-00144.parquet', 'data/python/data-00064-of-00144.parquet', 'data/python/data-00065-of-00144.parquet', 'data/python/data-00066-of-00144.parquet', 'data/python/data-00067-of-00144.parquet', 'data/python/data-00068-of-00144.parquet', 'data/python/data-00069-of-00144.parquet', 'data/python/data-00070-of-00144.parquet', 'data/python/data-00071-of-00144.parquet', 'data/python/data-00072-of-00144.parquet', 'data/python/data-00073-of-00144.parquet', 'data/python/data-00074-of-00144.parquet', 'data/python/data-00075-of-00144.parquet', 'data/python/data-00076-of-00144.parquet', 'data/python/data-00077-of-00144.parquet', 'data/python/data-00078-of-00144.parquet', 'data/python/data-00079-of-00144.parquet', 'data/python/data-00080-of-00144.parquet', 'data/python/data-00081-of-00144.parquet', 'data/python/data-00082-of-00144.parquet', 'data/python/data-00083-of-00144.parquet', 'data/python/data-00084-of-00144.parquet', 'data/python/data-00085-of-00144.parquet', 'data/python/data-00086-of-00144.parquet', 'data/python/data-00087-of-00144.parquet', 'data/python/data-00088-of-00144.parquet', 'data/python/data-00089-of-00144.parquet', 'data/python/data-00090-of-00144.parquet', 'data/python/data-00091-of-00144.parquet', 'data/python/data-00092-of-00144.parquet', 'data/python/data-00093-of-00144.parquet', 'data/python/data-00094-of-00144.parquet', 'data/python/data-00095-of-00144.parquet', 'data/python/data-00096-of-00144.parquet', 'data/python/data-00097-of-00144.parquet', 'data/python/data-00098-of-00144.parquet', 'data/python/data-00099-of-00144.parquet', 'data/python/data-00100-of-00144.parquet', 'data/python/data-00101-of-00144.parquet', 'data/python/data-00102-of-00144.parquet', 'data/python/data-00103-of-00144.parquet', 'data/python/data-00104-of-00144.parquet', 'data/python/data-00105-of-00144.parquet', 'data/python/data-00106-of-00144.parquet', 'data/python/data-00107-of-00144.parquet', 'data/python/data-00108-of-00144.parquet', 'data/python/data-00109-of-00144.parquet', 'data/python/data-00110-of-00144.parquet', 'data/python/data-00111-of-00144.parquet', 'data/python/data-00112-of-00144.parquet', 'data/python/data-00113-of-00144.parquet', 'data/python/data-00114-of-00144.parquet', 'data/python/data-00115-of-00144.parquet', 'data/python/data-00116-of-00144.parquet', 'data/python/data-00117-of-00144.parquet', 'data/python/data-00118-of-00144.parquet', 'data/python/data-00119-of-00144.parquet', 'data/python/data-00120-of-00144.parquet', 'data/python/data-00121-of-00144.parquet', 'data/python/data-00122-of-00144.parquet', 'data/python/data-00123-of-00144.parquet', 'data/python/data-00124-of-00144.parquet', 'data/python/data-00125-of-00144.parquet', 'data/python/data-00126-of-00144.parquet', 'data/python/data-00127-of-00144.parquet', 'data/python/data-00128-of-00144.parquet', 'data/python/data-00129-of-00144.parquet', 'data/python/data-00130-of-00144.parquet', 'data/python/data-00131-of-00144.parquet', 'data/python/data-00132-of-00144.parquet', 'data/python/data-00133-of-00144.parquet', 'data/python/data-00134-of-00144.parquet', 'data/python/data-00135-of-00144.parquet', 'data/python/data-00136-of-00144.parquet', 'data/python/data-00137-of-00144.parquet', 'data/python/data-00138-of-00144.parquet', 'data/python/data-00139-of-00144.parquet', 'data/python/data-00140-of-00144.parquet', 'data/python/data-00141-of-00144.parquet', 'data/python/data-00142-of-00144.parquet', 'data/python/data-00143-of-00144.parquet']\n", - "Downloading data/python/data-00000-of-00144.parquet...\n", - "Downloading data/python/data-00001-of-00144.parquet...\n", - "Downloading data/python/data-00002-of-00144.parquet...\n", - "Downloading data/python/data-00003-of-00144.parquet...\n", - "Downloading data/python/data-00004-of-00144.parquet...\n", - "Downloading data/python/data-00005-of-00144.parquet...\n", - "Downloading data/python/data-00006-of-00144.parquet...\n", - "Downloading data/python/data-00007-of-00144.parquet...\n", - "Downloading data/python/data-00008-of-00144.parquet...\n", - "Downloading data/python/data-00009-of-00144.parquet...\n", - "Downloading data/python/data-00010-of-00144.parquet...\n", - "Downloading data/python/data-00011-of-00144.parquet...\n", - "Downloading data/python/data-00012-of-00144.parquet...\n", - "Downloading data/python/data-00013-of-00144.parquet...\n", - "Downloading data/python/data-00014-of-00144.parquet...\n", - "Downloading data/python/data-00015-of-00144.parquet...\n", - "Downloading data/python/data-00016-of-00144.parquet...\n", - "Downloading data/python/data-00017-of-00144.parquet...\n", - "Downloading data/python/data-00018-of-00144.parquet...\n", - "Downloading data/python/data-00019-of-00144.parquet...\n", - "Downloading data/python/data-00020-of-00144.parquet...\n", - "Downloading data/python/data-00021-of-00144.parquet...\n", - "Downloading data/python/data-00022-of-00144.parquet...\n", - "Downloading data/python/data-00023-of-00144.parquet...\n", - "Downloading data/python/data-00024-of-00144.parquet...\n", - "Downloading data/python/data-00025-of-00144.parquet...\n", - "Downloading data/python/data-00026-of-00144.parquet...\n", - "Downloading data/python/data-00027-of-00144.parquet...\n", - "Downloading data/python/data-00028-of-00144.parquet...\n", - "Downloading data/python/data-00029-of-00144.parquet...\n", - "Downloading data/python/data-00030-of-00144.parquet...\n", - "Downloading data/python/data-00031-of-00144.parquet...\n", - "Downloading data/python/data-00032-of-00144.parquet...\n", - "Downloading data/python/data-00033-of-00144.parquet...\n", - "Downloading data/python/data-00034-of-00144.parquet...\n", - "Downloading data/python/data-00035-of-00144.parquet...\n", - "Downloading data/python/data-00036-of-00144.parquet...\n", - "Downloading data/python/data-00037-of-00144.parquet...\n", - "Downloading data/python/data-00038-of-00144.parquet...\n", - "Downloading data/python/data-00039-of-00144.parquet...\n", - "Downloading data/python/data-00040-of-00144.parquet...\n", - "Downloading data/python/data-00041-of-00144.parquet...\n", - "Downloading data/python/data-00042-of-00144.parquet...\n", - "Downloading data/python/data-00043-of-00144.parquet...\n", - "Downloading data/python/data-00044-of-00144.parquet...\n", - "Downloading data/python/data-00045-of-00144.parquet...\n", - "Downloading data/python/data-00046-of-00144.parquet...\n", - "Downloading data/python/data-00047-of-00144.parquet...\n", - "Downloading data/python/data-00048-of-00144.parquet...\n", - "Downloading data/python/data-00049-of-00144.parquet...\n", - "Downloading data/python/data-00050-of-00144.parquet...\n", - "Downloading data/python/data-00051-of-00144.parquet...\n", - "Downloading data/python/data-00052-of-00144.parquet...\n", - "Downloading data/python/data-00053-of-00144.parquet...\n", - "Downloading data/python/data-00054-of-00144.parquet...\n", - "Downloading data/python/data-00055-of-00144.parquet...\n", - "Downloading data/python/data-00056-of-00144.parquet...\n", - "Downloading data/python/data-00057-of-00144.parquet...\n", - "Downloading data/python/data-00058-of-00144.parquet...\n", - "Downloading data/python/data-00059-of-00144.parquet...\n", - "Downloading data/python/data-00060-of-00144.parquet...\n", - "Downloading data/python/data-00061-of-00144.parquet...\n", - "Downloading data/python/data-00062-of-00144.parquet...\n", - "Downloading data/python/data-00063-of-00144.parquet...\n", - "Downloading data/python/data-00064-of-00144.parquet...\n", - "Downloading data/python/data-00065-of-00144.parquet...\n", - "Downloading data/python/data-00066-of-00144.parquet...\n", - "Downloading data/python/data-00067-of-00144.parquet...\n", - "Downloading data/python/data-00068-of-00144.parquet...\n", - "Downloading data/python/data-00069-of-00144.parquet...\n", - "Downloading data/python/data-00070-of-00144.parquet...\n", - "Downloading data/python/data-00071-of-00144.parquet...\n", - "Downloading data/python/data-00072-of-00144.parquet...\n", - "Downloading data/python/data-00073-of-00144.parquet...\n", - "Downloading data/python/data-00074-of-00144.parquet...\n", - "Downloading data/python/data-00075-of-00144.parquet...\n", - "Downloading data/python/data-00076-of-00144.parquet...\n", - "Downloading data/python/data-00077-of-00144.parquet...\n", - "Downloading data/python/data-00078-of-00144.parquet...\n", - "Downloading data/python/data-00079-of-00144.parquet...\n", - "Downloading data/python/data-00080-of-00144.parquet...\n", - "Downloading data/python/data-00081-of-00144.parquet...\n", - "Downloading data/python/data-00082-of-00144.parquet...\n", - "Downloading data/python/data-00083-of-00144.parquet...\n", - "Downloading data/python/data-00084-of-00144.parquet...\n", - "Downloading data/python/data-00085-of-00144.parquet...\n", - "Downloading data/python/data-00086-of-00144.parquet...\n", - "Downloading data/python/data-00087-of-00144.parquet...\n", - "Downloading data/python/data-00088-of-00144.parquet...\n", - "Downloading data/python/data-00089-of-00144.parquet...\n", - "Downloading data/python/data-00090-of-00144.parquet...\n", - "Downloading data/python/data-00091-of-00144.parquet...\n", - "Downloading data/python/data-00092-of-00144.parquet...\n", - "Downloading data/python/data-00093-of-00144.parquet...\n", - "Downloading data/python/data-00094-of-00144.parquet...\n", - "Downloading data/python/data-00095-of-00144.parquet...\n", - "Downloading data/python/data-00096-of-00144.parquet...\n", - "Downloading data/python/data-00097-of-00144.parquet...\n", - "Downloading data/python/data-00098-of-00144.parquet...\n", - "Downloading data/python/data-00099-of-00144.parquet...\n", - "Downloading data/python/data-00100-of-00144.parquet...\n", - "Downloading data/python/data-00101-of-00144.parquet...\n", - "Downloading data/python/data-00102-of-00144.parquet...\n", - "Downloading data/python/data-00103-of-00144.parquet...\n", - "Downloading data/python/data-00104-of-00144.parquet...\n", - "Downloading data/python/data-00105-of-00144.parquet...\n", - "Downloading data/python/data-00106-of-00144.parquet...\n", - "Downloading data/python/data-00107-of-00144.parquet...\n", - "Downloading data/python/data-00108-of-00144.parquet...\n", - "Downloading data/python/data-00109-of-00144.parquet...\n", - "Downloading data/python/data-00110-of-00144.parquet...\n", - "Downloading data/python/data-00111-of-00144.parquet...\n", - "Downloading data/python/data-00112-of-00144.parquet...\n", - "Downloading data/python/data-00113-of-00144.parquet...\n", - "Downloading data/python/data-00114-of-00144.parquet...\n", - "Downloading data/python/data-00115-of-00144.parquet...\n", - "Downloading data/python/data-00116-of-00144.parquet...\n", - "Downloading data/python/data-00117-of-00144.parquet...\n", - "Downloading data/python/data-00118-of-00144.parquet...\n", - "Downloading data/python/data-00119-of-00144.parquet...\n", - "Downloading data/python/data-00120-of-00144.parquet...\n", - "Downloading data/python/data-00121-of-00144.parquet...\n", - "Downloading data/python/data-00122-of-00144.parquet...\n", - "Downloading data/python/data-00123-of-00144.parquet...\n", - "Downloading data/python/data-00124-of-00144.parquet...\n", - "Downloading data/python/data-00125-of-00144.parquet...\n", - "Downloading data/python/data-00126-of-00144.parquet...\n", - "Downloading data/python/data-00127-of-00144.parquet...\n", - "Downloading data/python/data-00128-of-00144.parquet...\n", - "Downloading data/python/data-00129-of-00144.parquet...\n", - "Downloading data/python/data-00130-of-00144.parquet...\n", - "Downloading data/python/data-00131-of-00144.parquet...\n", - "Downloading data/python/data-00132-of-00144.parquet...\n", - "Downloading data/python/data-00133-of-00144.parquet...\n", - "Downloading data/python/data-00134-of-00144.parquet...\n", - "Downloading data/python/data-00135-of-00144.parquet...\n", - "Downloading data/python/data-00136-of-00144.parquet...\n", - "Downloading data/python/data-00137-of-00144.parquet...\n", - "Downloading data/python/data-00138-of-00144.parquet...\n", - "Downloading data/python/data-00139-of-00144.parquet...\n", - "Downloading data/python/data-00140-of-00144.parquet...\n", - "Downloading data/python/data-00141-of-00144.parquet...\n", - "Downloading data/python/data-00142-of-00144.parquet...\n", - "Downloading data/python/data-00143-of-00144.parquet...\n", - "All files have been downloaded successfully.\n" - ] - } - ], - "source": [ - "# Define the repository details\n", - "repo_id = \"bigcode/the-stack-dedup\" # Repository ID for the dataset\n", - "subfolder = \"data/python\" # Subfolder path within the repository\n", - "repo_type = \"dataset\" # Specify that it's a dataset\n", - "\n", - "# Specify the local directory to save the files\n", - "local_dir = \"/work/s452638/datasets/the-stack-python\"\n", - "os.makedirs(local_dir, exist_ok=True)\n", - "\n", - "# List all files in the repository's subfolder\n", - "files_list = list_repo_files(repo_id=repo_id, repo_type=repo_type)\n", - "\n", - "# Filter files in the desired subfolder\n", - "files_to_download = [file for file in files_list if file.startswith(f'{subfolder}/')]\n", - "\n", - "print(f\"Found {len(files_to_download)} files in the repository.\")\n", - "\n", - "# Download each file\n", - "for file_name in files_to_download:\n", - " print(f\"Downloading {file_name}...\")\n", - " hf_hub_download(repo_id=repo_id, repo_type=repo_type, filename=file_name, local_dir=local_dir)\n", - "\n", - "print(\"All files have been downloaded successfully.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Found 6 files in the repository.\n", - "Downloading data/Python/train-00004-of-00006.parquet...\n", - "Downloading data/Python/train-00005-of-00006.parquet...\n", - "All files have been downloaded successfully.\n" - ] - } - ], - "source": [ - "# Define the repository details\n", - "repo_id = \"bigcode/the-stack-v2-dedup\" # Repository ID for the dataset\n", - "subfolder = \"data/Python\" # Subfolder path within the repository\n", - "repo_type = \"dataset\" # Specify that it's a dataset\n", - "\n", - "# Specify the local directory to save the files\n", - "local_dir = \"/work/s452638/datasets/the-stack-v2-python\"\n", - "os.makedirs(local_dir, exist_ok=True)\n", - "\n", - "# List all files in the repository's subfolder\n", - "files_list = list_repo_files(repo_id=repo_id, repo_type=repo_type)\n", - "\n", - "# Filter files in the desired subfolder\n", - "files_to_download = [file for file in files_list if file.startswith(f'{subfolder}/')]\n", - "\n", - "print(f\"Found {len(files_to_download)} files in the repository.\")\n", - "\n", - "# Download each file\n", - "for file_name in files_to_download[4:]:\n", - " print(f\"Downloading {file_name}...\")\n", - " hf_hub_download(repo_id=repo_id, repo_type=repo_type, filename=file_name, local_dir=local_dir)\n", - "\n", - "print(\"All files have been downloaded successfully.\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.19" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/code/src/preprocess_data.py b/code/src/preprocess_data.py deleted file mode 100644 index 19d77d3..0000000 --- a/code/src/preprocess_data.py +++ /dev/null @@ -1,58 +0,0 @@ -from datasets import load_dataset, disable_caching -from transformers import RobertaTokenizer - -disable_caching() - - -def visible_print(text): - print('\n\n') - print('=' * 100) - print(text) - print('=' * 100) - print('\n\n') - - -if __name__ == '__main__': - # Load the dataset - train_data = load_dataset('/work/s452638/datasets/the-stack-python', split='train') - valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')['train'] - test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')['train'] - - visible_print('Loaded data') - - # Rename the columns - train_data = train_data.rename_column('content', 'code') - - # Remove all the columns except the code - train_columns = train_data.column_names - valid_columns = valid_data.column_names - test_columns = test_data.column_names - - train_columns.remove('code') - valid_columns.remove('code') - test_columns.remove('code') - - train_data = train_data.remove_columns(train_columns) - valid_data = valid_data.remove_columns(valid_columns) - test_data = test_data.remove_columns(test_columns) - - visible_print('Removed unnecessary columns') - - # Tokenize the data - tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True) - - def tokenize_function(examples): - return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt') - - train_data = train_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Train] Running tokenizer', num_proc=8) - valid_data = valid_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Valid] Running tokenizer', num_proc=8) - test_data = test_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Test] Running tokenizer', num_proc=8) - - visible_print('Tokenized data') - - # Save the tokenized data - train_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/train') - valid_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/valid') - test_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/test') - - visible_print('Saved tokenized data') \ No newline at end of file diff --git a/code/src/the_stack_test.ipynb b/code/src/the_stack_test.ipynb deleted file mode 100644 index 08a80be..0000000 --- a/code/src/the_stack_test.ipynb +++ /dev/null @@ -1,246 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import torch\n", - "from torch.utils.data import DataLoader\n", - "from datasets import load_dataset, disable_caching, DatasetDict\n", - "from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer, DataCollatorForLanguageModeling\n", - "\n", - "disable_caching()" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Dataset({\n", - " features: ['hexsha', 'size', 'ext', 'lang', 'max_stars_repo_path', 'max_stars_repo_name', 'max_stars_repo_head_hexsha', 'max_stars_repo_licenses', 'max_stars_count', 'max_stars_repo_stars_event_min_datetime', 'max_stars_repo_stars_event_max_datetime', 'max_issues_repo_path', 'max_issues_repo_name', 'max_issues_repo_head_hexsha', 'max_issues_repo_licenses', 'max_issues_count', 'max_issues_repo_issues_event_min_datetime', 'max_issues_repo_issues_event_max_datetime', 'max_forks_repo_path', 'max_forks_repo_name', 'max_forks_repo_head_hexsha', 'max_forks_repo_licenses', 'max_forks_count', 'max_forks_repo_forks_event_min_datetime', 'max_forks_repo_forks_event_max_datetime', 'content', 'avg_line_length', 'max_line_length', 'alphanum_fraction'],\n", - " num_rows: 12962249\n", - "})" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "train_data = load_dataset(\"/work/s452638/datasets/the-stack-python\", split=\"train\")\n", - "train_data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "DatasetDict({\n", - " train: Dataset({\n", - " features: ['repo', 'path', 'func_name', 'original_string', 'language', 'code', 'code_tokens', 'docstring', 'docstring_tokens', 'sha', 'url', 'partition'],\n", - " num_rows: 13914\n", - " })\n", - "})" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')\n", - "valid_data" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "DatasetDict({\n", - " train: Dataset({\n", - " features: ['repo', 'path', 'func_name', 'original_string', 'language', 'code', 'code_tokens', 'docstring', 'docstring_tokens', 'sha', 'url', 'partition'],\n", - " num_rows: 14918\n", - " })\n", - "})" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')\n", - "test_data" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "train_data = train_data.rename_column('content', 'code')" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "RobertaTokenizer(name_or_path='microsoft/codebert-base', vocab_size=50265, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '', 'eos_token': '', 'unk_token': '', 'sep_token': '', 'pad_token': '', 'cls_token': '', 'mask_token': ''}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n", - "\t0: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", - "\t1: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", - "\t2: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", - "\t3: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", - "\t50264: AddedToken(\"\", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),\n", - "}" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)\n", - "tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Running tokenizer: 0%| | 1000/12962249 [00:17<61:47:17, 58.27 examples/s]\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtokenize_function\u001b[39m(examples):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tokenizer(examples[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m], truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmax_length\u001b[39m\u001b[38;5;124m'\u001b[39m, max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m512\u001b[39m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m train_data \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_data\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokenize_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mremove_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcode\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mRunning tokenizer\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m valid_data \u001b[38;5;241m=\u001b[39m valid_data\u001b[38;5;241m.\u001b[39mmap(tokenize_function, batched\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, remove_columns\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m], desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mRunning tokenizer\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 6\u001b[0m test_data \u001b[38;5;241m=\u001b[39m test_data\u001b[38;5;241m.\u001b[39mmap(tokenize_function, batched\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, remove_columns\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m], desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mRunning tokenizer\u001b[39m\u001b[38;5;124m'\u001b[39m)\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:602\u001b[0m, in \u001b[0;36mtransmit_tasks..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[38;5;28mself\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mself\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 601\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 602\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 603\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m 604\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m dataset \u001b[38;5;129;01min\u001b[39;00m datasets:\n\u001b[1;32m 605\u001b[0m \u001b[38;5;66;03m# Remove task templates if a column mapping of the template is no longer valid\u001b[39;00m\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:567\u001b[0m, in \u001b[0;36mtransmit_format..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 560\u001b[0m self_format \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 561\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_type,\n\u001b[1;32m 562\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mformat_kwargs\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_kwargs,\n\u001b[1;32m 563\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_columns,\n\u001b[1;32m 564\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_all_columns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_output_all_columns,\n\u001b[1;32m 565\u001b[0m }\n\u001b[1;32m 566\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 567\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 568\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m 569\u001b[0m \u001b[38;5;66;03m# re-apply format to the output\u001b[39;00m\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:3161\u001b[0m, in \u001b[0;36mDataset.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)\u001b[0m\n\u001b[1;32m 3155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m transformed_dataset \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3156\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m hf_tqdm(\n\u001b[1;32m 3157\u001b[0m unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m examples\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3158\u001b[0m total\u001b[38;5;241m=\u001b[39mpbar_total,\n\u001b[1;32m 3159\u001b[0m desc\u001b[38;5;241m=\u001b[39mdesc \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMap\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3160\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m pbar:\n\u001b[0;32m-> 3161\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m rank, done, content \u001b[38;5;129;01min\u001b[39;00m Dataset\u001b[38;5;241m.\u001b[39m_map_single(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdataset_kwargs):\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m done:\n\u001b[1;32m 3163\u001b[0m shards_done \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:3552\u001b[0m, in \u001b[0;36mDataset._map_single\u001b[0;34m(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)\u001b[0m\n\u001b[1;32m 3548\u001b[0m indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\n\u001b[1;32m 3549\u001b[0m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m*\u001b[39m(\u001b[38;5;28mslice\u001b[39m(i, i \u001b[38;5;241m+\u001b[39m batch_size)\u001b[38;5;241m.\u001b[39mindices(shard\u001b[38;5;241m.\u001b[39mnum_rows)))\n\u001b[1;32m 3550\u001b[0m ) \u001b[38;5;66;03m# Something simpler?\u001b[39;00m\n\u001b[1;32m 3551\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3552\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[43mapply_function_on_filtered_inputs\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3553\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3554\u001b[0m \u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3555\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_same_num_examples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlist_indexes\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m>\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3556\u001b[0m \u001b[43m \u001b[49m\u001b[43moffset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffset\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3557\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3558\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m NumExamplesMismatchError:\n\u001b[1;32m 3559\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DatasetTransformationNotAllowedError(\n\u001b[1;32m 3560\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3561\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:3421\u001b[0m, in \u001b[0;36mDataset._map_single..apply_function_on_filtered_inputs\u001b[0;34m(pa_inputs, indices, check_same_num_examples, offset)\u001b[0m\n\u001b[1;32m 3419\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m with_rank:\n\u001b[1;32m 3420\u001b[0m additional_args \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (rank,)\n\u001b[0;32m-> 3421\u001b[0m processed_inputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfn_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43madditional_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfn_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3422\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(processed_inputs, LazyDict):\n\u001b[1;32m 3423\u001b[0m processed_inputs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 3424\u001b[0m k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m processed_inputs\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m processed_inputs\u001b[38;5;241m.\u001b[39mkeys_to_format\n\u001b[1;32m 3425\u001b[0m }\n", - "Cell \u001b[0;32mIn[7], line 2\u001b[0m, in \u001b[0;36mtokenize_function\u001b[0;34m(examples)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtokenize_function\u001b[39m(examples):\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtokenizer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mexamples\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcode\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtruncation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmax_length\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m512\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3055\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.__call__\u001b[0;34m(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 3053\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_in_target_context_manager:\n\u001b[1;32m 3054\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_switch_to_input_mode()\n\u001b[0;32m-> 3055\u001b[0m encodings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_one\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext_pair\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtext_pair\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mall_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3056\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m text_target \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3057\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_switch_to_target_mode()\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3142\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase._call_one\u001b[0;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)\u001b[0m\n\u001b[1;32m 3137\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3138\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch length of `text`: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(text)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m does not match batch length of `text_pair`:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3139\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(text_pair)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3140\u001b[0m )\n\u001b[1;32m 3141\u001b[0m batch_text_or_text_pairs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(text, text_pair)) \u001b[38;5;28;01mif\u001b[39;00m text_pair \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m text\n\u001b[0;32m-> 3142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_encode_plus\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3143\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_text_or_text_pairs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_text_or_text_pairs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3144\u001b[0m \u001b[43m \u001b[49m\u001b[43madd_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madd_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3145\u001b[0m \u001b[43m \u001b[49m\u001b[43mpadding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3146\u001b[0m \u001b[43m \u001b[49m\u001b[43mtruncation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtruncation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3147\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3148\u001b[0m \u001b[43m \u001b[49m\u001b[43mstride\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3149\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_split_into_words\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_split_into_words\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3150\u001b[0m \u001b[43m \u001b[49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3151\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_tensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3152\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3153\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3154\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_overflowing_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_overflowing_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3155\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_special_tokens_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_special_tokens_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3156\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_offsets_mapping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_offsets_mapping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3157\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3158\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3159\u001b[0m \u001b[43m \u001b[49m\u001b[43msplit_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msplit_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3160\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3161\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencode_plus(\n\u001b[1;32m 3164\u001b[0m text\u001b[38;5;241m=\u001b[39mtext,\n\u001b[1;32m 3165\u001b[0m text_pair\u001b[38;5;241m=\u001b[39mtext_pair,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3182\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 3183\u001b[0m )\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3338\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.batch_encode_plus\u001b[0;34m(self, batch_text_or_text_pairs, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)\u001b[0m\n\u001b[1;32m 3328\u001b[0m \u001b[38;5;66;03m# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'\u001b[39;00m\n\u001b[1;32m 3329\u001b[0m padding_strategy, truncation_strategy, max_length, kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_padding_truncation_strategies(\n\u001b[1;32m 3330\u001b[0m padding\u001b[38;5;241m=\u001b[39mpadding,\n\u001b[1;32m 3331\u001b[0m truncation\u001b[38;5;241m=\u001b[39mtruncation,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3335\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 3336\u001b[0m )\n\u001b[0;32m-> 3338\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_batch_encode_plus\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3339\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_text_or_text_pairs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_text_or_text_pairs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3340\u001b[0m \u001b[43m \u001b[49m\u001b[43madd_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madd_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3341\u001b[0m \u001b[43m \u001b[49m\u001b[43mpadding_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpadding_strategy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3342\u001b[0m \u001b[43m \u001b[49m\u001b[43mtruncation_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtruncation_strategy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3343\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3344\u001b[0m \u001b[43m \u001b[49m\u001b[43mstride\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3345\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_split_into_words\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_split_into_words\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3346\u001b[0m \u001b[43m \u001b[49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3347\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_tensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3348\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3349\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3350\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_overflowing_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_overflowing_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3351\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_special_tokens_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_special_tokens_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3352\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_offsets_mapping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_offsets_mapping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3353\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3354\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3355\u001b[0m \u001b[43m \u001b[49m\u001b[43msplit_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msplit_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3356\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3357\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils.py:882\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._batch_encode_plus\u001b[0;34m(self, batch_text_or_text_pairs, add_special_tokens, padding_strategy, truncation_strategy, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)\u001b[0m\n\u001b[1;32m 879\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 880\u001b[0m ids, pair_ids \u001b[38;5;241m=\u001b[39m ids_or_pair_ids\n\u001b[0;32m--> 882\u001b[0m first_ids \u001b[38;5;241m=\u001b[39m \u001b[43mget_input_ids\u001b[49m\u001b[43m(\u001b[49m\u001b[43mids\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 883\u001b[0m second_ids \u001b[38;5;241m=\u001b[39m get_input_ids(pair_ids) \u001b[38;5;28;01mif\u001b[39;00m pair_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 884\u001b[0m input_ids\u001b[38;5;241m.\u001b[39mappend((first_ids, second_ids))\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils.py:849\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._batch_encode_plus..get_input_ids\u001b[0;34m(text)\u001b[0m\n\u001b[1;32m 847\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_input_ids\u001b[39m(text):\n\u001b[1;32m 848\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 849\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 850\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconvert_tokens_to_ids(tokens)\n\u001b[1;32m 851\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(text) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;28mstr\u001b[39m):\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils.py:695\u001b[0m, in \u001b[0;36mPreTrainedTokenizer.tokenize\u001b[0;34m(self, text, **kwargs)\u001b[0m\n\u001b[1;32m 693\u001b[0m tokenized_text\u001b[38;5;241m.\u001b[39mappend(token)\n\u001b[1;32m 694\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 695\u001b[0m tokenized_text\u001b[38;5;241m.\u001b[39mextend(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_tokenize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 696\u001b[0m \u001b[38;5;66;03m# [\"This\", \" is\", \" something\", \"\", \"else\"]\u001b[39;00m\n\u001b[1;32m 697\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tokenized_text\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/models/roberta/tokenization_roberta.py:270\u001b[0m, in \u001b[0;36mRobertaTokenizer._tokenize\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Tokenize a string.\"\"\"\u001b[39;00m\n\u001b[1;32m 269\u001b[0m bpe_tokens \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m token \u001b[38;5;129;01min\u001b[39;00m \u001b[43mre\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfindall\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 271\u001b[0m token \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbyte_encoder[b] \u001b[38;5;28;01mfor\u001b[39;00m b \u001b[38;5;129;01min\u001b[39;00m token\u001b[38;5;241m.\u001b[39mencode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 273\u001b[0m ) \u001b[38;5;66;03m# Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)\u001b[39;00m\n\u001b[1;32m 274\u001b[0m bpe_tokens\u001b[38;5;241m.\u001b[39mextend(bpe_token \u001b[38;5;28;01mfor\u001b[39;00m bpe_token \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbpe(token)\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m))\n", - "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/regex/regex.py:338\u001b[0m, in \u001b[0;36mfindall\u001b[0;34m(pattern, string, flags, pos, endpos, overlapped, concurrent, timeout, ignore_unused, **kwargs)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return a list of all matches in the string. The matches may be overlapped\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03mif overlapped is True. If one or more groups are present in the pattern,\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \u001b[38;5;124;03mreturn a list of groups; this will be a list of tuples if the pattern has\u001b[39;00m\n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03mmore than one group. Empty matches are included in the result.\"\"\"\u001b[39;00m\n\u001b[1;32m 337\u001b[0m pat \u001b[38;5;241m=\u001b[39m _compile(pattern, flags, ignore_unused, kwargs, \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m--> 338\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpat\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfindall\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstring\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpos\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mendpos\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverlapped\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconcurrent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "def tokenize_function(examples):\n", - " return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt')\n", - "\n", - "train_data = train_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')\n", - "valid_data = valid_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')\n", - "test_data = test_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tokenized_datasets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)\n", - "data_collator" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", - "batch_size = 1\n", - "train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=True, collate_fn=data_collator, generator=torch.Generator(device=device))\n", - "valid_dataloader = DataLoader(tokenized_datasets['valid'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))\n", - "test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.19" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/code/src/train_codebert_mlm.py b/code/src/train_codebert_mlm.py index 9495ecd..ae84076 100644 --- a/code/src/train_codebert_mlm.py +++ b/code/src/train_codebert_mlm.py @@ -1,254 +1,245 @@ import wandb -import torch -from torch.optim import AdamW -from torch.utils.data import DataLoader import os +import json import random import datetime +import logging +from pathlib import Path +from typing import Dict, Any, Tuple, List + import numpy as np -from datasets import load_dataset, load_from_disk, disable_caching, DatasetDict -from tree_sitter import Language, Parser -from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer, DataCollatorForLanguageModeling +import torch +from torch import Tensor +from torch.optim import AdamW +from torch.utils.data import DataLoader, Dataset +from datasets import load_dataset, disable_caching, DatasetDict +from huggingface_hub import list_repo_files, hf_hub_download +from transformers import ( + RobertaForMaskedLM, + RobertaConfig, + RobertaTokenizer, + DataCollatorForLanguageModeling, + get_linear_schedule_with_warmup, + PreTrainedTokenizer, + PreTrainedModel +) from tqdm import tqdm -from utils import remove_docstrings_and_comments_from_code +# Configure logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) -# Disable caching for datasets -disable_caching() +class OnTheFlyTokenizationDataset(Dataset): + def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int): + self.dataset = dataset + self.tokenizer = tokenizer + self.max_length = max_length -############################### CONFIG ############################### -dataset_name = 'the-stack-tokenized' # 'the-stack' or 'code-search-net' or 'the-stack-tokenized -remove_comments = False -###################################################################### + def __len__(self) -> int: + return len(self.dataset) -# Initialize Weights & Biases and output directory -curr_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M') -wandb.init(project='codebert-training', name=curr_time) -output_dir = f'/home/s452638/magisterka/output/{curr_time}/' + def __getitem__(self, idx: int) -> Dict[str, Tensor]: + content: str = self.dataset[idx]['content'] + tokenized = self.tokenizer( + content, + truncation=True, + padding='max_length', + max_length=self.max_length, + return_tensors='pt' + ) + return { + 'input_ids': tokenized['input_ids'].squeeze(0), + 'attention_mask': tokenized['attention_mask'].squeeze(0), + 'labels': tokenized['input_ids'].squeeze(0) + } -# Save this file to Weights & Biases -wandb.save('train_codebert_mlm.py') - -# Create the output directory if it does not exist -if not os.path.exists(output_dir): - os.makedirs(output_dir) - -# Set the seed for reproducibility -SEED = 42 -def set_seed(seed): +def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) -set_seed(SEED) +def setup_wandb(config: Dict[str, Any]) -> None: + curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M') + wandb.init(project='codebert-training', name=curr_time, config=config) + wandb.save('train_codebert_mlm.py') -# Set the device for PyTorch (use GPU if available, otherwise CPU) -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -torch.set_default_device(device) -print('*' * 10, 'Device', '*' * 10) -print(f'Using device: {device}') -if device.type == 'cuda': - print(f'Device name: {torch.cuda.get_device_name()}') +def setup_directories(current_dir: Path) -> Path: + curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M') + output_dir: Path = current_dir.parent.parent / 'outputs' / curr_time + output_dir.mkdir(parents=True, exist_ok=True) + return output_dir -# Load the dataset -if dataset_name == 'the-stack-tokenized': - train_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/train') - valid_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/valid') - test_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/test') -else: - if dataset_name == 'the-stack': - train_data = load_dataset("/work/s452638/datasets/the-stack-python", split="train") - train_data = train_data.rename_column_('content', 'code') - elif dataset_name == 'code-search-net': - train_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/train.jsonl')['train'] +def load_config(config_file: Path) -> Dict[str, Any]: + with open(config_file, 'r') as f: + return json.load(f) - valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')['valid'] - test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')['test'] +def setup_device() -> torch.device: + device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + torch.set_default_device(device) + logger.info(f'Using device: {device}') + if device.type == 'cuda': + logger.info(f'Device name: {torch.cuda.get_device_name()}') + torch.set_float32_matmul_precision('high') + return device -dataset = DatasetDict({'train': train_data, 'valid': valid_data, 'test': test_data}) -print('\n\n', '*' * 10, 'Dataset', '*' * 10) -print(dataset) +def download_dataset(dataset_dir: Path) -> None: + if not dataset_dir.exists(): + logger.info("Downloading the dataset...") + dataset_dir.mkdir(parents=True, exist_ok=True) + files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset') + files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')] + for file_name in files_to_download: + hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir) + logger.info("Dataset downloaded successfully.") -if remove_comments: - # Build the language library if not already built - Language.build_library('/home/s452638/magisterka/build/my-languages.so', ['/home/s452638/magisterka/vendor/tree-sitter-python']) +def load_and_prepare_dataset(dataset_dir: Path, seed: int) -> DatasetDict: + dataset: DatasetDict = load_dataset(str(dataset_dir), split='train') + dataset = dataset.train_test_split(test_size=0.01, seed=seed) + logger.info(f'Dataset loaded: {dataset}') + return dataset - # Load the language - PYTHON_LANGUAGE = Language('/home/s452638/magisterka/build/my-languages.so', 'python') +def create_dataloaders( + dataset: DatasetDict, + tokenizer: PreTrainedTokenizer, + config: Dict[str, Any], + device: torch.device +) -> Tuple[DataLoader, DataLoader]: + dataset['train'] = OnTheFlyTokenizationDataset(dataset['train'], tokenizer, max_length=512) + dataset['test'] = OnTheFlyTokenizationDataset(dataset['test'], tokenizer, max_length=512) + + data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=config['mlm_probability']) + + train_dataloader = DataLoader( + dataset['train'], + batch_size=config['batch'], + shuffle=False, + collate_fn=data_collator, + generator=torch.Generator(device=device) + ) + valid_dataloader = DataLoader( + dataset['test'], + batch_size=config['batch'], + shuffle=False, + collate_fn=data_collator, + generator=torch.Generator(device=device) + ) + return train_dataloader, valid_dataloader - # Initialize the parser - parser = Parser() - parser.set_language(PYTHON_LANGUAGE) +def setup_model_and_optimizer( + config: Dict[str, Any], + current_dir: Path +) -> Tuple[PreTrainedModel, AdamW]: + os.environ['HF_HOME'] = str(current_dir.parent / 'models') + model_config = RobertaConfig.from_pretrained('roberta-base') + model: PreTrainedModel = RobertaForMaskedLM(model_config) + model = torch.compile(model) + wandb.watch(model) + logger.info(f'Model config: {model_config}') + wandb.config.update({'model_config': model_config.to_dict()}) + + optimizer: AdamW = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay']) + return model, optimizer - # Remove docstrings and comments from the code - dataset = dataset.map(lambda x: {'code': remove_docstrings_and_comments_from_code(x['code'], parser)}, batched=False, desc='Removing docstrings and comments') - -# Load the tokenizer -tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True) -print('\n\n', '*' * 10, 'Tokenizer', '*' * 10) -print(tokenizer) - -if dataset_name != 'the-stack-tokenized': - # Tokenize the dataset - def tokenize_function(examples): - return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt') - - tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer') - print('\n\n', '*' * 10, 'Tokenized dataset', '*' * 10) - print(tokenized_datasets) -else: - tokenized_datasets = dataset - -# Set data collator for MLM -data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15) - -# Create DataLoaders -batch_size = 64 -train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device)) -valid_dataloader = DataLoader(tokenized_datasets['valid'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device)) -test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device)) - -# Initialize a model with random weights based on the configuration for RoBERTa (CodeBERT is based on RoBERTa) -config = RobertaConfig.from_pretrained('roberta-base') -model = RobertaForMaskedLM(config) -model = torch.compile(model) -wandb.watch(model) -print('\n\n', '*' * 10, 'Model', '*' * 10) -print(config) - -# Log the model configuration to wandb -wandb.config.update({'model_config': config.to_dict()}) - -# Set the optimizer and scaler -learning_rate = 5e-4 -optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01) -scaler = torch.amp.GradScaler() - -# Training settings -num_epochs = 1 -num_training_steps = num_epochs * len(train_dataloader) -eval_every = 10_000 - -# Log training settings to wandb -wandb.config.update({ - 'training_settings': { - 'num_epochs': num_epochs, - 'num_training_steps': num_training_steps, - 'eval_every': eval_every, - 'batch_size': batch_size, - 'learning_rate': learning_rate, - } -}) - -# Initialize variables to track validation loss, accuracy, and best model path -valid_acc = 0.0 -valid_loss = 0.0 -best_valid_loss = float('inf') - -# Train the model -print('\n\n', '*' * 10, 'Training', '*' * 10) -model.train() -with tqdm(total=num_training_steps, desc='Training') as pbar: - for epoch_idx in range(num_epochs): - for train_idx, train_batch in enumerate(train_dataloader): - - # Forward pass with mixed precision - with torch.autocast(device_type=device.type, dtype=torch.bfloat16): +def train_and_evaluate( + model: PreTrainedModel, + train_dataloader: DataLoader, + valid_dataloader: DataLoader, + optimizer: AdamW, + scheduler: Any, + config: Dict[str, Any], + output_dir: Path +) -> None: + num_training_steps: int = config['epochs'] * len(train_dataloader) + best_valid_loss: float = float('inf') + + with tqdm(total=num_training_steps, desc='Training') as pbar: + for epoch_idx in range(config['epochs']): + model.train() + for train_idx, train_batch in enumerate(train_dataloader): outputs = model(**train_batch) - - train_loss = outputs.loss - scaler.scale(train_loss).backward() - - # Gradient clipping to prevent exploding gradients - norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - - pbar.update(1) - pbar.set_postfix({'norm': norm.item(), 'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc}) - - # Log metrics to Weights & Biases - wandb.log({ - 'step': train_idx + len(train_dataloader) * epoch_idx, - 'train_loss': train_loss.item(), - 'gradient_norm': norm.item(), - 'learning_rate': optimizer.param_groups[0]['lr'], - }) - - # Evaluate the model - if train_idx != 0 and train_idx % eval_every == 0: - model.eval() - valid_loss = 0.0 - valid_acc = 0.0 - - with tqdm(total=len(valid_dataloader), desc='Validation') as pbar_valid: - with torch.no_grad(): - for valid_idx, valid_batch in enumerate(valid_dataloader): - - # Forward pass with mixed precision for validation - with torch.autocast(device_type=device.type, dtype=torch.bfloat16): - outputs = model(**valid_batch) - - # Accumulate validation loss and accuracy - valid_loss += outputs.loss.item() - valid_acc += outputs.logits.argmax(dim=-1).eq(valid_batch['labels']).sum().item() - pbar_valid.update(1) - - # Compute average validation loss and accuracy - valid_loss /= len(valid_dataloader) - valid_acc /= len(valid_dataloader.dataset) - model.train() - - # Log validation metrics to Weights & Biases + train_loss: Tensor = outputs.loss + train_loss.backward() + + norm: Tensor = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm']) + + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + pbar.update(1) + pbar.set_postfix({'train_loss': train_loss.item()}) + wandb.log({ - 'valid_loss': valid_loss, - 'valid_acc': valid_acc, 'step': train_idx + len(train_dataloader) * epoch_idx, + 'train_loss': train_loss.item(), + 'gradient_norm': norm.item(), + 'learning_rate': scheduler.get_last_lr()[0], }) + + if train_idx != 0 and train_idx % config['eval_every'] == 0: + valid_loss, valid_acc = evaluate(model, valid_dataloader) + pbar.set_postfix({'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc}) + + wandb.log({ + 'valid_loss': valid_loss, + 'valid_acc': valid_acc, + 'step': train_idx + len(train_dataloader) * epoch_idx, + }) + + if valid_loss < best_valid_loss: + best_valid_loss = valid_loss + torch.save(model.state_dict(), output_dir / 'best_model.pt') + + logger.info(f'Best validation loss: {best_valid_loss}') - # Update best model if current validation loss is lower - if valid_loss < best_valid_loss: - best_valid_loss = valid_loss - torch.save(model.state_dict(), output_dir + f'best_model.pt') - -print('\n\n', '*' * 10, 'Training results', '*' * 10) -print(f'Best validation loss: {best_valid_loss}') - -# Load the best model and evaluate on the test set -print('\n\n', '*' * 10, 'Testing', '*' * 10) -model.load_state_dict(torch.load(output_dir + f'best_model.pt', weights_only=True, map_location=device)) -model.eval() -test_loss = 0.0 -test_acc = 0.0 - -with tqdm(total=len(test_dataloader), desc='Testing') as pbar_test: +def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]: + model.eval() + total_loss: float = 0.0 + total_acc: float = 0.0 + with torch.no_grad(): - for test_idx, test_batch in enumerate(test_dataloader): + for batch in tqdm(dataloader, desc='Validation'): + outputs = model(**batch) + total_loss += outputs.loss.item() + total_acc += outputs.logits.argmax(dim=-1).eq(batch['labels']).sum().item() + + avg_loss: float = total_loss / len(dataloader) + avg_acc: float = total_acc / len(dataloader.dataset) + return avg_loss, avg_acc - # Forward pass with mixed precision for testing - with torch.autocast(device_type=device.type, dtype=torch.bfloat16): - outputs = model(**test_batch) +def main() -> None: + disable_caching() + + current_dir: Path = Path(__file__).parent + output_dir: Path = setup_directories(current_dir) + config: Dict[str, Any] = load_config(current_dir / 'config.json') - # Accumulate test loss and accuracy - test_loss += outputs.loss.item() - test_acc += outputs.logits.argmax(dim=-1).eq(test_batch['labels']).sum().item() - pbar_test.update(1) + setup_wandb(config) + set_seed(config['seed']) + device: torch.device = setup_device() + + dataset_dir: Path = current_dir.parent / 'data' / 'the-stack-python' + download_dataset(dataset_dir) + dataset: DatasetDict = load_and_prepare_dataset(dataset_dir, config['seed']) + + tokenizer: PreTrainedTokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True) + logger.info(f'Tokenizer loaded: {tokenizer}') + + train_dataloader, valid_dataloader = create_dataloaders(dataset, tokenizer, config, device) + + model, optimizer = setup_model_and_optimizer(config, current_dir) + + num_training_steps: int = config['epochs'] * len(train_dataloader) + scheduler = get_linear_schedule_with_warmup( + optimizer, + num_warmup_steps=config['warmup_steps'], + num_training_steps=num_training_steps + ) + + train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir) -# Compute average test loss and accuracy -test_loss /= len(test_dataloader) -test_acc /= len(test_dataloader.dataset) - -# Log test metrics to Weights & Biases -wandb.log({ - 'test_loss': test_loss, - 'test_acc': test_acc, -}) - -print('\n\n', '*' * 10, 'Test results', '*' * 10) -print(f'Test loss: {test_loss}') -print(f'Test accuracy: {test_acc}') +if __name__ == "__main__": + main() \ No newline at end of file