standardized utils for training
parent 96fc1041cf
commit 0cd04e6131
@@ -37,3 +37,4 @@ distribution = true
[tool.pdm.scripts]
run_training = {cmd = "src/train_codebert_mlm.py"}
run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
parse_dataset = {cmd = "src/parse_dataset.py"}
@@ -8,5 +8,5 @@
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
    "warmup_steps": 20000,
    "pause_on_instability": false
    "train_size": 0.95
}
code/src/parse_dataset.py (new file, 232 lines)
@@ -0,0 +1,232 @@
import ast
from pathlib import Path
import logging
from tqdm import tqdm
from pprint import pprint
from typing import Dict, Any, List
from datasets import load_dataset, load_from_disk, Dataset, concatenate_datasets
from transformers import RobertaTokenizer
from dataclasses import dataclass
import numpy as np
import multiprocessing

import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

@dataclass
class ASTNodeInfo:
    """Stores structural information about an AST node."""
    node_type: str
    start_token_idx: int
    end_token_idx: int
    depth: int
    sibling_pos: int
    parent_idx: int

    def to_dict(self) -> Dict[str, Any]:
        return {
            'node_type': self.node_type,
            'start_token_idx': self.start_token_idx,
            'end_token_idx': self.end_token_idx,
            'depth': self.depth,
            'sibling_pos': self.sibling_pos,
            'parent_idx': self.parent_idx
        }

    def from_dict(data: Dict[str, Any]) -> 'ASTNodeInfo':
        return ASTNodeInfo(
            node_type=data['node_type'],
            start_token_idx=data['start_token_idx'],
            end_token_idx=data['end_token_idx'],
            depth=data['depth'],
            sibling_pos=data['sibling_pos'],
            parent_idx=data['parent_idx']
        )

def safe_ast_parse(content: str, max_length: int = 50000) -> bool:
    """Safely attempt to parse Python code."""
    if not content or not content.strip() or '\0' in content or len(content) > max_length:
        return False

    try:
        tree = ast.parse(content)
        return bool(list(ast.walk(tree)))
    except:
        return False

def process_and_extract_ast_info(content: str, max_ast_size: int = 1000) -> List[Dict[str, Any]]:
    """Process AST and extract node information."""
    try:
        tree = ast.parse(content)
        nodes_info = []

        def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
            if len(nodes_info) >= max_ast_size:
                return

            current_idx = len(nodes_info)
            if hasattr(node, 'lineno'):
                node_info = ASTNodeInfo(
                    node_type=type(node).__name__,
                    start_token_idx=node.lineno,
                    end_token_idx=getattr(node, 'end_lineno', node.lineno),
                    depth=min(depth, 31),
                    sibling_pos=min(sibling_pos, 31),
                    parent_idx=parent_idx
                )
                nodes_info.append(node_info.to_dict())

            if len(nodes_info) < max_ast_size:
                for i, child in enumerate(ast.iter_child_nodes(node)):
                    visit_node(child, depth + 1, current_idx, i)

        visit_node(tree, 0, -1, 0)
        return nodes_info
    except:
        return []

def process_batch(examples: Dict[str, List[Any]], tokenizer: RobertaTokenizer) -> Dict[str, List[Any]]:
    """Process a batch of examples."""
    contents = examples['content']

    processed_contents = []
    processed_ast_nodes = []
    processed_input_ids = []
    processed_attention_masks = []

    for content in contents:
        if safe_ast_parse(content):
            ast_nodes = process_and_extract_ast_info(content)
            if ast_nodes:
                try:
                    encoding = tokenizer(
                        content,
                        max_length=512,
                        truncation=True,
                        padding='max_length',
                        return_tensors='pt'
                    )

                    processed_contents.append(content)
                    processed_ast_nodes.append(ast_nodes)
                    processed_input_ids.append(encoding['input_ids'].squeeze(0).tolist())
                    processed_attention_masks.append(encoding['attention_mask'].squeeze(0).tolist())
                except:
                    continue

    return {
        'content': processed_contents,
        'ast_nodes': processed_ast_nodes,
        'input_ids': processed_input_ids,
        'attention_mask': processed_attention_masks,
    }

def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
    """Save dataset to disk in chunks to manage memory usage."""
    num_chunks = (len(dataset) + chunk_size - 1) // chunk_size

    # Create directory for chunks
    output_dir = Path(output_path).parent
    chunks_dir = output_dir / 'chunks'
    chunks_dir.mkdir(exist_ok=True)

    stats_accumulator = {
        'total_ast_nodes': 0,
        'total_samples': 0
    }

    pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
    for i in pbar:
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(dataset))

        # Select chunk from dataset
        chunk_dataset = dataset.select(range(start_idx, end_idx))

        # Update statistics
        stats_accumulator['total_ast_nodes'] += sum(len(nodes) for nodes in chunk_dataset['ast_nodes'])
        stats_accumulator['total_samples'] += len(chunk_dataset)

        # Save chunk using datasets native method
        chunk_path = chunks_dir / f'chunk_{i:04d}'
        chunk_dataset.save_to_disk(str(chunk_path))

        # Update progress bar postfix with current chunk info
        pbar.set_postfix({
            'samples': len(chunk_dataset),
            'path': str(chunk_path.name)
        })

    return stats_accumulator

def load_all_chunks(chunks_dir: Path) -> Dataset:
    """Load and concatenate all dataset chunks."""
    chunks = []
    for chunk_path in sorted(chunks_dir.glob('chunk_*')):
        chunks.append(load_from_disk(str(chunk_path)))
    return concatenate_datasets(chunks)

def main():
    current_dir = Path(__file__).parent
    input_dir = current_dir.parent / 'data' / 'the-stack-python'
    output_dir = current_dir.parent / 'data' / 'processed-python'
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')

    logger.info("Loading dataset...")
    dataset = load_dataset(str(input_dir))['train']
    original_dataset_size = len(dataset)

    logger.info("Dataset:")
    pprint(dataset)

    num_proc = min(multiprocessing.cpu_count() - 1, 32)
    logger.info(f"Using {num_proc} processes for dataset processing")

    logger.info("Processing dataset...")
    processed_dataset = dataset.map(
        process_batch,
        fn_kwargs={'tokenizer': tokenizer},
        batched=True,
        remove_columns=dataset.column_names,
        desc="Processing dataset",
        num_proc=num_proc,
        load_from_cache_file=False
    )

    processed_dataset = processed_dataset.filter(
        lambda batch: [len(nodes) > 0 for nodes in batch['ast_nodes']],
        batched=True,
        num_proc=num_proc,
        desc="Filtering invalid examples"
    )
    reduced_dataset_size = len(processed_dataset)

    logger.info("Processed dataset:")
    pprint(processed_dataset)

    logger.info(f"Saving {len(processed_dataset)} processed examples in chunks...")
    stats_accumulator = save_dataset_in_chunks(
        processed_dataset,
        str(output_dir / 'processed_dataset'),
        chunk_size=100_000
    )

    stats = {
        'original_dataset_size': original_dataset_size,
        'reduced_dataset_size': reduced_dataset_size,
        '% samples removed': (1 - reduced_dataset_size / original_dataset_size) * 100,
        'avg_ast_nodes': float(stats_accumulator['total_ast_nodes'] / stats_accumulator['total_samples'])
    }

    # Log stats with pprint
    logger.info("Processing completed! Stats:")
    pprint(stats)

if __name__ == "__main__":
    main()
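A minimal downstream-loading sketch (not part of this commit), assuming the chunk layout that save_dataset_in_chunks and main() above produce under code/data/processed-python/chunks:

from pathlib import Path
from datasets import load_from_disk, concatenate_datasets

# Assumed output location: main() writes chunks next to the requested output path.
chunks_dir = Path('code/data/processed-python/chunks')

# Mirror load_all_chunks(): load every chunk_#### directory and concatenate them.
chunks = [load_from_disk(str(p)) for p in sorted(chunks_dir.glob('chunk_*'))]
dataset = concatenate_datasets(chunks)

# Columns produced by process_batch(): content, ast_nodes, input_ids, attention_mask.
print(len(dataset), dataset.column_names)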
@@ -1,72 +0,0 @@
|
||||
import torch
|
||||
from transformers import RobertaTokenizer, RobertaForMaskedLM, DataCollatorForLanguageModeling
|
||||
from pathlib import Path
|
||||
|
||||
def load_model_and_tokenizer(model_path: Path, tokenizer_name: str = 'microsoft/codebert-base'):
|
||||
# Load the pre-trained tokenizer
|
||||
tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
|
||||
|
||||
# Load the trained model
|
||||
state_dict = torch.load(model_path, weights_only=True)
|
||||
corrected_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
|
||||
model = RobertaForMaskedLM.from_pretrained('roberta-base')
|
||||
model.load_state_dict(corrected_state_dict)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
def run_inference(model, tokenizer, input_text: str, max_length: int = 512):
|
||||
# Tokenize the input text
|
||||
inputs = tokenizer(
|
||||
input_text,
|
||||
return_tensors='pt',
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=max_length
|
||||
)
|
||||
|
||||
# Use DataCollatorForLanguageModeling for MLM-style dynamic masking
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
|
||||
inputs = data_collator([inputs]) # Collate the input and add random masks
|
||||
|
||||
# Squeeze the batch dimension
|
||||
for k, v in inputs.items():
|
||||
inputs[k] = v.squeeze(0)
|
||||
|
||||
print(tokenizer.decode(inputs['input_ids'][0][inputs['attention_mask'][0] == 1], skip_special_tokens=False))
|
||||
|
||||
with torch.no_grad():
|
||||
# Run inference
|
||||
outputs = model(**inputs).logits.squeeze(0)
|
||||
|
||||
# Get predicted token ids (argmax over logits)
|
||||
predicted_token_ids = outputs.argmax(dim=-1)
|
||||
|
||||
# Ignore padding tokens
|
||||
predicted_token_ids = predicted_token_ids[inputs['attention_mask'][0] == 1]
|
||||
|
||||
# Decode predicted token ids to text
|
||||
predicted_text = tokenizer.decode(predicted_token_ids, skip_special_tokens=False)
|
||||
|
||||
return predicted_text
|
||||
|
||||
def main_inference():
|
||||
# Load the trained model and tokenizer
|
||||
model_path = Path('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt') # Update this to your trained model path
|
||||
model, tokenizer = load_model_and_tokenizer(model_path)
|
||||
|
||||
# Define input string
|
||||
input_string = """def compute_area(radius):
|
||||
# This function calculates the area of a circle
|
||||
pi = 3.14159
|
||||
return pi * radius ** 2"""
|
||||
|
||||
# Run inference
|
||||
print(f"Input:\n{input_string}")
|
||||
print("\n\n")
|
||||
print("Masked input:")
|
||||
output = run_inference(model, tokenizer, input_string)
|
||||
print("\n\n")
|
||||
print(f"Output:\n{output}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_inference()
|
File diff suppressed because it is too large
@@ -1,374 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from tree_sitter import Language, Parser\n",
|
||||
"from datasets import load_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Dataset({\n",
|
||||
" features: ['repo', 'path', 'func_name', 'original_string', 'language', 'code', 'code_tokens', 'docstring', 'docstring_tokens', 'sha', 'url', 'partition'],\n",
|
||||
" num_rows: 251820\n",
|
||||
"})\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Load the dataset\n",
|
||||
"dataset = load_dataset('json', data_files={'train': '/work/s452638/datasets/CodeSearchNet/python/train.jsonl'}, split='train')\n",
|
||||
"print(dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tree_sitter/__init__.py:36: FutureWarning: Language.build_library is deprecated. Use the new bindings instead.\n",
|
||||
" warn(\"{} is deprecated. Use {} instead.\".format(old, new), FutureWarning)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"False"
|
||||
]
|
||||
},
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Build the language library if not already built\n",
|
||||
"# This should be done only once, and the resulting .so file can be reused\n",
|
||||
"Language.build_library(\n",
|
||||
" '/home/s452638/magisterka/build/my-languages.so', # Output location of compiled language library\n",
|
||||
" [\n",
|
||||
" '/home/s452638/magisterka/vendor/tree-sitter-python' # Replace with the path to the tree-sitter-python grammar\n",
|
||||
" ]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tree_sitter/__init__.py:36: FutureWarning: Language(path, name) is deprecated. Use Language(ptr, name) instead.\n",
|
||||
" warn(\"{} is deprecated. Use {} instead.\".format(old, new), FutureWarning)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Load the language\n",
|
||||
"PYTHON_LANGUAGE = Language('/home/s452638/magisterka/build/my-languages.so', 'python')\n",
|
||||
"\n",
|
||||
"# Initialize the parser\n",
|
||||
"parser = Parser()\n",
|
||||
"parser.set_language(PYTHON_LANGUAGE)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def remove_docstrings_and_comments_from_code(code):\n",
|
||||
" # Parse the code\n",
|
||||
" tree = parser.parse(bytes(code, \"utf8\"))\n",
|
||||
" cursor = tree.walk()\n",
|
||||
"\n",
|
||||
" # Traverse the tree and collect all docstrings\n",
|
||||
" to_remove = []\n",
|
||||
"\n",
|
||||
" def traverse_tree(cursor, prev_node_type=None):\n",
|
||||
" node_type = cursor.node.type\n",
|
||||
" node_text = cursor.node.text.decode(\"utf-8\")\n",
|
||||
" # Check if the current node is a function or class definition\n",
|
||||
" if node_type == \"string\" and node_text.startswith('\"\"\"') and node_text.endswith('\"\"\"') and prev_node_type == \"expression_statement\":\n",
|
||||
" to_remove.append((cursor.node.start_byte, cursor.node.end_byte))\n",
|
||||
" if cursor.node.type == \"comment\":\n",
|
||||
" to_remove.append((cursor.node.start_byte, cursor.node.end_byte))\n",
|
||||
"\n",
|
||||
" # Traverse children\n",
|
||||
" if cursor.goto_first_child():\n",
|
||||
" while True:\n",
|
||||
" traverse_tree(cursor, node_type)\n",
|
||||
" if not cursor.goto_next_sibling():\n",
|
||||
" break\n",
|
||||
" cursor.goto_parent()\n",
|
||||
"\n",
|
||||
" return node_type\n",
|
||||
"\n",
|
||||
" # Start traversing from the root\n",
|
||||
" traverse_tree(cursor)\n",
|
||||
"\n",
|
||||
" # Remove docstrings from code\n",
|
||||
" code_without_docstrings = code\n",
|
||||
" for start, end in sorted(to_remove, reverse=True):\n",
|
||||
" code_without_docstrings = code_without_docstrings[:start] + code_without_docstrings[end:]\n",
|
||||
"\n",
|
||||
" return code_without_docstrings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"def gather_categories(imap, header, categories=None):\n",
|
||||
" \"\"\"\n",
|
||||
" Find the user specified categories in the map and create a dictionary to contain the\n",
|
||||
" relevant data for each type within the categories. Multiple categories will have their\n",
|
||||
" types combined such that each possible combination will have its own entry in the\n",
|
||||
" dictionary.\n",
|
||||
"\n",
|
||||
" :type imap: dict\n",
|
||||
" :param imap: The input mapping file data keyed by SampleID\n",
|
||||
" :type header: list\n",
|
||||
" :param header: The header line from the input mapping file. This will be searched for\n",
|
||||
" the user-specified categories\n",
|
||||
" :type categories: list\n",
|
||||
" :param categories: The list of user-specified category column name from mapping file\n",
|
||||
" :rtype: dict\n",
|
||||
" :return: A sorted dictionary keyed on the combinations of all the types found within\n",
|
||||
" the user-specified categories. Each entry will contain an empty DataCategory\n",
|
||||
" namedtuple. If no categories are specified, a single entry with the key\n",
|
||||
" 'default' will be returned\n",
|
||||
" \"\"\"\n",
|
||||
" # If no categories provided, return all SampleIDs\n",
|
||||
" if categories is None:\n",
|
||||
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
|
||||
"\n",
|
||||
" cat_ids = [header.index(cat)\n",
|
||||
" for cat in categories if cat in header and \"=\" not in cat]\n",
|
||||
"\n",
|
||||
" table = OrderedDict()\n",
|
||||
" conditions = defaultdict(set)\n",
|
||||
" for i, cat in enumerate(categories):\n",
|
||||
" if \"=\" in cat and cat.split(\"=\")[0] in header:\n",
|
||||
" cat_name = header[header.index(cat.split(\"=\")[0])]\n",
|
||||
" conditions[cat_name].add(cat.split(\"=\")[1])\n",
|
||||
"\n",
|
||||
" # If invalid categories or conditions identified, return all SampleIDs\n",
|
||||
" if not cat_ids and not conditions:\n",
|
||||
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
|
||||
"\n",
|
||||
" #If only category column given, return column-wise SampleIDs\n",
|
||||
" if cat_ids and not conditions:\n",
|
||||
" for sid, row in imap.items():\n",
|
||||
" cat_name = \"_\".join([row[cid] for cid in cat_ids])\n",
|
||||
" if cat_name not in table:\n",
|
||||
" table[cat_name] = DataCategory(set(), {})\n",
|
||||
" table[cat_name].sids.add(sid)\n",
|
||||
" return table\n",
|
||||
"\n",
|
||||
" # Collect all condition names\n",
|
||||
" cond_ids = set()\n",
|
||||
" for k in conditions:\n",
|
||||
" try:\n",
|
||||
" cond_ids.add(header.index(k))\n",
|
||||
" except ValueError:\n",
|
||||
" continue\n",
|
||||
" idx_to_test = set(cat_ids).union(cond_ids)\n",
|
||||
"\n",
|
||||
" # If column name and condition given, return overlapping SampleIDs of column and\n",
|
||||
" # condition combinations\n",
|
||||
" for sid, row in imap.items():\n",
|
||||
" if all([row[header.index(c)] in conditions[c] for c in conditions]):\n",
|
||||
" key = \"_\".join([row[idx] for idx in idx_to_test])\n",
|
||||
" try:\n",
|
||||
" assert key in table.keys()\n",
|
||||
" except AssertionError:\n",
|
||||
" table[key] = DataCategory(set(), {})\n",
|
||||
" table[key].sids.add(sid)\n",
|
||||
" try:\n",
|
||||
" assert len(table) > 0\n",
|
||||
" except AssertionError:\n",
|
||||
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
|
||||
" else:\n",
|
||||
" return table\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"idx = 3\n",
|
||||
"print(dataset[idx]['code'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"def gather_categories(imap, header, categories=None):\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" if categories is None:\n",
|
||||
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
|
||||
"\n",
|
||||
" cat_ids = [header.index(cat)\n",
|
||||
" for cat in categories if cat in header and \"=\" not in cat]\n",
|
||||
"\n",
|
||||
" table = OrderedDict()\n",
|
||||
" conditions = defaultdict(set)\n",
|
||||
" for i, cat in enumerate(categories):\n",
|
||||
" if \"=\" in cat and cat.split(\"=\")[0] in header:\n",
|
||||
" cat_name = header[header.index(cat.split(\"=\")[0])]\n",
|
||||
" conditions[cat_name].add(cat.split(\"=\")[1])\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" if not cat_ids and not conditions:\n",
|
||||
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" if cat_ids and not conditions:\n",
|
||||
" for sid, row in imap.items():\n",
|
||||
" cat_name = \"_\".join([row[cid] for cid in cat_ids])\n",
|
||||
" if cat_name not in table:\n",
|
||||
" table[cat_name] = DataCategory(set(), {})\n",
|
||||
" table[cat_name].sids.add(sid)\n",
|
||||
" return table\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" cond_ids = set()\n",
|
||||
" for k in conditions:\n",
|
||||
" try:\n",
|
||||
" cond_ids.add(header.index(k))\n",
|
||||
" except ValueError:\n",
|
||||
" continue\n",
|
||||
" idx_to_test = set(cat_ids).union(cond_ids)\n",
|
||||
"\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" for sid, row in imap.items():\n",
|
||||
" if all([row[header.index(c)] in conditions[c] for c in conditions]):\n",
|
||||
" key = \"_\".join([row[idx] for idx in idx_to_test])\n",
|
||||
" try:\n",
|
||||
" assert key in table.keys()\n",
|
||||
" except AssertionError:\n",
|
||||
" table[key] = DataCategory(set(), {})\n",
|
||||
" table[key].sids.add(sid)\n",
|
||||
" try:\n",
|
||||
" assert len(table) > 0\n",
|
||||
" except AssertionError:\n",
|
||||
" return {\"default\": DataCategory(set(imap.keys()), {})}\n",
|
||||
" else:\n",
|
||||
" return table\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(remove_docstrings_and_comments_from_code(dataset[idx]['code']))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n",
|
||||
"def test():\n",
|
||||
" \n",
|
||||
" str_variable = \"\"\"\n",
|
||||
" This is a string\n",
|
||||
" \"\"\"\n",
|
||||
" print(\"Hello World\")\n",
|
||||
"\n",
|
||||
"class Test:\n",
|
||||
" \n",
|
||||
" def __init__(self):\n",
|
||||
" pass\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"test_code = '''\n",
|
||||
"def test():\n",
|
||||
" \"\"\"\n",
|
||||
" This is a test function\n",
|
||||
" \"\"\"\n",
|
||||
" str_variable = \"\"\"\n",
|
||||
" This is a string\n",
|
||||
" \"\"\"\n",
|
||||
" print(\"Hello World\")\n",
|
||||
"\n",
|
||||
"class Test:\n",
|
||||
" \"\"\"\n",
|
||||
" This is a test class\n",
|
||||
" \"\"\"\n",
|
||||
" def __init__(self):\n",
|
||||
" pass\n",
|
||||
"'''\n",
|
||||
"\n",
|
||||
"print(remove_docstrings_and_comments_from_code(test_code))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.19"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
code/src/tmp_config.json (new file, 12 lines)
@@ -0,0 +1,12 @@
{
    "seed": 42,
    "mlm_probability": 0.15,
    "batch": 16,
    "epochs": 1,
    "eval_every": 5000,
    "learning_rate": 5e-4,
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
    "warmup_steps": 5000,
    "train_size": 0.95
}
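For reference, a minimal sketch (not part of this commit) of how this file is consumed: the training scripts call load_config from training_utils.py, which simply json.loads the given path.

import json
from pathlib import Path

# Equivalent of training_utils.load_config(current_dir / 'tmp_config.json'),
# assuming the working directory is code/src.
with open(Path('tmp_config.json')) as f:
    config = json.load(f)

print(config['learning_rate'], config['warmup_steps'])  # 0.0005 5000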
@@ -1,261 +1,50 @@
|
||||
import wandb
|
||||
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import datetime
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Tuple, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from datasets import load_dataset, DatasetDict
|
||||
from huggingface_hub import list_repo_files, hf_hub_download
|
||||
from transformers import (
|
||||
RobertaForMaskedLM,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
DataCollatorForLanguageModeling,
|
||||
get_linear_schedule_with_warmup,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedModel
|
||||
from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer
|
||||
|
||||
from training_utils import (
|
||||
set_seed, setup_wandb, setup_directories, load_config,
|
||||
setup_device, create_optimizer_and_scheduler,
|
||||
train_and_evaluate, create_base_dataloaders, set_deterministic_mode
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OnTheFlyTokenizationDataset(Dataset):
|
||||
def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int):
|
||||
self.dataset = dataset
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, Tensor]:
|
||||
content: str = self.dataset[idx]['content']
|
||||
tokenized = self.tokenizer(
|
||||
content,
|
||||
truncation=True,
|
||||
padding='max_length',
|
||||
max_length=self.max_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
return {
|
||||
'input_ids': tokenized['input_ids'].squeeze(0),
|
||||
'attention_mask': tokenized['attention_mask'].squeeze(0),
|
||||
'labels': tokenized['input_ids'].squeeze(0)
|
||||
}
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def setup_wandb(config: Dict[str, Any], exec_file: str = 'train_codebert_mlm.py') -> None:
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
|
||||
wandb.init(project='codebert-training', name=f'tree_{curr_time}', config=config)
|
||||
wandb.save('train_codebert_mlm.py')
|
||||
|
||||
def setup_directories(current_dir: Path) -> Path:
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')
|
||||
output_dir: Path = current_dir.parent.parent / 'outputs' / f'tree_{curr_time}'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
return output_dir
|
||||
|
||||
def load_config(config_file: Path) -> Dict[str, Any]:
|
||||
with open(config_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def setup_device() -> torch.device:
|
||||
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
torch.set_default_device(device)
|
||||
logger.info(f'Using device: {device}')
|
||||
if device.type == 'cuda':
|
||||
logger.info(f'Device name: {torch.cuda.get_device_name()}')
|
||||
torch.set_float32_matmul_precision('high')
|
||||
return device
|
||||
|
||||
def download_dataset(dataset_dir: Path) -> None:
|
||||
if not dataset_dir.exists():
|
||||
logger.info("Downloading the dataset...")
|
||||
dataset_dir.mkdir(parents=True, exist_ok=True)
|
||||
files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
|
||||
files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
|
||||
for file_name in files_to_download:
|
||||
hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
|
||||
logger.info("Dataset downloaded successfully.")
|
||||
|
||||
def load_and_prepare_dataset(dataset_dir: Path, seed: int) -> DatasetDict:
|
||||
dataset: DatasetDict = load_dataset(str(dataset_dir), split='train')
|
||||
dataset = dataset.train_test_split(test_size=0.01, seed=seed)
|
||||
logger.info(f'Dataset loaded: {dataset}')
|
||||
return dataset
|
||||
|
||||
def create_dataloaders(
|
||||
dataset: DatasetDict,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config: Dict[str, Any],
|
||||
device: torch.device
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
dataset['train'] = OnTheFlyTokenizationDataset(dataset['train'], tokenizer, max_length=512)
|
||||
dataset['test'] = OnTheFlyTokenizationDataset(dataset['test'], tokenizer, max_length=512)
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=config['mlm_probability'])
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
dataset['train'],
|
||||
batch_size=config['batch'],
|
||||
shuffle=False,
|
||||
collate_fn=data_collator,
|
||||
generator=torch.Generator(device=device)
|
||||
)
|
||||
valid_dataloader = DataLoader(
|
||||
dataset['test'],
|
||||
batch_size=config['batch'],
|
||||
shuffle=False,
|
||||
collate_fn=data_collator,
|
||||
generator=torch.Generator(device=device)
|
||||
)
|
||||
return train_dataloader, valid_dataloader
|
||||
|
||||
def setup_model_and_optimizer(
|
||||
config: Dict[str, Any],
|
||||
models_dir: Path
|
||||
) -> Tuple[PreTrainedModel, AdamW]:
|
||||
if not models_dir.exists():
|
||||
logger.info("Downloading the model...")
|
||||
model_config = RobertaConfig.from_pretrained('roberta-base', cache_dir=models_dir)
|
||||
model: PreTrainedModel = RobertaForMaskedLM(model_config)
|
||||
model = torch.compile(model)
|
||||
wandb.watch(model)
|
||||
logger.info(f'Model config: {model_config}')
|
||||
wandb.config.update({'model_config': model_config.to_dict()})
|
||||
|
||||
optimizer: AdamW = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
|
||||
return model, optimizer
|
||||
|
||||
def train_and_evaluate(
|
||||
model: PreTrainedModel,
|
||||
train_dataloader: DataLoader,
|
||||
valid_dataloader: DataLoader,
|
||||
optimizer: AdamW,
|
||||
scheduler: Any,
|
||||
config: Dict[str, Any],
|
||||
output_dir: Path
|
||||
) -> None:
|
||||
num_training_steps: int = config['epochs'] * len(train_dataloader)
|
||||
best_valid_loss: float = float('inf')
|
||||
|
||||
with tqdm(total=num_training_steps, desc='Training') as pbar:
|
||||
for epoch_idx in range(config['epochs']):
|
||||
model.train()
|
||||
for train_idx, train_batch in enumerate(train_dataloader):
|
||||
outputs = model(**train_batch)
|
||||
train_loss: Tensor = outputs.loss
|
||||
train_loss.backward()
|
||||
|
||||
norm: Tensor = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
|
||||
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
pbar.update(1)
|
||||
pbar.set_postfix({'train_loss': train_loss.item()})
|
||||
|
||||
wandb.log({
|
||||
'step': train_idx + len(train_dataloader) * epoch_idx,
|
||||
'train_loss': train_loss.item(),
|
||||
'gradient_norm': norm.item(),
|
||||
'learning_rate': scheduler.get_last_lr()[0],
|
||||
})
|
||||
|
||||
if train_idx != 0 and train_idx % config['eval_every'] == 0:
|
||||
valid_loss, valid_acc = evaluate(model, valid_dataloader)
|
||||
pbar.set_postfix({'train_loss': train_loss.item(), 'valid_loss': valid_loss, 'valid_acc': valid_acc})
|
||||
|
||||
wandb.log({
|
||||
'valid_loss': valid_loss,
|
||||
'valid_acc': valid_acc,
|
||||
'step': train_idx + len(train_dataloader) * epoch_idx,
|
||||
})
|
||||
|
||||
if valid_loss < best_valid_loss:
|
||||
best_valid_loss = valid_loss
|
||||
torch.save(model.state_dict(), output_dir / 'best_model.pt')
|
||||
|
||||
logger.info(f'Best validation loss: {best_valid_loss}')
|
||||
|
||||
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
|
||||
model.eval()
|
||||
total_loss: float = 0.0
|
||||
total_acc: float = 0.0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in tqdm(dataloader, desc='Validation'):
|
||||
outputs = model(**batch)
|
||||
total_loss += outputs.loss.item()
|
||||
total_acc += outputs.logits.argmax(dim=-1).eq(batch['labels']).sum().item()
|
||||
|
||||
avg_loss: float = total_loss / len(dataloader)
|
||||
avg_acc: float = total_acc / len(dataloader.dataset)
|
||||
model.train()
|
||||
return avg_loss, avg_acc
|
||||
|
||||
def main() -> None:
|
||||
current_dir: Path = Path(__file__).parent
|
||||
output_dir: Path = setup_directories(current_dir)
|
||||
config: Dict[str, Any] = load_config(current_dir / 'config.json')
|
||||
set_deterministic_mode()
|
||||
|
||||
current_dir = Path(__file__).parent
|
||||
output_dir = setup_directories(current_dir, model_name='base')
|
||||
config = load_config(current_dir / 'tmp_config.json')
|
||||
|
||||
setup_wandb(config)
|
||||
setup_wandb(config, model_name='base')
|
||||
set_seed(config['seed'])
|
||||
device: torch.device = setup_device()
|
||||
device = setup_device()
|
||||
|
||||
dataset_dir: Path = current_dir.parent / 'data' / 'the-stack-python'
|
||||
download_dataset(dataset_dir)
|
||||
dataset: DatasetDict = load_and_prepare_dataset(dataset_dir, config['seed'])
|
||||
# Load tokenizer
|
||||
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
|
||||
|
||||
tokenizer: PreTrainedTokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base', clean_up_tokenization_spaces=True)
|
||||
logger.info(f'Tokenizer loaded: {tokenizer}')
|
||||
|
||||
######################## Reproducing last training here ########################
|
||||
# Remove first 186513 batches
|
||||
dataset['train'] = dataset['train'].select(range(186_513 * 32, len(dataset['train'])))
|
||||
################################################################################
|
||||
|
||||
train_dataloader, valid_dataloader = create_dataloaders(dataset, tokenizer, config, device)
|
||||
|
||||
models_dir: Path = current_dir.parent / 'models' / 'roberta-base'
|
||||
model, optimizer = setup_model_and_optimizer(config, models_dir)
|
||||
|
||||
num_training_steps: int = config['epochs'] * len(train_dataloader)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=config['warmup_steps'],
|
||||
num_training_steps=num_training_steps
|
||||
# Create dataloaders
|
||||
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
|
||||
train_dataloader, valid_dataloader = create_base_dataloaders(
|
||||
processed_data_dir,
|
||||
tokenizer,
|
||||
config,
|
||||
device
|
||||
)
|
||||
|
||||
######################## Reproducing last training here ########################
|
||||
# Change opitmizer learning rate to 0.00021575814536340852
|
||||
optimizer = AdamW(model.parameters(), lr=0.00021575814536340852, weight_decay=config['weight_decay'])
|
||||
# Set warmup_steps to 0
|
||||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)
|
||||
# Load the best model weights
|
||||
state_dict = torch.load('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt', weights_only=True, map_location=device)
|
||||
model.load_state_dict(state_dict)
|
||||
################################################################################
|
||||
|
||||
train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
|
||||
# Model setup
|
||||
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
|
||||
model = RobertaForMaskedLM(model_config)
|
||||
model = model.to(device)
|
||||
|
||||
# Training setup and run
|
||||
num_training_steps = config['epochs'] * len(train_dataloader)
|
||||
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
|
||||
|
||||
train_and_evaluate(
|
||||
model, train_dataloader, valid_dataloader,
|
||||
optimizer, scheduler, config, output_dir,
|
||||
log_weights=False, device=device
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -1,78 +1,18 @@
|
||||
import wandb
|
||||
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.optim import AdamW
|
||||
from dataclasses import dataclass
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
import ast
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datasets import DatasetDict
|
||||
from transformers import (
|
||||
RobertaForMaskedLM,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedModel,
|
||||
)
|
||||
from typing import Dict, List, Optional
|
||||
from transformers import RobertaConfig, RobertaTokenizer, RobertaForMaskedLM
|
||||
|
||||
sys.setrecursionlimit(3000) # Increase recursion limit
|
||||
|
||||
# Import existing training functionality
|
||||
from train_codebert_mlm import (
|
||||
from training_utils import (
|
||||
set_seed, setup_wandb, setup_directories, load_config,
|
||||
setup_device, download_dataset, load_and_prepare_dataset, logger
|
||||
setup_device, create_optimizer_and_scheduler,
|
||||
train_and_evaluate, create_base_dataloaders, PreprocessedTreeDataset,
|
||||
set_deterministic_mode
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class ASTNodeInfo:
|
||||
"""Stores structural information about an AST node."""
|
||||
node_type: str
|
||||
start_token_idx: int
|
||||
end_token_idx: int
|
||||
depth: int
|
||||
sibling_pos: int
|
||||
parent_idx: int
|
||||
|
||||
def ast_collate_fn(batch):
|
||||
"""Custom collate function with improved error handling."""
|
||||
# Remove failed parses (where attention_mask is all zeros)
|
||||
valid_batch = [
|
||||
item for item in batch
|
||||
if item['attention_mask'].sum() > 0 and len(item['ast_nodes']) > 0
|
||||
]
|
||||
from parse_dataset import ASTNodeInfo
|
||||
|
||||
if not valid_batch:
|
||||
# Return minimal batch if no valid items
|
||||
return {
|
||||
'input_ids': torch.zeros((1, 512), dtype=torch.long),
|
||||
'attention_mask': torch.zeros((1, 512), dtype=torch.long),
|
||||
'labels': torch.zeros((1, 512), dtype=torch.long),
|
||||
'ast_nodes': [[]]
|
||||
}
|
||||
|
||||
# Stack tensors
|
||||
input_ids = torch.stack([item['input_ids'] for item in valid_batch])
|
||||
attention_mask = torch.stack([item['attention_mask'] for item in valid_batch])
|
||||
labels = torch.stack([item['labels'] for item in valid_batch])
|
||||
|
||||
# Collect AST nodes
|
||||
ast_nodes = [item['ast_nodes'] for item in valid_batch]
|
||||
|
||||
return {
|
||||
'input_ids': input_ids,
|
||||
'attention_mask': attention_mask,
|
||||
'labels': labels,
|
||||
'ast_nodes': ast_nodes
|
||||
}
|
||||
|
||||
class TreePositionalEmbedding(nn.Module):
|
||||
"""Generates tree-aware positional embeddings for code tokens."""
|
||||
|
||||
@@ -126,7 +66,7 @@ class TreePositionalEmbedding(nn.Module):
|
||||
embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb]))
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
|
||||
"""CodeBERT model enhanced with normalized embedding weights for stable training."""
|
||||
|
||||
@@ -244,253 +184,44 @@ class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
|
||||
}
|
||||
}
|
||||
|
||||
class TreeEnhancedDataset(Dataset):
|
||||
"""Dataset that processes code into tokens and AST nodes with improved error handling."""
|
||||
def __init__(self, dataset: Dataset, tokenizer: PreTrainedTokenizer, max_length: int):
|
||||
self.dataset = dataset
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
self.max_ast_size = 1000 # Limit maximum AST size to prevent memory issues
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def create_dummy_data(self) -> Dict[str, Any]:
|
||||
"""Create dummy data for invalid/problematic code samples."""
|
||||
pad_token_id = self.tokenizer.pad_token_id or 0
|
||||
dummy_tensor = torch.full((self.max_length,), pad_token_id, dtype=torch.long)
|
||||
return {
|
||||
'input_ids': dummy_tensor,
|
||||
'attention_mask': torch.zeros(self.max_length, dtype=torch.long),
|
||||
'labels': dummy_tensor.clone(),
|
||||
'ast_nodes': []
|
||||
}
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||
try:
|
||||
content: str = self.dataset[idx]['content']
|
||||
|
||||
# Skip extremely long files
|
||||
if len(content) > 50000: # ~1000 lines
|
||||
return self.create_dummy_data()
|
||||
|
||||
# Basic code validation
|
||||
if not content.strip() or '\0' in content:
|
||||
return self.create_dummy_data()
|
||||
|
||||
# Tokenize first
|
||||
encoding = self.tokenizer(
|
||||
content,
|
||||
max_length=self.max_length,
|
||||
truncation=True,
|
||||
padding='max_length',
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse AST with timeout
|
||||
tree = ast.parse(content)
|
||||
|
||||
# Get AST nodes info
|
||||
nodes_info = []
|
||||
def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
|
||||
if len(nodes_info) >= self.max_ast_size:
|
||||
return
|
||||
|
||||
current_idx = len(nodes_info)
|
||||
if hasattr(node, 'lineno'):
|
||||
nodes_info.append(ASTNodeInfo(
|
||||
node_type=type(node).__name__,
|
||||
start_token_idx=node.lineno,
|
||||
end_token_idx=getattr(node, 'end_lineno', node.lineno),
|
||||
depth=min(depth, 31), # Limit depth to prevent issues
|
||||
sibling_pos=min(sibling_pos, 31),
|
||||
parent_idx=parent_idx
|
||||
))
|
||||
|
||||
# Only process children if we haven't hit the limit
|
||||
if len(nodes_info) < self.max_ast_size:
|
||||
for i, child in enumerate(ast.iter_child_nodes(node)):
|
||||
visit_node(child, depth + 1, current_idx, i)
|
||||
|
||||
visit_node(tree, 0, -1, 0)
|
||||
|
||||
# If we hit the AST size limit, use dummy data
|
||||
if len(nodes_info) >= self.max_ast_size:
|
||||
return self.create_dummy_data()
|
||||
|
||||
return {
|
||||
'input_ids': encoding['input_ids'].squeeze(0),
|
||||
'attention_mask': encoding['attention_mask'].squeeze(0),
|
||||
'labels': encoding['input_ids'].squeeze(0).clone(),
|
||||
'ast_nodes': nodes_info
|
||||
}
|
||||
|
||||
except (SyntaxError, ValueError, RecursionError, TimeoutError, MemoryError):
|
||||
return self.create_dummy_data()
|
||||
|
||||
except Exception as e:
|
||||
return self.create_dummy_data()
|
||||
|
||||
def create_tree_dataloaders(
|
||||
dataset: DatasetDict,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
config: Dict[str, Any],
|
||||
device: torch.device
|
||||
) -> Tuple[DataLoader, DataLoader]:
|
||||
"""Create dataloaders with tree-enhanced datasets."""
|
||||
train_dataset = TreeEnhancedDataset(dataset['train'], tokenizer, max_length=512)
|
||||
valid_dataset = TreeEnhancedDataset(dataset['test'], tokenizer, max_length=512)
|
||||
def main() -> None:
|
||||
set_deterministic_mode()
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config['batch'],
|
||||
shuffle=True,
|
||||
collate_fn=ast_collate_fn,
|
||||
generator=torch.Generator(device=device)
|
||||
)
|
||||
valid_dataloader = DataLoader(
|
||||
valid_dataset,
|
||||
batch_size=config['batch'],
|
||||
shuffle=False,
|
||||
collate_fn=ast_collate_fn,
|
||||
generator=torch.Generator(device=device)
|
||||
)
|
||||
|
||||
return train_dataloader, valid_dataloader
|
||||
|
||||
def train_and_evaluate(
|
||||
model: PreTrainedModel,
|
||||
train_dataloader: DataLoader,
|
||||
valid_dataloader: DataLoader,
|
||||
optimizer: AdamW,
|
||||
scheduler: Any,
|
||||
config: Dict[str, Any],
|
||||
output_dir: Path
|
||||
) -> None:
|
||||
"""Training loop with explicit tracking of alpha, beta, gamma weights."""
|
||||
num_training_steps: int = config['epochs'] * len(train_dataloader)
|
||||
best_valid_loss: float = float('inf')
|
||||
|
||||
with tqdm(total=num_training_steps, desc='Training') as pbar:
|
||||
for epoch_idx in range(config['epochs']):
|
||||
model.train()
|
||||
|
||||
for train_idx, train_batch in enumerate(train_dataloader):
|
||||
outputs = model(**train_batch)
|
||||
train_loss = outputs["loss"]
|
||||
|
||||
if train_loss is not None:
|
||||
train_loss.backward()
|
||||
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Get current metrics
|
||||
current_loss = train_loss.item()
|
||||
weights = outputs["embedding_weights"]
|
||||
|
||||
# Update progress bar with all three weights
|
||||
pbar.update(1)
|
||||
pbar.set_postfix({
|
||||
'train_loss': f"{current_loss:.3f}",
|
||||
'α': f"{weights['token']:.2f}",
|
||||
'β': f"{weights['tree']:.2f}",
|
||||
'γ': f"{weights['sequential']:.2f}"
|
||||
})
|
||||
|
||||
# Log all three weights separately
|
||||
step = train_idx + len(train_dataloader) * epoch_idx
|
||||
wandb.log({
|
||||
'train_loss': current_loss,
|
||||
'token_weight': weights['token'],
|
||||
'tree_weight': weights['tree'],
|
||||
'sequential_weight': weights['sequential'],
|
||||
'gradient_norm': norm.item(),
|
||||
'learning_rate': scheduler.get_last_lr()[0],
|
||||
'step': step,
|
||||
})
|
||||
|
||||
# Periodic evaluation
|
||||
if train_idx != 0 and train_idx % config['eval_every'] == 0:
|
||||
valid_loss, valid_acc = evaluate(model, valid_dataloader)
|
||||
|
||||
wandb.log({
|
||||
'valid_loss': valid_loss,
|
||||
'valid_acc': valid_acc,
|
||||
'step': step,
|
||||
})
|
||||
|
||||
if valid_loss < best_valid_loss:
|
||||
best_valid_loss = valid_loss
|
||||
torch.save(model.state_dict(), output_dir / 'best_model.pt')
|
||||
else:
|
||||
pbar.update(1)
|
||||
pbar.set_postfix({'train_loss': 'N/A'})
|
||||
|
||||
logger.info(f'Best validation loss: {best_valid_loss}')
|
||||
|
||||
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
|
||||
"""Evaluation function with simple metrics."""
|
||||
model.eval()
|
||||
total_loss: float = 0.0
|
||||
total_acc: float = 0.0
|
||||
num_batches: int = 0
|
||||
|
||||
with torch.no_grad():
|
||||
for batch in dataloader:
|
||||
outputs = model(**batch)
|
||||
if outputs["loss"] is not None:
|
||||
total_loss += outputs["loss"].item()
|
||||
logits = outputs["logits"]
|
||||
labels = batch['labels']
|
||||
predictions = logits.argmax(dim=-1)
|
||||
total_acc += (predictions == labels).float().mean().item()
|
||||
num_batches += 1
|
||||
|
||||
avg_loss: float = total_loss / max(num_batches, 1)
|
||||
avg_acc: float = total_acc / max(num_batches, 1)
|
||||
|
||||
model.train()
|
||||
return avg_loss, avg_acc
|
||||
|
||||
def main():
|
||||
# Setup identical to original training script
|
||||
current_dir = Path(__file__).parent
|
||||
output_dir = setup_directories(current_dir)
|
||||
config = load_config(current_dir / 'config.json')
|
||||
|
||||
setup_wandb(config, exec_file='src/train_tree_codebert_mlm.py')
|
||||
output_dir = setup_directories(current_dir, model_name='tree')
|
||||
config = load_config(current_dir / 'tmp_config.json')
|
||||
|
||||
setup_wandb(config, model_name='tree')
|
||||
set_seed(config['seed'])
|
||||
device = setup_device()
|
||||
|
||||
dataset_dir = current_dir.parent / 'data' / 'the-stack-python'
|
||||
download_dataset(dataset_dir)
|
||||
dataset = load_and_prepare_dataset(dataset_dir, config['seed'])
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
|
||||
|
||||
# Create tree-enhanced dataloaders
|
||||
train_dataloader, valid_dataloader = create_tree_dataloaders(dataset, tokenizer, config, device)
|
||||
|
||||
# Initialize tree-enhanced model
|
||||
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
|
||||
model = TreeCodeBERTForPreTraining(model_config)
|
||||
model = model.to(device) # Just move to device without compilation
|
||||
# model = torch.compile(model)
|
||||
|
||||
# Optimizer and scheduler setup identical to original
|
||||
optimizer = AdamW(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
|
||||
num_training_steps = config['epochs'] * len(train_dataloader)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer,
|
||||
num_warmup_steps=config['warmup_steps'],
|
||||
num_training_steps=num_training_steps
|
||||
# Create dataloaders with tree dataset
|
||||
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
|
||||
train_dataloader, valid_dataloader = create_base_dataloaders(
|
||||
processed_data_dir,
|
||||
tokenizer,
|
||||
config,
|
||||
device,
|
||||
dataset_class=PreprocessedTreeDataset
|
||||
)
|
||||
|
||||
# Training loop (using original train_and_evaluate function)
|
||||
train_and_evaluate(model, train_dataloader, valid_dataloader, optimizer, scheduler, config, output_dir)
|
||||
# Model setup
|
||||
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
|
||||
model = TreeCodeBERTForPreTraining(model_config)
|
||||
model = model.to(device)
|
||||
|
||||
# Training setup and run
|
||||
num_training_steps = config['epochs'] * len(train_dataloader)
|
||||
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
|
||||
|
||||
train_and_evaluate(
|
||||
model, train_dataloader, valid_dataloader,
|
||||
optimizer, scheduler, config, output_dir,
|
||||
log_weights=True, device=device
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
code/src/training_utils.py (new file, 450 lines)
@@ -0,0 +1,450 @@
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import datetime
|
||||
import platform
|
||||
import logging
|
||||
import wandb
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Tuple, Type, Optional
|
||||
from datasets import load_from_disk, concatenate_datasets
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from transformers import (
|
||||
PreTrainedModel,
|
||||
get_linear_schedule_with_warmup,
|
||||
DataCollatorForLanguageModeling,
|
||||
PreTrainedTokenizer
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from parse_dataset import ASTNodeInfo
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def set_deterministic_mode() -> None:
|
||||
"""Enable deterministic mode for reproducibility."""
|
||||
# Set Python random seed
|
||||
random.seed(42)
|
||||
|
||||
# Set NumPy random seed
|
||||
np.random.seed(42)
|
||||
|
||||
# Set PyTorch random seed
|
||||
torch.manual_seed(42)
|
||||
|
||||
# Set CUDA random seed
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(42)
|
||||
|
||||
# Enable deterministic operations for PyTorch 2.x
|
||||
torch.use_deterministic_algorithms(True)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
# Set environment variables for reproducibility
|
||||
os.environ['PYTHONHASHSEED'] = '42'
|
||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2
|
||||
|
||||
def get_device_info() -> Dict[str, Any]:
|
||||
"""Get detailed information about the computing environment."""
|
||||
info = {
|
||||
'python_version': platform.python_version(),
|
||||
'torch_version': torch.__version__,
|
||||
'cuda_available': torch.cuda.is_available(),
|
||||
'deterministic_algorithms': torch.are_deterministic_algorithms_enabled(),
|
||||
'cudnn_deterministic': torch.backends.cudnn.deterministic,
|
||||
'cudnn_benchmark': torch.backends.cudnn.benchmark,
|
||||
}
|
||||
|
||||
if torch.cuda.is_available():
|
||||
info.update({
|
||||
'cuda_version': torch.version.cuda,
|
||||
'gpu_name': torch.cuda.get_device_name(0),
|
||||
'gpu_count': torch.cuda.device_count(),
|
||||
})
|
||||
|
||||
return info
|
||||
|
||||
class TreeDataCollator(DataCollatorForLanguageModeling):
|
||||
"""Custom data collator that handles both MLM and optional AST node information."""
|
||||
|
||||
def torch_call(self, examples):
|
||||
# Check if we have AST nodes
|
||||
has_ast = 'ast_nodes' in examples[0]
|
||||
|
||||
# Extract AST nodes before MLM processing if they exist
|
||||
ast_nodes = None
|
||||
if has_ast:
|
||||
ast_nodes = [e.pop('ast_nodes') for e in examples]
|
||||
|
||||
# Process normal MLM features
|
||||
batch = super().torch_call(examples)
|
||||
|
||||
# Add AST nodes back to batch if they existed
|
||||
if has_ast:
|
||||
batch['ast_nodes'] = ast_nodes
|
||||
|
||||
return batch
|
||||
|
||||
def __call__(self, examples):
|
||||
return self.torch_call(examples)
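A minimal usage sketch for TreeDataCollator (not part of this commit); the example tensors are hypothetical, and the point is only that 'ast_nodes' bypasses the MLM masking applied to input_ids:

import torch
from transformers import RobertaTokenizer
from training_utils import TreeDataCollator  # the collator defined above in this file

tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
collator = TreeDataCollator(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Hypothetical pre-tokenized examples; real ones come from PreprocessedTreeDataset.
examples = [
    {'input_ids': torch.tensor([0, 31414, 232, 2]),   # arbitrary token ids
     'attention_mask': torch.tensor([1, 1, 1, 1]),
     'ast_nodes': []},
]

batch = collator(examples)
print(batch['input_ids'].shape, len(batch['ast_nodes']))  # masked ids plus untouched ast_nodes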
|
||||
|
||||
class PreprocessedBaseDataset(Dataset):
|
||||
"""Dataset that uses pre-processed data without AST information."""
|
||||
def __init__(self, dataset: Any):
|
||||
self.dataset = dataset
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.dataset)
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||
item = self.dataset[idx]
|
||||
return {
|
||||
'input_ids': torch.tensor(item['input_ids']),
|
||||
'attention_mask': torch.tensor(item['attention_mask']),
|
||||
'labels': torch.tensor(item['input_ids']).clone()
|
||||
}
|
||||
|
||||
class PreprocessedTreeDataset(PreprocessedBaseDataset):
|
||||
"""Dataset that includes AST information."""
|
||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
||||
item = super().__getitem__(idx)
|
||||
item['ast_nodes'] = [ASTNodeInfo.from_dict(node) for node in self.dataset[idx]['ast_nodes']]
|
||||
return item
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
"""Set random seeds for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
def setup_wandb(
|
||||
config: Dict[str, Any],
|
||||
model_name: str = 'base',
|
||||
device_info: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""Initialize W&B logging with reproducibility information."""
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
|
||||
|
||||
# Add reproducibility settings to config
|
||||
full_config = {
|
||||
**config,
|
||||
'random_seed': 42,
|
||||
'deterministic_mode': True,
|
||||
'device_info': device_info or get_device_info(),
|
||||
}
|
||||
|
||||
wandb.init(
|
||||
project='codebert-training-test',
|
||||
name=f'{model_name}_{curr_time}',
|
||||
config=full_config
|
||||
)
|
||||
|
||||
def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
|
||||
"""Create output directories for model artifacts."""
|
||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
|
||||
output_dir: Path = current_dir.parent.parent / 'outputs' / f'{model_name}_{curr_time}'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
return output_dir
|
||||
|
||||
def load_config(config_file: Path) -> Dict[str, Any]:
|
||||
"""Load training configuration from JSON file."""
|
||||
with open(config_file, 'r') as f:
|
||||
return json.load(f)
|
||||
|
||||
def setup_device() -> torch.device:
|
||||
"""Setup and configure training device."""
|
||||
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
torch.set_default_device(device)
|
||||
logger.info(f'Using device: {device}')
|
||||
if device.type == 'cuda':
|
||||
logger.info(f'Device name: {torch.cuda.get_device_name()}')
|
||||
torch.set_float32_matmul_precision('high')
|
||||
return device
|
||||
|
||||
def create_optimizer_and_scheduler(
    model: PreTrainedModel,
    config: Dict[str, Any],
    num_training_steps: int
) -> Tuple[AdamW, Any]:
    """Create optimizer and learning rate scheduler."""
    optimizer = AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['warmup_steps'],
        num_training_steps=num_training_steps
    )

    return optimizer, scheduler

def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
    """Evaluate model on validation set; returns (loss, masked-token accuracy)."""
    model.eval()
    total_loss: float = 0.0
    total_correct: int = 0
    total_masked: int = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation'):
            outputs = model(**batch)
            total_loss += outputs.loss.item() if hasattr(outputs, 'loss') else outputs['loss'].item()
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs['logits']
            # Accuracy is computed only over masked positions (labels != -100)
            masked = batch['labels'] != -100
            total_correct += logits.argmax(dim=-1).eq(batch['labels']).logical_and(masked).sum().item()
            total_masked += masked.sum().item()

    avg_loss: float = total_loss / len(dataloader)
    avg_acc: float = total_correct / max(total_masked, 1)
    model.train()
    return avg_loss, avg_acc

def train_and_evaluate(
    model: PreTrainedModel,
    train_dataloader: DataLoader,
    valid_dataloader: DataLoader,
    optimizer: AdamW,
    scheduler: Any,
    config: Dict[str, Any],
    output_dir: Path,
    log_weights: bool = False,
    device: torch.device = torch.device('cpu')
) -> None:
    """Train and evaluate model with deterministic behavior and comprehensive logging."""
    # Enable deterministic algorithms for PyTorch 2.5
    torch.use_deterministic_algorithms(True)

    num_training_steps: int = config['epochs'] * len(train_dataloader)
    best_valid_loss: float = float('inf')

    # Save initial model state for reproducibility
    torch.save({
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
        'rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    }, output_dir / 'initial_state.pt')

    with tqdm(total=num_training_steps, desc='Training') as pbar:
        for epoch_idx in range(config['epochs']):
            model.train()
            epoch_loss = 0.0
            epoch_steps = 0

            for train_idx, train_batch in enumerate(train_dataloader):
                # Move batch to device
                train_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                               for k, v in train_batch.items()}

                # Forward pass
                outputs = model(**train_batch)
                train_loss = outputs.loss if hasattr(outputs, 'loss') else outputs['loss']

                # Backward pass with deterministic behavior
                train_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    config['max_grad_norm']
                )

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=False)  # Use False for determinism

                # Update metrics
                current_loss = train_loss.item()
                epoch_loss += current_loss
                epoch_steps += 1

                # Calculate global step
                step = train_idx + len(train_dataloader) * epoch_idx

                # Prepare logging dictionary
                log_dict = {
                    'step': step,
                    'epoch': epoch_idx,
                    'train_loss': current_loss,
                    'gradient_norm': grad_norm.item(),
                    'learning_rate': scheduler.get_last_lr()[0],
                }

                # Add embedding weights if using tree model
                if log_weights and 'embedding_weights' in outputs:
                    weights = outputs['embedding_weights']
                    log_dict.update({
                        'token_weight': weights['token'],
                        'tree_weight': weights['tree'],
                        'sequential_weight': weights['sequential']
                    })
                    pbar_dict = {
                        'epoch': f"{epoch_idx + 1}/{config['epochs']}",
                        'train_loss': f"{current_loss:.3f}",
                        'α': f"{weights['token']:.2f}",
                        'β': f"{weights['tree']:.2f}",
                        'γ': f"{weights['sequential']:.2f}"
                    }
                else:
                    pbar_dict = {
                        'epoch': f"{epoch_idx + 1}/{config['epochs']}",
                        'train_loss': f"{current_loss:.3f}"
                    }

                # Log to wandb
                wandb.log(log_dict)

                # Update progress bar
                pbar.update(1)
                pbar.set_postfix(pbar_dict)

                # Periodic evaluation
                if (train_idx != 0 and train_idx % config['eval_every'] == 0) or train_idx == len(train_dataloader) - 1:
                    model.eval()
                    valid_loss, valid_acc = evaluate(model, valid_dataloader)
                    model.train()

                    # Log validation metrics
                    wandb.log({
                        'valid_loss': valid_loss,
                        'valid_acc': valid_acc,
                        'step': step,
                        'epoch': epoch_idx,
                    })

                    # Save checkpoint if best model
                    if valid_loss < best_valid_loss:
                        best_valid_loss = valid_loss

                        # Save complete state
                        torch.save({
                            'epoch': epoch_idx,
                            'step': step,
                            'model_state': model.state_dict(),
                            'optimizer_state': optimizer.state_dict(),
                            'scheduler_state': scheduler.state_dict(),
                            'rng_state': torch.get_rng_state(),
                            'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
                            'best_valid_loss': best_valid_loss,
                            'config': config,
                            'device_info': get_device_info()
                        }, output_dir / 'best_model.pt')

                        # Log best metrics
                        wandb.run.summary['best_valid_loss'] = best_valid_loss
                        wandb.run.summary['best_valid_acc'] = valid_acc
                        wandb.run.summary['best_epoch'] = epoch_idx
                        wandb.run.summary['best_step'] = step

            # End of epoch logging
            avg_epoch_loss = epoch_loss / epoch_steps
            wandb.log({
                'epoch': epoch_idx,
                'epoch_avg_loss': avg_epoch_loss,
            })

            # Save end of epoch checkpoint
            torch.save({
                'epoch': epoch_idx,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict(),
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
                'train_loss': avg_epoch_loss,
                'config': config,
                'device_info': get_device_info()
            }, output_dir / f'checkpoint_epoch_{epoch_idx}.pt')

    # End of training logging
    wandb.run.summary['final_epoch'] = config['epochs'] - 1
    wandb.run.summary['total_steps'] = num_training_steps

    logger.info(f'Training completed. Best validation loss: {best_valid_loss:.4f}')
    logger.info(f'Model checkpoints saved in {output_dir}')

def load_all_chunks(chunks_dir: Path) -> Any:
    """Load and concatenate all dataset chunks."""
    logger.info(f"Loading dataset chunks from {chunks_dir}")
    chunks = []
    # Note: only the first 5 chunks (sorted by name) are loaded here
    for chunk_path in sorted(chunks_dir.glob('chunk_*'))[:5]:
        chunks.append(load_from_disk(str(chunk_path)))
    dataset = concatenate_datasets(chunks)
    logger.info(f"Loaded {len(dataset)} examples from {len(chunks)} chunks")
    return dataset

def create_base_dataloaders(
    processed_data_dir: Path,
    tokenizer: PreTrainedTokenizer,
    config: Dict[str, Any],
    device: torch.device,
    dataset_class: Type[Dataset] = PreprocessedBaseDataset,
) -> Tuple[DataLoader, DataLoader]:
    """Create reproducible dataloaders from pre-processed data."""
    # Load chunks
    chunks_dir = processed_data_dir / 'chunks'
    full_dataset = load_all_chunks(chunks_dir)

    # Calculate split sizes
    dataset_size = len(full_dataset)
    train_size = int(config['train_size'] * dataset_size)
    val_size = dataset_size - train_size

    # Create splits using indices to avoid device issues
    indices = torch.arange(dataset_size, device=device)

    # Create a deterministic generator on the correct device
    generator = torch.Generator(device=device)
    generator.manual_seed(42)

    # Shuffle indices deterministically
    shuffled_indices = indices[torch.randperm(len(indices), device=device, generator=generator)]
    train_indices = shuffled_indices[:train_size].cpu()
    valid_indices = shuffled_indices[train_size:].cpu()

    # Create train/validation splits using the shuffled indices
    train_dataset = torch.utils.data.Subset(full_dataset, train_indices.tolist())
    valid_dataset = torch.utils.data.Subset(full_dataset, valid_indices.tolist())

    # Create dataset wrappers
    train_dataset = dataset_class(train_dataset)
    valid_dataset = dataset_class(valid_dataset)

    logger.info(f"Created train dataset with {len(train_dataset)} samples")
    logger.info(f"Created validation dataset with {len(valid_dataset)} samples")

    # Use TreeDataCollator for both base and tree models
    data_collator = TreeDataCollator(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=config['mlm_probability']
    )

    # Create train dataloader without generator
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config['batch'],
        shuffle=False,  # We already shuffled the data
        collate_fn=data_collator,
        num_workers=0,  # Use single worker for reproducibility
        drop_last=True  # Ensure consistent batch sizes
    )

    # Create validation dataloader
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=config['batch'],
        shuffle=False,
        collate_fn=data_collator,
        num_workers=0,
        drop_last=True
    )

    return train_dataloader, valid_dataloader

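# Illustrative end-to-end usage sketch (not part of this commit). It shows how
# the utilities above are expected to compose in a training entrypoint such as
# train_codebert_mlm.py; the checkpoint name, config path and data directory are
# assumptions made for the example, and the function name is hypothetical.
def _example_training_entrypoint() -> None:
    """Hypothetical driver wiring the utilities in this module together."""
    from transformers import RobertaForMaskedLM, RobertaTokenizer  # local imports for the sketch

    current_dir = Path(__file__).parent
    config = load_config(current_dir / 'config.json')  # assumed config location
    set_seed(42)
    device = setup_device()
    output_dir = setup_directories(current_dir, model_name='base')
    setup_wandb(config, model_name='base')

    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')  # assumed checkpoint
    model = RobertaForMaskedLM.from_pretrained('microsoft/codebert-base').to(device)

    train_dataloader, valid_dataloader = create_base_dataloaders(
        processed_data_dir=current_dir.parent / 'data' / 'processed',  # assumed layout
        tokenizer=tokenizer,
        config=config,
        device=device,
        dataset_class=PreprocessedBaseDataset,
    )
    optimizer, scheduler = create_optimizer_and_scheduler(
        model, config, num_training_steps=config['epochs'] * len(train_dataloader)
    )
    train_and_evaluate(
        model, train_dataloader, valid_dataloader, optimizer, scheduler,
        config, output_dir, log_weights=False, device=device
    )
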
@ -1,36 +0,0 @@
def remove_docstrings_and_comments_from_code(code, parser):
    # Parse the code
    tree = parser.parse(bytes(code, "utf8"))
    cursor = tree.walk()

    # Traverse the tree and collect all docstrings
    to_remove = []

    def traverse_tree(cursor, prev_node_type=None):
        node_type = cursor.node.type
        node_text = cursor.node.text.decode("utf-8")
        # Check if the current node is a docstring (a triple-quoted string whose parent is an expression_statement) or a comment
        if node_type == "string" and node_text.startswith('"""') and node_text.endswith('"""') and prev_node_type == "expression_statement":
            to_remove.append((cursor.node.start_byte, cursor.node.end_byte))
        if cursor.node.type == "comment":
            to_remove.append((cursor.node.start_byte, cursor.node.end_byte))

        # Traverse children
        if cursor.goto_first_child():
            while True:
                traverse_tree(cursor, node_type)
                if not cursor.goto_next_sibling():
                    break
            cursor.goto_parent()

        return node_type

    # Start traversing from the root
    traverse_tree(cursor)

    # Remove docstrings from code
    code_without_docstrings = code
    for start, end in sorted(to_remove, reverse=True):
        code_without_docstrings = code_without_docstrings[:start] + code_without_docstrings[end:]

    return code_without_docstrings