standardized utils for training

This commit is contained in:
Patryk Bartkowiak 2024-11-17 21:51:52 +00:00
parent 96fc1041cf
commit 0cd04e6131
12 changed files with 769 additions and 2507 deletions

View File

@@ -37,3 +37,4 @@ distribution = true
[tool.pdm.scripts]
run_training = {cmd = "src/train_codebert_mlm.py"}
run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
parse_dataset = {cmd = "src/parse_dataset.py"}

View File

@@ -8,5 +8,5 @@
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"warmup_steps": 20000,
"pause_on_instability": false
"train_size": 0.95
}

232
code/src/parse_dataset.py Normal file
View File

@@ -0,0 +1,232 @@
import ast
from pathlib import Path
import logging
from tqdm import tqdm
from pprint import pprint
from typing import Dict, Any, List
from datasets import load_dataset, load_from_disk, Dataset, concatenate_datasets
from transformers import RobertaTokenizer
from dataclasses import dataclass
import numpy as np
import multiprocessing
import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)
# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
@dataclass
class ASTNodeInfo:
"""Stores structural information about an AST node."""
node_type: str
start_token_idx: int
end_token_idx: int
depth: int
sibling_pos: int
parent_idx: int
def to_dict(self) -> Dict[str, Any]:
return {
'node_type': self.node_type,
'start_token_idx': self.start_token_idx,
'end_token_idx': self.end_token_idx,
'depth': self.depth,
'sibling_pos': self.sibling_pos,
'parent_idx': self.parent_idx
}
@staticmethod
def from_dict(data: Dict[str, Any]) -> 'ASTNodeInfo':
return ASTNodeInfo(
node_type=data['node_type'],
start_token_idx=data['start_token_idx'],
end_token_idx=data['end_token_idx'],
depth=data['depth'],
sibling_pos=data['sibling_pos'],
parent_idx=data['parent_idx']
)
def safe_ast_parse(content: str, max_length: int = 50000) -> bool:
"""Safely attempt to parse Python code."""
if not content or not content.strip() or '\0' in content or len(content) > max_length:
return False
try:
tree = ast.parse(content)
return bool(list(ast.walk(tree)))
except Exception:
return False
def process_and_extract_ast_info(content: str, max_ast_size: int = 1000) -> List[Dict[str, Any]]:
"""Process AST and extract node information."""
try:
tree = ast.parse(content)
nodes_info = []
def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
if len(nodes_info) >= max_ast_size:
return
current_idx = len(nodes_info)
if hasattr(node, 'lineno'):
node_info = ASTNodeInfo(
node_type=type(node).__name__,
start_token_idx=node.lineno,
end_token_idx=getattr(node, 'end_lineno', node.lineno),
depth=min(depth, 31),
sibling_pos=min(sibling_pos, 31),
parent_idx=parent_idx
)
nodes_info.append(node_info.to_dict())
if len(nodes_info) < max_ast_size:
for i, child in enumerate(ast.iter_child_nodes(node)):
visit_node(child, depth + 1, current_idx, i)
visit_node(tree, 0, -1, 0)
return nodes_info
except Exception:
return []
def process_batch(examples: Dict[str, List[Any]], tokenizer: RobertaTokenizer) -> Dict[str, List[Any]]:
"""Process a batch of examples."""
contents = examples['content']
processed_contents = []
processed_ast_nodes = []
processed_input_ids = []
processed_attention_masks = []
for content in contents:
if safe_ast_parse(content):
ast_nodes = process_and_extract_ast_info(content)
if ast_nodes:
try:
encoding = tokenizer(
content,
max_length=512,
truncation=True,
padding='max_length',
return_tensors='pt'
)
processed_contents.append(content)
processed_ast_nodes.append(ast_nodes)
processed_input_ids.append(encoding['input_ids'].squeeze(0).tolist())
processed_attention_masks.append(encoding['attention_mask'].squeeze(0).tolist())
except Exception:
continue
return {
'content': processed_contents,
'ast_nodes': processed_ast_nodes,
'input_ids': processed_input_ids,
'attention_mask': processed_attention_masks,
}
def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
"""Save dataset to disk in chunks to manage memory usage."""
num_chunks = (len(dataset) + chunk_size - 1) // chunk_size
# Create directory for chunks
output_dir = Path(output_path).parent
chunks_dir = output_dir / 'chunks'
chunks_dir.mkdir(exist_ok=True)
stats_accumulator = {
'total_ast_nodes': 0,
'total_samples': 0
}
pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
for i in pbar:
start_idx = i * chunk_size
end_idx = min((i + 1) * chunk_size, len(dataset))
# Select chunk from dataset
chunk_dataset = dataset.select(range(start_idx, end_idx))
# Update statistics
stats_accumulator['total_ast_nodes'] += sum(len(nodes) for nodes in chunk_dataset['ast_nodes'])
stats_accumulator['total_samples'] += len(chunk_dataset)
# Save chunk using datasets native method
chunk_path = chunks_dir / f'chunk_{i:04d}'
chunk_dataset.save_to_disk(str(chunk_path))
# Update progress bar postfix with current chunk info
pbar.set_postfix({
'samples': len(chunk_dataset),
'path': str(chunk_path.name)
})
return stats_accumulator
def load_all_chunks(chunks_dir: Path) -> Dataset:
"""Load and concatenate all dataset chunks."""
chunks = []
for chunk_path in sorted(chunks_dir.glob('chunk_*')):
chunks.append(load_from_disk(str(chunk_path)))
return concatenate_datasets(chunks)
def main():
current_dir = Path(__file__).parent
input_dir = current_dir.parent / 'data' / 'the-stack-python'
output_dir = current_dir.parent / 'data' / 'processed-python'
output_dir.mkdir(parents=True, exist_ok=True)
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
logger.info("Loading dataset...")
dataset = load_dataset(str(input_dir))['train']
original_dataset_size = len(dataset)
logger.info("Dataset:")
pprint(dataset)
num_proc = min(multiprocessing.cpu_count() - 1, 32)
logger.info(f"Using {num_proc} processes for dataset processing")
logger.info("Processing dataset...")
processed_dataset = dataset.map(
process_batch,
fn_kwargs={'tokenizer': tokenizer},
batched=True,
remove_columns=dataset.column_names,
desc="Processing dataset",
num_proc=num_proc,
load_from_cache_file=False
)
processed_dataset = processed_dataset.filter(
lambda batch: [len(nodes) > 0 for nodes in batch['ast_nodes']],
batched=True,
num_proc=num_proc,
desc="Filtering invalid examples"
)
reduced_dataset_size = len(processed_dataset)
logger.info("Processed dataset:")
pprint(processed_dataset)
logger.info(f"Saving {len(processed_dataset)} processed examples in chunks...")
stats_accumulator = save_dataset_in_chunks(
processed_dataset,
str(output_dir / 'processed_dataset'),
chunk_size=100_000
)
stats = {
'original_dataset_size': original_dataset_size,
'reduced_dataset_size': reduced_dataset_size,
'% samples removed': (1 - reduced_dataset_size / original_dataset_size) * 100,
'avg_ast_nodes': float(stats_accumulator['total_ast_nodes'] / stats_accumulator['total_samples'])
}
# Log stats with pprint
logger.info("Processing completed! Stats:")
pprint(stats)
if __name__ == "__main__":
main()
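A quick usage sketch (editorial, not part of the commit) of the helpers above, assuming the file is importable as parse_dataset from code/src; the printed fields are illustrative:

# Sketch: exercise the AST helpers on a toy snippet.
from parse_dataset import safe_ast_parse, process_and_extract_ast_info, ASTNodeInfo

snippet = "def add(a, b):\n    return a + b\n"
if safe_ast_parse(snippet):
    nodes = process_and_extract_ast_info(snippet)          # list of ASTNodeInfo.to_dict() dicts
    restored = [ASTNodeInfo.from_dict(d) for d in nodes]   # same round trip training_utils.py uses
    for info in restored:
        print(info.node_type, info.depth, info.sibling_pos, info.parent_idx)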

View File

@@ -1,72 +0,0 @@
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM, DataCollatorForLanguageModeling
from pathlib import Path
def load_model_and_tokenizer(model_path: Path, tokenizer_name: str = 'microsoft/codebert-base'):
# Load the pre-trained tokenizer
tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
# Load the trained model
state_dict = torch.load(model_path, weights_only=True)
corrected_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
model = RobertaForMaskedLM.from_pretrained('roberta-base')
model.load_state_dict(corrected_state_dict)
return model, tokenizer
def run_inference(model, tokenizer, input_text: str, max_length: int = 512):
# Tokenize the input text
inputs = tokenizer(
input_text,
return_tensors='pt',
padding='max_length',
truncation=True,
max_length=max_length
)
# Use DataCollatorForLanguageModeling for MLM-style dynamic masking
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
inputs = data_collator([inputs]) # Collate the input and add random masks
# Squeeze the batch dimension
for k, v in inputs.items():
inputs[k] = v.squeeze(0)
print(tokenizer.decode(inputs['input_ids'][0][inputs['attention_mask'][0] == 1], skip_special_tokens=False))
with torch.no_grad():
# Run inference
outputs = model(**inputs).logits.squeeze(0)
# Get predicted token ids (argmax over logits)
predicted_token_ids = outputs.argmax(dim=-1)
# Ignore padding tokens
predicted_token_ids = predicted_token_ids[inputs['attention_mask'][0] == 1]
# Decode predicted token ids to text
predicted_text = tokenizer.decode(predicted_token_ids, skip_special_tokens=False)
return predicted_text
def main_inference():
# Load the trained model and tokenizer
model_path = Path('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt') # Update this to your trained model path
model, tokenizer = load_model_and_tokenizer(model_path)
# Define input string
input_string = """def compute_area(radius):
# This function calculates the area of a circle
pi = 3.14159
return pi * radius ** 2"""
# Run inference
print(f"Input:\n{input_string}")
print("\n\n")
print("Masked input:")
output = run_inference(model, tokenizer, input_string)
print("\n\n")
print(f"Output:\n{output}")
if __name__ == "__main__":
main_inference()
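Note that the deleted helper above expected a bare state_dict (with an "_orig_mod." prefix from torch.compile); checkpoints written by the new training_utils.py later in this commit wrap the weights in a dictionary under 'model_state' and come from an uncompiled model. A minimal loading sketch under those assumptions (the checkpoint path is a placeholder):

import torch
from transformers import RobertaConfig, RobertaForMaskedLM, RobertaTokenizer

def load_new_style_checkpoint(checkpoint_path: str, device: str = "cpu"):
    # training_utils.train_and_evaluate saves {'model_state': ..., 'optimizer_state': ..., ...}
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = RobertaForMaskedLM(RobertaConfig.from_pretrained("microsoft/codebert-base"))
    model.load_state_dict(checkpoint["model_state"])
    model.eval()
    tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
    return model, tokenizer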

File diff suppressed because it is too large

View File

@@ -1,374 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from tree_sitter import Language, Parser\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Dataset({\n",
" features: ['repo', 'path', 'func_name', 'original_string', 'language', 'code', 'code_tokens', 'docstring', 'docstring_tokens', 'sha', 'url', 'partition'],\n",
" num_rows: 251820\n",
"})\n"
]
}
],
"source": [
"# Load the dataset\n",
"dataset = load_dataset('json', data_files={'train': '/work/s452638/datasets/CodeSearchNet/python/train.jsonl'}, split='train')\n",
"print(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tree_sitter/__init__.py:36: FutureWarning: Language.build_library is deprecated. Use the new bindings instead.\n",
" warn(\"{} is deprecated. Use {} instead.\".format(old, new), FutureWarning)\n"
]
},
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Build the language library if not already built\n",
"# This should be done only once, and the resulting .so file can be reused\n",
"Language.build_library(\n",
" '/home/s452638/magisterka/build/my-languages.so', # Output location of compiled language library\n",
" [\n",
" '/home/s452638/magisterka/vendor/tree-sitter-python' # Replace with the path to the tree-sitter-python grammar\n",
" ]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tree_sitter/__init__.py:36: FutureWarning: Language(path, name) is deprecated. Use Language(ptr, name) instead.\n",
" warn(\"{} is deprecated. Use {} instead.\".format(old, new), FutureWarning)\n"
]
}
],
"source": [
"# Load the language\n",
"PYTHON_LANGUAGE = Language('/home/s452638/magisterka/build/my-languages.so', 'python')\n",
"\n",
"# Initialize the parser\n",
"parser = Parser()\n",
"parser.set_language(PYTHON_LANGUAGE)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def remove_docstrings_and_comments_from_code(code):\n",
" # Parse the code\n",
" tree = parser.parse(bytes(code, \"utf8\"))\n",
" cursor = tree.walk()\n",
"\n",
" # Traverse the tree and collect all docstrings\n",
" to_remove = []\n",
"\n",
" def traverse_tree(cursor, prev_node_type=None):\n",
" node_type = cursor.node.type\n",
" node_text = cursor.node.text.decode(\"utf-8\")\n",
" # Check if the current node is a function or class definition\n",
" if node_type == \"string\" and node_text.startswith('\"\"\"') and node_text.endswith('\"\"\"') and prev_node_type == \"expression_statement\":\n",
" to_remove.append((cursor.node.start_byte, cursor.node.end_byte))\n",
" if cursor.node.type == \"comment\":\n",
" to_remove.append((cursor.node.start_byte, cursor.node.end_byte))\n",
"\n",
" # Traverse children\n",
" if cursor.goto_first_child():\n",
" while True:\n",
" traverse_tree(cursor, node_type)\n",
" if not cursor.goto_next_sibling():\n",
" break\n",
" cursor.goto_parent()\n",
"\n",
" return node_type\n",
"\n",
" # Start traversing from the root\n",
" traverse_tree(cursor)\n",
"\n",
" # Remove docstrings from code\n",
" code_without_docstrings = code\n",
" for start, end in sorted(to_remove, reverse=True):\n",
" code_without_docstrings = code_without_docstrings[:start] + code_without_docstrings[end:]\n",
"\n",
" return code_without_docstrings"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def gather_categories(imap, header, categories=None):\n",
" \"\"\"\n",
" Find the user specified categories in the map and create a dictionary to contain the\n",
" relevant data for each type within the categories. Multiple categories will have their\n",
" types combined such that each possible combination will have its own entry in the\n",
" dictionary.\n",
"\n",
" :type imap: dict\n",
" :param imap: The input mapping file data keyed by SampleID\n",
" :type header: list\n",
" :param header: The header line from the input mapping file. This will be searched for\n",
" the user-specified categories\n",
" :type categories: list\n",
" :param categories: The list of user-specified category column name from mapping file\n",
" :rtype: dict\n",
" :return: A sorted dictionary keyed on the combinations of all the types found within\n",
" the user-specified categories. Each entry will contain an empty DataCategory\n",
" namedtuple. If no categories are specified, a single entry with the key\n",
" 'default' will be returned\n",
" \"\"\"\n",
" # If no categories provided, return all SampleIDs\n",
" if categories is None:\n",
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
"\n",
" cat_ids = [header.index(cat)\n",
" for cat in categories if cat in header and \"=\" not in cat]\n",
"\n",
" table = OrderedDict()\n",
" conditions = defaultdict(set)\n",
" for i, cat in enumerate(categories):\n",
" if \"=\" in cat and cat.split(\"=\")[0] in header:\n",
" cat_name = header[header.index(cat.split(\"=\")[0])]\n",
" conditions[cat_name].add(cat.split(\"=\")[1])\n",
"\n",
" # If invalid categories or conditions identified, return all SampleIDs\n",
" if not cat_ids and not conditions:\n",
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
"\n",
" #If only category column given, return column-wise SampleIDs\n",
" if cat_ids and not conditions:\n",
" for sid, row in imap.items():\n",
" cat_name = \"_\".join([row[cid] for cid in cat_ids])\n",
" if cat_name not in table:\n",
" table[cat_name] = DataCategory(set(), {})\n",
" table[cat_name].sids.add(sid)\n",
" return table\n",
"\n",
" # Collect all condition names\n",
" cond_ids = set()\n",
" for k in conditions:\n",
" try:\n",
" cond_ids.add(header.index(k))\n",
" except ValueError:\n",
" continue\n",
" idx_to_test = set(cat_ids).union(cond_ids)\n",
"\n",
" # If column name and condition given, return overlapping SampleIDs of column and\n",
" # condition combinations\n",
" for sid, row in imap.items():\n",
" if all([row[header.index(c)] in conditions[c] for c in conditions]):\n",
" key = \"_\".join([row[idx] for idx in idx_to_test])\n",
" try:\n",
" assert key in table.keys()\n",
" except AssertionError:\n",
" table[key] = DataCategory(set(), {})\n",
" table[key].sids.add(sid)\n",
" try:\n",
" assert len(table) > 0\n",
" except AssertionError:\n",
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
" else:\n",
" return table\n"
]
}
],
"source": [
"idx = 3\n",
"print(dataset[idx]['code'])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def gather_categories(imap, header, categories=None):\n",
" \n",
" \n",
" if categories is None:\n",
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
"\n",
" cat_ids = [header.index(cat)\n",
" for cat in categories if cat in header and \"=\" not in cat]\n",
"\n",
" table = OrderedDict()\n",
" conditions = defaultdict(set)\n",
" for i, cat in enumerate(categories):\n",
" if \"=\" in cat and cat.split(\"=\")[0] in header:\n",
" cat_name = header[header.index(cat.split(\"=\")[0])]\n",
" conditions[cat_name].add(cat.split(\"=\")[1])\n",
"\n",
" \n",
" if not cat_ids and not conditions:\n",
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
"\n",
" \n",
" if cat_ids and not conditions:\n",
" for sid, row in imap.items():\n",
" cat_name = \"_\".join([row[cid] for cid in cat_ids])\n",
" if cat_name not in table:\n",
" table[cat_name] = DataCategory(set(), {})\n",
" table[cat_name].sids.add(sid)\n",
" return table\n",
"\n",
" \n",
" cond_ids = set()\n",
" for k in conditions:\n",
" try:\n",
" cond_ids.add(header.index(k))\n",
" except ValueError:\n",
" continue\n",
" idx_to_test = set(cat_ids).union(cond_ids)\n",
"\n",
" \n",
" \n",
" for sid, row in imap.items():\n",
" if all([row[header.index(c)] in conditions[c] for c in conditions]):\n",
" key = \"_\".join([row[idx] for idx in idx_to_test])\n",
" try:\n",
" assert key in table.keys()\n",
" except AssertionError:\n",
" table[key] = DataCategory(set(), {})\n",
" table[key].sids.add(sid)\n",
" try:\n",
" assert len(table) > 0\n",
" except AssertionError:\n",
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
" else:\n",
" return table\n"
]
}
],
"source": [
"print(remove_docstrings_and_comments_from_code(dataset[idx]['code']))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"def test():\n",
" \n",
" str_variable = \"\"\"\n",
" This is a string\n",
" \"\"\"\n",
" print(\"Hello World\")\n",
"\n",
"class Test:\n",
" \n",
" def __init__(self):\n",
" pass\n",
"\n"
]
}
],
"source": [
"test_code = '''\n",
"def test():\n",
" \"\"\"\n",
" This is a test function\n",
" \"\"\"\n",
" str_variable = \"\"\"\n",
" This is a string\n",
" \"\"\"\n",
" print(\"Hello World\")\n",
"\n",
"class Test:\n",
" \"\"\"\n",
" This is a test class\n",
" \"\"\"\n",
" def __init__(self):\n",
" pass\n",
"'''\n",
"\n",
"print(remove_docstrings_and_comments_from_code(test_code))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

12
code/src/tmp_config.json Normal file
View File

@@ -0,0 +1,12 @@
{
"seed": 42,
"mlm_probability": 0.15,
"batch": 16,
"epochs": 1,
"eval_every": 5000,
"learning_rate": 5e-4,
"weight_decay": 0.01,
"max_grad_norm": 1.0,
"warmup_steps": 5000,
"train_size": 0.95
}
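For orientation (editorial sketch, not part of the commit): these keys feed training_utils.load_config and create_optimizer_and_scheduler below; run from code/src so the local imports resolve, and treat the step count as a placeholder:

from pathlib import Path
from transformers import RobertaConfig, RobertaForMaskedLM
from training_utils import load_config, create_optimizer_and_scheduler

config = load_config(Path("tmp_config.json"))
model = RobertaForMaskedLM(RobertaConfig.from_pretrained("microsoft/codebert-base"))

# epochs * len(train_dataloader) in the real scripts; 1000 stands in here.
num_training_steps = config["epochs"] * 1000
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)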

View File

@@ -1,261 +1,50 @@
import wandb
import os
import json
import random
import datetime
import logging
from pathlib import Path
from typing import Dict, Any, Tuple, List
import numpy as np
import torch
from torch import Tensor
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset, DatasetDict
from huggingface_hub import list_repo_files, hf_hub_download
from transformers import (
RobertaForMaskedLM,
RobertaConfig,
RobertaTokenizer,
DataCollatorForLanguageModeling,
get_linear_schedule_with_warmup,
PreTrainedTokenizer,
PreTrainedModel
from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer
from training_utils import (
set_seed, setup_wandb, setup_directories, load_config,
setup_device, create_optimizer_and_scheduler,
train_and_evaluate, create_base_dataloaders, set_deterministic_mode
)
from tqdm import tqdm
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class OnTheFlyTokenizationDataset(Dataset):
def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int):
self.dataset = dataset
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int) -> Dict[str, Tensor]:
content: str = self.dataset[idx]['content']
tokenized = self.tokenizer(
content,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
return {
'input_ids': tokenized['input_ids'].squeeze(0),
'attention_mask': tokenized['attention_mask'].squeeze(0),
'labels': tokenized['input_ids'].squeeze(0)
}
def set_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def setup_wandb(config: Dict[str, Any], exec_file: str = 'train_codebert_mlm.py') -> None:
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
wandb.init(project='codebert-training', name=f'tree_{curr_time}', config=config)
wandb.save('train_codebert_mlm.py')
def setup_directories(current_dir: Path) -> Path:
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
output_dir: Path = current_dir.parent.parent / 'outputs' / f'tree_{curr_time}'
output_dir.mkdir(parents=True, exist_ok=True)
return output_dir
def load_config(config_file: Path) -> Dict[str, Any]:
with open(config_file, 'r') as f:
return json.load(f)
def setup_device() -> torch.device:
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
logger.info(f'Using device: {device}')
if device.type == 'cuda':
logger.info(f'Device name: {torch.cuda.get_device_name()}')
torch.set_float32_matmul_precision('high')
return device
def download_dataset(dataset_dir: Path) -> None:
if not dataset_dir.exists():
logger.info("Downloading the dataset...")
dataset_dir.mkdir(parents=True, exist_ok=True)
files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
for file_name in files_to_download:
hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
logger.info("Dataset downloaded successfully.")
def load_and_prepare_dataset(dataset_dir: Path, seed: int) -> DatasetDict:
dataset: DatasetDict = load_dataset(str(dataset_dir), split='train')
dataset = dataset.train_test_split(test_size=0.01, seed=seed)
logger.info(f'Dataset loaded: {dataset}')
return dataset
def create_dataloaders(
dataset: DatasetDict,
tokenizer: PreTrainedTokenizer,
config: Dict[str, Any],
device: torch.device
) -> Tuple[DataLoader, DataLoader]:
dataset['train'] = OnTheFlyTokenizationDataset(dataset['train'], tokenizer, max_length=512)
dataset['test'] = OnTheFlyTokenizationDataset(dataset['test'], tokenizer, max_length=512)
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=config['mlm_probability'])
train_dataloader = DataLoader(
dataset['train'],
batch_size=config['batch'],
shuffle=False,
collate_fn=data_collator,
generator=torch.Generator(device=device)
)
valid_dataloader = DataLoader(
dataset['test'],
batch_size=config['batch'],
shuffle=False,
collate_fn=data_collator,
generator=torch.Generator(device=device)
)
return train_dataloader, valid_dataloader
def setup_model_and_optimizer(
config: Dict[str, Any],
models_dir: Path
) -> Tuple[PreTrainedModel, AdamW]:
if not models_dir.exists():
logger.info("Downloading the model...")
model_config = RobertaConfig.from_pretrained('roberta-base', cache_dir=models_dir)
model: PreTrainedModel = RobertaForMaskedLM(model_config)
model = torch.compile(model)
wandb.watch(model)
logger.info(f'Model config: {model_config}')
wandb.config.update({'model_config': model_config.to_dict()})
optimizer: AdamW = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
return model, optimizer
def train_and_evaluate(
model: PreTrainedModel,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
optimizer: AdamW,
scheduler: Any,
config: Dict[str, Any],
output_dir: Path
) -> None:
num_training_steps: int = config['epochs'] * len(train_dataloader)
best_valid_loss: float = float('inf')
with tqdm(total=num_training_steps, desc='Training') as pbar:
for epoch_idx in range(config['epochs']):
model.train()
for train_idx, train_batch in enumerate(train_dataloader):
outputs = model(**train_batch)
train_loss: Tensor = outputs.loss
train_loss.backward()
norm: Tensor = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
optimizer.step()
scheduler.step()
optimizer.zero_grad()
pbar.update(1)
pbar.set_postfix({'train_loss': train_loss.item()})
wandb.log({
'step': train_idx + len(train_dataloader) * epoch_idx,
'train_loss': train_loss.item(),
'gradient_norm': norm.item(),
'learning_rate': scheduler.get_last_lr()[0],
})
if train_idx != 0 and train_idx % config['eval_every'] == 0:
valid_loss, valid_acc = evaluate(model, valid_dataloader)
pbar.set_postfix({'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc})
wandb.log({
'valid_loss': valid_loss,
'valid_acc': valid_acc,
'step': train_idx + len(train_dataloader) * epoch_idx,
})
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), output_dir / 'best_model.pt')
logger.info(f'Best validation loss: {best_valid_loss}')
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
model.eval()
total_loss: float = 0.0
total_acc: float = 0.0
with torch.no_grad():
for batch in tqdm(dataloader, desc='Validation'):
outputs = model(**batch)
total_loss += outputs.loss.item()
total_acc += outputs.logits.argmax(dim=-1).eq(batch['labels']).sum().item()
avg_loss: float = total_loss / len(dataloader)
avg_acc: float = total_acc / len(dataloader.dataset)
model.train()
return avg_loss, avg_acc
def main() -> None:
current_dir: Path = Path(__file__).parent
output_dir: Path = setup_directories(current_dir)
config: Dict[str, Any] = load_config(current_dir / 'config.json')
set_deterministic_mode()
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir, model_name='base')
config = load_config(current_dir / 'tmp_config.json')
setup_wandb(config)
setup_wandb(config, model_name='base')
set_seed(config['seed'])
device: torch.device = setup_device()
device = setup_device()
dataset_dir: Path = current_dir.parent / 'data' / 'the-stack-python'
download_dataset(dataset_dir)
dataset: DatasetDict = load_and_prepare_dataset(dataset_dir, config['seed'])
# Load tokenizer
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
tokenizer: PreTrainedTokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
logger.info(f'Tokenizer loaded: {tokenizer}')
######################## Reproducing last training here ########################
# Remove first 186513 batches
dataset['train'] = dataset['train'].select(range(186_513 * 32, len(dataset['train'])))
################################################################################
train_dataloader, valid_dataloader = create_dataloaders(dataset, tokenizer, config, device)
models_dir: Path = current_dir.parent / 'models' / 'roberta-base'
model, optimizer = setup_model_and_optimizer(config, models_dir)
num_training_steps: int = config['epochs'] * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=config['warmup_steps'],
num_training_steps=num_training_steps
# Create dataloaders
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
device
)
######################## Reproducing last training here ########################
# Change optimizer learning rate to 0.00021575814536340852
optimizer = AdamW(model.parameters(), lr=0.00021575814536340852, weight_decay=config['weight_decay'])
# Set warmup_steps to 0
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
# Load the best model weights
state_dict = torch.load('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt', weights_only=True, map_location=device)
model.load_state_dict(state_dict)
################################################################################
train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
# Model setup
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model = RobertaForMaskedLM(model_config)
model = model.to(device)
# Training setup and run
num_training_steps = config['epochs'] * len(train_dataloader)
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
train_and_evaluate(
model, train_dataloader, valid_dataloader,
optimizer, scheduler, config, output_dir,
log_weights=False, device=device
)
if __name__ == "__main__":
main()

View File

@@ -1,78 +1,18 @@
import wandb
import sys
import math
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import AdamW
from dataclasses import dataclass
from tqdm import tqdm
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import ast
import numpy as np
from torch.utils.data import Dataset, DataLoader
from datasets import DatasetDict
from transformers import (
RobertaForMaskedLM,
RobertaConfig,
RobertaTokenizer,
get_linear_schedule_with_warmup,
PreTrainedTokenizer,
PreTrainedModel,
)
from typing import Dict, List, Optional
from transformers import RobertaConfig, RobertaTokenizer, RobertaForMaskedLM
sys.setrecursionlimit(3000) # Increase recursion limit
# Import existing training functionality
from train_codebert_mlm import (
from training_utils import (
set_seed, setup_wandb, setup_directories, load_config,
setup_device, download_dataset, load_and_prepare_dataset, logger
setup_device, create_optimizer_and_scheduler,
train_and_evaluate, create_base_dataloaders, PreprocessedTreeDataset,
set_deterministic_mode
)
@dataclass
class ASTNodeInfo:
"""Stores structural information about an AST node."""
node_type: str
start_token_idx: int
end_token_idx: int
depth: int
sibling_pos: int
parent_idx: int
def ast_collate_fn(batch):
"""Custom collate function with improved error handling."""
# Remove failed parses (where attention_mask is all zeros)
valid_batch = [
item for item in batch
if item['attention_mask'].sum() > 0 and len(item['ast_nodes']) > 0
]
from parse_dataset import ASTNodeInfo
if not valid_batch:
# Return minimal batch if no valid items
return {
'input_ids': torch.zeros((1, 512), dtype=torch.long),
'attention_mask': torch.zeros((1, 512), dtype=torch.long),
'labels': torch.zeros((1, 512), dtype=torch.long),
'ast_nodes': [[]]
}
# Stack tensors
input_ids = torch.stack([item['input_ids'] for item in valid_batch])
attention_mask = torch.stack([item['attention_mask'] for item in valid_batch])
labels = torch.stack([item['labels'] for item in valid_batch])
# Collect AST nodes
ast_nodes = [item['ast_nodes'] for item in valid_batch]
return {
'input_ids': input_ids,
'attention_mask': attention_mask,
'labels': labels,
'ast_nodes': ast_nodes
}
class TreePositionalEmbedding(nn.Module):
"""Generates tree-aware positional embeddings for code tokens."""
@@ -126,7 +66,7 @@ class TreePositionalEmbedding(nn.Module):
embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb]))
return embeddings
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
"""CodeBERT model enhanced with normalized embedding weights for stable training."""
@@ -244,253 +184,44 @@ class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
}
}
class TreeEnhancedDataset(Dataset):
"""Dataset that processes code into tokens and AST nodes with improved error handling."""
def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int):
self.dataset = dataset
self.tokenizer = tokenizer
self.max_length = max_length
self.max_ast_size = 1000 # Limit maximum AST size to prevent memory issues
def __len__(self) -> int:
return len(self.dataset)
def create_dummy_data(self) -> Dict[str, Any]:
"""Create dummy data for invalid/problematic code samples."""
pad_token_id = self.tokenizer.pad_token_id or 0
dummy_tensor = torch.full((self.max_length,), pad_token_id, dtype=torch.long)
return {
'input_ids': dummy_tensor,
'attention_mask': torch.zeros(self.max_length, dtype=torch.long),
'labels': dummy_tensor.clone(),
'ast_nodes': []
}
def __getitem__(self, idx: int) -> Dict[str, Any]:
try:
content: str = self.dataset[idx]['content']
# Skip extremely long files
if len(content) > 50000: # ~1000 lines
return self.create_dummy_data()
# Basic code validation
if not content.strip() or '\0' in content:
return self.create_dummy_data()
# Tokenize first
encoding = self.tokenizer(
content,
max_length=self.max_length,
truncation=True,
padding='max_length',
return_tensors='pt'
)
try:
# Parse AST with timeout
tree = ast.parse(content)
# Get AST nodes info
nodes_info = []
def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
if len(nodes_info) >= self.max_ast_size:
return
current_idx = len(nodes_info)
if hasattr(node, 'lineno'):
nodes_info.append(ASTNodeInfo(
node_type=type(node).__name__,
start_token_idx=node.lineno,
end_token_idx=getattr(node, 'end_lineno', node.lineno),
depth=min(depth, 31), # Limit depth to prevent issues
sibling_pos=min(sibling_pos, 31),
parent_idx=parent_idx
))
# Only process children if we haven't hit the limit
if len(nodes_info) < self.max_ast_size:
for i, child in enumerate(ast.iter_child_nodes(node)):
visit_node(child, depth + 1, current_idx, i)
visit_node(tree, 0, -1, 0)
# If we hit the AST size limit, use dummy data
if len(nodes_info) >= self.max_ast_size:
return self.create_dummy_data()
return {
'input_ids': encoding['input_ids'].squeeze(0),
'attention_mask': encoding['attention_mask'].squeeze(0),
'labels': encoding['input_ids'].squeeze(0).clone(),
'ast_nodes': nodes_info
}
except (SyntaxError, ValueError, RecursionError, TimeoutError, MemoryError):
return self.create_dummy_data()
except Exception as e:
return self.create_dummy_data()
def create_tree_dataloaders(
dataset: DatasetDict,
tokenizer: PreTrainedTokenizer,
config: Dict[str, Any],
device: torch.device
) -> Tuple[DataLoader, DataLoader]:
"""Create dataloaders with tree-enhanced datasets."""
train_dataset = TreeEnhancedDataset(dataset['train'], tokenizer, max_length=512)
valid_dataset = TreeEnhancedDataset(dataset['test'], tokenizer, max_length=512)
def main() -> None:
set_deterministic_mode()
train_dataloader = DataLoader(
train_dataset,
batch_size=config['batch'],
shuffle=True,
collate_fn=ast_collate_fn,
generator=torch.Generator(device=device)
)
valid_dataloader = DataLoader(
valid_dataset,
batch_size=config['batch'],
shuffle=False,
collate_fn=ast_collate_fn,
generator=torch.Generator(device=device)
)
return train_dataloader, valid_dataloader
def train_and_evaluate(
model: PreTrainedModel,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
optimizer: AdamW,
scheduler: Any,
config: Dict[str, Any],
output_dir: Path
) -> None:
"""Training loop with explicit tracking of alpha, beta, gamma weights."""
num_training_steps: int = config['epochs'] * len(train_dataloader)
best_valid_loss: float = float('inf')
with tqdm(total=num_training_steps, desc='Training') as pbar:
for epoch_idx in range(config['epochs']):
model.train()
for train_idx, train_batch in enumerate(train_dataloader):
outputs = model(**train_batch)
train_loss = outputs["loss"]
if train_loss is not None:
train_loss.backward()
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
optimizer.step()
scheduler.step()
optimizer.zero_grad()
# Get current metrics
current_loss = train_loss.item()
weights = outputs["embedding_weights"]
# Update progress bar with all three weights
pbar.update(1)
pbar.set_postfix({
'train_loss': f"{current_loss:.3f}",
'α': f"{weights['token']:.2f}",
'β': f"{weights['tree']:.2f}",
'γ': f"{weights['sequential']:.2f}"
})
# Log all three weights separately
step = train_idx + len(train_dataloader) * epoch_idx
wandb.log({
'train_loss': current_loss,
'token_weight': weights['token'],
'tree_weight': weights['tree'],
'sequential_weight': weights['sequential'],
'gradient_norm': norm.item(),
'learning_rate': scheduler.get_last_lr()[0],
'step': step,
})
# Periodic evaluation
if train_idx != 0 and train_idx % config['eval_every'] == 0:
valid_loss, valid_acc = evaluate(model, valid_dataloader)
wandb.log({
'valid_loss': valid_loss,
'valid_acc': valid_acc,
'step': step,
})
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
torch.save(model.state_dict(), output_dir / 'best_model.pt')
else:
pbar.update(1)
pbar.set_postfix({'train_loss': 'N/A'})
logger.info(f'Best validation loss: {best_valid_loss}')
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
"""Evaluation function with simple metrics."""
model.eval()
total_loss: float = 0.0
total_acc: float = 0.0
num_batches: int = 0
with torch.no_grad():
for batch in dataloader:
outputs = model(**batch)
if outputs["loss"] is not None:
total_loss += outputs["loss"].item()
logits = outputs["logits"]
labels = batch['labels']
predictions = logits.argmax(dim=-1)
total_acc += (predictions == labels).float().mean().item()
num_batches += 1
avg_loss: float = total_loss / max(num_batches, 1)
avg_acc: float = total_acc / max(num_batches, 1)
model.train()
return avg_loss, avg_acc
def main():
# Setup identical to original training script
current_dir = Path(__file__).parent
output_dir = setup_directories(current_dir)
config = load_config(current_dir / 'config.json')
setup_wandb(config, exec_file='src/train_tree_codebert_mlm.py')
output_dir = setup_directories(current_dir, model_name='tree')
config = load_config(current_dir / 'tmp_config.json')
setup_wandb(config, model_name='tree')
set_seed(config['seed'])
device = setup_device()
dataset_dir = current_dir.parent / 'data' / 'the-stack-python'
download_dataset(dataset_dir)
dataset = load_and_prepare_dataset(dataset_dir, config['seed'])
# Load tokenizer
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
# Create tree-enhanced dataloaders
train_dataloader, valid_dataloader = create_tree_dataloaders(dataset, tokenizer, config, device)
# Initialize tree-enhanced model
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model = TreeCodeBERTForPreTraining(model_config)
model = model.to(device) # Just move to device without compilation
# model = torch.compile(model)
# Optimizer and scheduler setup identical to original
optimizer = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
num_training_steps = config['epochs'] * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=config['warmup_steps'],
num_training_steps=num_training_steps
# Create dataloaders with tree dataset
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
train_dataloader, valid_dataloader = create_base_dataloaders(
processed_data_dir,
tokenizer,
config,
device,
dataset_class=PreprocessedTreeDataset
)
# Training loop (using original train_and_evaluate function)
train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
# Model setup
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
model = TreeCodeBERTForPreTraining(model_config)
model = model.to(device)
# Training setup and run
num_training_steps = config['epochs'] * len(train_dataloader)
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
train_and_evaluate(
model, train_dataloader, valid_dataloader,
optimizer, scheduler, config, output_dir,
log_weights=True, device=device
)
if __name__ == "__main__":
main()

450
code/src/training_utils.py Normal file
View File

@@ -0,0 +1,450 @@
import os
import json
import random
import datetime
import platform
import logging
import wandb
import numpy as np
import torch
from pathlib import Path
from typing import Dict, Any, Tuple, Type, Optional
from datasets import load_from_disk, concatenate_datasets
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
PreTrainedModel,
get_linear_schedule_with_warmup,
DataCollatorForLanguageModeling,
PreTrainedTokenizer
)
from tqdm import tqdm
from parse_dataset import ASTNodeInfo
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
def set_deterministic_mode() -> None:
"""Enable deterministic mode for reproducibility."""
# Set Python random seed
random.seed(42)
# Set NumPy random seed
np.random.seed(42)
# Set PyTorch random seed
torch.manual_seed(42)
# Set CUDA random seed
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42)
# Enable deterministic operations for PyTorch 2.x
torch.use_deterministic_algorithms(True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Set environment variables for reproducibility
os.environ['PYTHONHASHSEED'] = '42'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2
def get_device_info() -> Dict[str, Any]:
"""Get detailed information about the computing environment."""
info = {
'python_version': platform.python_version(),
'torch_version': torch.__version__,
'cuda_available': torch.cuda.is_available(),
'deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
'cudnn_deterministic': torch.backends.cudnn.deterministic,
'cudnn_benchmark': torch.backends.cudnn.benchmark,
}
if torch.cuda.is_available():
info.update({
'cuda_version': torch.version.cuda,
'gpu_name': torch.cuda.get_device_name(0),
'gpu_count': torch.cuda.device_count(),
})
return info
class TreeDataCollator(DataCollatorForLanguageModeling):
"""Custom data collator that handles both MLM and optional AST node information."""
def torch_call(self, examples):
# Check if we have AST nodes
has_ast = 'ast_nodes' in examples[0]
# Extract AST nodes before MLM processing if they exist
ast_nodes = None
if has_ast:
ast_nodes = [e.pop('ast_nodes') for e in examples]
# Process normal MLM features
batch = super().torch_call(examples)
# Add AST nodes back to batch if they existed
if has_ast:
batch['ast_nodes'] = ast_nodes
return batch
def __call__(self, examples):
return self.torch_call(examples)
class PreprocessedBaseDataset(Dataset):
"""Dataset that uses pre-processed data without AST information."""
def __init__(self, dataset: Any):
self.dataset = dataset
def __len__(self) -> int:
return len(self.dataset)
def __getitem__(self, idx: int) -> Dict[str, Any]:
item = self.dataset[idx]
return {
'input_ids': torch.tensor(item['input_ids']),
'attention_mask': torch.tensor(item['attention_mask']),
'labels': torch.tensor(item['input_ids']).clone()
}
class PreprocessedTreeDataset(PreprocessedBaseDataset):
"""Dataset that includes AST information."""
def __getitem__(self, idx: int) -> Dict[str, Any]:
item = super().__getitem__(idx)
item['ast_nodes'] = [ASTNodeInfo.from_dict(node) for node in self.dataset[idx]['ast_nodes']]
return item
def set_seed(seed: int) -> None:
"""Set random seeds for reproducibility."""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def setup_wandb(
config: Dict[str, Any],
model_name: str = 'base',
device_info: Optional[Dict[str, Any]] = None
) -> None:
"""Initialize W&B logging with reproducibility information."""
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
# Add reproducibility settings to config
full_config = {
**config,
'random_seed': 42,
'deterministic_mode': True,
'device_info': device_info or get_device_info(),
}
wandb.init(
project='codebert-training-test',
name=f'{model_name}_{curr_time}',
config=full_config
)
def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
"""Create output directories for model artifacts."""
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
output_dir: Path = current_dir.parent.parent / 'outputs' / f'{model_name}_{curr_time}'
output_dir.mkdir(parents=True, exist_ok=True)
return output_dir
def load_config(config_file: Path) -> Dict[str, Any]:
"""Load training configuration from JSON file."""
with open(config_file, 'r') as f:
return json.load(f)
def setup_device() -> torch.device:
"""Setup and configure training device."""
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
logger.info(f'Using device: {device}')
if device.type == 'cuda':
logger.info(f'Device name: {torch.cuda.get_device_name()}')
torch.set_float32_matmul_precision('high')
return device
def create_optimizer_and_scheduler(
model: PreTrainedModel,
config: Dict[str, Any],
num_training_steps: int
) -> Tuple[AdamW, Any]:
"""Create optimizer and learning rate scheduler."""
optimizer = AdamW(
model.parameters(),
lr=config['learning_rate'],
weight_decay=config['weight_decay']
)
scheduler = get_linear_schedule_with_warmup(
optimizer,
num_warmup_steps=config['warmup_steps'],
num_training_steps=num_training_steps
)
return optimizer, scheduler
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
"""Evaluate model on validation set."""
model.eval()
total_loss: float = 0.0
total_acc: float = 0.0
with torch.no_grad():
for batch in tqdm(dataloader, desc='Validation'):
outputs = model(**batch)
total_loss += outputs.loss.item() if hasattr(outputs, 'loss') else outputs['loss'].item()
logits = outputs.logits if hasattr(outputs, 'logits') else outputs['logits']
total_acc += logits.argmax(dim=-1).eq(batch['labels']).sum().item()
avg_loss: float = total_loss / len(dataloader)
avg_acc: float = total_acc / len(dataloader.dataset)
model.train()
return avg_loss, avg_acc
def train_and_evaluate(
model: PreTrainedModel,
train_dataloader: DataLoader,
valid_dataloader: DataLoader,
optimizer: AdamW,
scheduler: Any,
config: Dict[str, Any],
output_dir: Path,
log_weights: bool = False,
device: torch.device = torch.device('cpu')
) -> None:
"""Train and evaluate model with deterministic behavior and comprehensive logging."""
# Enable deterministic algorithms for PyTorch 2.5
torch.use_deterministic_algorithms(True)
num_training_steps: int = config['epochs'] * len(train_dataloader)
best_valid_loss: float = float('inf')
# Save initial model state for reproducibility
torch.save({
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
}, output_dir / 'initial_state.pt')
with tqdm(total=num_training_steps, desc='Training') as pbar:
for epoch_idx in range(config['epochs']):
model.train()
epoch_loss = 0.0
epoch_steps = 0
for train_idx, train_batch in enumerate(train_dataloader):
# Move batch to device
train_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in train_batch.items()}
# Forward pass
outputs = model(**train_batch)
train_loss = outputs.loss if hasattr(outputs, 'loss') else outputs['loss']
# Backward pass with deterministic behavior
train_loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(),
config['max_grad_norm']
)
optimizer.step()
scheduler.step()
optimizer.zero_grad(set_to_none=False) # Use False for determinism
# Update metrics
current_loss = train_loss.item()
epoch_loss += current_loss
epoch_steps += 1
# Calculate global step
step = train_idx + len(train_dataloader) * epoch_idx
# Prepare logging dictionary
log_dict = {
'step': step,
'epoch': epoch_idx,
'train_loss': current_loss,
'gradient_norm': grad_norm.item(),
'learning_rate': scheduler.get_last_lr()[0],
}
# Add embedding weights if using tree model
if log_weights and 'embedding_weights' in outputs:
weights = outputs['embedding_weights']
log_dict.update({
'token_weight': weights['token'],
'tree_weight': weights['tree'],
'sequential_weight': weights['sequential']
})
pbar_dict = {
'epoch': f"{epoch_idx + 1}/{config['epochs']}",
'train_loss': f"{current_loss:.3f}",
'α': f"{weights['token']:.2f}",
'β': f"{weights['tree']:.2f}",
'γ': f"{weights['sequential']:.2f}"
}
else:
pbar_dict = {
'epoch': f"{epoch_idx + 1}/{config['epochs']}",
'train_loss': f"{current_loss:.3f}"
}
# Log to wandb
wandb.log(log_dict)
# Update progress bar
pbar.update(1)
pbar.set_postfix(pbar_dict)
# Periodic evaluation
if (train_idx != 0 and train_idx % config['eval_every'] == 0) or train_idx == len(train_dataloader) - 1:
model.eval()
valid_loss, valid_acc = evaluate(model, valid_dataloader)
model.train()
# Log validation metrics
wandb.log({
'valid_loss': valid_loss,
'valid_acc': valid_acc,
'step': step,
'epoch': epoch_idx,
})
# Save checkpoint if best model
if valid_loss < best_valid_loss:
best_valid_loss = valid_loss
# Save complete state
torch.save({
'epoch': epoch_idx,
'step': step,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
'best_valid_loss': best_valid_loss,
'config': config,
'device_info': get_device_info()
}, output_dir / 'best_model.pt')
# Log best metrics
wandb.run.summary['best_valid_loss'] = best_valid_loss
wandb.run.summary['best_valid_acc'] = valid_acc
wandb.run.summary['best_epoch'] = epoch_idx
wandb.run.summary['best_step'] = step
# End of epoch logging
avg_epoch_loss = epoch_loss / epoch_steps
wandb.log({
'epoch': epoch_idx,
'epoch_avg_loss': avg_epoch_loss,
})
# Save end of epoch checkpoint
torch.save({
'epoch': epoch_idx,
'model_state': model.state_dict(),
'optimizer_state': optimizer.state_dict(),
'scheduler_state': scheduler.state_dict(),
'rng_state': torch.get_rng_state(),
'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
'train_loss': avg_epoch_loss,
'config': config,
'device_info': get_device_info()
}, output_dir / f'checkpoint_epoch_{epoch_idx}.pt')
# End of training logging
wandb.run.summary['final_epoch'] = config['epochs'] - 1
wandb.run.summary['total_steps'] = num_training_steps
logger.info(f'Training completed. Best validation loss: {best_valid_loss:.4f}')
logger.info(f'Model checkpoints saved in {output_dir}')
def load_all_chunks(chunks_dir: Path) -> Any:
"""Load and concatenate all dataset chunks."""
logger.info(f"Loading dataset chunks from {chunks_dir}")
chunks = []
for chunk_path in sorted(chunks_dir.glob('chunk_*'))[:5]:
chunks.append(load_from_disk(str(chunk_path)))
dataset = concatenate_datasets(chunks)
logger.info(f"Loaded {len(dataset)} examples from {len(chunks)} chunks")
return dataset
def create_base_dataloaders(
processed_data_dir: Path,
tokenizer: PreTrainedTokenizer,
config: Dict[str, Any],
device: torch.device,
dataset_class: Type[Dataset] = PreprocessedBaseDataset,
) -> Tuple[DataLoader, DataLoader]:
"""Create reproducible dataloaders from pre-processed data."""
# Load chunks
chunks_dir = processed_data_dir / 'chunks'
full_dataset = load_all_chunks(chunks_dir)
# Calculate split sizes
dataset_size = len(full_dataset)
train_size = int(config['train_size'] * dataset_size)
val_size = dataset_size - train_size
# Create splits using indices to avoid device issues
indices = torch.arange(dataset_size, device=device)
# Create a deterministic generator on the correct device
generator = torch.Generator(device=device)
generator.manual_seed(42)
# Shuffle indices deterministically
shuffled_indices = indices[torch.randperm(len(indices), device=device, generator=generator)]
train_indices = shuffled_indices[:train_size].cpu()
valid_indices = shuffled_indices[train_size:].cpu()
# Create train/validation splits using the shuffled indices
train_dataset = torch.utils.data.Subset(full_dataset, train_indices.tolist())
valid_dataset = torch.utils.data.Subset(full_dataset, valid_indices.tolist())
# Create dataset wrappers
train_dataset = dataset_class(train_dataset)
valid_dataset = dataset_class(valid_dataset)
logger.info(f"Created train dataset with {len(train_dataset)} samples")
logger.info(f"Created validation dataset with {len(valid_dataset)} samples")
# Use TreeDataCollator for both base and tree models
data_collator = TreeDataCollator(
tokenizer=tokenizer,
mlm=True,
mlm_probability=config['mlm_probability']
)
# Create train dataloader without generator
train_dataloader = DataLoader(
train_dataset,
batch_size=config['batch'],
shuffle=False, # We already shuffled the data
collate_fn=data_collator,
num_workers=0, # Use single worker for reproducibility
drop_last=True # Ensure consistent batch sizes
)
# Create validation dataloader
valid_dataloader = DataLoader(
valid_dataset,
batch_size=config['batch'],
shuffle=False,
collate_fn=data_collator,
num_workers=0,
drop_last=True
)
return train_dataloader, valid_dataloader
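A minimal usage sketch (editorial, not part of the commit): TreeDataCollator pops 'ast_nodes' before MLM masking and re-attaches them afterwards, so a tree-model batch carries 'input_ids', 'attention_mask', 'labels', and 'ast_nodes', while the base dataset class omits the last key. Paths assume the layout produced by parse_dataset.py and a working directory of code/src:

from pathlib import Path
from transformers import RobertaTokenizer
from training_utils import (
    load_config, setup_device, create_base_dataloaders, PreprocessedTreeDataset
)

config = load_config(Path("tmp_config.json"))
device = setup_device()
tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")

train_dl, valid_dl = create_base_dataloaders(
    Path("../data/processed-python"),          # must contain the chunks/ directory
    tokenizer, config, device,
    dataset_class=PreprocessedTreeDataset,     # drop this kwarg for the base model
)

batch = next(iter(train_dl))
print(batch["input_ids"].shape)                # (batch, 512), padded in parse_dataset.py
print(len(batch["ast_nodes"]))                 # one list of ASTNodeInfo per sample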

View File

@@ -1,36 +0,0 @@
def remove_docstrings_and_comments_from_code(code, parser):
# Parse the code
tree = parser.parse(bytes(code, "utf8"))
cursor = tree.walk()
# Traverse the tree and collect all docstrings
to_remove = []
def traverse_tree(cursor, prev_node_type=None):
node_type = cursor.node.type
node_text = cursor.node.text.decode("utf-8")
# Check if the current node is a function or class definition
if node_type == "string" and node_text.startswith('"""') and node_text.endswith('"""') and prev_node_type == "expression_statement":
to_remove.append((cursor.node.start_byte, cursor.node.end_byte))
if cursor.node.type == "comment":
to_remove.append((cursor.node.start_byte, cursor.node.end_byte))
# Traverse children
if cursor.goto_first_child():
while True:
traverse_tree(cursor, node_type)
if not cursor.goto_next_sibling():
break
cursor.goto_parent()
return node_type
# Start traversing from the root
traverse_tree(cursor)
# Remove docstrings from code
code_without_docstrings = code
for start, end in sorted(to_remove, reverse=True):
code_without_docstrings = code_without_docstrings[:start] + code_without_docstrings[end:]
return code_without_docstrings