diff --git a/code/pyproject.toml b/code/pyproject.toml index 75c4f34..43c9ff8 100644 --- a/code/pyproject.toml +++ b/code/pyproject.toml @@ -7,15 +7,18 @@ authors = [ ] dependencies = [ "wandb==0.18.5", - "torch==2.5.0", + "torch==2.5.1", "tqdm==4.66.5", "tree-sitter==0.23.1", - "transformers==4.45.2", + "transformers[torch]>=4.46.3", "datasets==3.0.1", "huggingface-hub==0.26.0", "matplotlib==3.9.2", "scikit-learn==1.5.2", "seaborn==0.13.2", + "tree-sitter-python==0.23.4", + "ipykernel>=6.29.5", + "ipywidgets>=8.1.5", ] requires-python = "==3.11.*" readme = "README.md" @@ -35,6 +38,5 @@ build-backend = "pdm.backend" distribution = true [tool.pdm.scripts] -run_training = {cmd = "src/train_codebert_mlm.py"} -run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"} parse_dataset = {cmd = "src/parse_dataset.py"} +run_training = {cmd = "src/training.py"} diff --git a/code/src/config.json b/code/src/config.json index 17425ff..370198d 100644 --- a/code/src/config.json +++ b/code/src/config.json @@ -1,12 +1,13 @@ { - "seed": 42, + "seed": 420, "mlm_probability": 0.15, - "batch": 16, - "epochs": 1, - "eval_every": 20000, + "batch_size": 32, + "epochs": 5, + "eval_every": 1000, "learning_rate": 5e-4, - "weight_decay": 0.01, + "weight_decay": 0.1, "max_grad_norm": 1.0, - "warmup_steps": 20000, - "train_size": 0.95 + "warmup_steps": 1000, + "train_size": 0.95, + "fp16": true } \ No newline at end of file diff --git a/code/src/dataset_parsing_test.ipynb b/code/src/dataset_parsing_test.ipynb new file mode 100644 index 0000000..aa5b341 --- /dev/null +++ b/code/src/dataset_parsing_test.ipynb @@ -0,0 +1,4507 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 183, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import random\n", + "import pandas as pd\n", + "import multiprocessing\n", + "import tree_sitter_python as tspython\n", + "from typing import Dict, Any, List\n", + "from datasets import load_dataset\n", + "from tree_sitter import Language, Parser\n", + "from transformers import AutoTokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 184, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "32" + ] + }, + "execution_count": 184, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "num_proc = min(multiprocessing.cpu_count() - 1, 32)\n", + "num_proc" + ] + }, + { + "cell_type": "code", + "execution_count": 185, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "RobertaTokenizerFast(name_or_path='microsoft/codebert-base', vocab_size=50265, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '', 'eos_token': '', 'unk_token': '', 'sep_token': '', 'pad_token': '', 'cls_token': '', 'mask_token': ''}, clean_up_tokenization_spaces=False), added_tokens_decoder={\n", + "\t0: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", + "\t1: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", + "\t2: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", + "\t3: AddedToken(\"\", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),\n", + "\t50264: AddedToken(\"\", rstrip=False, lstrip=True, single_word=False, normalized=False, special=True),\n", + "}" + ] + }, + "execution_count": 185, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer = 
AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)\n", + "tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 186, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['hexsha', 'size', 'ext', 'lang', 'max_stars_repo_path', 'max_stars_repo_name', 'max_stars_repo_head_hexsha', 'max_stars_repo_licenses', 'max_stars_count', 'max_stars_repo_stars_event_min_datetime', 'max_stars_repo_stars_event_max_datetime', 'max_issues_repo_path', 'max_issues_repo_name', 'max_issues_repo_head_hexsha', 'max_issues_repo_licenses', 'max_issues_count', 'max_issues_repo_issues_event_min_datetime', 'max_issues_repo_issues_event_max_datetime', 'max_forks_repo_path', 'max_forks_repo_name', 'max_forks_repo_head_hexsha', 'max_forks_repo_licenses', 'max_forks_count', 'max_forks_repo_forks_event_min_datetime', 'max_forks_repo_forks_event_max_datetime', 'content', 'avg_line_length', 'max_line_length', 'alphanum_fraction'],\n", + " num_rows: 1000\n", + "})" + ] + }, + "execution_count": 186, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "input_dir = os.path.join('..', 'data', 'the-stack-python')\n", + "dataset = load_dataset(str(input_dir))['train']\n", + "dataset = dataset.select(range(1000))\n", + "dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 187, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['hexsha',\n", + " 'ext',\n", + " 'lang',\n", + " 'max_stars_repo_path',\n", + " 'max_stars_repo_name',\n", + " 'max_stars_repo_head_hexsha',\n", + " 'max_stars_repo_licenses',\n", + " 'max_stars_repo_stars_event_min_datetime',\n", + " 'max_stars_repo_stars_event_max_datetime',\n", + " 'max_issues_repo_path',\n", + " 'max_issues_repo_name',\n", + " 'max_issues_repo_head_hexsha',\n", + " 'max_issues_repo_licenses',\n", + " 'max_issues_count',\n", + " 'max_issues_repo_issues_event_min_datetime',\n", + " 'max_issues_repo_issues_event_max_datetime',\n", + " 'max_forks_repo_path',\n", + " 'max_forks_repo_name',\n", + " 'max_forks_repo_head_hexsha',\n", + " 'max_forks_repo_licenses',\n", + " 'max_forks_count',\n", + " 'max_forks_repo_forks_event_min_datetime',\n", + " 'max_forks_repo_forks_event_max_datetime']" + ] + }, + "execution_count": 187, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "columns_to_remove = dataset.column_names\n", + "columns_to_remove.remove('content')\n", + "columns_to_remove.remove('size')\n", + "columns_to_remove.remove('avg_line_length')\n", + "columns_to_remove.remove('max_line_length')\n", + "columns_to_remove.remove('alphanum_fraction')\n", + "columns_to_remove.remove('max_stars_count')\n", + "columns_to_remove" + ] + }, + { + "cell_type": "code", + "execution_count": 188, + "metadata": {}, + "outputs": [], + "source": [ + "def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: AutoTokenizer):\n", + " \"\"\"\n", + " Map tokens to their original text and Tree-sitter AST position\n", + " \"\"\"\n", + " # Initialize Tree-sitter\n", + " PY_LANGUAGE = Language(tspython.language())\n", + " parser = Parser(PY_LANGUAGE)\n", + " \n", + " # Parse with Tree-sitter\n", + " tree = parser.parse(bytes(code_snippet, \"utf8\"))\n", + " \n", + " encoded = tokenizer(\n", + " code_snippet,\n", + " add_special_tokens=True,\n", + " return_offsets_mapping=True,\n", + " return_tensors='pt',\n", + " padding='max_length',\n", + " truncation=True,\n", + " max_length=512,\n", + " )\n", + " \n", + " tokens = 
tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])\n", + " offset_mapping = encoded['offset_mapping'][0].tolist()\n", + " \n", + " def get_node_position(node, depth=0, idx=0):\n", + " \"\"\"Get depth and sibling index for a node\"\"\"\n", + " return (depth, idx)\n", + " \n", + " def find_node_at_position(node, start_byte, depth=0, sibling_idx=0):\n", + " \"\"\"Find the most specific node containing the given position\"\"\"\n", + " if start_byte >= node.start_byte and start_byte < node.end_byte:\n", + " # Check children first for more specific nodes\n", + " for idx, child in enumerate(node.children):\n", + " result = find_node_at_position(child, start_byte, depth + 1, idx)\n", + " if result:\n", + " return result\n", + " # Return current node's position if no child contains the position\n", + " return get_node_position(node, depth, sibling_idx)\n", + " return None\n", + " \n", + " depths = []\n", + " sibling_idxs = []\n", + " node_texts = []\n", + " for token, (start, end) in zip(tokens, offset_mapping):\n", + " if token not in ['', '', '', '']:\n", + " text = code_snippet[start:end] if start < len(code_snippet) else \"\"\n", + " # Tree-sitter works with bytes, so convert position\n", + " start_byte = len(code_snippet[:start].encode('utf8')) if start < len(code_snippet) else 0\n", + " position = find_node_at_position(tree.root_node, start_byte) or (-1, -1)\n", + " \n", + " depths.append(position[0])\n", + " sibling_idxs.append(position[1])\n", + " node_texts.append(text)\n", + " else:\n", + " depths.append(-1)\n", + " sibling_idxs.append(-1)\n", + " node_texts.append(None)\n", + " \n", + " return tokens, depths, sibling_idxs, node_texts" + ] + }, + { + "cell_type": "code", + "execution_count": 189, + "metadata": {}, + "outputs": [], + "source": [ + "def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:\n", + " \"\"\"Process a batch of examples.\"\"\"\n", + " contents = examples['content']\n", + "\n", + " processed_tokens = []\n", + " processed_depths = []\n", + " processed_sibling_idxs = []\n", + " processed_node_texts = []\n", + "\n", + " for content in contents:\n", + " tokens, depths, sibling_idxs, node_texts = analyze_code_with_codebert_and_treesitter(content, tokenizer)\n", + " processed_tokens.append(tokens)\n", + " processed_depths.append(depths)\n", + " processed_sibling_idxs.append(sibling_idxs)\n", + " processed_node_texts.append(node_texts)\n", + "\n", + " return {\n", + " 'tokens': processed_tokens,\n", + " 'depths': processed_depths,\n", + " 'sibling_idxs': processed_sibling_idxs,\n", + " 'node_texts': processed_node_texts,\n", + " }" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing dataset (num_proc=32): 100%|██████████| 1000/1000 [00:03<00:00, 290.50 examples/s]\n" + ] + }, + { + "data": { + "text/plain": [ + "Dataset({\n", + " features: ['size', 'max_stars_count', 'content', 'avg_line_length', 'max_line_length', 'alphanum_fraction', 'tokens', 'depths', 'sibling_idxs', 'node_texts'],\n", + " num_rows: 1000\n", + "})" + ] + }, + "execution_count": 190, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset = dataset.map(\n", + " process_batch,\n", + " fn_kwargs={'tokenizer': tokenizer},\n", + " batched=True,\n", + " remove_columns=columns_to_remove,\n", + " desc=\"Processing dataset\",\n", + " num_proc=num_proc,\n", + " load_from_cache_file=False\n", + ")\n", + "dataset" + ] + }, + { + "cell_type": "code", + 
"execution_count": 191, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "#!/usr/bin/env python\n", + "# -*- coding: utf-8 -*-\n", + "import os\n", + "from setuptools import find_packages, setup\n", + "\n", + "from app import __version__\n", + "\n", + "# get the dependencies and installs\n", + "here = os.path.abspath(os.path.dirname(__file__))\n", + "with open(os.path.join(here, 'requirements.txt')) as f:\n", + " all_requirements = f.read().split('\\n')\n", + "\n", + "setup(\n", + " name='webspider',\n", + " version=__version__,\n", + " license='MIT',\n", + " author='heguozhu',\n", + " author_email='heguozhu@zhihu.com',\n", + " description='lagou.com spider',\n", + " url='git@github.com:GuozhuHe/webspider.git',\n", + " packages=find_packages(exclude=['tests']),\n", + " package_data={'webspider': ['README.md']},\n", + " zip_safe=False,\n", + " install_requires=all_requirements,\n", + " entry_points={\n", + " 'console_scripts': [\n", + " 'web = app.web_app:main',\n", + " 'production_web = app.quickly_cmd:run_web_app_by_gunicorn',\n", + " 'crawl_lagou_data = app.tasks:crawl_lagou_data',\n", + " 'crawl_jobs_count = app.tasks.jobs_count:crawl_lagou_jobs_count',\n", + " 'celery_jobs_count_worker = app.quickly_cmd:run_celery_jobs_count_worker',\n", + " 'celery_lagou_data_worker = app.quickly_cmd:run_celery_lagou_data_worker',\n", + " 'celery_beat = app.quickly_cmd:run_celery_beat',\n", + " 'celery_flower = app.quickly_cmd.py:run_celery_flower',\n", + " ],\n", + " }\n", + ")\n", + "\n" + ] + } + ], + "source": [ + "random_sample = random.choice(dataset)\n", + "print(random_sample['content'])" + ] + }, + { + "cell_type": "code", + "execution_count": 192, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " tokens depths sibling_idxs node_texts\n", + "0 -1 -1 None\n", + "1 # 1 0 #\n", + "2 !/ 1 0 !/\n", + "3 usr 1 0 usr\n", + "4 / 1 0 /\n", + "5 bin 1 0 bin\n", + "6 / 1 0 /\n", + "7 env 1 0 env\n", + "8 Ġpython 1 0 python\n", + "9 Ċ 0 0 \\n\n", + "10 # 1 1 #\n", + "11 Ġ- 1 1 -\n", + "12 * 1 1 *\n", + "13 - 1 1 -\n", + "14 Ġcoding 1 1 coding\n", + "15 : 1 1 :\n", + "16 Ġut 1 1 ut\n", + "17 f 1 1 f\n", + "18 - 1 1 -\n", + "19 8 1 1 8\n", + "20 Ġ- 1 1 -\n", + "21 * 1 1 *\n", + "22 - 1 1 -\n", + "23 Ċ 0 0 \\n\n", + "24 import 2 0 import\n", + "25 Ġos 3 0 os\n", + "26 Ċ 0 0 \\n\n", + "27 from 2 0 from\n", + "28 Ġset 3 0 set\n", + "29 upt 3 0 upt\n", + "30 ools 3 0 ools\n", + "31 Ġimport 2 2 import\n", + "32 Ġfind 3 0 find\n", + "33 _ 3 0 _\n", + "34 packages 3 0 packages\n", + "35 , 2 4 ,\n", + "36 Ġsetup 3 0 setup\n", + "37 Ċ 0 0 \\n\n", + "38 Ċ 0 0 \\n\n", + "39 from 2 0 from\n", + "40 Ġapp 3 0 app\n", + "41 Ġimport 2 2 import\n", + "42 Ġ__ 3 0 __\n", + "43 version 3 0 version\n", + "44 __ 3 0 __\n", + "45 Ċ 0 0 \\n\n", + "46 Ċ 0 0 \\n\n", + "47 # 1 5 #\n", + "48 Ġget 1 5 get\n", + "49 Ġthe 1 5 the\n", + "50 Ġdependencies 1 5 dependencies\n", + "51 Ġand 1 5 and\n", + "52 Ġinstalls 1 5 installs\n", + "53 Ċ 0 0 \\n\n", + "54 here 3 0 here\n", + "55 Ġ= 3 1 =\n", + "56 Ġos 6 0 os\n", + "57 . 6 1 .\n", + "58 path 6 2 path\n", + "59 . 5 1 .\n", + "60 ab 5 2 ab\n", + "61 sp 5 2 sp\n", + "62 ath 5 2 ath\n", + "63 ( 5 0 (\n", + "64 os 8 0 os\n", + "65 . 8 1 .\n", + "66 path 8 2 path\n", + "67 . 7 1 .\n", + "68 dir 7 2 dir\n", + "69 name 7 2 name\n", + "70 ( 7 0 (\n", + "71 __ 7 1 __\n", + "72 file 7 1 file\n", + "73 __ 7 1 __\n", + "74 )) 7 2 ))\n", + "75 Ċ 0 0 \\n\n", + "76 with 2 0 with\n", + "77 Ġopen 6 0 open\n", + "78 ( 7 0 (\n", + "79 os 10 0 os\n", + "80 . 10 1 .\n", + "81 path 10 2 path\n", + "82 . 9 1 .\n", + "83 join 9 2 join\n", + "84 ( 9 0 (\n", + "85 here 9 1 here\n", + "86 , 9 2 ,\n", + "87 Ġ' 10 0 '\n", + "88 requ 10 1 requ\n", + "89 irements 10 1 irements\n", + "90 . 10 1 .\n", + "91 txt 10 1 txt\n", + "92 ' 10 2 '\n", + "93 )) 9 4 ))\n", + "94 Ġas 5 1 as\n", + "95 Ġf 6 0 f\n", + "96 : 2 2 :\n", + "97 Ċ 1 7 \\n\n", + "98 Ġ 1 7 \n", + "99 Ġ 1 7 \n", + "100 Ġ 1 7 \n", + "101 Ġall 5 0 all\n", + "102 _ 5 0 _\n", + "103 requ 5 0 requ\n", + "104 irements 5 0 irements\n", + "105 Ġ= 5 1 =\n", + "106 Ġf 9 0 f\n", + "107 . 9 1 .\n", + "108 read 9 2 read\n", + "109 (). 
9 0 ().\n", + "110 split 7 2 split\n", + "111 (' 7 0 ('\n", + "112 \\ 9 0 \\\n", + "113 n 9 0 n\n", + "114 ') 8 2 ')\n", + "115 Ċ 0 0 \\n\n", + "116 Ċ 0 0 \\n\n", + "117 setup 3 0 setup\n", + "118 ( 4 0 (\n", + "119 Ċ 3 1 \\n\n", + "120 Ġ 3 1 \n", + "121 Ġ 3 1 \n", + "122 Ġ 3 1 \n", + "123 Ġname 5 0 name\n", + "124 =' 5 1 ='\n", + "125 we 6 1 we\n", + "126 bsp 6 1 bsp\n", + "127 ider 6 1 ider\n", + "128 ', 6 2 ',\n", + "129 Ċ 3 1 \\n\n", + "130 Ġ 3 1 \n", + "131 Ġ 3 1 \n", + "132 Ġ 3 1 \n", + "133 Ġversion 5 0 version\n", + "134 = 5 1 =\n", + "135 __ 5 2 __\n", + "136 version 5 2 version\n", + "137 __ 5 2 __\n", + "138 , 4 4 ,\n", + "139 Ċ 3 1 \\n\n", + "140 Ġ 3 1 \n", + "141 Ġ 3 1 \n", + "142 Ġ 3 1 \n", + "143 Ġlicense 5 0 license\n", + "144 =' 5 1 ='\n", + "145 MIT 6 1 MIT\n", + "146 ', 6 2 ',\n", + "147 Ċ 3 1 \\n\n", + "148 Ġ 3 1 \n", + "149 Ġ 3 1 \n", + "150 Ġ 3 1 \n", + "151 Ġauthor 5 0 author\n", + "152 =' 5 1 ='\n", + "153 he 6 1 he\n", + "154 gu 6 1 gu\n", + "155 oz 6 1 oz\n", + "156 hu 6 1 hu\n", + "157 ', 6 2 ',\n", + "158 Ċ 3 1 \\n\n", + "159 Ġ 3 1 \n", + "160 Ġ 3 1 \n", + "161 Ġ 3 1 \n", + "162 Ġauthor 5 0 author\n", + "163 _ 5 0 _\n", + "164 email 5 0 email\n", + "165 =' 5 1 ='\n", + "166 he 6 1 he\n", + "167 gu 6 1 gu\n", + "168 oz 6 1 oz\n", + "169 hu 6 1 hu\n", + "170 @ 6 1 @\n", + "171 zh 6 1 zh\n", + "172 ihu 6 1 ihu\n", + "173 . 6 1 .\n", + "174 com 6 1 com\n", + "175 ', 6 2 ',\n", + "176 Ċ 3 1 \\n\n", + "177 Ġ 3 1 \n", + "178 Ġ 3 1 \n", + "179 Ġ 3 1 \n", + "180 Ġdescription 5 0 description\n", + "181 =' 5 1 ='\n", + "182 lag 6 1 lag\n", + "183 ou 6 1 ou\n", + "184 . 6 1 .\n", + "185 com 6 1 com\n", + "186 Ġspider 6 1 spider\n", + "187 ', 6 2 ',\n", + "188 Ċ 3 1 \\n\n", + "189 Ġ 3 1 \n", + "190 Ġ 3 1 \n", + "191 Ġ 3 1 \n", + "192 Ġurl 5 0 url\n", + "193 =' 5 1 ='\n", + "194 git 6 1 git\n", + "195 @ 6 1 @\n", + "196 github 6 1 github\n", + "197 . 6 1 .\n", + "198 com 6 1 com\n", + "199 : 6 1 :\n", + "200 Gu 6 1 Gu\n", + "201 oz 6 1 oz\n", + "202 hu 6 1 hu\n", + "203 He 6 1 He\n", + "204 / 6 1 /\n", + "205 we 6 1 we\n", + "206 bsp 6 1 bsp\n", + "207 ider 6 1 ider\n", + "208 . 6 1 .\n", + "209 git 6 1 git\n", + "210 ', 6 2 ',\n", + "211 Ċ 3 1 \\n\n", + "212 Ġ 3 1 \n", + "213 Ġ 3 1 \n", + "214 Ġ 3 1 \n", + "215 Ġpackages 5 0 packages\n", + "216 = 5 1 =\n", + "217 find 6 0 find\n", + "218 _ 6 0 _\n", + "219 packages 6 0 packages\n", + "220 ( 7 0 (\n", + "221 ex 8 0 ex\n", + "222 clude 8 0 clude\n", + "223 = 8 1 =\n", + "224 [' 9 0 ['\n", + "225 tests 10 1 tests\n", + "226 '] 10 2 ']\n", + "227 ), 7 2 ),\n", + "228 Ċ 3 1 \\n\n", + "229 Ġ 3 1 \n", + "230 Ġ 3 1 \n", + "231 Ġ 3 1 \n", + "232 Ġpackage 5 0 package\n", + "233 _ 5 0 _\n", + "234 data 5 0 data\n", + "235 ={ 5 1 ={\n", + "236 ' 8 0 '\n", + "237 we 8 1 we\n", + "238 bsp 8 1 bsp\n", + "239 ider 8 1 ider\n", + "240 ': 8 2 ':\n", + "241 Ġ[' 8 0 ['\n", + "242 READ 9 1 READ\n", + "243 ME 9 1 ME\n", + "244 . 
9 1 .\n", + "245 md 9 1 md\n", + "246 '] 9 2 ']\n", + "247 }, 6 2 },\n", + "248 Ċ 3 1 \\n\n", + "249 Ġ 3 1 \n", + "250 Ġ 3 1 \n", + "251 Ġ 3 1 \n", + "252 Ġzip 5 0 zip\n", + "253 _ 5 0 _\n", + "254 safe 5 0 safe\n", + "255 = 5 1 =\n", + "256 False 5 2 False\n", + "257 , 4 20 ,\n", + "258 Ċ 3 1 \\n\n", + "259 Ġ 3 1 \n", + "260 Ġ 3 1 \n", + "261 Ġ 3 1 \n", + "262 Ġinstall 5 0 install\n", + "263 _ 5 0 _\n", + "264 requires 5 0 requires\n", + "265 = 5 1 =\n", + "266 all 5 2 all\n", + "267 _ 5 2 _\n", + "268 requ 5 2 requ\n", + "269 irements 5 2 irements\n", + "270 , 4 22 ,\n", + "271 Ċ 3 1 \\n\n", + "272 Ġ 3 1 \n", + "273 Ġ 3 1 \n", + "274 Ġ 3 1 \n", + "275 Ġentry 5 0 entry\n", + "276 _ 5 0 _\n", + "277 points 5 0 points\n", + "278 ={ 5 1 ={\n", + "279 Ċ 5 2 \\n\n", + "280 Ġ 5 2 \n", + "281 Ġ 5 2 \n", + "282 Ġ 5 2 \n", + "283 Ġ 5 2 \n", + "284 Ġ 5 2 \n", + "285 Ġ 5 2 \n", + "286 Ġ 5 2 \n", + "287 Ġ' 8 0 '\n", + "288 console 8 1 console\n", + "289 _ 8 1 _\n", + "290 scripts 8 1 scripts\n", + "291 ': 8 2 ':\n", + "292 Ġ[ 8 0 [\n", + "293 Ċ 7 2 \\n\n", + "294 Ġ 7 2 \n", + "295 Ġ 7 2 \n", + "296 Ġ 7 2 \n", + "297 Ġ 7 2 \n", + "298 Ġ 7 2 \n", + "299 Ġ 7 2 \n", + "300 Ġ 7 2 \n", + "301 Ġ 7 2 \n", + "302 Ġ 7 2 \n", + "303 Ġ 7 2 \n", + "304 Ġ 7 2 \n", + "305 Ġ' 9 0 '\n", + "306 web 9 1 web\n", + "307 Ġ= 9 1 =\n", + "308 Ġapp 9 1 app\n", + "309 . 9 1 .\n", + "310 web 9 1 web\n", + "311 _ 9 1 _\n", + "312 app 9 1 app\n", + "313 : 9 1 :\n", + "314 main 9 1 main\n", + "315 ', 9 2 ',\n", + "316 Ċ 7 2 \\n\n", + "317 Ġ 7 2 \n", + "318 Ġ 7 2 \n", + "319 Ġ 7 2 \n", + "320 Ġ 7 2 \n", + "321 Ġ 7 2 \n", + "322 Ġ 7 2 \n", + "323 Ġ 7 2 \n", + "324 Ġ 7 2 \n", + "325 Ġ 7 2 \n", + "326 Ġ 7 2 \n", + "327 Ġ 7 2 \n", + "328 Ġ' 9 0 '\n", + "329 production 9 1 production\n", + "330 _ 9 1 _\n", + "331 web 9 1 web\n", + "332 Ġ= 9 1 =\n", + "333 Ġapp 9 1 app\n", + "334 . 9 1 .\n", + "335 quick 9 1 quick\n", + "336 ly 9 1 ly\n", + "337 _ 9 1 _\n", + "338 cmd 9 1 cmd\n", + "339 : 9 1 :\n", + "340 run 9 1 run\n", + "341 _ 9 1 _\n", + "342 web 9 1 web\n", + "343 _ 9 1 _\n", + "344 app 9 1 app\n", + "345 _ 9 1 _\n", + "346 by 9 1 by\n", + "347 _ 9 1 _\n", + "348 gun 9 1 gun\n", + "349 ic 9 1 ic\n", + "350 orn 9 1 orn\n", + "351 ', 9 2 ',\n", + "352 Ċ 7 2 \\n\n", + "353 Ġ 7 2 \n", + "354 Ġ 7 2 \n", + "355 Ġ 7 2 \n", + "356 Ġ 7 2 \n", + "357 Ġ 7 2 \n", + "358 Ġ 7 2 \n", + "359 Ġ 7 2 \n", + "360 Ġ 7 2 \n", + "361 Ġ 7 2 \n", + "362 Ġ 7 2 \n", + "363 Ġ 7 2 \n", + "364 Ġ' 9 0 '\n", + "365 c 9 1 c\n", + "366 rawl 9 1 rawl\n", + "367 _ 9 1 _\n", + "368 lag 9 1 lag\n", + "369 ou 9 1 ou\n", + "370 _ 9 1 _\n", + "371 data 9 1 data\n", + "372 Ġ= 9 1 =\n", + "373 Ġapp 9 1 app\n", + "374 . 9 1 .\n", + "375 t 9 1 t\n", + "376 asks 9 1 asks\n", + "377 : 9 1 :\n", + "378 c 9 1 c\n", + "379 rawl 9 1 rawl\n", + "380 _ 9 1 _\n", + "381 lag 9 1 lag\n", + "382 ou 9 1 ou\n", + "383 _ 9 1 _\n", + "384 data 9 1 data\n", + "385 ', 9 2 ',\n", + "386 Ċ 7 2 \\n\n", + "387 Ġ 7 2 \n", + "388 Ġ 7 2 \n", + "389 Ġ 7 2 \n", + "390 Ġ 7 2 \n", + "391 Ġ 7 2 \n", + "392 Ġ 7 2 \n", + "393 Ġ 7 2 \n", + "394 Ġ 7 2 \n", + "395 Ġ 7 2 \n", + "396 Ġ 7 2 \n", + "397 Ġ 7 2 \n", + "398 Ġ' 9 0 '\n", + "399 c 9 1 c\n", + "400 rawl 9 1 rawl\n", + "401 _ 9 1 _\n", + "402 jobs 9 1 jobs\n", + "403 _ 9 1 _\n", + "404 count 9 1 count\n", + "405 Ġ= 9 1 =\n", + "406 Ġapp 9 1 app\n", + "407 . 9 1 .\n", + "408 t 9 1 t\n", + "409 asks 9 1 asks\n", + "410 . 
9 1 .\n", + "411 jobs 9 1 jobs\n", + "412 _ 9 1 _\n", + "413 count 9 1 count\n", + "414 : 9 1 :\n", + "415 c 9 1 c\n", + "416 rawl 9 1 rawl\n", + "417 _ 9 1 _\n", + "418 lag 9 1 lag\n", + "419 ou 9 1 ou\n", + "420 _ 9 1 _\n", + "421 jobs 9 1 jobs\n", + "422 _ 9 1 _\n", + "423 count 9 1 count\n", + "424 ', 9 2 ',\n", + "425 Ċ 7 2 \\n\n", + "426 Ġ 7 2 \n", + "427 Ġ 7 2 \n", + "428 Ġ 7 2 \n", + "429 Ġ 7 2 \n", + "430 Ġ 7 2 \n", + "431 Ġ 7 2 \n", + "432 Ġ 7 2 \n", + "433 Ġ 7 2 \n", + "434 Ġ 7 2 \n", + "435 Ġ 7 2 \n", + "436 Ġ 7 2 \n", + "437 Ġ' 9 0 '\n", + "438 celer 9 1 celer\n", + "439 y 9 1 y\n", + "440 _ 9 1 _\n", + "441 jobs 9 1 jobs\n", + "442 _ 9 1 _\n", + "443 count 9 1 count\n", + "444 _ 9 1 _\n", + "445 worker 9 1 worker\n", + "446 Ġ= 9 1 =\n", + "447 Ġapp 9 1 app\n", + "448 . 9 1 .\n", + "449 quick 9 1 quick\n", + "450 ly 9 1 ly\n", + "451 _ 9 1 _\n", + "452 cmd 9 1 cmd\n", + "453 : 9 1 :\n", + "454 run 9 1 run\n", + "455 _ 9 1 _\n", + "456 celer 9 1 celer\n", + "457 y 9 1 y\n", + "458 _ 9 1 _\n", + "459 jobs 9 1 jobs\n", + "460 _ 9 1 _\n", + "461 count 9 1 count\n", + "462 _ 9 1 _\n", + "463 worker 9 1 worker\n", + "464 ', 9 2 ',\n", + "465 Ċ 7 2 \\n\n", + "466 Ġ 7 2 \n", + "467 Ġ 7 2 \n", + "468 Ġ 7 2 \n", + "469 Ġ 7 2 \n", + "470 Ġ 7 2 \n", + "471 Ġ 7 2 \n", + "472 Ġ 7 2 \n", + "473 Ġ 7 2 \n", + "474 Ġ 7 2 \n", + "475 Ġ 7 2 \n", + "476 Ġ 7 2 \n", + "477 Ġ' 9 0 '\n", + "478 celer 9 1 celer\n", + "479 y 9 1 y\n", + "480 _ 9 1 _\n", + "481 lag 9 1 lag\n", + "482 ou 9 1 ou\n", + "483 _ 9 1 _\n", + "484 data 9 1 data\n", + "485 _ 9 1 _\n", + "486 worker 9 1 worker\n", + "487 Ġ= 9 1 =\n", + "488 Ġapp 9 1 app\n", + "489 . 9 1 .\n", + "490 quick 9 1 quick\n", + "491 ly 9 1 ly\n", + "492 _ 9 1 _\n", + "493 cmd 9 1 cmd\n", + "494 : 9 1 :\n", + "495 run 9 1 run\n", + "496 _ 9 1 _\n", + "497 celer 9 1 celer\n", + "498 y 9 1 y\n", + "499 _ 9 1 _\n", + "500 lag 9 1 lag\n", + "501 ou 9 1 ou\n", + "502 _ 9 1 _\n", + "503 data 9 1 data\n", + "504 _ 9 1 _\n", + "505 worker 9 1 worker\n", + "506 ', 9 2 ',\n", + "507 Ċ 7 2 \\n\n", + "508 Ġ 7 2 \n", + "509 Ġ 7 2 \n", + "510 Ġ 7 2 \n", + "511 -1 -1 None" + ] + }, + "execution_count": 192, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame({\n", + " 'tokens': random_sample['tokens'],\n", + " 'depths': random_sample['depths'],\n", + " 'sibling_idxs': random_sample['sibling_idxs'],\n", + " 'node_texts': random_sample['node_texts'],\n", + "})" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/code/src/download_dataset.py b/code/src/download_dataset.py new file mode 100644 index 0000000..efbc6e9 --- /dev/null +++ b/code/src/download_dataset.py @@ -0,0 +1,22 @@ +import logging +from pathlib import Path +from typing import List +from huggingface_hub import list_repo_files, hf_hub_download + +# Setup logging +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def download_dataset(dataset_dir: Path) -> None: + if not dataset_dir.exists(): + logger.info("Downloading the dataset...") + dataset_dir.mkdir(parents=True, exist_ok=True) + files_list: List[str] = 
list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset') + files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')] + for file_name in files_to_download: + hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir) + logger.info("Dataset downloaded successfully.") + +if __name__ == '__main__': + dataset_dir = Path('data/the-stack-python') + download_dataset(dataset_dir) diff --git a/code/src/parse_dataset.py b/code/src/parse_dataset.py index f362e28..8883afe 100644 --- a/code/src/parse_dataset.py +++ b/code/src/parse_dataset.py @@ -1,128 +1,123 @@ -import ast -from pathlib import Path import logging +import multiprocessing +import tree_sitter_python as tspython from tqdm import tqdm +from pathlib import Path from pprint import pprint from typing import Dict, Any, List -from datasets import load_dataset, Dataset, concatenate_datasets -from transformers import RobertaTokenizer -from dataclasses import dataclass -import numpy as np -import multiprocessing +from datasets import load_dataset, Dataset +from tree_sitter import Language, Parser +from transformers import AutoTokenizer import warnings warnings.filterwarnings("ignore", category=SyntaxWarning) -# Setup logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) logger = logging.getLogger(__name__) -@dataclass -class ASTNodeInfo: - """Stores structural information about an AST node.""" - node_type: str - start_token_idx: int - end_token_idx: int - depth: int - sibling_pos: int - parent_idx: int - - def to_dict(self) -> Dict[str, Any]: - return { - 'node_type': self.node_type, - 'start_token_idx': self.start_token_idx, - 'end_token_idx': self.end_token_idx, - 'depth': self.depth, - 'sibling_pos': self.sibling_pos, - 'parent_idx': self.parent_idx - } +def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: AutoTokenizer): + """ + Map tokens to their original text and Tree-sitter AST position + """ + # Initialize Tree-sitter + PY_LANGUAGE = Language(tspython.language()) + parser = Parser(PY_LANGUAGE) - def from_dict(data: Dict[str, Any]) -> 'ASTNodeInfo': - return ASTNodeInfo( - node_type=data['node_type'], - start_token_idx=data['start_token_idx'], - end_token_idx=data['end_token_idx'], - depth=data['depth'], - sibling_pos=data['sibling_pos'], - parent_idx=data['parent_idx'] - ) - -def safe_ast_parse(content: str, max_length: int = 50000) -> bool: - """Safely attempt to parse Python code.""" - if not content or not content.strip() or '\0' in content or len(content) > max_length: - return False - - try: - tree = ast.parse(content) - return bool(list(ast.walk(tree))) - except: - return False - -def process_and_extract_ast_info(content: str, max_ast_size: int = 1000) -> List[Dict[str, Any]]: - """Process AST and extract node information.""" - try: - tree = ast.parse(content) - nodes_info = [] - - def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int): - if len(nodes_info) >= max_ast_size: - return + # Parse with Tree-sitter + tree = parser.parse(bytes(code_snippet, "utf8")) + + encoded = tokenizer( + code_snippet, + add_special_tokens=True, + return_offsets_mapping=True, + return_tensors='pt', + padding='max_length', + truncation=True, + max_length=512, + ) + + tokens = 
tokenizer.convert_ids_to_tokens(encoded['input_ids'][0]) + offset_mapping = encoded['offset_mapping'][0].tolist() + + def get_node_position(node, depth=0, idx=0): + """Get depth and sibling index for a node""" + return (depth, idx) + + def find_node_at_position(node, start_byte, depth=0, sibling_idx=0): + """Find the most specific node containing the given position""" + if start_byte >= node.start_byte and start_byte < node.end_byte: + # Check children first for more specific nodes + for idx, child in enumerate(node.children): + result = find_node_at_position(child, start_byte, depth + 1, idx) + if result: + return result + # Return current node's position if no child contains the position + return get_node_position(node, depth, sibling_idx) + return None + + depths = [] + sibling_idxs = [] + node_texts = [] + for token, (start, end) in zip(tokens, offset_mapping): + if token not in ['', '', '', '']: + text = code_snippet[start:end] if start < len(code_snippet) else "" + # Tree-sitter works with bytes, so convert position + start_byte = len(code_snippet[:start].encode('utf8')) if start < len(code_snippet) else 0 + position = find_node_at_position(tree.root_node, start_byte) or (-1, -1) - current_idx = len(nodes_info) - if hasattr(node, 'lineno'): - node_info = ASTNodeInfo( - node_type=type(node).__name__, - start_token_idx=node.lineno, - end_token_idx=getattr(node, 'end_lineno', node.lineno), - depth=min(depth, 31), - sibling_pos=min(sibling_pos, 31), - parent_idx=parent_idx - ) - nodes_info.append(node_info.to_dict()) - - if len(nodes_info) < max_ast_size: - for i, child in enumerate(ast.iter_child_nodes(node)): - visit_node(child, depth + 1, current_idx, i) - - visit_node(tree, 0, -1, 0) - return nodes_info - except: - return [] + depths.append(position[0]) + sibling_idxs.append(position[1]) + node_texts.append(text) + else: + depths.append(-1) + sibling_idxs.append(-1) + node_texts.append(None) + + return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts -def process_batch(examples: Dict[str, List[Any]], tokenizer: RobertaTokenizer) -> Dict[str, List[Any]]: +def process_batch(examples, tokenizer) -> Dict[str, List[Any]]: """Process a batch of examples.""" contents = examples['content'] - - processed_contents = [] - processed_ast_nodes = [] + processed_input_ids = [] - processed_attention_masks = [] - + processed_attention_mask = [] + processed_tokens = [] + processed_depths = [] + processed_sibling_idxs = [] + processed_node_texts = [] + for content in contents: - if safe_ast_parse(content): - ast_nodes = process_and_extract_ast_info(content) - if ast_nodes: - try: - encoding = tokenizer( - content, - max_length=512, - truncation=True, - padding='max_length', - return_tensors='pt' - ) - - processed_contents.append(content) - processed_ast_nodes.append(ast_nodes) - processed_input_ids.append(encoding['input_ids'].squeeze(0).tolist()) - processed_attention_masks.append(encoding['attention_mask'].squeeze(0).tolist()) - except: - continue - + try: + input_ids, attention_mask, tokens, depths, sibling_idxs, node_texts = analyze_code_with_codebert_and_treesitter(content, tokenizer) + processed_input_ids.append(input_ids[0]) + processed_attention_mask.append(attention_mask[0]) + processed_tokens.append(tokens) + processed_depths.append(depths) + processed_sibling_idxs.append(sibling_idxs) + processed_node_texts.append(node_texts) + + except Exception as e: + logger.error(f"Error processing example: {e}") + # Return empty lists that will be filtered out 
later + processed_input_ids.append([]) + processed_attention_mask.append([]) + processed_tokens.append([]) + processed_depths.append([]) + processed_sibling_idxs.append([]) + processed_node_texts.append([]) + return { - 'content': processed_contents, - 'ast_nodes': processed_ast_nodes, 'input_ids': processed_input_ids, - 'attention_mask': processed_attention_masks, + 'attention_mask': processed_attention_mask, + 'tokens': processed_tokens, + 'depths': processed_depths, + 'sibling_idxs': processed_sibling_idxs, + 'node_texts': processed_node_texts, } def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000): @@ -131,14 +126,9 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = # Create directory for chunks output_dir = Path(output_path).parent - chunks_dir = output_dir / 'chunks' + chunks_dir = output_dir / 'data' chunks_dir.mkdir(exist_ok=True) - - stats_accumulator = { - 'total_ast_nodes': 0, - 'total_samples': 0 - } - + pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk") for i in pbar: start_idx = i * chunk_size @@ -147,28 +137,10 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = # Select chunk from dataset chunk_dataset = dataset.select(range(start_idx, end_idx)) - # Update statistics - stats_accumulator['total_ast_nodes'] += sum(len(nodes) for nodes in chunk_dataset['ast_nodes']) - stats_accumulator['total_samples'] += len(chunk_dataset) - - # Save chunk using datasets native method - chunk_path = chunks_dir / f'chunk_{i:04d}' - chunk_dataset.save_to_disk(str(chunk_path)) - - # Update progress bar postfix with current chunk info - pbar.set_postfix({ - 'samples': len(chunk_dataset), - 'path': str(chunk_path.name) - }) - - return stats_accumulator - -def load_all_chunks(chunks_dir: Path) -> Dataset: - """Load and concatenate all dataset chunks.""" - chunks = [] - for chunk_path in sorted(chunks_dir.glob('chunk_*')): - chunks.append(load_from_disk(str(chunk_path))) - return concatenate_datasets(chunks) + if len(chunk_dataset) > 0: # Only process if chunk has data + # Save chunk using datasets native method + chunk_path = chunks_dir / f'data-{i:05d}-of-{num_chunks:05d}.parquet' + chunk_dataset.to_parquet(str(chunk_path)) def main(): current_dir = Path(__file__).parent @@ -176,14 +148,23 @@ def main(): output_dir = current_dir.parent / 'data' / 'processed-python' output_dir.mkdir(parents=True, exist_ok=True) - tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base') + tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True) logger.info("Loading dataset...") dataset = load_dataset(str(input_dir))['train'] + # dataset = dataset.select(range(200_000)) # Limit dataset size for testing + original_dataset_size = len(dataset) + logger.info(f"Original dataset size: {original_dataset_size}") - logger.info("Dataset:") - pprint(dataset) + columns_to_remove = dataset.column_names + columns_to_remove.remove('content') + columns_to_remove.remove('size') + columns_to_remove.remove('avg_line_length') + columns_to_remove.remove('max_line_length') + columns_to_remove.remove('alphanum_fraction') + columns_to_remove.remove('max_stars_count') + logging.info(f"Columns to remove: {columns_to_remove}") num_proc = min(multiprocessing.cpu_count() - 1, 32) logger.info(f"Using {num_proc} processes for dataset processing") @@ -193,35 +174,38 @@ def main(): process_batch, fn_kwargs={'tokenizer': tokenizer}, batched=True, - remove_columns=dataset.column_names, + 
remove_columns=columns_to_remove, desc="Processing dataset", num_proc=num_proc, load_from_cache_file=False ) processed_dataset = processed_dataset.filter( - lambda batch: [len(nodes) > 0 for nodes in batch['ast_nodes']], + lambda batch: [len(tokens) > 0 for tokens in batch['tokens']], batched=True, num_proc=num_proc, desc="Filtering invalid examples" ) - reduced_dataset_size = len(processed_dataset) - - logger.info("Processed dataset:") - pprint(processed_dataset) - logger.info(f"Saving {len(processed_dataset)} processed examples in chunks...") - stats_accumulator = save_dataset_in_chunks( + reduced_dataset_size = len(processed_dataset) + logger.info(f"Processed dataset size: {reduced_dataset_size}") + + if reduced_dataset_size == 0: + logger.error("No valid examples found in dataset!") + return + + logger.info(f"Saving {reduced_dataset_size} processed examples in chunks...") + save_dataset_in_chunks( processed_dataset, str(output_dir / 'processed_dataset'), chunk_size=100_000 ) + # Add derived statistics stats = { 'original_dataset_size': original_dataset_size, 'reduced_dataset_size': reduced_dataset_size, - '% samples removed': (1 - reduced_dataset_size / original_dataset_size) * 100, - 'avg_ast_nodes': float(stats_accumulator['total_ast_nodes'] / stats_accumulator['total_samples']) + 'samples_removed_pct': (1 - reduced_dataset_size / original_dataset_size) * 100, } # Log stats with pprint diff --git a/code/src/reduce_dataset.py b/code/src/reduce_dataset.py new file mode 100644 index 0000000..e1c4351 --- /dev/null +++ b/code/src/reduce_dataset.py @@ -0,0 +1,50 @@ +import logging +import multiprocessing +from pathlib import Path +from datasets import load_dataset + +from parse_dataset import save_dataset_in_chunks + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) +logger = logging.getLogger(__name__) + +def main(): + current_dir = Path(__file__).parent + input_dir = current_dir.parent / 'data' / 'processed-python' + output_dir = current_dir.parent / 'data' / 'filtered-python' + output_dir.mkdir(parents=True, exist_ok=True) + + num_proc = min(multiprocessing.cpu_count() - 1, 32) + logger.info(f"Using {num_proc} processes for dataset processing") + + logger.info("Loading dataset...") + dataset = load_dataset(str(input_dir))['data'] + logging.info(f"Dataset:\n{dataset}") + + logger.info("Filtering dataset by max_stars_count > 3...") + filtered_dataset = dataset.filter( + lambda batch: [example and example > 3 for example in batch['max_stars_count']], + num_proc=num_proc, + batched=True, + desc="Filtering dataset" + ) + + filtered_dataset_size = len(filtered_dataset) + logger.info(f"Filtered dataset size: {filtered_dataset_size}") + + if filtered_dataset_size == 0: + logger.error("No examples found with max_stars_count > 3!") + return + + logger.info(f"Saving {filtered_dataset_size} filtered examples...") + save_dataset_in_chunks(filtered_dataset, output_dir, chunk_size=100_000) + + logger.info("Filtering and saving completed!") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/code/src/tmp_config.json b/code/src/tmp_config.json deleted file mode 100644 index 0e97216..0000000 --- a/code/src/tmp_config.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "seed": 42, - "mlm_probability": 0.15, - "batch": 16, - "epochs": 1, - "eval_every": 5000, - "learning_rate": 5e-4, - "weight_decay": 0.01, - "max_grad_norm": 1.0, - "warmup_steps": 5000, - "train_size": 0.95 -} \ No 
newline at end of file diff --git a/code/src/train_codebert_mlm.py b/code/src/train_codebert_mlm.py deleted file mode 100644 index 13ac329..0000000 --- a/code/src/train_codebert_mlm.py +++ /dev/null @@ -1,50 +0,0 @@ -from pathlib import Path -import torch -from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer - -from training_utils import ( - set_seed, setup_wandb, setup_directories, load_config, - setup_device, create_optimizer_and_scheduler, - train_and_evaluate, create_base_dataloaders, set_deterministic_mode -) - -def main() -> None: - set_deterministic_mode() - - current_dir = Path(__file__).parent - output_dir = setup_directories(current_dir, model_name='base') - config = load_config(current_dir / 'tmp_config.json') - - setup_wandb(config, model_name='base') - set_seed(config['seed']) - device = setup_device() - - # Load tokenizer - tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base') - - # Create dataloaders - processed_data_dir = current_dir.parent / 'data' / 'processed-python' - train_dataloader, valid_dataloader = create_base_dataloaders( - processed_data_dir, - tokenizer, - config, - device - ) - - # Model setup - model_config = RobertaConfig.from_pretrained('microsoft/codebert-base') - model = RobertaForMaskedLM(model_config) - model = model.to(device) - - # Training setup and run - num_training_steps = config['epochs'] * len(train_dataloader) - optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps) - - train_and_evaluate( - model, train_dataloader, valid_dataloader, - optimizer, scheduler, config, output_dir, - log_weights=False, device=device - ) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/code/src/train_tree_codebert_mlm.py b/code/src/train_tree_codebert_mlm.py deleted file mode 100644 index d4989ee..0000000 --- a/code/src/train_tree_codebert_mlm.py +++ /dev/null @@ -1,227 +0,0 @@ -import math -import torch -import torch.nn as nn -from pathlib import Path -from typing import Dict, List, Optional -from transformers import RobertaConfig, RobertaTokenizer, RobertaForMaskedLM - -from training_utils import ( - set_seed, setup_wandb, setup_directories, load_config, - setup_device, create_optimizer_and_scheduler, - train_and_evaluate, create_base_dataloaders, PreprocessedTreeDataset, - set_deterministic_mode -) -from parse_dataset import ASTNodeInfo - -class TreePositionalEmbedding(nn.Module): - """Generates tree-aware positional embeddings for code tokens.""" - - def __init__(self, d_model: int = 768, max_depth: int = 32): - super().__init__() - self.d_model = d_model - self.max_depth = max_depth - - self.depth_embedding = nn.Embedding(max_depth, d_model) - self.sibling_embedding = nn.Embedding(max_depth, d_model) - self.combine = nn.Linear(d_model * 2, d_model) - - self._initialize_embeddings() - - def _initialize_embeddings(self): - position = torch.arange(self.max_depth).unsqueeze(1).float() - div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * - (-math.log(10000.0) / self.d_model)) - - pe = torch.zeros(self.max_depth, self.d_model) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - - with torch.no_grad(): - self.depth_embedding.weight.copy_(pe) - self.sibling_embedding.weight.copy_(pe) - - def forward(self, input_ids: torch.Tensor, ast_nodes_batch: List[List[ASTNodeInfo]]) -> torch.Tensor: - """Process batched input with corresponding AST nodes.""" - batch_size, seq_len = input_ids.shape - device = 
input_ids.device - embeddings = torch.zeros((batch_size, seq_len, self.d_model), device=device) - - # Process each item in the batch - for batch_idx in range(batch_size): - ast_nodes = ast_nodes_batch[batch_idx] - # Process each position in the sequence - for i in range(seq_len): - containing_nodes = [ - node for node in ast_nodes - if node.start_token_idx <= i < node.end_token_idx - ] - - if containing_nodes: - node = max(containing_nodes, key=lambda n: n.depth) - depth = min(node.depth, self.max_depth - 1) - sibling_pos = min(node.sibling_pos, self.max_depth - 1) - - depth_emb = self.depth_embedding(torch.tensor(depth, device=device)) - sibling_emb = self.sibling_embedding(torch.tensor(sibling_pos, device=device)) - embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb])) - - return embeddings - -class TreeCodeBERTForPreTraining(RobertaForMaskedLM): - """CodeBERT model enhanced with normalized embedding weights for stable training.""" - - def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512): - super().__init__(config) - - self.tree_pos_embeddings = TreePositionalEmbedding( - d_model=config.hidden_size, - max_depth=max_depth - ) - - self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size) - - # Initialize sequential position embeddings with sinusoidal pattern - position = torch.arange(max_seq_length).unsqueeze(1) - div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size)) - pe = torch.zeros(max_seq_length, config.hidden_size) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - self.seq_pos_embeddings.weight.data.copy_(pe) - - # Initialize weights with small random values around 0 - self.alpha = nn.Parameter(torch.randn(1) * 0.02) - self.beta = nn.Parameter(torch.randn(1) * 0.02) - self.gamma = nn.Parameter(torch.randn(1) * 0.02) - - self.embedding_combination_layer_norm = nn.LayerNorm(config.hidden_size) - self.final_layer_norm = nn.LayerNorm(config.hidden_size) - - # Add dropout for regularization - self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob) - - def get_normalized_weights(self) -> torch.Tensor: - """ - Compute softmax-normalized weights for embedding combination. - Returns tensor of shape (3,) containing normalized [alpha, beta, gamma]. 
- """ - weights = torch.stack([self.alpha, self.beta, self.gamma]) - return torch.softmax(weights, dim=0) - - def forward( - self, - input_ids: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None, - ast_nodes: Optional[List[List[ASTNodeInfo]]] = None, - output_attentions: bool = False, - **kwargs - ) -> Dict[str, torch.Tensor]: - # Move tensors to device - device = input_ids.device - - # Get embeddings - token_embeddings = self.roberta.embeddings.word_embeddings(input_ids) - - seq_positions = torch.arange(input_ids.size(1), device=device) - seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1) - - # Get normalized weights - norm_weights = self.get_normalized_weights() - - # Combine embeddings based on presence of AST nodes - if ast_nodes is not None: - tree_embeddings = self.tree_pos_embeddings(input_ids, ast_nodes) - combined_embeddings = ( - norm_weights[0] * token_embeddings + - norm_weights[1] * tree_embeddings + - norm_weights[2] * seq_embeddings - ) - else: - # Redistribute tree weight to other components when no AST available - token_seq_weights = torch.softmax(torch.stack([self.alpha, self.gamma]), dim=0) - combined_embeddings = ( - token_seq_weights[0] * token_embeddings + - token_seq_weights[1] * seq_embeddings - ) - - # Apply layer normalization and dropout - combined_embeddings = self.embedding_combination_layer_norm(combined_embeddings) - combined_embeddings = self.embedding_dropout(combined_embeddings) - combined_embeddings = self.final_layer_norm(combined_embeddings) - - # Forward pass through transformer - outputs = self.roberta( - inputs_embeds=combined_embeddings, - attention_mask=attention_mask, - output_attentions=output_attentions, - **kwargs - ) - - sequence_output = outputs[0] - prediction_scores = self.lm_head(sequence_output) - - # Calculate loss if labels provided - masked_lm_loss = None - if labels is not None: - loss_fct = nn.CrossEntropyLoss() - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1) - ) - - # Get normalized weights for logging - norm_weights_cpu = norm_weights.detach().cpu() - - return { - "loss": masked_lm_loss, - "logits": prediction_scores, - "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None, - "attentions": outputs.attentions, - "embedding_weights": { - "token": norm_weights_cpu[0].item(), - "tree": norm_weights_cpu[1].item(), - "sequential": norm_weights_cpu[2].item() - } - } - -def main() -> None: - set_deterministic_mode() - - current_dir = Path(__file__).parent - output_dir = setup_directories(current_dir, model_name='tree') - config = load_config(current_dir / 'tmp_config.json') - - setup_wandb(config, model_name='tree') - set_seed(config['seed']) - device = setup_device() - - # Load tokenizer - tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base') - - # Create dataloaders with tree dataset - processed_data_dir = current_dir.parent / 'data' / 'processed-python' - train_dataloader, valid_dataloader = create_base_dataloaders( - processed_data_dir, - tokenizer, - config, - device, - dataset_class=PreprocessedTreeDataset - ) - - # Model setup - model_config = RobertaConfig.from_pretrained('microsoft/codebert-base') - model = TreeCodeBERTForPreTraining(model_config) - model = model.to(device) - - # Training setup and run - num_training_steps = config['epochs'] * len(train_dataloader) - optimizer, scheduler = create_optimizer_and_scheduler(model, 
config, num_training_steps) - - train_and_evaluate( - model, train_dataloader, valid_dataloader, - optimizer, scheduler, config, output_dir, - log_weights=True, device=device - ) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/code/src/training.py b/code/src/training.py new file mode 100644 index 0000000..168845d --- /dev/null +++ b/code/src/training.py @@ -0,0 +1,306 @@ +import math +import torch +import torch.nn as nn +from pathlib import Path +from typing import Dict, Optional +from transformers import RobertaConfig, AutoTokenizer, RobertaForMaskedLM + +from training_utils import * + +MODEL = 'base-custom' # 'base' or 'tree' or 'base-custom' +DATASET = 'filtered-python' # 'processed-python-small' or 'processed-python' + +class TreePositionalEmbedding(nn.Module): + """Improved tree-aware positional embeddings that work directly with depth and sibling tensors.""" + + def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1): + super().__init__() + self.d_model = d_model + self.max_depth = max_depth + + # Separate embeddings for different features + self.depth_embedding = nn.Embedding(max_depth, d_model) + self.sibling_embedding = nn.Embedding(max_depth, d_model) + + # Improved projection layers + self.node_projection = nn.Sequential( + nn.Linear(d_model * 2, d_model * 2), + nn.GELU(), + nn.Linear(d_model * 2, d_model), + nn.Dropout(dropout) + ) + + # Layer norm for stability + self.layer_norm = nn.LayerNorm(d_model) + + self._initialize_embeddings() + + def _initialize_embeddings(self): + std = 0.02 + for embedding in [self.depth_embedding, self.sibling_embedding]: + nn.init.normal_(embedding.weight, mean=0.0, std=std) + + # Initialize projection layers + for layer in self.node_projection: + if isinstance(layer, nn.Linear): + nn.init.normal_(layer.weight, mean=0.0, std=std) + nn.init.zeros_(layer.bias) + + def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor: + """ + Args: + depths: Tensor of shape [batch_size, seq_len] containing depth values + sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions + Returns: + Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings + """ + # Clamp values to max_depth + depths = torch.clamp(depths, 0, self.max_depth - 1) + sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1) + + # Get embeddings for each feature + depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model] + sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model] + + # Combine features + combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1) + embeddings = self.node_projection(combined) + + # Apply layer norm + normalized_embeddings = self.layer_norm(embeddings) + + return normalized_embeddings + +class TreeCodeBERTForPreTraining(RobertaForMaskedLM): + """CodeBERT model enhanced with tree-structural information.""" + + def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512, base_custom=False): + super().__init__(config) + self.base_custom = base_custom + + self.tree_pos_embeddings = TreePositionalEmbedding( + d_model=config.hidden_size, + max_depth=max_depth, + dropout=config.hidden_dropout_prob + ) + + self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size) + + # Initialize sequential position embeddings with sinusoidal pattern + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp(torch.arange(0, 
config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size)) + pe = torch.zeros(max_seq_length, config.hidden_size) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.seq_pos_embeddings.weight.data.copy_(pe) + + # Initialize embedding weights equally + if base_custom: + initial_weight = math.log(1/2) + self.embedding_weights = nn.Parameter(torch.full((2,), initial_weight)) + else: + initial_weight = math.log(1/3) # log(1/3) because we use softmax later + self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight)) + + # Layer norms for embedding combination + self.pre_combination_norm = nn.LayerNorm(config.hidden_size) + self.post_combination_norm = nn.LayerNorm(config.hidden_size) + + def get_normalized_weights(self): + """Get softmaxed weights for embedding combination.""" + return torch.softmax(self.embedding_weights, dim=0) + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + depths: Optional[torch.Tensor] = None, + sibling_idxs: Optional[torch.Tensor] = None, + output_attentions: bool = False, + **kwargs + ) -> Dict[str, torch.Tensor]: + device = input_ids.device + + # Get normalized weights for embedding combination and calculate regularization + weights = self.get_normalized_weights() + + # Calculate weight variance regularization + # We want weights to remain somewhat balanced, so penalize high variance + weight_variance = torch.var(weights) + weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient + + # Add L2 regularization to prevent any weight from getting too close to 1 + # This helps maintain a more balanced contribution from each embedding type + max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8 + l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient + + # Get token embeddings + token_embeddings = self.roberta.embeddings.word_embeddings(input_ids) + token_embeddings = self.pre_combination_norm(token_embeddings) + + # Get sequential position embeddings + seq_positions = torch.arange(input_ids.size(1), device=device) + seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1) + + # Get tree positional embeddings if tree information is provided + if depths is not None and sibling_idxs is not None and not self.base_custom: + tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs) + else: + tree_embeddings = torch.zeros_like(token_embeddings) + + # Combine all embeddings using learned weights + if self.base_custom: + combined_embeddings = ( + weights[0] * token_embeddings + + weights[1] * seq_embeddings + ) + else: + combined_embeddings = ( + weights[0] * token_embeddings + + weights[1] * tree_embeddings + + weights[2] * seq_embeddings + ) + + combined_embeddings = self.post_combination_norm(combined_embeddings) + + # Forward pass through base model + outputs = self.roberta( + inputs_embeds=combined_embeddings, + attention_mask=attention_mask, + output_attentions=output_attentions, + **kwargs + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + # Calculate MLM loss if labels are provided + masked_lm_loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1) + ) + + # Add regularization losses to final loss + if masked_lm_loss is not None: + final_loss = 
masked_lm_loss + weight_reg_loss + l2_reg_loss + else: + final_loss = weight_reg_loss + l2_reg_loss + else: + final_loss = None + + # Prepare embedding weights for logging + weights_cpu = weights.detach().cpu() + if self.base_custom: + embedding_weights = { + "token": weights_cpu[0].item(), + "sequential": weights_cpu[1].item() + } + reg_metrics = { + "weight_variance": weight_variance.item(), + "max_weight_penalty": max_weight_penalty.item(), + "weight_reg_loss": weight_reg_loss.item(), + "l2_reg_loss": l2_reg_loss.item() + } + else: + embedding_weights = { + "token": weights_cpu[0].item(), + "tree": weights_cpu[1].item(), + "sequential": weights_cpu[2].item() + } + reg_metrics = { + "weight_variance": weight_variance.item(), + "max_weight_penalty": max_weight_penalty.item(), + "weight_reg_loss": weight_reg_loss.item(), + "l2_reg_loss": l2_reg_loss.item() + } + + return { + "loss": final_loss, + "logits": prediction_scores, + "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None, + "attentions": outputs.attentions if output_attentions else None, + "embedding_weights": embedding_weights, + "regularization_metrics": reg_metrics + } + +def main() -> None: + current_dir = Path(__file__).parent + output_dir = setup_directories(current_dir, model_name=MODEL) + config = load_config(current_dir / 'config.json') + + set_deterministic_mode(config['seed']) + + setup_wandb(config, model_name=MODEL, script_path=__file__) + set_seed(config['seed']) + device = setup_device() + + # Load tokenizer + logger.info("Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True) + + # Create dataloaders with tree dataset + logger.info("Creating dataloaders...") + processed_data_dir = current_dir.parent / 'data' / DATASET + train_dataloader, valid_dataloader = create_base_dataloaders( + processed_data_dir, + tokenizer, + config, + base_training=True if MODEL == 'base' else False + ) + + # Initialize model config + logger.info("Initializing model...") + model_config = RobertaConfig.from_pretrained('microsoft/codebert-base') + + # Update W&B + wandb.config.update({'model_config': model_config.__dict__}) + + # Initialize model + if MODEL == 'tree': + model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512) + elif MODEL == 'base': + model = RobertaForMaskedLM(model_config) + elif MODEL == 'base-custom': + model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512, base_custom=True) + else: + raise ValueError(f"Invalid model type: {MODEL}") + model = model.to(device) + logging.info(f"Model initialized: {MODEL}") + + # Model precision + if device.type == 'cuda' and config['fp16']: + model = model.half() + logging.info("Model set to half precision.") + for param in model.parameters(): + logging.info(f"Parameter dtype: {param.dtype}") + break + + # Training setup + logger.info("Setting up optimizer and scheduler...") + optimizer, scheduler = create_optimizer_and_scheduler(model, config) + + # Train and evaluate model + logger.info("Starting training...") + train_and_evaluate( + model=model, + train_dataloader=train_dataloader, + valid_dataloader=valid_dataloader, + optimizer=optimizer, + scheduler=scheduler, + config=config, + device=device, + output_dir=output_dir + ) + + logger.info("Training completed!") + final_output_dir = output_dir / 'final-model' + model.save_pretrained(final_output_dir) + tokenizer.save_pretrained(final_output_dir) + logger.info(f"Model saved to {final_output_dir}") + 
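# --- Editor's illustration, not part of this patch ------------------------------
# A minimal sketch of how the learned embedding-combination weights and the two
# regularizers in TreeCodeBERTForPreTraining.forward() above behave at
# initialisation; the helper name is hypothetical and the snippet is standalone.
def _sketch_embedding_weight_regularizers() -> None:
    import math
    import torch

    # Equal initialisation (log(1/3)), so softmax assigns one third to each of the
    # token, tree and sequential embeddings.
    weights = torch.softmax(torch.full((3,), math.log(1 / 3)), dim=0)

    # Both penalties are zero while the weights stay balanced and only start to
    # contribute once one embedding type begins to dominate the softmax.
    weight_reg_loss = 0.1 * torch.var(weights)                       # variance penalty
    l2_reg_loss = 0.05 * torch.sum(torch.relu(weights - 0.8) ** 2)   # penalty for any weight > 0.8
    print(weights, weight_reg_loss.item(), l2_reg_loss.item())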
+if __name__ == "__main__": + main() \ No newline at end of file diff --git a/code/src/training_utils.py b/code/src/training_utils.py index 0fd07ef..6a71b3a 100644 --- a/code/src/training_utils.py +++ b/code/src/training_utils.py @@ -8,38 +8,41 @@ import wandb import numpy as np import torch from pathlib import Path -from typing import Dict, Any, Tuple, Type, Optional -from datasets import load_from_disk, concatenate_datasets +from typing import Dict, Any, Tuple, Optional +from datasets import load_dataset from torch.optim import AdamW -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from transformers import ( PreTrainedModel, - get_linear_schedule_with_warmup, + get_constant_schedule_with_warmup, DataCollatorForLanguageModeling, - PreTrainedTokenizer + PreTrainedTokenizer, + RobertaForMaskedLM ) from tqdm import tqdm -from parse_dataset import ASTNodeInfo - # Configure logging -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' +) logger = logging.getLogger(__name__) -def set_deterministic_mode() -> None: +def set_deterministic_mode(seed: int) -> None: """Enable deterministic mode for reproducibility.""" # Set Python random seed - random.seed(42) + random.seed(seed) # Set NumPy random seed - np.random.seed(42) + np.random.seed(seed) # Set PyTorch random seed - torch.manual_seed(42) + torch.manual_seed(seed) # Set CUDA random seed if torch.cuda.is_available(): - torch.cuda.manual_seed_all(42) + torch.cuda.manual_seed_all(seed) # Enable deterministic operations for PyTorch 2.x torch.use_deterministic_algorithms(True) @@ -47,7 +50,7 @@ def set_deterministic_mode() -> None: torch.backends.cudnn.benchmark = False # Set environment variables for reproducibility - os.environ['PYTHONHASHSEED'] = '42' + os.environ['PYTHONHASHSEED'] = f'{seed}' os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2 def get_device_info() -> Dict[str, Any]: @@ -70,53 +73,6 @@ def get_device_info() -> Dict[str, Any]: return info -class TreeDataCollator(DataCollatorForLanguageModeling): - """Custom data collator that handles both MLM and optional AST node information.""" - - def torch_call(self, examples): - # Check if we have AST nodes - has_ast = 'ast_nodes' in examples[0] - - # Extract AST nodes before MLM processing if they exist - ast_nodes = None - if has_ast: - ast_nodes = [e.pop('ast_nodes') for e in examples] - - # Process normal MLM features - batch = super().torch_call(examples) - - # Add AST nodes back to batch if they existed - if has_ast: - batch['ast_nodes'] = ast_nodes - - return batch - - def __call__(self, examples): - return self.torch_call(examples) - -class PreprocessedBaseDataset(Dataset): - """Dataset that uses pre-processed data without AST information.""" - def __init__(self, dataset: Any): - self.dataset = dataset - - def __len__(self) -> int: - return len(self.dataset) - - def __getitem__(self, idx: int) -> Dict[str, Any]: - item = self.dataset[idx] - return { - 'input_ids': torch.tensor(item['input_ids']), - 'attention_mask': torch.tensor(item['attention_mask']), - 'labels': torch.tensor(item['input_ids']).clone() - } - -class PreprocessedTreeDataset(PreprocessedBaseDataset): - """Dataset that includes AST information.""" - def __getitem__(self, idx: int) -> Dict[str, Any]: - item = super().__getitem__(idx) - item['ast_nodes'] = [ASTNodeInfo.from_dict(node) for 
node in self.dataset[idx]['ast_nodes']] - return item - def set_seed(seed: int) -> None: """Set random seeds for reproducibility.""" random.seed(seed) @@ -128,6 +84,7 @@ def set_seed(seed: int) -> None: def setup_wandb( config: Dict[str, Any], model_name: str = 'base', + script_path: Path = None, device_info: Optional[Dict[str, Any]] = None ) -> None: """Initialize W&B logging with reproducibility information.""" @@ -136,17 +93,21 @@ def setup_wandb( # Add reproducibility settings to config full_config = { **config, - 'random_seed': 42, 'deterministic_mode': True, 'device_info': device_info or get_device_info(), } wandb.init( - project='codebert-training-test', + project='new-codebert-tree-training', name=f'{model_name}_{curr_time}', config=full_config ) + # Upload the training script + wandb.save(script_path) + wandb.save(Path(script_path).parent / 'training_utils.py') + logger.info(f'Saving script {script_path} to W&B') + def setup_directories(current_dir: Path, model_name: str = 'base') -> Path: """Create output directories for model artifacts.""" curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M') @@ -172,7 +133,6 @@ def setup_device() -> torch.device: def create_optimizer_and_scheduler( model: PreTrainedModel, config: Dict[str, Any], - num_training_steps: int ) -> Tuple[AdamW, Any]: """Create optimizer and learning rate scheduler.""" optimizer = AdamW( @@ -181,270 +141,215 @@ def create_optimizer_and_scheduler( weight_decay=config['weight_decay'] ) - scheduler = get_linear_schedule_with_warmup( + scheduler = get_constant_schedule_with_warmup( optimizer, num_warmup_steps=config['warmup_steps'], - num_training_steps=num_training_steps ) return optimizer, scheduler -def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]: - """Evaluate model on validation set.""" +def evaluate( + model: RobertaForMaskedLM, + dataloader: DataLoader, + device: torch.device +) -> Tuple[float, float]: + """Evaluate the model on the validation set. 
+ + Returns: + Tuple of (average loss, accuracy on masked tokens) + """ model.eval() - total_loss: float = 0.0 - total_acc: float = 0.0 + total_loss = 0 + total_correct = 0 + total_predictions = 0 with torch.no_grad(): - for batch in tqdm(dataloader, desc='Validation'): + for batch in tqdm(dataloader, desc='Evaluating'): + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } outputs = model(**batch) - total_loss += outputs.loss.item() if hasattr(outputs, 'loss') else outputs['loss'].item() - logits = outputs.logits if hasattr(outputs, 'logits') else outputs['logits'] - total_acc += logits.argmax(dim=-1).eq(batch['labels']).sum().item() + + # Get loss + total_loss += outputs['loss'].item() + + # Calculate accuracy only on masked tokens + predictions = outputs['logits'].argmax(dim=-1) # [batch_size, seq_length] + labels = batch['labels'] # [batch_size, seq_length] + + # Create mask for tokens we actually want to predict (ignore padding and unmasked tokens) + predict_mask = labels != -100 # -100 is the ignore index + + # Calculate accuracy + correct = predictions[predict_mask] == labels[predict_mask] + total_correct += correct.sum().item() + total_predictions += predict_mask.sum().item() - avg_loss: float = total_loss / len(dataloader) - avg_acc: float = total_acc / len(dataloader.dataset) - model.train() - return avg_loss, avg_acc + avg_loss = total_loss / len(dataloader) + accuracy = total_correct / total_predictions if total_predictions > 0 else 0 + + return avg_loss, accuracy def train_and_evaluate( - model: PreTrainedModel, - train_dataloader: DataLoader, - valid_dataloader: DataLoader, - optimizer: AdamW, - scheduler: Any, - config: Dict[str, Any], + model: RobertaForMaskedLM, + train_dataloader: DataLoader, + valid_dataloader: DataLoader, + optimizer: AdamW, + scheduler: Any, + config: Dict[str, Any], + device: torch.device, output_dir: Path, - log_weights: bool = False, - device: torch.device = torch.device('cpu') ) -> None: - """Train and evaluate model with deterministic behavior and comprehensive logging.""" - # Enable deterministic algorithms for PyTorch 2.5 - torch.use_deterministic_algorithms(True) + """Train and evaluate the model.""" - num_training_steps: int = config['epochs'] * len(train_dataloader) - best_valid_loss: float = float('inf') + model.train() + best_valid_loss = float('inf') - # Save initial model state for reproducibility - torch.save({ - 'model_state': model.state_dict(), - 'optimizer_state': optimizer.state_dict(), - 'scheduler_state': scheduler.state_dict(), - 'rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None - }, output_dir / 'initial_state.pt') - - with tqdm(total=num_training_steps, desc='Training') as pbar: - for epoch_idx in range(config['epochs']): - model.train() - epoch_loss = 0.0 - epoch_steps = 0 - - for train_idx, train_batch in enumerate(train_dataloader): + for epoch in range(config['epochs']): + total_loss = 0 + + # Training loop + with tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{config["epochs"]}') as pbar: + for step, batch in enumerate(pbar): # Move batch to device - train_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v - for k, v in train_batch.items()} - - # Forward pass - outputs = model(**train_batch) - train_loss = outputs.loss if hasattr(outputs, 'loss') else outputs['loss'] - - # Backward pass with deterministic behavior - train_loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - 
model.parameters(), - config['max_grad_norm'] - ) - - optimizer.step() - scheduler.step() - optimizer.zero_grad(set_to_none=False) # Use False for determinism - - # Update metrics - current_loss = train_loss.item() - epoch_loss += current_loss - epoch_steps += 1 - - # Calculate global step - step = train_idx + len(train_dataloader) * epoch_idx - - # Prepare logging dictionary - log_dict = { - 'step': step, - 'epoch': epoch_idx, - 'train_loss': current_loss, - 'gradient_norm': grad_norm.item(), - 'learning_rate': scheduler.get_last_lr()[0], + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() } - # Add embedding weights if using tree model - if log_weights and 'embedding_weights' in outputs: - weights = outputs['embedding_weights'] - log_dict.update({ - 'token_weight': weights['token'], - 'tree_weight': weights['tree'], - 'sequential_weight': weights['sequential'] - }) - pbar_dict = { - 'epoch': f"{epoch_idx + 1}/{config['epochs']}", - 'train_loss': f"{current_loss:.3f}", - 'α': f"{weights['token']:.2f}", - 'β': f"{weights['tree']:.2f}", - 'γ': f"{weights['sequential']:.2f}" - } - else: - pbar_dict = { - 'epoch': f"{epoch_idx + 1}/{config['epochs']}", - 'train_loss': f"{current_loss:.3f}" - } + # Forward pass + outputs = model(**batch) + loss = outputs['loss'] - # Log to wandb - wandb.log(log_dict) + # Backward pass + loss.backward() + + # Clip gradients + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm']) + + optimizer.step() + scheduler.step() + optimizer.zero_grad() + + # Update metrics + total_loss += loss.item() # Update progress bar - pbar.update(1) - pbar.set_postfix(pbar_dict) + pbar.set_postfix({'loss': f'{loss.item():.4f}'}) - # Periodic evaluation - if (train_idx != 0 and train_idx % config['eval_every'] == 0) or train_idx == len(train_dataloader) - 1: - model.eval() - valid_loss, valid_acc = evaluate(model, valid_dataloader) - model.train() + # Log to wandb if configured + log_values = { + 'train_loss': loss.item(), + 'grad_norm': grad_norm.item(), + 'learning_rate': scheduler.get_last_lr()[0] + } + if 'embedding_weights' in outputs and 'tree' in outputs['embedding_weights']: + log_values.update({ + 'token_weight': outputs['embedding_weights']['token'], + 'tree_weight': outputs['embedding_weights']['tree'], + 'seq_weight': outputs['embedding_weights']['sequential'], + }) + elif 'embedding_weights' in outputs: + log_values.update({ + 'token_weight': outputs['embedding_weights']['token'], + 'seq_weight': outputs['embedding_weights']['sequential'], + }) + if 'regularization_metrics' in outputs: + log_values.update({f"regularization/{k}": v for k, v in outputs['regularization_metrics'].items()}) + wandb.log(log_values) + + # Evaluate periodically + if step % config['eval_every'] == 0 and step > 0: + valid_loss, valid_acc = evaluate(model, valid_dataloader, device) # Log validation metrics - wandb.log({ - 'valid_loss': valid_loss, - 'valid_acc': valid_acc, - 'step': step, - 'epoch': epoch_idx, - }) + wandb.log({'valid_loss': valid_loss, 'valid_accuracy': valid_acc}) - # Save checkpoint if best model + # Print validation metrics + print(f"\nValidation - Loss: {valid_loss:.4f}, Accuracy: {valid_acc:.4f}") + + # Save best model if valid_loss < best_valid_loss: best_valid_loss = valid_loss - - # Save complete state - torch.save({ - 'epoch': epoch_idx, - 'step': step, - 'model_state': model.state_dict(), - 'optimizer_state': optimizer.state_dict(), - 'scheduler_state': scheduler.state_dict(), - 'rng_state': 
torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None, - 'best_valid_loss': best_valid_loss, - 'config': config, - 'device_info': get_device_info() - }, output_dir / 'best_model.pt') - - # Log best metrics - wandb.run.summary['best_valid_loss'] = best_valid_loss - wandb.run.summary['best_valid_acc'] = valid_acc - wandb.run.summary['best_epoch'] = epoch_idx - wandb.run.summary['best_step'] = step - - # End of epoch logging - avg_epoch_loss = epoch_loss / epoch_steps - wandb.log({ - 'epoch': epoch_idx, - 'epoch_avg_loss': avg_epoch_loss, - }) - - # Save end of epoch checkpoint - torch.save({ - 'epoch': epoch_idx, - 'model_state': model.state_dict(), - 'optimizer_state': optimizer.state_dict(), - 'scheduler_state': scheduler.state_dict(), - 'rng_state': torch.get_rng_state(), - 'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None, - 'train_loss': avg_epoch_loss, - 'config': config, - 'device_info': get_device_info() - }, output_dir / f'checkpoint_epoch_{epoch_idx}.pt') - - # End of training logging - wandb.run.summary['final_epoch'] = config['epochs'] - 1 - wandb.run.summary['total_steps'] = num_training_steps - - logger.info(f'Training completed. Best validation loss: {best_valid_loss:.4f}') - logger.info(f'Model checkpoints saved in {output_dir}') - -def load_all_chunks(chunks_dir: Path) -> Any: - """Load and concatenate all dataset chunks.""" - logger.info(f"Loading dataset chunks from {chunks_dir}") - chunks = [] - for chunk_path in sorted(chunks_dir.glob('chunk_*'))[:5]: - chunks.append(load_from_disk(str(chunk_path))) - dataset = concatenate_datasets(chunks) - logger.info(f"Loaded {len(dataset)} examples from {len(chunks)} chunks") - return dataset + model.save_pretrained(output_dir / f'checkpoint-{epoch}-{step}') + + model.train() # Resume training mode def create_base_dataloaders( processed_data_dir: Path, tokenizer: PreTrainedTokenizer, config: Dict[str, Any], - device: torch.device, - dataset_class: Type[Dataset] = PreprocessedBaseDataset, + base_training=False, ) -> Tuple[DataLoader, DataLoader]: - """Create reproducible dataloaders from pre-processed data.""" - # Load chunks - chunks_dir = processed_data_dir / 'chunks' - full_dataset = load_all_chunks(chunks_dir) + """Create dataloaders from pre-processed parquet data.""" + + # Load chunks using datasets library + chunks_dir = processed_data_dir / 'data' + dataset = load_dataset( + "parquet", + data_files=str(chunks_dir / "*.parquet"), + split="train" + ) + + contents = dataset['content'][:config['batch_size']] + with open("contents.txt", "w") as f: + for content in contents: + f.write(content) + f.write("\n\n\n") + f.write("#" * 80) + f.write("\n\n\n") + + # Remove columns that are not needed + columns_to_remove = dataset.column_names + columns_to_remove.remove('input_ids') + columns_to_remove.remove('attention_mask') + + if not base_training: + columns_to_remove.remove('depths') + columns_to_remove.remove('sibling_idxs') + + dataset = dataset.remove_columns(columns_to_remove) + + logging.info(f"Loaded dataset:\n{dataset}") # Calculate split sizes - dataset_size = len(full_dataset) + dataset_size = len(dataset) train_size = int(config['train_size'] * dataset_size) val_size = dataset_size - train_size - # Create splits using indices to avoid device issues - indices = torch.arange(dataset_size, device=device) + # Create splits + splits = dataset.train_test_split( + train_size=train_size, + test_size=val_size, + seed=config['seed'] + ) + 
train_dataset = splits['train'] + valid_dataset = splits['test'] - # Create a deterministic generator on the correct device - generator = torch.Generator(device=device) - generator.manual_seed(42) - - # Shuffle indices deterministically - shuffled_indices = indices[torch.randperm(len(indices), device=device, generator=generator)] - train_indices = shuffled_indices[:train_size].cpu() - valid_indices = shuffled_indices[train_size:].cpu() - - # Create train/validation splits using the shuffled indices - train_dataset = torch.utils.data.Subset(full_dataset, train_indices.tolist()) - valid_dataset = torch.utils.data.Subset(full_dataset, valid_indices.tolist()) - - # Create dataset wrappers - train_dataset = dataset_class(train_dataset) - valid_dataset = dataset_class(valid_dataset) - - logger.info(f"Created train dataset with {len(train_dataset)} samples") - logger.info(f"Created validation dataset with {len(valid_dataset)} samples") - - # Use TreeDataCollator for both base and tree models - data_collator = TreeDataCollator( - tokenizer=tokenizer, - mlm=True, + # Create data collator for MLM + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=True, mlm_probability=config['mlm_probability'] ) - - # Create train dataloader without generator + train_dataloader = DataLoader( train_dataset, - batch_size=config['batch'], - shuffle=False, # We already shuffled the data + batch_size=config['batch_size'], + shuffle=False, # We'll let the dataset handle shuffling collate_fn=data_collator, - num_workers=0, # Use single worker for reproducibility - drop_last=True # Ensure consistent batch sizes + num_workers=0, + drop_last=True, ) - # Create validation dataloader valid_dataloader = DataLoader( valid_dataset, - batch_size=config['batch'], + batch_size=config['batch_size'], shuffle=False, collate_fn=data_collator, num_workers=0, - drop_last=True + drop_last=True, ) return train_dataloader, valid_dataloader
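Closing usage sketch (editor's illustration, not part of the patch): assuming the parquet chunks and config.json described above are in place, this is roughly how create_base_dataloaders is consumed. DataCollatorForLanguageModeling sets labels to -100 at every position it does not select for masking, which is the same mask evaluate() relies on when computing masked-token accuracy; the paths and the batch inspection below are hypothetical.

from pathlib import Path
from transformers import AutoTokenizer
from training_utils import load_config, create_base_dataloaders

config = load_config(Path('src/config.json'))
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)

# base_training=False keeps the 'depths' and 'sibling_idxs' columns for the tree model.
train_dl, valid_dl = create_base_dataloaders(
    Path('data/filtered-python'), tokenizer, config, base_training=False
)

batch = next(iter(train_dl))
masked = batch['labels'] != -100   # only these positions enter the MLM loss and accuracy
print(batch['input_ids'].shape, masked.float().mean().item())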