diff --git a/code/.gitignore b/code/.gitignore
index 3a8816c..d51ef0e 100644
--- a/code/.gitignore
+++ b/code/.gitignore
@@ -160,3 +160,7 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+
+# Weights & Biases
+wandb/
+outputs/
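
Note: the two new ignore entries cover the local run directory that Weights & Biases creates and the training output directory. A minimal sketch of where those files come from (the project name and the role of `outputs/` are assumptions, not taken from this diff):

```python
import wandb

# wandb.init() writes run metadata and logs under ./wandb/ by default,
# which is why that directory is ignored above.
run = wandb.init(project="codebert-pretraining")  # project name is illustrative
run.log({"train/loss": 0.0})
run.finish()

# "outputs/" is assumed to be where the training script writes checkpoints
# and other artifacts; it is not created by wandb itself.
```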
diff --git a/code/data/.gitignore b/code/data/.gitignore
new file mode 100644
index 0000000..d6b7ef3
--- /dev/null
+++ b/code/data/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
diff --git a/code/models/.gitignore b/code/models/.gitignore
new file mode 100644
index 0000000..d6b7ef3
--- /dev/null
+++ b/code/models/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
diff --git a/code/pdm.lock b/code/pdm.lock
index d359502..5c28202 100644
--- a/code/pdm.lock
+++ b/code/pdm.lock
@@ -5,7 +5,7 @@
groups = ["default"]
strategy = ["inherit_metadata"]
lock_version = "4.5.0"
-content_hash = "sha256:ac6621f3bd9193d786ab94f80f8b1711100fe418959f2e131ae03afeab616788"
+content_hash = "sha256:bf0a0ea826769cf12a84888d394edd8c3c5599c4d369b8b19b75c2fa5e16f5f0"
[[metadata.targets]]
requires_python = "==3.11.*"
diff --git a/code/pyproject.toml b/code/pyproject.toml
index cafd29f..ac23f58 100644
--- a/code/pyproject.toml
+++ b/code/pyproject.toml
@@ -6,13 +6,13 @@ authors = [
{name = "Patryk Bartkowiak", email = "patbar15@st.amu.edu.pl"},
]
dependencies = [
- "wandb>=0.18.5",
- "torch>=2.5.0",
- "tqdm>=4.66.5",
- "tree-sitter>=0.23.1",
- "transformers>=4.45.2",
- "datasets>=3.0.1",
- "huggingface-hub>=0.26.0",
+ "wandb==0.18.5",
+ "torch==2.5.0",
+ "tqdm==4.66.5",
+ "tree-sitter==0.23.1",
+ "transformers==4.45.2",
+ "datasets==3.0.1",
+ "huggingface-hub==0.26.0",
]
requires-python = "==3.11.*"
readme = "README.md"
diff --git a/code/src/config.json b/code/src/config.json
new file mode 100644
index 0000000..7991809
--- /dev/null
+++ b/code/src/config.json
@@ -0,0 +1,11 @@
+{
+ "seed": 42,
+ "mlm_probability": 0.15,
+ "batch": 32,
+ "epochs": 1,
+ "eval_every": 10000,
+ "learning_rate": 5e-4,
+ "weight_decay": 0.01,
+ "max_grad_norm": 1.0,
+ "warmup_steps": 10000
+}
\ No newline at end of file
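
Note: the new config.json gathers the MLM pretraining hyperparameters in one place. A minimal sketch of reading it from a training script (the `load_config` helper and the relative path are assumptions, not part of this diff):

```python
import json
from pathlib import Path

def load_config(path: str = "src/config.json") -> dict:
    """Load the hyperparameter file added in this diff into a plain dict."""
    with Path(path).open() as f:
        return json.load(f)

if __name__ == "__main__":
    config = load_config()
    # e.g. config["learning_rate"] == 5e-4, config["batch"] == 32,
    # config["mlm_probability"] == 0.15
    print(config)
```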
diff --git a/code/src/download_the_stack.ipynb b/code/src/download_the_stack.ipynb
deleted file mode 100644
index 82c518e..0000000
--- a/code/src/download_the_stack.ipynb
+++ /dev/null
@@ -1,262 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from huggingface_hub import list_repo_files, hf_hub_download\n",
- "import os"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Found 5144 files in the repository.\n",
- "['data/python/data-00000-of-00144.parquet', 'data/python/data-00001-of-00144.parquet', 'data/python/data-00002-of-00144.parquet', 'data/python/data-00003-of-00144.parquet', 'data/python/data-00004-of-00144.parquet', 'data/python/data-00005-of-00144.parquet', 'data/python/data-00006-of-00144.parquet', 'data/python/data-00007-of-00144.parquet', 'data/python/data-00008-of-00144.parquet', 'data/python/data-00009-of-00144.parquet', 'data/python/data-00010-of-00144.parquet', 'data/python/data-00011-of-00144.parquet', 'data/python/data-00012-of-00144.parquet', 'data/python/data-00013-of-00144.parquet', 'data/python/data-00014-of-00144.parquet', 'data/python/data-00015-of-00144.parquet', 'data/python/data-00016-of-00144.parquet', 'data/python/data-00017-of-00144.parquet', 'data/python/data-00018-of-00144.parquet', 'data/python/data-00019-of-00144.parquet', 'data/python/data-00020-of-00144.parquet', 'data/python/data-00021-of-00144.parquet', 'data/python/data-00022-of-00144.parquet', 'data/python/data-00023-of-00144.parquet', 'data/python/data-00024-of-00144.parquet', 'data/python/data-00025-of-00144.parquet', 'data/python/data-00026-of-00144.parquet', 'data/python/data-00027-of-00144.parquet', 'data/python/data-00028-of-00144.parquet', 'data/python/data-00029-of-00144.parquet', 'data/python/data-00030-of-00144.parquet', 'data/python/data-00031-of-00144.parquet', 'data/python/data-00032-of-00144.parquet', 'data/python/data-00033-of-00144.parquet', 'data/python/data-00034-of-00144.parquet', 'data/python/data-00035-of-00144.parquet', 'data/python/data-00036-of-00144.parquet', 'data/python/data-00037-of-00144.parquet', 'data/python/data-00038-of-00144.parquet', 'data/python/data-00039-of-00144.parquet', 'data/python/data-00040-of-00144.parquet', 'data/python/data-00041-of-00144.parquet', 'data/python/data-00042-of-00144.parquet', 'data/python/data-00043-of-00144.parquet', 'data/python/data-00044-of-00144.parquet', 'data/python/data-00045-of-00144.parquet', 'data/python/data-00046-of-00144.parquet', 'data/python/data-00047-of-00144.parquet', 'data/python/data-00048-of-00144.parquet', 'data/python/data-00049-of-00144.parquet', 'data/python/data-00050-of-00144.parquet', 'data/python/data-00051-of-00144.parquet', 'data/python/data-00052-of-00144.parquet', 'data/python/data-00053-of-00144.parquet', 'data/python/data-00054-of-00144.parquet', 'data/python/data-00055-of-00144.parquet', 'data/python/data-00056-of-00144.parquet', 'data/python/data-00057-of-00144.parquet', 'data/python/data-00058-of-00144.parquet', 'data/python/data-00059-of-00144.parquet', 'data/python/data-00060-of-00144.parquet', 'data/python/data-00061-of-00144.parquet', 'data/python/data-00062-of-00144.parquet', 'data/python/data-00063-of-00144.parquet', 'data/python/data-00064-of-00144.parquet', 'data/python/data-00065-of-00144.parquet', 'data/python/data-00066-of-00144.parquet', 'data/python/data-00067-of-00144.parquet', 'data/python/data-00068-of-00144.parquet', 'data/python/data-00069-of-00144.parquet', 'data/python/data-00070-of-00144.parquet', 'data/python/data-00071-of-00144.parquet', 'data/python/data-00072-of-00144.parquet', 'data/python/data-00073-of-00144.parquet', 'data/python/data-00074-of-00144.parquet', 'data/python/data-00075-of-00144.parquet', 'data/python/data-00076-of-00144.parquet', 'data/python/data-00077-of-00144.parquet', 'data/python/data-00078-of-00144.parquet', 'data/python/data-00079-of-00144.parquet', 'data/python/data-00080-of-00144.parquet', 'data/python/data-00081-of-00144.parquet', 
'data/python/data-00082-of-00144.parquet', 'data/python/data-00083-of-00144.parquet', 'data/python/data-00084-of-00144.parquet', 'data/python/data-00085-of-00144.parquet', 'data/python/data-00086-of-00144.parquet', 'data/python/data-00087-of-00144.parquet', 'data/python/data-00088-of-00144.parquet', 'data/python/data-00089-of-00144.parquet', 'data/python/data-00090-of-00144.parquet', 'data/python/data-00091-of-00144.parquet', 'data/python/data-00092-of-00144.parquet', 'data/python/data-00093-of-00144.parquet', 'data/python/data-00094-of-00144.parquet', 'data/python/data-00095-of-00144.parquet', 'data/python/data-00096-of-00144.parquet', 'data/python/data-00097-of-00144.parquet', 'data/python/data-00098-of-00144.parquet', 'data/python/data-00099-of-00144.parquet', 'data/python/data-00100-of-00144.parquet', 'data/python/data-00101-of-00144.parquet', 'data/python/data-00102-of-00144.parquet', 'data/python/data-00103-of-00144.parquet', 'data/python/data-00104-of-00144.parquet', 'data/python/data-00105-of-00144.parquet', 'data/python/data-00106-of-00144.parquet', 'data/python/data-00107-of-00144.parquet', 'data/python/data-00108-of-00144.parquet', 'data/python/data-00109-of-00144.parquet', 'data/python/data-00110-of-00144.parquet', 'data/python/data-00111-of-00144.parquet', 'data/python/data-00112-of-00144.parquet', 'data/python/data-00113-of-00144.parquet', 'data/python/data-00114-of-00144.parquet', 'data/python/data-00115-of-00144.parquet', 'data/python/data-00116-of-00144.parquet', 'data/python/data-00117-of-00144.parquet', 'data/python/data-00118-of-00144.parquet', 'data/python/data-00119-of-00144.parquet', 'data/python/data-00120-of-00144.parquet', 'data/python/data-00121-of-00144.parquet', 'data/python/data-00122-of-00144.parquet', 'data/python/data-00123-of-00144.parquet', 'data/python/data-00124-of-00144.parquet', 'data/python/data-00125-of-00144.parquet', 'data/python/data-00126-of-00144.parquet', 'data/python/data-00127-of-00144.parquet', 'data/python/data-00128-of-00144.parquet', 'data/python/data-00129-of-00144.parquet', 'data/python/data-00130-of-00144.parquet', 'data/python/data-00131-of-00144.parquet', 'data/python/data-00132-of-00144.parquet', 'data/python/data-00133-of-00144.parquet', 'data/python/data-00134-of-00144.parquet', 'data/python/data-00135-of-00144.parquet', 'data/python/data-00136-of-00144.parquet', 'data/python/data-00137-of-00144.parquet', 'data/python/data-00138-of-00144.parquet', 'data/python/data-00139-of-00144.parquet', 'data/python/data-00140-of-00144.parquet', 'data/python/data-00141-of-00144.parquet', 'data/python/data-00142-of-00144.parquet', 'data/python/data-00143-of-00144.parquet']\n",
- "Downloading data/python/data-00000-of-00144.parquet...\n",
- "Downloading data/python/data-00001-of-00144.parquet...\n",
- "Downloading data/python/data-00002-of-00144.parquet...\n",
- "Downloading data/python/data-00003-of-00144.parquet...\n",
- "Downloading data/python/data-00004-of-00144.parquet...\n",
- "Downloading data/python/data-00005-of-00144.parquet...\n",
- "Downloading data/python/data-00006-of-00144.parquet...\n",
- "Downloading data/python/data-00007-of-00144.parquet...\n",
- "Downloading data/python/data-00008-of-00144.parquet...\n",
- "Downloading data/python/data-00009-of-00144.parquet...\n",
- "Downloading data/python/data-00010-of-00144.parquet...\n",
- "Downloading data/python/data-00011-of-00144.parquet...\n",
- "Downloading data/python/data-00012-of-00144.parquet...\n",
- "Downloading data/python/data-00013-of-00144.parquet...\n",
- "Downloading data/python/data-00014-of-00144.parquet...\n",
- "Downloading data/python/data-00015-of-00144.parquet...\n",
- "Downloading data/python/data-00016-of-00144.parquet...\n",
- "Downloading data/python/data-00017-of-00144.parquet...\n",
- "Downloading data/python/data-00018-of-00144.parquet...\n",
- "Downloading data/python/data-00019-of-00144.parquet...\n",
- "Downloading data/python/data-00020-of-00144.parquet...\n",
- "Downloading data/python/data-00021-of-00144.parquet...\n",
- "Downloading data/python/data-00022-of-00144.parquet...\n",
- "Downloading data/python/data-00023-of-00144.parquet...\n",
- "Downloading data/python/data-00024-of-00144.parquet...\n",
- "Downloading data/python/data-00025-of-00144.parquet...\n",
- "Downloading data/python/data-00026-of-00144.parquet...\n",
- "Downloading data/python/data-00027-of-00144.parquet...\n",
- "Downloading data/python/data-00028-of-00144.parquet...\n",
- "Downloading data/python/data-00029-of-00144.parquet...\n",
- "Downloading data/python/data-00030-of-00144.parquet...\n",
- "Downloading data/python/data-00031-of-00144.parquet...\n",
- "Downloading data/python/data-00032-of-00144.parquet...\n",
- "Downloading data/python/data-00033-of-00144.parquet...\n",
- "Downloading data/python/data-00034-of-00144.parquet...\n",
- "Downloading data/python/data-00035-of-00144.parquet...\n",
- "Downloading data/python/data-00036-of-00144.parquet...\n",
- "Downloading data/python/data-00037-of-00144.parquet...\n",
- "Downloading data/python/data-00038-of-00144.parquet...\n",
- "Downloading data/python/data-00039-of-00144.parquet...\n",
- "Downloading data/python/data-00040-of-00144.parquet...\n",
- "Downloading data/python/data-00041-of-00144.parquet...\n",
- "Downloading data/python/data-00042-of-00144.parquet...\n",
- "Downloading data/python/data-00043-of-00144.parquet...\n",
- "Downloading data/python/data-00044-of-00144.parquet...\n",
- "Downloading data/python/data-00045-of-00144.parquet...\n",
- "Downloading data/python/data-00046-of-00144.parquet...\n",
- "Downloading data/python/data-00047-of-00144.parquet...\n",
- "Downloading data/python/data-00048-of-00144.parquet...\n",
- "Downloading data/python/data-00049-of-00144.parquet...\n",
- "Downloading data/python/data-00050-of-00144.parquet...\n",
- "Downloading data/python/data-00051-of-00144.parquet...\n",
- "Downloading data/python/data-00052-of-00144.parquet...\n",
- "Downloading data/python/data-00053-of-00144.parquet...\n",
- "Downloading data/python/data-00054-of-00144.parquet...\n",
- "Downloading data/python/data-00055-of-00144.parquet...\n",
- "Downloading data/python/data-00056-of-00144.parquet...\n",
- "Downloading data/python/data-00057-of-00144.parquet...\n",
- "Downloading data/python/data-00058-of-00144.parquet...\n",
- "Downloading data/python/data-00059-of-00144.parquet...\n",
- "Downloading data/python/data-00060-of-00144.parquet...\n",
- "Downloading data/python/data-00061-of-00144.parquet...\n",
- "Downloading data/python/data-00062-of-00144.parquet...\n",
- "Downloading data/python/data-00063-of-00144.parquet...\n",
- "Downloading data/python/data-00064-of-00144.parquet...\n",
- "Downloading data/python/data-00065-of-00144.parquet...\n",
- "Downloading data/python/data-00066-of-00144.parquet...\n",
- "Downloading data/python/data-00067-of-00144.parquet...\n",
- "Downloading data/python/data-00068-of-00144.parquet...\n",
- "Downloading data/python/data-00069-of-00144.parquet...\n",
- "Downloading data/python/data-00070-of-00144.parquet...\n",
- "Downloading data/python/data-00071-of-00144.parquet...\n",
- "Downloading data/python/data-00072-of-00144.parquet...\n",
- "Downloading data/python/data-00073-of-00144.parquet...\n",
- "Downloading data/python/data-00074-of-00144.parquet...\n",
- "Downloading data/python/data-00075-of-00144.parquet...\n",
- "Downloading data/python/data-00076-of-00144.parquet...\n",
- "Downloading data/python/data-00077-of-00144.parquet...\n",
- "Downloading data/python/data-00078-of-00144.parquet...\n",
- "Downloading data/python/data-00079-of-00144.parquet...\n",
- "Downloading data/python/data-00080-of-00144.parquet...\n",
- "Downloading data/python/data-00081-of-00144.parquet...\n",
- "Downloading data/python/data-00082-of-00144.parquet...\n",
- "Downloading data/python/data-00083-of-00144.parquet...\n",
- "Downloading data/python/data-00084-of-00144.parquet...\n",
- "Downloading data/python/data-00085-of-00144.parquet...\n",
- "Downloading data/python/data-00086-of-00144.parquet...\n",
- "Downloading data/python/data-00087-of-00144.parquet...\n",
- "Downloading data/python/data-00088-of-00144.parquet...\n",
- "Downloading data/python/data-00089-of-00144.parquet...\n",
- "Downloading data/python/data-00090-of-00144.parquet...\n",
- "Downloading data/python/data-00091-of-00144.parquet...\n",
- "Downloading data/python/data-00092-of-00144.parquet...\n",
- "Downloading data/python/data-00093-of-00144.parquet...\n",
- "Downloading data/python/data-00094-of-00144.parquet...\n",
- "Downloading data/python/data-00095-of-00144.parquet...\n",
- "Downloading data/python/data-00096-of-00144.parquet...\n",
- "Downloading data/python/data-00097-of-00144.parquet...\n",
- "Downloading data/python/data-00098-of-00144.parquet...\n",
- "Downloading data/python/data-00099-of-00144.parquet...\n",
- "Downloading data/python/data-00100-of-00144.parquet...\n",
- "Downloading data/python/data-00101-of-00144.parquet...\n",
- "Downloading data/python/data-00102-of-00144.parquet...\n",
- "Downloading data/python/data-00103-of-00144.parquet...\n",
- "Downloading data/python/data-00104-of-00144.parquet...\n",
- "Downloading data/python/data-00105-of-00144.parquet...\n",
- "Downloading data/python/data-00106-of-00144.parquet...\n",
- "Downloading data/python/data-00107-of-00144.parquet...\n",
- "Downloading data/python/data-00108-of-00144.parquet...\n",
- "Downloading data/python/data-00109-of-00144.parquet...\n",
- "Downloading data/python/data-00110-of-00144.parquet...\n",
- "Downloading data/python/data-00111-of-00144.parquet...\n",
- "Downloading data/python/data-00112-of-00144.parquet...\n",
- "Downloading data/python/data-00113-of-00144.parquet...\n",
- "Downloading data/python/data-00114-of-00144.parquet...\n",
- "Downloading data/python/data-00115-of-00144.parquet...\n",
- "Downloading data/python/data-00116-of-00144.parquet...\n",
- "Downloading data/python/data-00117-of-00144.parquet...\n",
- "Downloading data/python/data-00118-of-00144.parquet...\n",
- "Downloading data/python/data-00119-of-00144.parquet...\n",
- "Downloading data/python/data-00120-of-00144.parquet...\n",
- "Downloading data/python/data-00121-of-00144.parquet...\n",
- "Downloading data/python/data-00122-of-00144.parquet...\n",
- "Downloading data/python/data-00123-of-00144.parquet...\n",
- "Downloading data/python/data-00124-of-00144.parquet...\n",
- "Downloading data/python/data-00125-of-00144.parquet...\n",
- "Downloading data/python/data-00126-of-00144.parquet...\n",
- "Downloading data/python/data-00127-of-00144.parquet...\n",
- "Downloading data/python/data-00128-of-00144.parquet...\n",
- "Downloading data/python/data-00129-of-00144.parquet...\n",
- "Downloading data/python/data-00130-of-00144.parquet...\n",
- "Downloading data/python/data-00131-of-00144.parquet...\n",
- "Downloading data/python/data-00132-of-00144.parquet...\n",
- "Downloading data/python/data-00133-of-00144.parquet...\n",
- "Downloading data/python/data-00134-of-00144.parquet...\n",
- "Downloading data/python/data-00135-of-00144.parquet...\n",
- "Downloading data/python/data-00136-of-00144.parquet...\n",
- "Downloading data/python/data-00137-of-00144.parquet...\n",
- "Downloading data/python/data-00138-of-00144.parquet...\n",
- "Downloading data/python/data-00139-of-00144.parquet...\n",
- "Downloading data/python/data-00140-of-00144.parquet...\n",
- "Downloading data/python/data-00141-of-00144.parquet...\n",
- "Downloading data/python/data-00142-of-00144.parquet...\n",
- "Downloading data/python/data-00143-of-00144.parquet...\n",
- "All files have been downloaded successfully.\n"
- ]
- }
- ],
- "source": [
- "# Define the repository details\n",
- "repo_id = \"bigcode/the-stack-dedup\" # Repository ID for the dataset\n",
- "subfolder = \"data/python\" # Subfolder path within the repository\n",
- "repo_type = \"dataset\" # Specify that it's a dataset\n",
- "\n",
- "# Specify the local directory to save the files\n",
- "local_dir = \"/work/s452638/datasets/the-stack-python\"\n",
- "os.makedirs(local_dir, exist_ok=True)\n",
- "\n",
- "# List all files in the repository's subfolder\n",
- "files_list = list_repo_files(repo_id=repo_id, repo_type=repo_type)\n",
- "\n",
- "# Filter files in the desired subfolder\n",
- "files_to_download = [file for file in files_list if file.startswith(f'{subfolder}/')]\n",
- "\n",
- "print(f\"Found {len(files_to_download)} files in the repository.\")\n",
- "\n",
- "# Download each file\n",
- "for file_name in files_to_download:\n",
- " print(f\"Downloading {file_name}...\")\n",
- " hf_hub_download(repo_id=repo_id, repo_type=repo_type, filename=file_name, local_dir=local_dir)\n",
- "\n",
- "print(\"All files have been downloaded successfully.\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Found 6 files in the repository.\n",
- "Downloading data/Python/train-00004-of-00006.parquet...\n",
- "Downloading data/Python/train-00005-of-00006.parquet...\n",
- "All files have been downloaded successfully.\n"
- ]
- }
- ],
- "source": [
- "# Define the repository details\n",
- "repo_id = \"bigcode/the-stack-v2-dedup\" # Repository ID for the dataset\n",
- "subfolder = \"data/Python\" # Subfolder path within the repository\n",
- "repo_type = \"dataset\" # Specify that it's a dataset\n",
- "\n",
- "# Specify the local directory to save the files\n",
- "local_dir = \"/work/s452638/datasets/the-stack-v2-python\"\n",
- "os.makedirs(local_dir, exist_ok=True)\n",
- "\n",
- "# List all files in the repository's subfolder\n",
- "files_list = list_repo_files(repo_id=repo_id, repo_type=repo_type)\n",
- "\n",
- "# Filter files in the desired subfolder\n",
- "files_to_download = [file for file in files_list if file.startswith(f'{subfolder}/')]\n",
- "\n",
- "print(f\"Found {len(files_to_download)} files in the repository.\")\n",
- "\n",
- "# Download each file\n",
- "for file_name in files_to_download[4:]:\n",
- " print(f\"Downloading {file_name}...\")\n",
- " hf_hub_download(repo_id=repo_id, repo_type=repo_type, filename=file_name, local_dir=local_dir)\n",
- "\n",
- "print(\"All files have been downloaded successfully.\")"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.19"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
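
Note: the deleted notebook listed the repository contents with list_repo_files and fetched the data/python (and data/Python for v2) parquet shards one by one with hf_hub_download. For reference, a hedged sketch of the same download as a single call, assuming snapshot_download with pattern filtering is acceptable here; the local_dir value is illustrative:

```python
from huggingface_hub import snapshot_download

# Fetch only the Python parquet shards of the-stack-dedup, mirroring the
# per-file loop in the removed notebook. local_dir is illustrative.
snapshot_download(
    repo_id="bigcode/the-stack-dedup",
    repo_type="dataset",
    allow_patterns="data/python/*",
    local_dir="data/the-stack-python",
)
```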
diff --git a/code/src/preprocess_data.py b/code/src/preprocess_data.py
deleted file mode 100644
index 19d77d3..0000000
--- a/code/src/preprocess_data.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from datasets import load_dataset, disable_caching
-from transformers import RobertaTokenizer
-
-disable_caching()
-
-
-def visible_print(text):
- print('\n\n')
- print('=' * 100)
- print(text)
- print('=' * 100)
- print('\n\n')
-
-
-if __name__ == '__main__':
- # Load the dataset
- train_data = load_dataset('/work/s452638/datasets/the-stack-python', split='train')
- valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')['train']
- test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')['train']
-
- visible_print('Loaded data')
-
- # Rename the columns
- train_data = train_data.rename_column('content', 'code')
-
- # Remove all the columns except the code
- train_columns = train_data.column_names
- valid_columns = valid_data.column_names
- test_columns = test_data.column_names
-
- train_columns.remove('code')
- valid_columns.remove('code')
- test_columns.remove('code')
-
- train_data = train_data.remove_columns(train_columns)
- valid_data = valid_data.remove_columns(valid_columns)
- test_data = test_data.remove_columns(test_columns)
-
- visible_print('Removed unnecessary columns')
-
- # Tokenize the data
- tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
-
- def tokenize_function(examples):
- return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
-
- train_data = train_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Train] Running tokenizer', num_proc=8)
- valid_data = valid_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Valid] Running tokenizer', num_proc=8)
- test_data = test_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='[Test] Running tokenizer', num_proc=8)
-
- visible_print('Tokenized data')
-
- # Save the tokenized data
- train_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/train')
- valid_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/valid')
- test_data.save_to_disk('/work/s452638/datasets/the-stack-python-tokenized/test')
-
- visible_print('Saved tokenized data')
\ No newline at end of file
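
Note: the removed script tokenized The Stack and CodeSearchNet splits with the CodeBERT tokenizer and wrote them out with save_to_disk. If those artifacts still exist, they can be read back with load_from_disk; a short sketch using the cluster paths hard-coded in the deleted code:

```python
from datasets import load_from_disk

# Paths are the ones hard-coded in the removed script (cluster-specific).
train_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/train')
valid_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/valid')
test_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/test')
```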
diff --git a/code/src/the_stack_test.ipynb b/code/src/the_stack_test.ipynb
deleted file mode 100644
index 08a80be..0000000
--- a/code/src/the_stack_test.ipynb
+++ /dev/null
@@ -1,246 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n"
- ]
- }
- ],
- "source": [
- "import torch\n",
- "from torch.utils.data import DataLoader\n",
- "from datasets import load_dataset, disable_caching, DatasetDict\n",
- "from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer, DataCollatorForLanguageModeling\n",
- "\n",
- "disable_caching()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Dataset({\n",
- " features: ['hexsha', 'size', 'ext', 'lang', 'max_stars_repo_path', 'max_stars_repo_name', 'max_stars_repo_head_hexsha', 'max_stars_repo_licenses', 'max_stars_count', 'max_stars_repo_stars_event_min_datetime', 'max_stars_repo_stars_event_max_datetime', 'max_issues_repo_path', 'max_issues_repo_name', 'max_issues_repo_head_hexsha', 'max_issues_repo_licenses', 'max_issues_count', 'max_issues_repo_issues_event_min_datetime', 'max_issues_repo_issues_event_max_datetime', 'max_forks_repo_path', 'max_forks_repo_name', 'max_forks_repo_head_hexsha', 'max_forks_repo_licenses', 'max_forks_count', 'max_forks_repo_forks_event_min_datetime', 'max_forks_repo_forks_event_max_datetime', 'content', 'avg_line_length', 'max_line_length', 'alphanum_fraction'],\n",
- " num_rows: 12962249\n",
- "})"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "train_data = load_dataset(\"/work/s452638/datasets/the-stack-python\", split=\"train\")\n",
- "train_data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DatasetDict({\n",
- " train: Dataset({\n",
- " features: ['repo', 'path', 'func_name', 'original_string', 'language', 'code', 'code_tokens', 'docstring', 'docstring_tokens', 'sha', 'url', 'partition'],\n",
- " num_rows: 13914\n",
- " })\n",
- "})"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')\n",
- "valid_data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "DatasetDict({\n",
- " train: Dataset({\n",
- " features: ['repo', 'path', 'func_name', 'original_string', 'language', 'code', 'code_tokens', 'docstring', 'docstring_tokens', 'sha', 'url', 'partition'],\n",
- " num_rows: 14918\n",
- " })\n",
- "})"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')\n",
- "test_data"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "train_data = train_data.rename_column('content', 'code')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "RobertaTokenizer(name_or_path='microsoft/codebert-base', vocab_size=50265, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '', 'eos_token': '', 'unk_token': '', 'sep_token': '', 'pad_token': '', 'cls_token': '', 'mask_token': ''}, clean_up_tokenization_spaces=True), added_tokens_decoder={\n",
- "\t0: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
- "\t1: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
- "\t2: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
- "\t3: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n",
- "\t50264: AddedToken(\"\", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),\n",
- "}"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)\n",
- "tokenizer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Running tokenizer: 0%| | 1000/12962249 [00:17<61:47:17, 58.27 examples/s]\n"
- ]
- },
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[7], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtokenize_function\u001b[39m(examples):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tokenizer(examples[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m], truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mmax_length\u001b[39m\u001b[38;5;124m'\u001b[39m, max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m512\u001b[39m, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m train_data \u001b[38;5;241m=\u001b[39m \u001b[43mtrain_data\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mmap\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtokenize_function\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatched\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mremove_columns\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcode\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mRunning tokenizer\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m valid_data \u001b[38;5;241m=\u001b[39m valid_data\u001b[38;5;241m.\u001b[39mmap(tokenize_function, batched\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, remove_columns\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m], desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mRunning tokenizer\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 6\u001b[0m test_data \u001b[38;5;241m=\u001b[39m test_data\u001b[38;5;241m.\u001b[39mmap(tokenize_function, batched\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, remove_columns\u001b[38;5;241m=\u001b[39m[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcode\u001b[39m\u001b[38;5;124m'\u001b[39m], desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mRunning tokenizer\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:602\u001b[0m, in \u001b[0;36mtransmit_tasks..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 600\u001b[0m \u001b[38;5;28mself\u001b[39m: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;241m=\u001b[39m kwargs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mself\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 601\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 602\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 603\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m 604\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m dataset \u001b[38;5;129;01min\u001b[39;00m datasets:\n\u001b[1;32m 605\u001b[0m \u001b[38;5;66;03m# Remove task templates if a column mapping of the template is no longer valid\u001b[39;00m\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:567\u001b[0m, in \u001b[0;36mtransmit_format..wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 560\u001b[0m self_format \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 561\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtype\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_type,\n\u001b[1;32m 562\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mformat_kwargs\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_kwargs,\n\u001b[1;32m 563\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcolumns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_format_columns,\n\u001b[1;32m 564\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_all_columns\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_output_all_columns,\n\u001b[1;32m 565\u001b[0m }\n\u001b[1;32m 566\u001b[0m \u001b[38;5;66;03m# apply actual function\u001b[39;00m\n\u001b[0;32m--> 567\u001b[0m out: Union[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDatasetDict\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 568\u001b[0m datasets: List[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mDataset\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(out\u001b[38;5;241m.\u001b[39mvalues()) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(out, \u001b[38;5;28mdict\u001b[39m) \u001b[38;5;28;01melse\u001b[39;00m [out]\n\u001b[1;32m 569\u001b[0m \u001b[38;5;66;03m# re-apply format to the output\u001b[39;00m\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:3161\u001b[0m, in \u001b[0;36mDataset.map\u001b[0;34m(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)\u001b[0m\n\u001b[1;32m 3155\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m transformed_dataset \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3156\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m hf_tqdm(\n\u001b[1;32m 3157\u001b[0m unit\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m examples\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3158\u001b[0m total\u001b[38;5;241m=\u001b[39mpbar_total,\n\u001b[1;32m 3159\u001b[0m desc\u001b[38;5;241m=\u001b[39mdesc \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMap\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 3160\u001b[0m ) \u001b[38;5;28;01mas\u001b[39;00m pbar:\n\u001b[0;32m-> 3161\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m rank, done, content \u001b[38;5;129;01min\u001b[39;00m Dataset\u001b[38;5;241m.\u001b[39m_map_single(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mdataset_kwargs):\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m done:\n\u001b[1;32m 3163\u001b[0m shards_done \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:3552\u001b[0m, in \u001b[0;36mDataset._map_single\u001b[0;34m(shard, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, new_fingerprint, rank, offset)\u001b[0m\n\u001b[1;32m 3548\u001b[0m indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\n\u001b[1;32m 3549\u001b[0m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m*\u001b[39m(\u001b[38;5;28mslice\u001b[39m(i, i \u001b[38;5;241m+\u001b[39m batch_size)\u001b[38;5;241m.\u001b[39mindices(shard\u001b[38;5;241m.\u001b[39mnum_rows)))\n\u001b[1;32m 3550\u001b[0m ) \u001b[38;5;66;03m# Something simpler?\u001b[39;00m\n\u001b[1;32m 3551\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3552\u001b[0m batch \u001b[38;5;241m=\u001b[39m \u001b[43mapply_function_on_filtered_inputs\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3553\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3554\u001b[0m \u001b[43m \u001b[49m\u001b[43mindices\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3555\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheck_same_num_examples\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mshard\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlist_indexes\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m>\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3556\u001b[0m \u001b[43m \u001b[49m\u001b[43moffset\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moffset\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3557\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3558\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m NumExamplesMismatchError:\n\u001b[1;32m 3559\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m DatasetTransformationNotAllowedError(\n\u001b[1;32m 3560\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUsing `.map` in batched mode on a dataset with attached indexes is allowed only if it doesn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt create or remove existing examples. You can first run `.drop_index() to remove your index and then re-add it.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3561\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/datasets/arrow_dataset.py:3421\u001b[0m, in \u001b[0;36mDataset._map_single..apply_function_on_filtered_inputs\u001b[0;34m(pa_inputs, indices, check_same_num_examples, offset)\u001b[0m\n\u001b[1;32m 3419\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m with_rank:\n\u001b[1;32m 3420\u001b[0m additional_args \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m (rank,)\n\u001b[0;32m-> 3421\u001b[0m processed_inputs \u001b[38;5;241m=\u001b[39m \u001b[43mfunction\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfn_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43madditional_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mfn_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3422\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(processed_inputs, LazyDict):\n\u001b[1;32m 3423\u001b[0m processed_inputs \u001b[38;5;241m=\u001b[39m {\n\u001b[1;32m 3424\u001b[0m k: v \u001b[38;5;28;01mfor\u001b[39;00m k, v \u001b[38;5;129;01min\u001b[39;00m processed_inputs\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mitems() \u001b[38;5;28;01mif\u001b[39;00m k \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m processed_inputs\u001b[38;5;241m.\u001b[39mkeys_to_format\n\u001b[1;32m 3425\u001b[0m }\n",
- "Cell \u001b[0;32mIn[7], line 2\u001b[0m, in \u001b[0;36mtokenize_function\u001b[0;34m(examples)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mtokenize_function\u001b[39m(examples):\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtokenizer\u001b[49m\u001b[43m(\u001b[49m\u001b[43mexamples\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mcode\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtruncation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mmax_length\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m512\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mpt\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3055\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.__call__\u001b[0;34m(self, text, text_pair, text_target, text_pair_target, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, **kwargs)\u001b[0m\n\u001b[1;32m 3053\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_in_target_context_manager:\n\u001b[1;32m 3054\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_switch_to_input_mode()\n\u001b[0;32m-> 3055\u001b[0m encodings \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_one\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext_pair\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtext_pair\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mall_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3056\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m text_target \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 3057\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_switch_to_target_mode()\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3142\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase._call_one\u001b[0;34m(self, text, text_pair, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)\u001b[0m\n\u001b[1;32m 3137\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 3138\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch length of `text`: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(text)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m does not match batch length of `text_pair`:\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3139\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(text_pair)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 3140\u001b[0m )\n\u001b[1;32m 3141\u001b[0m batch_text_or_text_pairs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mzip\u001b[39m(text, text_pair)) \u001b[38;5;28;01mif\u001b[39;00m text_pair \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m text\n\u001b[0;32m-> 3142\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_encode_plus\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3143\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_text_or_text_pairs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_text_or_text_pairs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3144\u001b[0m \u001b[43m \u001b[49m\u001b[43madd_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madd_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3145\u001b[0m \u001b[43m \u001b[49m\u001b[43mpadding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpadding\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3146\u001b[0m \u001b[43m \u001b[49m\u001b[43mtruncation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtruncation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3147\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3148\u001b[0m \u001b[43m \u001b[49m\u001b[43mstride\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3149\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_split_into_words\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_split_into_words\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3150\u001b[0m \u001b[43m \u001b[49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3151\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_tensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3152\u001b[0m \u001b[43m 
\u001b[49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3153\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3154\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_overflowing_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_overflowing_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3155\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_special_tokens_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_special_tokens_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3156\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_offsets_mapping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_offsets_mapping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3157\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3158\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3159\u001b[0m \u001b[43m \u001b[49m\u001b[43msplit_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msplit_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3160\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3161\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3162\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 3163\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencode_plus(\n\u001b[1;32m 3164\u001b[0m text\u001b[38;5;241m=\u001b[39mtext,\n\u001b[1;32m 3165\u001b[0m text_pair\u001b[38;5;241m=\u001b[39mtext_pair,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3182\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 3183\u001b[0m )\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3338\u001b[0m, in \u001b[0;36mPreTrainedTokenizerBase.batch_encode_plus\u001b[0;34m(self, batch_text_or_text_pairs, add_special_tokens, padding, truncation, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)\u001b[0m\n\u001b[1;32m 3328\u001b[0m \u001b[38;5;66;03m# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'\u001b[39;00m\n\u001b[1;32m 3329\u001b[0m padding_strategy, truncation_strategy, max_length, kwargs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_get_padding_truncation_strategies(\n\u001b[1;32m 3330\u001b[0m padding\u001b[38;5;241m=\u001b[39mpadding,\n\u001b[1;32m 3331\u001b[0m truncation\u001b[38;5;241m=\u001b[39mtruncation,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3335\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs,\n\u001b[1;32m 3336\u001b[0m )\n\u001b[0;32m-> 3338\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_batch_encode_plus\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 3339\u001b[0m \u001b[43m \u001b[49m\u001b[43mbatch_text_or_text_pairs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_text_or_text_pairs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3340\u001b[0m \u001b[43m \u001b[49m\u001b[43madd_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43madd_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3341\u001b[0m \u001b[43m \u001b[49m\u001b[43mpadding_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpadding_strategy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3342\u001b[0m \u001b[43m \u001b[49m\u001b[43mtruncation_strategy\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtruncation_strategy\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3343\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3344\u001b[0m \u001b[43m \u001b[49m\u001b[43mstride\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstride\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3345\u001b[0m \u001b[43m \u001b[49m\u001b[43mis_split_into_words\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_split_into_words\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3346\u001b[0m \u001b[43m \u001b[49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpad_to_multiple_of\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3347\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_tensors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_tensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3348\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_token_type_ids\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3349\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_attention_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_attention_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3350\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_overflowing_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_overflowing_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3351\u001b[0m 
\u001b[43m \u001b[49m\u001b[43mreturn_special_tokens_mask\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_special_tokens_mask\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3352\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_offsets_mapping\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_offsets_mapping\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3353\u001b[0m \u001b[43m \u001b[49m\u001b[43mreturn_length\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreturn_length\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3354\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3355\u001b[0m \u001b[43m \u001b[49m\u001b[43msplit_special_tokens\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msplit_special_tokens\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3356\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 3357\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils.py:882\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._batch_encode_plus\u001b[0;34m(self, batch_text_or_text_pairs, add_special_tokens, padding_strategy, truncation_strategy, max_length, stride, is_split_into_words, pad_to_multiple_of, return_tensors, return_token_type_ids, return_attention_mask, return_overflowing_tokens, return_special_tokens_mask, return_offsets_mapping, return_length, verbose, split_special_tokens, **kwargs)\u001b[0m\n\u001b[1;32m 879\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 880\u001b[0m ids, pair_ids \u001b[38;5;241m=\u001b[39m ids_or_pair_ids\n\u001b[0;32m--> 882\u001b[0m first_ids \u001b[38;5;241m=\u001b[39m \u001b[43mget_input_ids\u001b[49m\u001b[43m(\u001b[49m\u001b[43mids\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 883\u001b[0m second_ids \u001b[38;5;241m=\u001b[39m get_input_ids(pair_ids) \u001b[38;5;28;01mif\u001b[39;00m pair_ids \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 884\u001b[0m input_ids\u001b[38;5;241m.\u001b[39mappend((first_ids, second_ids))\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils.py:849\u001b[0m, in \u001b[0;36mPreTrainedTokenizer._batch_encode_plus..get_input_ids\u001b[0;34m(text)\u001b[0m\n\u001b[1;32m 847\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mget_input_ids\u001b[39m(text):\n\u001b[1;32m 848\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text, \u001b[38;5;28mstr\u001b[39m):\n\u001b[0;32m--> 849\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtokenize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 850\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mconvert_tokens_to_ids(tokens)\n\u001b[1;32m 851\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text, (\u001b[38;5;28mlist\u001b[39m, \u001b[38;5;28mtuple\u001b[39m)) \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(text) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(text[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;28mstr\u001b[39m):\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/tokenization_utils.py:695\u001b[0m, in \u001b[0;36mPreTrainedTokenizer.tokenize\u001b[0;34m(self, text, **kwargs)\u001b[0m\n\u001b[1;32m 693\u001b[0m tokenized_text\u001b[38;5;241m.\u001b[39mappend(token)\n\u001b[1;32m 694\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 695\u001b[0m tokenized_text\u001b[38;5;241m.\u001b[39mextend(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_tokenize\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtoken\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 696\u001b[0m \u001b[38;5;66;03m# [\"This\", \" is\", \" something\", \"\", \"else\"]\u001b[39;00m\n\u001b[1;32m 697\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m tokenized_text\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/transformers/models/roberta/tokenization_roberta.py:270\u001b[0m, in \u001b[0;36mRobertaTokenizer._tokenize\u001b[0;34m(self, text)\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Tokenize a string.\"\"\"\u001b[39;00m\n\u001b[1;32m 269\u001b[0m bpe_tokens \u001b[38;5;241m=\u001b[39m []\n\u001b[0;32m--> 270\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m token \u001b[38;5;129;01min\u001b[39;00m \u001b[43mre\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfindall\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpat\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtext\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m 271\u001b[0m token \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\n\u001b[1;32m 272\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbyte_encoder[b] \u001b[38;5;28;01mfor\u001b[39;00m b \u001b[38;5;129;01min\u001b[39;00m token\u001b[38;5;241m.\u001b[39mencode(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 273\u001b[0m ) \u001b[38;5;66;03m# Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)\u001b[39;00m\n\u001b[1;32m 274\u001b[0m bpe_tokens\u001b[38;5;241m.\u001b[39mextend(bpe_token \u001b[38;5;28;01mfor\u001b[39;00m bpe_token \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbpe(token)\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m \u001b[39m\u001b[38;5;124m\"\u001b[39m))\n",
- "File \u001b[0;32m~/magisterka/magisterka_env/lib/python3.8/site-packages/regex/regex.py:338\u001b[0m, in \u001b[0;36mfindall\u001b[0;34m(pattern, string, flags, pos, endpos, overlapped, concurrent, timeout, ignore_unused, **kwargs)\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Return a list of all matches in the string. The matches may be overlapped\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03mif overlapped is True. If one or more groups are present in the pattern,\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \u001b[38;5;124;03mreturn a list of groups; this will be a list of tuples if the pattern has\u001b[39;00m\n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03mmore than one group. Empty matches are included in the result.\"\"\"\u001b[39;00m\n\u001b[1;32m 337\u001b[0m pat \u001b[38;5;241m=\u001b[39m _compile(pattern, flags, ignore_unused, kwargs, \u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m--> 338\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mpat\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfindall\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstring\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpos\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mendpos\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moverlapped\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mconcurrent\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtimeout\u001b[49m\u001b[43m)\u001b[49m\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
- }
- ],
- "source": [
- "def tokenize_function(examples):\n",
- " return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt')\n",
- "\n",
- "train_data = train_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')\n",
- "valid_data = valid_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')\n",
- "test_data = test_data.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "tokenized_datasets"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)\n",
- "data_collator"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
- "batch_size = 1\n",
- "train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=True, collate_fn=data_collator, generator=torch.Generator(device=device))\n",
- "valid_dataloader = DataLoader(tokenized_datasets['valid'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))\n",
- "test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.19"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
diff --git a/code/src/train_codebert_mlm.py b/code/src/train_codebert_mlm.py
index 9495ecd..ae84076 100644
--- a/code/src/train_codebert_mlm.py
+++ b/code/src/train_codebert_mlm.py
@@ -1,254 +1,245 @@
import wandb
-import torch
-from torch.optim import AdamW
-from torch.utils.data import DataLoader
import os
+import json
import random
import datetime
+import logging
+from pathlib import Path
+from typing import Dict, Any, Tuple, List
+
import numpy as np
-from datasets import load_dataset, load_from_disk, disable_caching, DatasetDict
-from tree_sitter import Language, Parser
-from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer, DataCollatorForLanguageModeling
+import torch
+from torch import Tensor
+from torch.optim import AdamW
+from torch.utils.data import DataLoader, Dataset
+from datasets import load_dataset, disable_caching, DatasetDict
+from huggingface_hub import list_repo_files, hf_hub_download
+from transformers import (
+ RobertaForMaskedLM,
+ RobertaConfig,
+ RobertaTokenizer,
+ DataCollatorForLanguageModeling,
+ get_linear_schedule_with_warmup,
+ PreTrainedTokenizer,
+ PreTrainedModel
+)
from tqdm import tqdm
-from utils import remove_docstrings_and_comments_from_code
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
-# Disable caching for datasets
-disable_caching()
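+# Wraps a raw HF dataset and tokenizes each example lazily in __getitem__, so the corpus is never pre-tokenized on disk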
+class OnTheFlyTokenizationDataset(Dataset):
+ def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int):
+ self.dataset = dataset
+ self.tokenizer = tokenizer
+ self.max_length = max_length
-############################### CONFIG ###############################
-dataset_name = 'the-stack-tokenized' # 'the-stack' or 'code-search-net' or 'the-stack-tokenized
-remove_comments = False
-######################################################################
+ def __len__(self) -> int:
+ return len(self.dataset)
-# Initialize Weights & Biases and output directory
-curr_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
-wandb.init(project='codebert-training', name=curr_time)
-output_dir = f'/home/s452638/magisterka/output/{curr_time}/'
+ def __getitem__(self, idx: int) -> Dict[str, Tensor]:
+ content: str = self.dataset[idx]['content']
+ tokenized = self.tokenizer(
+ content,
+ truncation=True,
+ padding='max_length',
+ max_length=self.max_length,
+ return_tensors='pt'
+ )
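+        # 'labels' start as a copy of input_ids; with mlm=True the data collator re-derives labels when it applies random masking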
+ return {
+ 'input_ids': tokenized['input_ids'].squeeze(0),
+ 'attention_mask': tokenized['attention_mask'].squeeze(0),
+ 'labels': tokenized['input_ids'].squeeze(0)
+ }
-# Save this file to Weights & Biases
-wandb.save('train_codebert_mlm.py')
-
-# Create the output directory if it does not exist
-if not os.path.exists(output_dir):
- os.makedirs(output_dir)
-
-# Set the seed for reproducibility
-SEED = 42
-def set_seed(seed):
+def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
-set_seed(SEED)
+def setup_wandb(config: Dict[str, Any]) -> None:
+ curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
+ wandb.init(project='codebert-training', name=curr_time, config=config)
+ wandb.save('train_codebert_mlm.py')
-# Set the device for PyTorch (use GPU if available, otherwise CPU)
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-torch.set_default_device(device)
-print('*' * 10, 'Device', '*' * 10)
-print(f'Using device: {device}')
-if device.type == 'cuda':
- print(f'Device name: {torch.cuda.get_device_name()}')
+def setup_directories(current_dir: Path) -> Path:
+ curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
+ output_dir: Path = current_dir.parent.parent / 'outputs' / curr_time
+ output_dir.mkdir(parents=True, exist_ok=True)
+ return output_dir
-# Load the dataset
-if dataset_name == 'the-stack-tokenized':
- train_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/train')
- valid_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/valid')
- test_data = load_from_disk('/work/s452638/datasets/the-stack-python-tokenized/test')
-else:
- if dataset_name == 'the-stack':
- train_data = load_dataset("/work/s452638/datasets/the-stack-python", split="train")
- train_data = train_data.rename_column_('content', 'code')
- elif dataset_name == 'code-search-net':
- train_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/train.jsonl')['train']
+def load_config(config_file: Path) -> Dict[str, Any]:
+ with open(config_file, 'r') as f:
+ return json.load(f)
- valid_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/valid.jsonl')['valid']
- test_data = load_dataset('json', data_files='/work/s452638/datasets/CodeSearchNet/python/test.jsonl')['test']
+def setup_device() -> torch.device:
+ device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ torch.set_default_device(device)
+ logger.info(f'Using device: {device}')
+ if device.type == 'cuda':
+ logger.info(f'Device name: {torch.cuda.get_device_name()}')
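+        # Allow TF32 matmul kernels on supported GPUs: faster float32 matmuls at slightly reduced precision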
+ torch.set_float32_matmul_precision('high')
+ return device
-dataset = DatasetDict({'train': train_data, 'valid': valid_data, 'test': test_data})
-print('\n\n', '*' * 10, 'Dataset', '*' * 10)
-print(dataset)
+def download_dataset(dataset_dir: Path) -> None:
+ if not dataset_dir.exists():
+ logger.info("Downloading the dataset...")
+ dataset_dir.mkdir(parents=True, exist_ok=True)
+ files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
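+        # Keep only the Python split of the-stack-dedup (parquet shards under data/python/)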
+ files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
+ for file_name in files_to_download:
+ hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
+ logger.info("Dataset downloaded successfully.")
-if remove_comments:
- # Build the language library if not already built
- Language.build_library('/home/s452638/magisterka/build/my-languages.so', ['/home/s452638/magisterka/vendor/tree-sitter-python'])
+def load_and_prepare_dataset(dataset_dir: Path, seed: int) -> DatasetDict:
+    dataset = load_dataset(str(dataset_dir), split='train')
+ dataset = dataset.train_test_split(test_size=0.01, seed=seed)
+ logger.info(f'Dataset loaded: {dataset}')
+ return dataset
- # Load the language
- PYTHON_LANGUAGE = Language('/home/s452638/magisterka/build/my-languages.so', 'python')
+def create_dataloaders(
+ dataset: DatasetDict,
+ tokenizer: PreTrainedTokenizer,
+ config: Dict[str, Any],
+ device: torch.device
+) -> Tuple[DataLoader, DataLoader]:
+ dataset['train'] = OnTheFlyTokenizationDataset(dataset['train'], tokenizer, max_length=512)
+ dataset['test'] = OnTheFlyTokenizationDataset(dataset['test'], tokenizer, max_length=512)
+
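+    # Dynamic masking: the collator randomly masks mlm_probability of the tokens each time a batch is collated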
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=config['mlm_probability'])
+
+ train_dataloader = DataLoader(
+ dataset['train'],
+ batch_size=config['batch'],
+ shuffle=False,
+ collate_fn=data_collator,
+ generator=torch.Generator(device=device)
+ )
+ valid_dataloader = DataLoader(
+ dataset['test'],
+ batch_size=config['batch'],
+ shuffle=False,
+ collate_fn=data_collator,
+ generator=torch.Generator(device=device)
+ )
+ return train_dataloader, valid_dataloader
- # Initialize the parser
- parser = Parser()
- parser.set_language(PYTHON_LANGUAGE)
+def setup_model_and_optimizer(
+ config: Dict[str, Any],
+ current_dir: Path
+) -> Tuple[PreTrainedModel, AdamW]:
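+    # Note: some huggingface_hub/transformers versions read HF_HOME at import time, so this may need to be set before the imports to take effect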
+ os.environ['HF_HOME'] = str(current_dir.parent / 'models')
+ model_config = RobertaConfig.from_pretrained('roberta-base')
+ model: PreTrainedModel = RobertaForMaskedLM(model_config)
+ model = torch.compile(model)
+ wandb.watch(model)
+ logger.info(f'Model config: {model_config}')
+ wandb.config.update({'model_config': model_config.to_dict()})
+
+ optimizer: AdamW = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
+ return model, optimizer
- # Remove docstrings and comments from the code
- dataset = dataset.map(lambda x: {'code': remove_docstrings_and_comments_from_code(x['code'], parser)}, batched=False, desc='Removing docstrings and comments')
-
-# Load the tokenizer
-tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
-print('\n\n', '*' * 10, 'Tokenizer', '*' * 10)
-print(tokenizer)
-
-if dataset_name != 'the-stack-tokenized':
- # Tokenize the dataset
- def tokenize_function(examples):
- return tokenizer(examples['code'], truncation=True, padding='max_length', max_length=512, return_tensors='pt')
-
- tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=['code'], desc='Running tokenizer')
- print('\n\n', '*' * 10, 'Tokenized dataset', '*' * 10)
- print(tokenized_datasets)
-else:
- tokenized_datasets = dataset
-
-# Set data collator for MLM
-data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
-
-# Create DataLoaders
-batch_size = 64
-train_dataloader = DataLoader(tokenized_datasets['train'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
-valid_dataloader = DataLoader(tokenized_datasets['valid'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
-test_dataloader = DataLoader(tokenized_datasets['test'], batch_size=batch_size, shuffle=False, collate_fn=data_collator, generator=torch.Generator(device=device))
-
-# Initialize a model with random weights based on the configuration for RoBERTa (CodeBERT is based on RoBERTa)
-config = RobertaConfig.from_pretrained('roberta-base')
-model = RobertaForMaskedLM(config)
-model = torch.compile(model)
-wandb.watch(model)
-print('\n\n', '*' * 10, 'Model', '*' * 10)
-print(config)
-
-# Log the model configuration to wandb
-wandb.config.update({'model_config': config.to_dict()})
-
-# Set the optimizer and scaler
-learning_rate = 5e-4
-optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)
-scaler = torch.amp.GradScaler()
-
-# Training settings
-num_epochs = 1
-num_training_steps = num_epochs * len(train_dataloader)
-eval_every = 10_000
-
-# Log training settings to wandb
-wandb.config.update({
- 'training_settings': {
- 'num_epochs': num_epochs,
- 'num_training_steps': num_training_steps,
- 'eval_every': eval_every,
- 'batch_size': batch_size,
- 'learning_rate': learning_rate,
- }
-})
-
-# Initialize variables to track validation loss, accuracy, and best model path
-valid_acc = 0.0
-valid_loss = 0.0
-best_valid_loss = float('inf')
-
-# Train the model
-print('\n\n', '*' * 10, 'Training', '*' * 10)
-model.train()
-with tqdm(total=num_training_steps, desc='Training') as pbar:
- for epoch_idx in range(num_epochs):
- for train_idx, train_batch in enumerate(train_dataloader):
-
- # Forward pass with mixed precision
- with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
+def train_and_evaluate(
+ model: PreTrainedModel,
+ train_dataloader: DataLoader,
+ valid_dataloader: DataLoader,
+ optimizer: AdamW,
+ scheduler: Any,
+ config: Dict[str, Any],
+ output_dir: Path
+) -> None:
+ num_training_steps: int = config['epochs'] * len(train_dataloader)
+ best_valid_loss: float = float('inf')
+
+ with tqdm(total=num_training_steps, desc='Training') as pbar:
+ for epoch_idx in range(config['epochs']):
+ model.train()
+ for train_idx, train_batch in enumerate(train_dataloader):
outputs = model(**train_batch)
-
- train_loss = outputs.loss
- scaler.scale(train_loss).backward()
-
- # Gradient clipping to prevent exploding gradients
- norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
-
- scaler.step(optimizer)
- scaler.update()
- optimizer.zero_grad()
-
- pbar.update(1)
- pbar.set_postfix({'norm': norm.item(), 'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc})
-
- # Log metrics to Weights & Biases
- wandb.log({
- 'step': train_idx + len(train_dataloader) * epoch_idx,
- 'train_loss': train_loss.item(),
- 'gradient_norm': norm.item(),
- 'learning_rate': optimizer.param_groups[0]['lr'],
- })
-
- # Evaluate the model
- if train_idx != 0 and train_idx % eval_every == 0:
- model.eval()
- valid_loss = 0.0
- valid_acc = 0.0
-
- with tqdm(total=len(valid_dataloader), desc='Validation') as pbar_valid:
- with torch.no_grad():
- for valid_idx, valid_batch in enumerate(valid_dataloader):
-
- # Forward pass with mixed precision for validation
- with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
- outputs = model(**valid_batch)
-
- # Accumulate validation loss and accuracy
- valid_loss += outputs.loss.item()
- valid_acc += outputs.logits.argmax(dim=-1).eq(valid_batch['labels']).sum().item()
- pbar_valid.update(1)
-
- # Compute average validation loss and accuracy
- valid_loss /= len(valid_dataloader)
- valid_acc /= len(valid_dataloader.dataset)
- model.train()
-
- # Log validation metrics to Weights & Biases
+ train_loss: Tensor = outputs.loss
+ train_loss.backward()
+
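+                # Clip the global gradient norm to max_grad_norm; 'norm' is the pre-clipping value, logged below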
+ norm: Tensor = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
+
+ optimizer.step()
+ scheduler.step()
+ optimizer.zero_grad()
+
+ pbar.update(1)
+ pbar.set_postfix({'train_loss': train_loss.item()})
+
wandb.log({
- 'valid_loss': valid_loss,
- 'valid_acc': valid_acc,
'step': train_idx + len(train_dataloader) * epoch_idx,
+ 'train_loss': train_loss.item(),
+ 'gradient_norm': norm.item(),
+ 'learning_rate': scheduler.get_last_lr()[0],
})
+
+ if train_idx != 0 and train_idx % config['eval_every'] == 0:
+                    valid_loss, valid_acc = evaluate(model, valid_dataloader)
+                    model.train()  # evaluate() leaves the model in eval mode; switch back before resuming training
+                    pbar.set_postfix({'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc})
+
+ wandb.log({
+ 'valid_loss': valid_loss,
+ 'valid_acc': valid_acc,
+ 'step': train_idx + len(train_dataloader) * epoch_idx,
+ })
+
+ if valid_loss < best_valid_loss:
+ best_valid_loss = valid_loss
+ torch.save(model.state_dict(), output_dir / 'best_model.pt')
+
+ logger.info(f'Best validation loss: {best_valid_loss}')
- # Update best model if current validation loss is lower
- if valid_loss < best_valid_loss:
- best_valid_loss = valid_loss
- torch.save(model.state_dict(), output_dir + f'best_model.pt')
-
-print('\n\n', '*' * 10, 'Training results', '*' * 10)
-print(f'Best validation loss: {best_valid_loss}')
-
-# Load the best model and evaluate on the test set
-print('\n\n', '*' * 10, 'Testing', '*' * 10)
-model.load_state_dict(torch.load(output_dir + f'best_model.pt', weights_only=True, map_location=device))
-model.eval()
-test_loss = 0.0
-test_acc = 0.0
-
-with tqdm(total=len(test_dataloader), desc='Testing') as pbar_test:
+def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
+ model.eval()
+ total_loss: float = 0.0
+ total_acc: float = 0.0
+
with torch.no_grad():
- for test_idx, test_batch in enumerate(test_dataloader):
+ for batch in tqdm(dataloader, desc='Validation'):
+ outputs = model(**batch)
+ total_loss += outputs.loss.item()
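+            # Sum of token positions where the argmax prediction matches the label; divided by the number of examples below, so this is matched tokens per example rather than a 0-1 accuracy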
+ total_acc += outputs.logits.argmax(dim=-1).eq(batch['labels']).sum().item()
+
+ avg_loss: float = total_loss / len(dataloader)
+ avg_acc: float = total_acc / len(dataloader.dataset)
+ return avg_loss, avg_acc
- # Forward pass with mixed precision for testing
- with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
- outputs = model(**test_batch)
+def main() -> None:
+ disable_caching()
+
+ current_dir: Path = Path(__file__).parent
+ output_dir: Path = setup_directories(current_dir)
+ config: Dict[str, Any] = load_config(current_dir / 'config.json')
- # Accumulate test loss and accuracy
- test_loss += outputs.loss.item()
- test_acc += outputs.logits.argmax(dim=-1).eq(test_batch['labels']).sum().item()
- pbar_test.update(1)
+ setup_wandb(config)
+ set_seed(config['seed'])
+ device: torch.device = setup_device()
+
+ dataset_dir: Path = current_dir.parent / 'data' / 'the-stack-python'
+ download_dataset(dataset_dir)
+ dataset: DatasetDict = load_and_prepare_dataset(dataset_dir, config['seed'])
+
+ tokenizer: PreTrainedTokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
+ logger.info(f'Tokenizer loaded: {tokenizer}')
+
+ train_dataloader, valid_dataloader = create_dataloaders(dataset, tokenizer, config, device)
+
+ model, optimizer = setup_model_and_optimizer(config, current_dir)
+
+ num_training_steps: int = config['epochs'] * len(train_dataloader)
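+    # Linear warmup for warmup_steps optimizer steps, then linear decay to zero over the remaining training steps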
+ scheduler = get_linear_schedule_with_warmup(
+ optimizer,
+ num_warmup_steps=config['warmup_steps'],
+ num_training_steps=num_training_steps
+ )
+
+ train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
-# Compute average test loss and accuracy
-test_loss /= len(test_dataloader)
-test_acc /= len(test_dataloader.dataset)
-
-# Log test metrics to Weights & Biases
-wandb.log({
- 'test_loss': test_loss,
- 'test_acc': test_acc,
-})
-
-print('\n\n', '*' * 10, 'Test results', '*' * 10)
-print(f'Test loss: {test_loss}')
-print(f'Test accuracy: {test_acc}')
+if __name__ == "__main__":
+ main()
\ No newline at end of file