On this commit there is a weird loss spike (unresolved).
This commit is contained in:
parent 0cd04e6131
commit b2d65059da
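The spike itself is easiest to pin down by logging a running statistic next to train_loss in W&B. A minimal sketch of such a check, assuming a deque-based helper that is not part of this repository (window size and threshold are arbitrary choices):

from collections import deque

# Hypothetical helper, not part of this commit: flag steps whose loss jumps
# well above the median of the last `window` steps.
class SpikeDetector:
    def __init__(self, window: int = 200, factor: float = 3.0):
        self.recent = deque(maxlen=window)
        self.factor = factor

    def update(self, loss: float) -> bool:
        median = sorted(self.recent)[len(self.recent) // 2] if self.recent else loss
        self.recent.append(loss)
        return loss > self.factor * median

# Inside the training loop (sketch): if detector.update(train_loss.item()) is True,
# log the step index to W&B so the offending batch can be inspected later.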
@@ -7,15 +7,18 @@ authors = [
 ]
 dependencies = [
     "wandb==0.18.5",
-    "torch==2.5.0",
+    "torch==2.5.1",
     "tqdm==4.66.5",
     "tree-sitter==0.23.1",
-    "transformers==4.45.2",
+    "transformers[torch]>=4.46.3",
     "datasets==3.0.1",
     "huggingface-hub==0.26.0",
     "matplotlib==3.9.2",
     "scikit-learn==1.5.2",
     "seaborn==0.13.2",
+    "tree-sitter-python==0.23.4",
+    "ipykernel>=6.29.5",
+    "ipywidgets>=8.1.5",
 ]
 requires-python = "==3.11.*"
 readme = "README.md"
@@ -35,6 +38,5 @@ build-backend = "pdm.backend"
 distribution = true
 
 [tool.pdm.scripts]
-run_training = {cmd = "src/train_codebert_mlm.py"}
-run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
 parse_dataset = {cmd = "src/parse_dataset.py"}
+run_training = {cmd = "src/training.py"}
@@ -1,12 +1,13 @@
 {
-    "seed": 42,
+    "seed": 420,
     "mlm_probability": 0.15,
-    "batch": 16,
-    "epochs": 1,
-    "eval_every": 20000,
+    "batch_size": 32,
+    "epochs": 5,
+    "eval_every": 1000,
     "learning_rate": 5e-4,
-    "weight_decay": 0.01,
+    "weight_decay": 0.1,
     "max_grad_norm": 1.0,
-    "warmup_steps": 20000,
-    "train_size": 0.95
+    "warmup_steps": 1000,
+    "train_size": 0.95,
+    "fp16": true
 }
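For reference, the renamed and added keys are consumed through load_config in training.py; a minimal sketch assuming load_config is a thin json.load wrapper (its body is not shown in this diff, the real helper lives in training_utils.py):

import json
from pathlib import Path

# Assumed implementation of load_config for illustration only.
def load_config(path: Path) -> dict:
    with open(path) as f:
        return json.load(f)

config = load_config(Path('config.json'))
print(config['batch_size'])  # 32, renamed from "batch"
print(config['fp16'])        # True, gates model.half() in training.py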
code/src/dataset_parsing_test.ipynb (new file, 4507 lines)
File diff suppressed because it is too large
code/src/download_dataset.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import logging
+from pathlib import Path
+from typing import List
+from huggingface_hub import list_repo_files, hf_hub_download
+
+# Setup logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+def download_dataset(dataset_dir: Path) -> None:
+    if not dataset_dir.exists():
+        logger.info("Downloading the dataset...")
+        dataset_dir.mkdir(parents=True, exist_ok=True)
+        files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
+        files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
+        for file_name in files_to_download:
+            hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
+        logger.info("Dataset downloaded successfully.")
+
+if __name__ == '__main__':
+    dataset_dir = Path('data/the-stack-python')
+    download_dataset(dataset_dir)
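The script only mirrors the data/python/ shards into data/the-stack-python; a hedged sketch of reading them back with datasets afterwards (assumes the shards are parquet files, as they are on the Hub, and that hf_hub_download kept the repo-relative paths):

from pathlib import Path
from datasets import load_dataset

# Shards end up under <local_dir>/data/python/ because the repo paths are preserved.
dataset_dir = Path('data/the-stack-python')
ds = load_dataset(
    'parquet',
    data_files=str(dataset_dir / 'data' / 'python' / '*.parquet'),
    split='train',
)
print(ds)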
@@ -1,128 +1,123 @@
|
|||||||
import ast
|
|
||||||
from pathlib import Path
|
|
||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
import tree_sitter_python as tspython
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
from pathlib import Path
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
from typing import Dict, Any, List
|
from typing import Dict, Any, List
|
||||||
from datasets import load_dataset, Dataset, concatenate_datasets
|
from datasets import load_dataset, Dataset
|
||||||
from transformers import RobertaTokenizer
|
from tree_sitter import Language, Parser
|
||||||
from dataclasses import dataclass
|
from transformers import AutoTokenizer
|
||||||
import numpy as np
|
|
||||||
import multiprocessing
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore", category=SyntaxWarning)
|
warnings.filterwarnings("ignore", category=SyntaxWarning)
|
||||||
|
|
||||||
# Setup logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@dataclass
|
def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: AutoTokenizer):
|
||||||
class ASTNodeInfo:
|
"""
|
||||||
"""Stores structural information about an AST node."""
|
Map tokens to their original text and Tree-sitter AST position
|
||||||
node_type: str
|
"""
|
||||||
start_token_idx: int
|
# Initialize Tree-sitter
|
||||||
end_token_idx: int
|
PY_LANGUAGE = Language(tspython.language())
|
||||||
depth: int
|
parser = Parser(PY_LANGUAGE)
|
||||||
sibling_pos: int
|
|
||||||
parent_idx: int
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
# Parse with Tree-sitter
|
||||||
return {
|
tree = parser.parse(bytes(code_snippet, "utf8"))
|
||||||
'node_type': self.node_type,
|
|
||||||
'start_token_idx': self.start_token_idx,
|
|
||||||
'end_token_idx': self.end_token_idx,
|
|
||||||
'depth': self.depth,
|
|
||||||
'sibling_pos': self.sibling_pos,
|
|
||||||
'parent_idx': self.parent_idx
|
|
||||||
}
|
|
||||||
|
|
||||||
def from_dict(data: Dict[str, Any]) -> 'ASTNodeInfo':
|
encoded = tokenizer(
|
||||||
return ASTNodeInfo(
|
code_snippet,
|
||||||
node_type=data['node_type'],
|
add_special_tokens=True,
|
||||||
start_token_idx=data['start_token_idx'],
|
return_offsets_mapping=True,
|
||||||
end_token_idx=data['end_token_idx'],
|
return_tensors='pt',
|
||||||
depth=data['depth'],
|
padding='max_length',
|
||||||
sibling_pos=data['sibling_pos'],
|
truncation=True,
|
||||||
parent_idx=data['parent_idx']
|
max_length=512,
|
||||||
)
|
)
|
||||||
|
|
||||||
def safe_ast_parse(content: str, max_length: int = 50000) -> bool:
|
tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
|
||||||
"""Safely attempt to parse Python code."""
|
offset_mapping = encoded['offset_mapping'][0].tolist()
|
||||||
if not content or not content.strip() or '\0' in content or len(content) > max_length:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
def get_node_position(node, depth=0, idx=0):
|
||||||
tree = ast.parse(content)
|
"""Get depth and sibling index for a node"""
|
||||||
return bool(list(ast.walk(tree)))
|
return (depth, idx)
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def process_and_extract_ast_info(content: str, max_ast_size: int = 1000) -> List[Dict[str, Any]]:
|
def find_node_at_position(node, start_byte, depth=0, sibling_idx=0):
|
||||||
"""Process AST and extract node information."""
|
"""Find the most specific node containing the given position"""
|
||||||
try:
|
if start_byte >= node.start_byte and start_byte < node.end_byte:
|
||||||
tree = ast.parse(content)
|
# Check children first for more specific nodes
|
||||||
nodes_info = []
|
for idx, child in enumerate(node.children):
|
||||||
|
result = find_node_at_position(child, start_byte, depth + 1, idx)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
# Return current node's position if no child contains the position
|
||||||
|
return get_node_position(node, depth, sibling_idx)
|
||||||
|
return None
|
||||||
|
|
||||||
def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
|
depths = []
|
||||||
if len(nodes_info) >= max_ast_size:
|
sibling_idxs = []
|
||||||
return
|
node_texts = []
|
||||||
|
for token, (start, end) in zip(tokens, offset_mapping):
|
||||||
|
if token not in ['<s>', '</s>', '<pad>', '<unk>']:
|
||||||
|
text = code_snippet[start:end] if start < len(code_snippet) else ""
|
||||||
|
# Tree-sitter works with bytes, so convert position
|
||||||
|
start_byte = len(code_snippet[:start].encode('utf8')) if start < len(code_snippet) else 0
|
||||||
|
position = find_node_at_position(tree.root_node, start_byte) or (-1, -1)
|
||||||
|
|
||||||
current_idx = len(nodes_info)
|
depths.append(position[0])
|
||||||
if hasattr(node, 'lineno'):
|
sibling_idxs.append(position[1])
|
||||||
node_info = ASTNodeInfo(
|
node_texts.append(text)
|
||||||
node_type=type(node).__name__,
|
else:
|
||||||
start_token_idx=node.lineno,
|
depths.append(-1)
|
||||||
end_token_idx=getattr(node, 'end_lineno', node.lineno),
|
sibling_idxs.append(-1)
|
||||||
depth=min(depth, 31),
|
node_texts.append(None)
|
||||||
sibling_pos=min(sibling_pos, 31),
|
|
||||||
parent_idx=parent_idx
|
|
||||||
)
|
|
||||||
nodes_info.append(node_info.to_dict())
|
|
||||||
|
|
||||||
if len(nodes_info) < max_ast_size:
|
return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts
|
||||||
for i, child in enumerate(ast.iter_child_nodes(node)):
|
|
||||||
visit_node(child, depth + 1, current_idx, i)
|
|
||||||
|
|
||||||
visit_node(tree, 0, -1, 0)
|
def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
|
||||||
return nodes_info
|
|
||||||
except:
|
|
||||||
return []
|
|
||||||
|
|
||||||
def process_batch(examples: Dict[str, List[Any]], tokenizer: RobertaTokenizer) -> Dict[str, List[Any]]:
|
|
||||||
"""Process a batch of examples."""
|
"""Process a batch of examples."""
|
||||||
contents = examples['content']
|
contents = examples['content']
|
||||||
|
|
||||||
processed_contents = []
|
|
||||||
processed_ast_nodes = []
|
|
||||||
processed_input_ids = []
|
processed_input_ids = []
|
||||||
processed_attention_masks = []
|
processed_attention_mask = []
|
||||||
|
processed_tokens = []
|
||||||
|
processed_depths = []
|
||||||
|
processed_sibling_idxs = []
|
||||||
|
processed_node_texts = []
|
||||||
|
|
||||||
for content in contents:
|
for content in contents:
|
||||||
if safe_ast_parse(content):
|
try:
|
||||||
ast_nodes = process_and_extract_ast_info(content)
|
input_ids, attention_mask, tokens, depths, sibling_idxs, node_texts = analyze_code_with_codebert_and_treesitter(content, tokenizer)
|
||||||
if ast_nodes:
|
processed_input_ids.append(input_ids[0])
|
||||||
try:
|
processed_attention_mask.append(attention_mask[0])
|
||||||
encoding = tokenizer(
|
processed_tokens.append(tokens)
|
||||||
content,
|
processed_depths.append(depths)
|
||||||
max_length=512,
|
processed_sibling_idxs.append(sibling_idxs)
|
||||||
truncation=True,
|
processed_node_texts.append(node_texts)
|
||||||
padding='max_length',
|
|
||||||
return_tensors='pt'
|
|
||||||
)
|
|
||||||
|
|
||||||
processed_contents.append(content)
|
except Exception as e:
|
||||||
processed_ast_nodes.append(ast_nodes)
|
logger.error(f"Error processing example: {e}")
|
||||||
processed_input_ids.append(encoding['input_ids'].squeeze(0).tolist())
|
# Return empty lists that will be filtered out later
|
||||||
processed_attention_masks.append(encoding['attention_mask'].squeeze(0).tolist())
|
processed_input_ids.append([])
|
||||||
except:
|
processed_attention_mask.append([])
|
||||||
continue
|
processed_tokens.append([])
|
||||||
|
processed_depths.append([])
|
||||||
|
processed_sibling_idxs.append([])
|
||||||
|
processed_node_texts.append([])
|
||||||
|
|
||||||
return {
|
return {
|
||||||
'content': processed_contents,
|
|
||||||
'ast_nodes': processed_ast_nodes,
|
|
||||||
'input_ids': processed_input_ids,
|
'input_ids': processed_input_ids,
|
||||||
'attention_mask': processed_attention_masks,
|
'attention_mask': processed_attention_mask,
|
||||||
|
'tokens': processed_tokens,
|
||||||
|
'depths': processed_depths,
|
||||||
|
'sibling_idxs': processed_sibling_idxs,
|
||||||
|
'node_texts': processed_node_texts,
|
||||||
}
|
}
|
||||||
|
|
||||||
def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
|
def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
|
||||||
@@ -131,14 +126,9 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int =
|
|||||||
|
|
||||||
# Create directory for chunks
|
# Create directory for chunks
|
||||||
output_dir = Path(output_path).parent
|
output_dir = Path(output_path).parent
|
||||||
chunks_dir = output_dir / 'chunks'
|
chunks_dir = output_dir / 'data'
|
||||||
chunks_dir.mkdir(exist_ok=True)
|
chunks_dir.mkdir(exist_ok=True)
|
||||||
|
|
||||||
stats_accumulator = {
|
|
||||||
'total_ast_nodes': 0,
|
|
||||||
'total_samples': 0
|
|
||||||
}
|
|
||||||
|
|
||||||
pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
|
pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
|
||||||
for i in pbar:
|
for i in pbar:
|
||||||
start_idx = i * chunk_size
|
start_idx = i * chunk_size
|
||||||
@@ -147,28 +137,10 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int =
|
|||||||
# Select chunk from dataset
|
# Select chunk from dataset
|
||||||
chunk_dataset = dataset.select(range(start_idx, end_idx))
|
chunk_dataset = dataset.select(range(start_idx, end_idx))
|
||||||
|
|
||||||
# Update statistics
|
if len(chunk_dataset) > 0: # Only process if chunk has data
|
||||||
stats_accumulator['total_ast_nodes'] += sum(len(nodes) for nodes in chunk_dataset['ast_nodes'])
|
# Save chunk using datasets native method
|
||||||
stats_accumulator['total_samples'] += len(chunk_dataset)
|
chunk_path = chunks_dir / f'data-{i:05d}-of-{num_chunks:05d}.parquet'
|
||||||
|
chunk_dataset.to_parquet(str(chunk_path))
|
||||||
# Save chunk using datasets native method
|
|
||||||
chunk_path = chunks_dir / f'chunk_{i:04d}'
|
|
||||||
chunk_dataset.save_to_disk(str(chunk_path))
|
|
||||||
|
|
||||||
# Update progress bar postfix with current chunk info
|
|
||||||
pbar.set_postfix({
|
|
||||||
'samples': len(chunk_dataset),
|
|
||||||
'path': str(chunk_path.name)
|
|
||||||
})
|
|
||||||
|
|
||||||
return stats_accumulator
|
|
||||||
|
|
||||||
def load_all_chunks(chunks_dir: Path) -> Dataset:
|
|
||||||
"""Load and concatenate all dataset chunks."""
|
|
||||||
chunks = []
|
|
||||||
for chunk_path in sorted(chunks_dir.glob('chunk_*')):
|
|
||||||
chunks.append(load_from_disk(str(chunk_path)))
|
|
||||||
return concatenate_datasets(chunks)
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
current_dir = Path(__file__).parent
|
current_dir = Path(__file__).parent
|
||||||
@@ -176,14 +148,23 @@ def main():
|
|||||||
output_dir = current_dir.parent / 'data' / 'processed-python'
|
output_dir = current_dir.parent / 'data' / 'processed-python'
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
|
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
|
||||||
|
|
||||||
logger.info("Loading dataset...")
|
logger.info("Loading dataset...")
|
||||||
dataset = load_dataset(str(input_dir))['train']
|
dataset = load_dataset(str(input_dir))['train']
|
||||||
original_dataset_size = len(dataset)
|
# dataset = dataset.select(range(200_000)) # Limit dataset size for testing
|
||||||
|
|
||||||
logger.info("Dataset:")
|
original_dataset_size = len(dataset)
|
||||||
pprint(dataset)
|
logger.info(f"Original dataset size: {original_dataset_size}")
|
||||||
|
|
||||||
|
columns_to_remove = dataset.column_names
|
||||||
|
columns_to_remove.remove('content')
|
||||||
|
columns_to_remove.remove('size')
|
||||||
|
columns_to_remove.remove('avg_line_length')
|
||||||
|
columns_to_remove.remove('max_line_length')
|
||||||
|
columns_to_remove.remove('alphanum_fraction')
|
||||||
|
columns_to_remove.remove('max_stars_count')
|
||||||
|
logging.info(f"Columns to remove: {columns_to_remove}")
|
||||||
|
|
||||||
num_proc = min(multiprocessing.cpu_count() - 1, 32)
|
num_proc = min(multiprocessing.cpu_count() - 1, 32)
|
||||||
logger.info(f"Using {num_proc} processes for dataset processing")
|
logger.info(f"Using {num_proc} processes for dataset processing")
|
||||||
@@ -193,35 +174,38 @@ def main():
|
|||||||
process_batch,
|
process_batch,
|
||||||
fn_kwargs={'tokenizer': tokenizer},
|
fn_kwargs={'tokenizer': tokenizer},
|
||||||
batched=True,
|
batched=True,
|
||||||
remove_columns=dataset.column_names,
|
remove_columns=columns_to_remove,
|
||||||
desc="Processing dataset",
|
desc="Processing dataset",
|
||||||
num_proc=num_proc,
|
num_proc=num_proc,
|
||||||
load_from_cache_file=False
|
load_from_cache_file=False
|
||||||
)
|
)
|
||||||
|
|
||||||
processed_dataset = processed_dataset.filter(
|
processed_dataset = processed_dataset.filter(
|
||||||
lambda batch: [len(nodes) > 0 for nodes in batch['ast_nodes']],
|
lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
|
||||||
batched=True,
|
batched=True,
|
||||||
num_proc=num_proc,
|
num_proc=num_proc,
|
||||||
desc="Filtering invalid examples"
|
desc="Filtering invalid examples"
|
||||||
)
|
)
|
||||||
|
|
||||||
reduced_dataset_size = len(processed_dataset)
|
reduced_dataset_size = len(processed_dataset)
|
||||||
|
logger.info(f"Processed dataset size: {reduced_dataset_size}")
|
||||||
|
|
||||||
logger.info("Processed dataset:")
|
if reduced_dataset_size == 0:
|
||||||
pprint(processed_dataset)
|
logger.error("No valid examples found in dataset!")
|
||||||
|
return
|
||||||
|
|
||||||
logger.info(f"Saving {len(processed_dataset)} processed examples in chunks...")
|
logger.info(f"Saving {reduced_dataset_size} processed examples in chunks...")
|
||||||
stats_accumulator = save_dataset_in_chunks(
|
save_dataset_in_chunks(
|
||||||
processed_dataset,
|
processed_dataset,
|
||||||
str(output_dir / 'processed_dataset'),
|
str(output_dir / 'processed_dataset'),
|
||||||
chunk_size=100_000
|
chunk_size=100_000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add derived statistics
|
||||||
stats = {
|
stats = {
|
||||||
'original_dataset_size': original_dataset_size,
|
'original_dataset_size': original_dataset_size,
|
||||||
'reduced_dataset_size': reduced_dataset_size,
|
'reduced_dataset_size': reduced_dataset_size,
|
||||||
'% samples removed': (1 - reduced_dataset_size / original_dataset_size) * 100,
|
'samples_removed_pct': (1 - reduced_dataset_size / original_dataset_size) * 100,
|
||||||
'avg_ast_nodes': float(stats_accumulator['total_ast_nodes'] / stats_accumulator['total_samples'])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Log stats with pprint
|
# Log stats with pprint
|
||||||
|
code/src/reduce_dataset.py (new file, 50 lines)
@@ -0,0 +1,50 @@
+import logging
+import multiprocessing
+from pathlib import Path
+from datasets import load_dataset
+
+from parse_dataset import save_dataset_in_chunks
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(levelname)s - %(message)s',
+    datefmt='%Y-%m-%d %H:%M:%S'
+)
+logger = logging.getLogger(__name__)
+
+def main():
+    current_dir = Path(__file__).parent
+    input_dir = current_dir.parent / 'data' / 'processed-python'
+    output_dir = current_dir.parent / 'data' / 'filtered-python'
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    num_proc = min(multiprocessing.cpu_count() - 1, 32)
+    logger.info(f"Using {num_proc} processes for dataset processing")
+
+    logger.info("Loading dataset...")
+    dataset = load_dataset(str(input_dir))['data']
+    logging.info(f"Dataset:\n{dataset}")
+
+    logger.info("Filtering dataset by max_stars_count > 3...")
+    filtered_dataset = dataset.filter(
+        lambda batch: [example and example > 3 for example in batch['max_stars_count']],
+        num_proc=num_proc,
+        batched=True,
+        desc="Filtering dataset"
+    )
+
+    filtered_dataset_size = len(filtered_dataset)
+    logger.info(f"Filtered dataset size: {filtered_dataset_size}")
+
+    if filtered_dataset_size == 0:
+        logger.error("No examples found with max_stars_count > 3!")
+        return
+
+    logger.info(f"Saving {filtered_dataset_size} filtered examples...")
+    save_dataset_in_chunks(filtered_dataset, output_dir, chunk_size=100_000)
+
+    logger.info("Filtering and saving completed!")
+
+if __name__ == "__main__":
+    main()
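The max_stars_count filter has to tolerate missing values, which is why the lambda uses `example and example > 3`; a toy sketch of the same batched filter on made-up rows:

from datasets import Dataset

# Toy rows; None stands in for repos with no star count in the metadata.
toy = Dataset.from_dict({'content': ['a', 'b', 'c'], 'max_stars_count': [None, 2, 10]})

filtered = toy.filter(
    lambda batch: [bool(stars and stars > 3) for stars in batch['max_stars_count']],
    batched=True,
)
print(filtered['max_stars_count'])  # [10]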
@@ -1,12 +0,0 @@
-{
-    "seed": 42,
-    "mlm_probability": 0.15,
-    "batch": 16,
-    "epochs": 1,
-    "eval_every": 5000,
-    "learning_rate": 5e-4,
-    "weight_decay": 0.01,
-    "max_grad_norm": 1.0,
-    "warmup_steps": 5000,
-    "train_size": 0.95
-}
@@ -1,50 +0,0 @@
-from pathlib import Path
-import torch
-from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer
-
-from training_utils import (
-    set_seed, setup_wandb, setup_directories, load_config,
-    setup_device, create_optimizer_and_scheduler,
-    train_and_evaluate, create_base_dataloaders, set_deterministic_mode
-)
-
-def main() -> None:
-    set_deterministic_mode()
-
-    current_dir = Path(__file__).parent
-    output_dir = setup_directories(current_dir, model_name='base')
-    config = load_config(current_dir / 'tmp_config.json')
-
-    setup_wandb(config, model_name='base')
-    set_seed(config['seed'])
-    device = setup_device()
-
-    # Load tokenizer
-    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
-
-    # Create dataloaders
-    processed_data_dir = current_dir.parent / 'data' / 'processed-python'
-    train_dataloader, valid_dataloader = create_base_dataloaders(
-        processed_data_dir,
-        tokenizer,
-        config,
-        device
-    )
-
-    # Model setup
-    model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
-    model = RobertaForMaskedLM(model_config)
-    model = model.to(device)
-
-    # Training setup and run
-    num_training_steps = config['epochs'] * len(train_dataloader)
-    optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
-
-    train_and_evaluate(
-        model, train_dataloader, valid_dataloader,
-        optimizer, scheduler, config, output_dir,
-        log_weights=False, device=device
-    )
-
-if __name__ == "__main__":
-    main()
@@ -1,227 +0,0 @@
|
|||||||
import math
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional
|
|
||||||
from transformers import RobertaConfig, RobertaTokenizer, RobertaForMaskedLM
|
|
||||||
|
|
||||||
from training_utils import (
|
|
||||||
set_seed, setup_wandb, setup_directories, load_config,
|
|
||||||
setup_device, create_optimizer_and_scheduler,
|
|
||||||
train_and_evaluate, create_base_dataloaders, PreprocessedTreeDataset,
|
|
||||||
set_deterministic_mode
|
|
||||||
)
|
|
||||||
from parse_dataset import ASTNodeInfo
|
|
||||||
|
|
||||||
class TreePositionalEmbedding(nn.Module):
|
|
||||||
"""Generates tree-aware positional embeddings for code tokens."""
|
|
||||||
|
|
||||||
def __init__(self, d_model: int = 768, max_depth: int = 32):
|
|
||||||
super().__init__()
|
|
||||||
self.d_model = d_model
|
|
||||||
self.max_depth = max_depth
|
|
||||||
|
|
||||||
self.depth_embedding = nn.Embedding(max_depth, d_model)
|
|
||||||
self.sibling_embedding = nn.Embedding(max_depth, d_model)
|
|
||||||
self.combine = nn.Linear(d_model * 2, d_model)
|
|
||||||
|
|
||||||
self._initialize_embeddings()
|
|
||||||
|
|
||||||
def _initialize_embeddings(self):
|
|
||||||
position = torch.arange(self.max_depth).unsqueeze(1).float()
|
|
||||||
div_term = torch.exp(torch.arange(0, self.d_model, 2).float() *
|
|
||||||
(-math.log(10000.0) / self.d_model))
|
|
||||||
|
|
||||||
pe = torch.zeros(self.max_depth, self.d_model)
|
|
||||||
pe[:, 0::2] = torch.sin(position * div_term)
|
|
||||||
pe[:, 1::2] = torch.cos(position * div_term)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
self.depth_embedding.weight.copy_(pe)
|
|
||||||
self.sibling_embedding.weight.copy_(pe)
|
|
||||||
|
|
||||||
def forward(self, input_ids: torch.Tensor, ast_nodes_batch: List[List[ASTNodeInfo]]) -> torch.Tensor:
|
|
||||||
"""Process batched input with corresponding AST nodes."""
|
|
||||||
batch_size, seq_len = input_ids.shape
|
|
||||||
device = input_ids.device
|
|
||||||
embeddings = torch.zeros((batch_size, seq_len, self.d_model), device=device)
|
|
||||||
|
|
||||||
# Process each item in the batch
|
|
||||||
for batch_idx in range(batch_size):
|
|
||||||
ast_nodes = ast_nodes_batch[batch_idx]
|
|
||||||
# Process each position in the sequence
|
|
||||||
for i in range(seq_len):
|
|
||||||
containing_nodes = [
|
|
||||||
node for node in ast_nodes
|
|
||||||
if node.start_token_idx <= i < node.end_token_idx
|
|
||||||
]
|
|
||||||
|
|
||||||
if containing_nodes:
|
|
||||||
node = max(containing_nodes, key=lambda n: n.depth)
|
|
||||||
depth = min(node.depth, self.max_depth - 1)
|
|
||||||
sibling_pos = min(node.sibling_pos, self.max_depth - 1)
|
|
||||||
|
|
||||||
depth_emb = self.depth_embedding(torch.tensor(depth, device=device))
|
|
||||||
sibling_emb = self.sibling_embedding(torch.tensor(sibling_pos, device=device))
|
|
||||||
embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb]))
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
|
|
||||||
"""CodeBERT model enhanced with normalized embedding weights for stable training."""
|
|
||||||
|
|
||||||
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
|
|
||||||
super().__init__(config)
|
|
||||||
|
|
||||||
self.tree_pos_embeddings = TreePositionalEmbedding(
|
|
||||||
d_model=config.hidden_size,
|
|
||||||
max_depth=max_depth
|
|
||||||
)
|
|
||||||
|
|
||||||
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
|
||||||
|
|
||||||
# Initialize sequential position embeddings with sinusoidal pattern
|
|
||||||
position = torch.arange(max_seq_length).unsqueeze(1)
|
|
||||||
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
|
||||||
pe = torch.zeros(max_seq_length, config.hidden_size)
|
|
||||||
pe[:, 0::2] = torch.sin(position * div_term)
|
|
||||||
pe[:, 1::2] = torch.cos(position * div_term)
|
|
||||||
self.seq_pos_embeddings.weight.data.copy_(pe)
|
|
||||||
|
|
||||||
# Initialize weights with small random values around 0
|
|
||||||
self.alpha = nn.Parameter(torch.randn(1) * 0.02)
|
|
||||||
self.beta = nn.Parameter(torch.randn(1) * 0.02)
|
|
||||||
self.gamma = nn.Parameter(torch.randn(1) * 0.02)
|
|
||||||
|
|
||||||
self.embedding_combination_layer_norm = nn.LayerNorm(config.hidden_size)
|
|
||||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size)
|
|
||||||
|
|
||||||
# Add dropout for regularization
|
|
||||||
self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)
|
|
||||||
|
|
||||||
def get_normalized_weights(self) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute softmax-normalized weights for embedding combination.
|
|
||||||
Returns tensor of shape (3,) containing normalized [alpha, beta, gamma].
|
|
||||||
"""
|
|
||||||
weights = torch.stack([self.alpha, self.beta, self.gamma])
|
|
||||||
return torch.softmax(weights, dim=0)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.Tensor,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
labels: Optional[torch.Tensor] = None,
|
|
||||||
ast_nodes: Optional[List[List[ASTNodeInfo]]] = None,
|
|
||||||
output_attentions: bool = False,
|
|
||||||
**kwargs
|
|
||||||
) -> Dict[str, torch.Tensor]:
|
|
||||||
# Move tensors to device
|
|
||||||
device = input_ids.device
|
|
||||||
|
|
||||||
# Get embeddings
|
|
||||||
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
|
|
||||||
|
|
||||||
seq_positions = torch.arange(input_ids.size(1), device=device)
|
|
||||||
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
|
|
||||||
|
|
||||||
# Get normalized weights
|
|
||||||
norm_weights = self.get_normalized_weights()
|
|
||||||
|
|
||||||
# Combine embeddings based on presence of AST nodes
|
|
||||||
if ast_nodes is not None:
|
|
||||||
tree_embeddings = self.tree_pos_embeddings(input_ids, ast_nodes)
|
|
||||||
combined_embeddings = (
|
|
||||||
norm_weights[0] * token_embeddings +
|
|
||||||
norm_weights[1] * tree_embeddings +
|
|
||||||
norm_weights[2] * seq_embeddings
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Redistribute tree weight to other components when no AST available
|
|
||||||
token_seq_weights = torch.softmax(torch.stack([self.alpha, self.gamma]), dim=0)
|
|
||||||
combined_embeddings = (
|
|
||||||
token_seq_weights[0] * token_embeddings +
|
|
||||||
token_seq_weights[1] * seq_embeddings
|
|
||||||
)
|
|
||||||
|
|
||||||
# Apply layer normalization and dropout
|
|
||||||
combined_embeddings = self.embedding_combination_layer_norm(combined_embeddings)
|
|
||||||
combined_embeddings = self.embedding_dropout(combined_embeddings)
|
|
||||||
combined_embeddings = self.final_layer_norm(combined_embeddings)
|
|
||||||
|
|
||||||
# Forward pass through transformer
|
|
||||||
outputs = self.roberta(
|
|
||||||
inputs_embeds=combined_embeddings,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
sequence_output = outputs[0]
|
|
||||||
prediction_scores = self.lm_head(sequence_output)
|
|
||||||
|
|
||||||
# Calculate loss if labels provided
|
|
||||||
masked_lm_loss = None
|
|
||||||
if labels is not None:
|
|
||||||
loss_fct = nn.CrossEntropyLoss()
|
|
||||||
masked_lm_loss = loss_fct(
|
|
||||||
prediction_scores.view(-1, self.config.vocab_size),
|
|
||||||
labels.view(-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get normalized weights for logging
|
|
||||||
norm_weights_cpu = norm_weights.detach().cpu()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"loss": masked_lm_loss,
|
|
||||||
"logits": prediction_scores,
|
|
||||||
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
|
||||||
"attentions": outputs.attentions,
|
|
||||||
"embedding_weights": {
|
|
||||||
"token": norm_weights_cpu[0].item(),
|
|
||||||
"tree": norm_weights_cpu[1].item(),
|
|
||||||
"sequential": norm_weights_cpu[2].item()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
set_deterministic_mode()
|
|
||||||
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
output_dir = setup_directories(current_dir, model_name='tree')
|
|
||||||
config = load_config(current_dir / 'tmp_config.json')
|
|
||||||
|
|
||||||
setup_wandb(config, model_name='tree')
|
|
||||||
set_seed(config['seed'])
|
|
||||||
device = setup_device()
|
|
||||||
|
|
||||||
# Load tokenizer
|
|
||||||
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
|
|
||||||
|
|
||||||
# Create dataloaders with tree dataset
|
|
||||||
processed_data_dir = current_dir.parent / 'data' / 'processed-python'
|
|
||||||
train_dataloader, valid_dataloader = create_base_dataloaders(
|
|
||||||
processed_data_dir,
|
|
||||||
tokenizer,
|
|
||||||
config,
|
|
||||||
device,
|
|
||||||
dataset_class=PreprocessedTreeDataset
|
|
||||||
)
|
|
||||||
|
|
||||||
# Model setup
|
|
||||||
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
|
|
||||||
model = TreeCodeBERTForPreTraining(model_config)
|
|
||||||
model = model.to(device)
|
|
||||||
|
|
||||||
# Training setup and run
|
|
||||||
num_training_steps = config['epochs'] * len(train_dataloader)
|
|
||||||
optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)
|
|
||||||
|
|
||||||
train_and_evaluate(
|
|
||||||
model, train_dataloader, valid_dataloader,
|
|
||||||
optimizer, scheduler, config, output_dir,
|
|
||||||
log_weights=True, device=device
|
|
||||||
)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
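Both the deleted script above and the new training.py below fill the sequential position table with the standard sinusoidal pattern PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)); a small standalone sketch of that initialization with toy sizes:

import math
import torch

def sinusoidal_table(max_len: int, d_model: int) -> torch.Tensor:
    # Same construction as in the model code, factored out for clarity.
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

print(sinusoidal_table(max_len=8, d_model=4).shape)  # torch.Size([8, 4])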
code/src/training.py (new file, 306 lines)
@@ -0,0 +1,306 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Optional
|
||||||
|
from transformers import RobertaConfig, AutoTokenizer, RobertaForMaskedLM
|
||||||
|
|
||||||
|
from training_utils import *
|
||||||
|
|
||||||
|
MODEL = 'base-custom' # 'base' or 'tree' or 'base-custom'
|
||||||
|
DATASET = 'filtered-python' # 'processed-python-small' or 'processed-python'
|
||||||
|
|
||||||
|
class TreePositionalEmbedding(nn.Module):
|
||||||
|
"""Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""
|
||||||
|
|
||||||
|
def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.max_depth = max_depth
|
||||||
|
|
||||||
|
# Separate embeddings for different features
|
||||||
|
self.depth_embedding = nn.Embedding(max_depth, d_model)
|
||||||
|
self.sibling_embedding = nn.Embedding(max_depth, d_model)
|
||||||
|
|
||||||
|
# Improved projection layers
|
||||||
|
self.node_projection = nn.Sequential(
|
||||||
|
nn.Linear(d_model * 2, d_model * 2),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Linear(d_model * 2, d_model),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Layer norm for stability
|
||||||
|
self.layer_norm = nn.LayerNorm(d_model)
|
||||||
|
|
||||||
|
self._initialize_embeddings()
|
||||||
|
|
||||||
|
def _initialize_embeddings(self):
|
||||||
|
std = 0.02
|
||||||
|
for embedding in [self.depth_embedding, self.sibling_embedding]:
|
||||||
|
nn.init.normal_(embedding.weight, mean=0.0, std=std)
|
||||||
|
|
||||||
|
# Initialize projection layers
|
||||||
|
for layer in self.node_projection:
|
||||||
|
if isinstance(layer, nn.Linear):
|
||||||
|
nn.init.normal_(layer.weight, mean=0.0, std=std)
|
||||||
|
nn.init.zeros_(layer.bias)
|
||||||
|
|
||||||
|
def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
depths: Tensor of shape [batch_size, seq_len] containing depth values
|
||||||
|
sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
|
||||||
|
Returns:
|
||||||
|
Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
|
||||||
|
"""
|
||||||
|
# Clamp values to max_depth
|
||||||
|
depths = torch.clamp(depths, 0, self.max_depth - 1)
|
||||||
|
sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)
|
||||||
|
|
||||||
|
# Get embeddings for each feature
|
||||||
|
depth_embeddings = self.depth_embedding(depths) # [batch, seq_len, d_model]
|
||||||
|
sibling_embeddings = self.sibling_embedding(sibling_idxs) # [batch, seq_len, d_model]
|
||||||
|
|
||||||
|
# Combine features
|
||||||
|
combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
|
||||||
|
embeddings = self.node_projection(combined)
|
||||||
|
|
||||||
|
# Apply layer norm
|
||||||
|
normalized_embeddings = self.layer_norm(embeddings)
|
||||||
|
|
||||||
|
return normalized_embeddings
|
||||||
|
|
||||||
|
class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
|
||||||
|
"""CodeBERT model enhanced with tree-structural information."""
|
||||||
|
|
||||||
|
def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512, base_custom=False):
|
||||||
|
super().__init__(config)
|
||||||
|
self.base_custom = base_custom
|
||||||
|
|
||||||
|
self.tree_pos_embeddings = TreePositionalEmbedding(
|
||||||
|
d_model=config.hidden_size,
|
||||||
|
max_depth=max_depth,
|
||||||
|
dropout=config.hidden_dropout_prob
|
||||||
|
)
|
||||||
|
|
||||||
|
self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
|
||||||
|
|
||||||
|
# Initialize sequential position embeddings with sinusoidal pattern
|
||||||
|
position = torch.arange(max_seq_length).unsqueeze(1)
|
||||||
|
div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
|
||||||
|
pe = torch.zeros(max_seq_length, config.hidden_size)
|
||||||
|
pe[:, 0::2] = torch.sin(position * div_term)
|
||||||
|
pe[:, 1::2] = torch.cos(position * div_term)
|
||||||
|
self.seq_pos_embeddings.weight.data.copy_(pe)
|
||||||
|
|
||||||
|
# Initialize embedding weights equally
|
||||||
|
if base_custom:
|
||||||
|
initial_weight = math.log(1/2)
|
||||||
|
self.embedding_weights = nn.Parameter(torch.full((2,), initial_weight))
|
||||||
|
else:
|
||||||
|
initial_weight = math.log(1/3) # log(1/3) because we use softmax later
|
||||||
|
self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))
|
||||||
|
|
||||||
|
# Layer norms for embedding combination
|
||||||
|
self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||||
|
self.post_combination_norm = nn.LayerNorm(config.hidden_size)
|
||||||
|
|
||||||
|
def get_normalized_weights(self):
|
||||||
|
"""Get softmaxed weights for embedding combination."""
|
||||||
|
return torch.softmax(self.embedding_weights, dim=0)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
labels: Optional[torch.Tensor] = None,
|
||||||
|
depths: Optional[torch.Tensor] = None,
|
||||||
|
sibling_idxs: Optional[torch.Tensor] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
device = input_ids.device
|
||||||
|
|
||||||
|
# Get normalized weights for embedding combination and calculate regularization
|
||||||
|
weights = self.get_normalized_weights()
|
||||||
|
|
||||||
|
# Calculate weight variance regularization
|
||||||
|
# We want weights to remain somewhat balanced, so penalize high variance
|
||||||
|
weight_variance = torch.var(weights)
|
||||||
|
weight_reg_loss = 0.1 * weight_variance # Adjustable coefficient
|
||||||
|
|
||||||
|
# Add L2 regularization to prevent any weight from getting too close to 1
|
||||||
|
# This helps maintain a more balanced contribution from each embedding type
|
||||||
|
max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2) # Penalize weights > 0.8
|
||||||
|
l2_reg_loss = 0.05 * max_weight_penalty # Adjustable coefficient
|
||||||
|
|
||||||
|
# Get token embeddings
|
||||||
|
token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
|
||||||
|
token_embeddings = self.pre_combination_norm(token_embeddings)
|
||||||
|
|
||||||
|
# Get sequential position embeddings
|
||||||
|
seq_positions = torch.arange(input_ids.size(1), device=device)
|
||||||
|
seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)
|
||||||
|
|
||||||
|
# Get tree positional embeddings if tree information is provided
|
||||||
|
if depths is not None and sibling_idxs is not None and not self.base_custom:
|
||||||
|
tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
|
||||||
|
else:
|
||||||
|
tree_embeddings = torch.zeros_like(token_embeddings)
|
||||||
|
|
||||||
|
# Combine all embeddings using learned weights
|
||||||
|
if self.base_custom:
|
||||||
|
combined_embeddings = (
|
||||||
|
weights[0] * token_embeddings +
|
||||||
|
weights[1] * seq_embeddings
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
combined_embeddings = (
|
||||||
|
weights[0] * token_embeddings +
|
||||||
|
weights[1] * tree_embeddings +
|
||||||
|
weights[2] * seq_embeddings
|
||||||
|
)
|
||||||
|
|
||||||
|
combined_embeddings = self.post_combination_norm(combined_embeddings)
|
||||||
|
|
||||||
|
# Forward pass through base model
|
||||||
|
outputs = self.roberta(
|
||||||
|
inputs_embeds=combined_embeddings,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
sequence_output = outputs[0]
|
||||||
|
prediction_scores = self.lm_head(sequence_output)
|
||||||
|
|
||||||
|
# Calculate MLM loss if labels are provided
|
||||||
|
masked_lm_loss = None
|
||||||
|
if labels is not None:
|
||||||
|
loss_fct = nn.CrossEntropyLoss()
|
||||||
|
masked_lm_loss = loss_fct(
|
||||||
|
prediction_scores.view(-1, self.config.vocab_size),
|
||||||
|
labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add regularization losses to final loss
|
||||||
|
if masked_lm_loss is not None:
|
||||||
|
final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
|
||||||
|
else:
|
||||||
|
final_loss = weight_reg_loss + l2_reg_loss
|
||||||
|
else:
|
||||||
|
final_loss = None
|
||||||
|
|
||||||
|
# Prepare embedding weights for logging
|
||||||
|
weights_cpu = weights.detach().cpu()
|
||||||
|
if self.base_custom:
|
||||||
|
embedding_weights = {
|
||||||
|
"token": weights_cpu[0].item(),
|
||||||
|
"sequential": weights_cpu[1].item()
|
||||||
|
}
|
||||||
|
reg_metrics = {
|
||||||
|
"weight_variance": weight_variance.item(),
|
||||||
|
"max_weight_penalty": max_weight_penalty.item(),
|
||||||
|
"weight_reg_loss": weight_reg_loss.item(),
|
||||||
|
"l2_reg_loss": l2_reg_loss.item()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
embedding_weights = {
|
||||||
|
"token": weights_cpu[0].item(),
|
||||||
|
"tree": weights_cpu[1].item(),
|
||||||
|
"sequential": weights_cpu[2].item()
|
||||||
|
}
|
||||||
|
reg_metrics = {
|
||||||
|
"weight_variance": weight_variance.item(),
|
||||||
|
"max_weight_penalty": max_weight_penalty.item(),
|
||||||
|
"weight_reg_loss": weight_reg_loss.item(),
|
||||||
|
"l2_reg_loss": l2_reg_loss.item()
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"loss": final_loss,
|
||||||
|
"logits": prediction_scores,
|
||||||
|
"hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
|
||||||
|
"attentions": outputs.attentions if output_attentions else None,
|
||||||
|
"embedding_weights": embedding_weights,
|
||||||
|
"regularization_metrics": reg_metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
current_dir = Path(__file__).parent
|
||||||
|
output_dir = setup_directories(current_dir, model_name=MODEL)
|
||||||
|
config = load_config(current_dir / 'config.json')
|
||||||
|
|
||||||
|
set_deterministic_mode(config['seed'])
|
||||||
|
|
||||||
|
setup_wandb(config, model_name=MODEL, script_path=__file__)
|
||||||
|
set_seed(config['seed'])
|
||||||
|
device = setup_device()
|
||||||
|
|
||||||
|
# Load tokenizer
|
||||||
|
logger.info("Loading tokenizer...")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
|
||||||
|
|
||||||
|
# Create dataloaders with tree dataset
|
||||||
|
logger.info("Creating dataloaders...")
|
||||||
|
processed_data_dir = current_dir.parent / 'data' / DATASET
|
||||||
|
train_dataloader, valid_dataloader = create_base_dataloaders(
|
||||||
|
processed_data_dir,
|
||||||
|
tokenizer,
|
||||||
|
config,
|
||||||
|
base_training=True if MODEL == 'base' else False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize model config
|
||||||
|
logger.info("Initializing model...")
|
||||||
|
model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
|
||||||
|
|
||||||
|
# Update W&B
|
||||||
|
wandb.config.update({'model_config': model_config.__dict__})
|
||||||
|
|
||||||
|
# Initialize model
|
||||||
|
if MODEL == 'tree':
|
||||||
|
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512)
|
||||||
|
elif MODEL == 'base':
|
||||||
|
model = RobertaForMaskedLM(model_config)
|
||||||
|
elif MODEL == 'base-custom':
|
||||||
|
model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512, base_custom=True)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid model type: {MODEL}")
|
||||||
|
model = model.to(device)
|
||||||
|
logging.info(f"Model initialized: {MODEL}")
|
||||||
|
|
||||||
|
# Model precision
|
||||||
|
if device.type == 'cuda' and config['fp16']:
|
||||||
|
model = model.half()
|
||||||
|
logging.info("Model set to half precision.")
|
||||||
|
for param in model.parameters():
|
||||||
|
logging.info(f"Parameter dtype: {param.dtype}")
|
||||||
|
break
|
||||||
|
|
||||||
|
# Training setup
|
||||||
|
logger.info("Setting up optimizer and scheduler...")
|
||||||
|
optimizer, scheduler = create_optimizer_and_scheduler(model, config)
|
||||||
|
|
||||||
|
# Train and evaluate model
|
||||||
|
logger.info("Starting training...")
|
||||||
|
train_and_evaluate(
|
||||||
|
model=model,
|
||||||
|
train_dataloader=train_dataloader,
|
||||||
|
valid_dataloader=valid_dataloader,
|
||||||
|
optimizer=optimizer,
|
||||||
|
scheduler=scheduler,
|
||||||
|
config=config,
|
||||||
|
device=device,
|
||||||
|
output_dir=output_dir
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Training completed!")
|
||||||
|
final_output_dir = output_dir / 'final-model'
|
||||||
|
model.save_pretrained(final_output_dir)
|
||||||
|
tokenizer.save_pretrained(final_output_dir)
|
||||||
|
logger.info(f"Model saved to {final_output_dir}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
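training.py combines token, tree, and sequential embeddings with softmaxed scalar weights and adds two small penalties to keep them balanced; a condensed standalone sketch of just that weighting logic, with toy tensors and the coefficients copied from the code above:

import math
import torch
import torch.nn as nn

# Toy shapes; in training.py these come from the RoBERTa config.
batch, seq_len, d_model = 2, 8, 16
token_emb = torch.randn(batch, seq_len, d_model)
tree_emb = torch.randn(batch, seq_len, d_model)
seq_emb = torch.randn(batch, seq_len, d_model)

# Learnable scalars, initialized to log(1/3) so the softmax starts uniform.
embedding_weights = nn.Parameter(torch.full((3,), math.log(1 / 3)))
weights = torch.softmax(embedding_weights, dim=0)

combined = weights[0] * token_emb + weights[1] * tree_emb + weights[2] * seq_emb

# Regularizers from the forward pass: penalize high variance and any weight > 0.8.
weight_reg_loss = 0.1 * torch.var(weights)
l2_reg_loss = 0.05 * torch.sum(torch.relu(weights - 0.8) ** 2)
print(combined.shape, weight_reg_loss.item(), l2_reg_loss.item())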
@@ -8,38 +8,41 @@ import wandb
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, Tuple, Type, Optional
|
from typing import Dict, Any, Tuple, Optional
|
||||||
from datasets import load_from_disk, concatenate_datasets
|
from datasets import load_dataset
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from torch.utils.data import DataLoader, Dataset
|
from torch.utils.data import DataLoader
|
||||||
from transformers import (
|
from transformers import (
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
get_linear_schedule_with_warmup,
|
get_constant_schedule_with_warmup,
|
||||||
DataCollatorForLanguageModeling,
|
DataCollatorForLanguageModeling,
|
||||||
PreTrainedTokenizer
|
PreTrainedTokenizer,
|
||||||
|
RobertaForMaskedLM
|
||||||
)
|
)
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from parse_dataset import ASTNodeInfo
|
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def set_deterministic_mode() -> None:
|
def set_deterministic_mode(seed: int) -> None:
|
||||||
"""Enable deterministic mode for reproducibility."""
|
"""Enable deterministic mode for reproducibility."""
|
||||||
# Set Python random seed
|
# Set Python random seed
|
||||||
random.seed(42)
|
random.seed(seed)
|
||||||
|
|
||||||
# Set NumPy random seed
|
# Set NumPy random seed
|
||||||
np.random.seed(42)
|
np.random.seed(seed)
|
||||||
|
|
||||||
# Set PyTorch random seed
|
# Set PyTorch random seed
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(seed)
|
||||||
|
|
||||||
# Set CUDA random seed
|
# Set CUDA random seed
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.manual_seed_all(42)
|
torch.cuda.manual_seed_all(seed)
|
||||||
|
|
||||||
# Enable deterministic operations for PyTorch 2.x
|
# Enable deterministic operations for PyTorch 2.x
|
||||||
torch.use_deterministic_algorithms(True)
|
torch.use_deterministic_algorithms(True)
|
||||||
@@ -47,7 +50,7 @@ def set_deterministic_mode() -> None:
|
|||||||
torch.backends.cudnn.benchmark = False
|
torch.backends.cudnn.benchmark = False
|
||||||
|
|
||||||
# Set environment variables for reproducibility
|
# Set environment variables for reproducibility
|
||||||
os.environ['PYTHONHASHSEED'] = '42'
|
os.environ['PYTHONHASHSEED'] = f'{seed}'
|
||||||
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2
|
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' # Required for CUDA >= 10.2
|
||||||
|
|
||||||
def get_device_info() -> Dict[str, Any]:
|
def get_device_info() -> Dict[str, Any]:
|
||||||
@@ -70,53 +73,6 @@ def get_device_info() -> Dict[str, Any]:
|
|||||||
|
|
||||||
return info
|
return info
|
||||||
|
|
||||||
class TreeDataCollator(DataCollatorForLanguageModeling):
|
|
||||||
"""Custom data collator that handles both MLM and optional AST node information."""
|
|
||||||
|
|
||||||
def torch_call(self, examples):
|
|
||||||
# Check if we have AST nodes
|
|
||||||
has_ast = 'ast_nodes' in examples[0]
|
|
||||||
|
|
||||||
# Extract AST nodes before MLM processing if they exist
|
|
||||||
ast_nodes = None
|
|
||||||
if has_ast:
|
|
||||||
ast_nodes = [e.pop('ast_nodes') for e in examples]
|
|
||||||
|
|
||||||
# Process normal MLM features
|
|
||||||
batch = super().torch_call(examples)
|
|
||||||
|
|
||||||
# Add AST nodes back to batch if they existed
|
|
||||||
if has_ast:
|
|
||||||
batch['ast_nodes'] = ast_nodes
|
|
||||||
|
|
||||||
return batch
|
|
||||||
|
|
||||||
def __call__(self, examples):
|
|
||||||
return self.torch_call(examples)
|
|
||||||
|
|
||||||
class PreprocessedBaseDataset(Dataset):
|
|
||||||
"""Dataset that uses pre-processed data without AST information."""
|
|
||||||
def __init__(self, dataset: Any):
|
|
||||||
self.dataset = dataset
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.dataset)
|
|
||||||
|
|
||||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
|
||||||
item = self.dataset[idx]
|
|
||||||
return {
|
|
||||||
'input_ids': torch.tensor(item['input_ids']),
|
|
||||||
'attention_mask': torch.tensor(item['attention_mask']),
|
|
||||||
'labels': torch.tensor(item['input_ids']).clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
class PreprocessedTreeDataset(PreprocessedBaseDataset):
|
|
||||||
"""Dataset that includes AST information."""
|
|
||||||
def __getitem__(self, idx: int) -> Dict[str, Any]:
|
|
||||||
item = super().__getitem__(idx)
|
|
||||||
item['ast_nodes'] = [ASTNodeInfo.from_dict(node) for node in self.dataset[idx]['ast_nodes']]
|
|
||||||
return item
|
|
||||||
|
|
||||||
def set_seed(seed: int) -> None:
|
def set_seed(seed: int) -> None:
|
||||||
"""Set random seeds for reproducibility."""
|
"""Set random seeds for reproducibility."""
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
@@ -128,6 +84,7 @@ def set_seed(seed: int) -> None:
|
|||||||
def setup_wandb(
|
def setup_wandb(
|
||||||
config: Dict[str, Any],
|
config: Dict[str, Any],
|
||||||
model_name: str = 'base',
|
model_name: str = 'base',
|
||||||
|
script_path: Path = None,
|
||||||
device_info: Optional[Dict[str, Any]] = None
|
device_info: Optional[Dict[str, Any]] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize W&B logging with reproducibility information."""
|
"""Initialize W&B logging with reproducibility information."""
|
||||||
@@ -136,17 +93,21 @@ def setup_wandb(
|
|||||||
# Add reproducibility settings to config
|
# Add reproducibility settings to config
|
||||||
full_config = {
|
full_config = {
|
||||||
**config,
|
**config,
|
||||||
'random_seed': 42,
|
|
||||||
'deterministic_mode': True,
|
'deterministic_mode': True,
|
||||||
'device_info': device_info or get_device_info(),
|
'device_info': device_info or get_device_info(),
|
||||||
}
|
}
|
||||||
|
|
||||||
wandb.init(
|
wandb.init(
|
||||||
project='codebert-training-test',
|
project='new-codebert-tree-training',
|
||||||
name=f'{model_name}_{curr_time}',
|
name=f'{model_name}_{curr_time}',
|
||||||
config=full_config
|
config=full_config
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Upload the training script
|
||||||
|
wandb.save(script_path)
|
||||||
|
wandb.save(Path(script_path).parent / 'training_utils.py')
|
||||||
|
logger.info(f'Saving script {script_path} to W&B')
|
||||||
|
|
||||||
def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
|
def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
|
||||||
"""Create output directories for model artifacts."""
|
"""Create output directories for model artifacts."""
|
||||||
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
|
curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
|
||||||
@@ -172,7 +133,6 @@ def setup_device() -> torch.device:
|
|||||||
def create_optimizer_and_scheduler(
|
def create_optimizer_and_scheduler(
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel,
|
||||||
config: Dict[str, Any],
|
config: Dict[str, Any],
|
||||||
num_training_steps: int
|
|
||||||
) -> Tuple[AdamW, Any]:
|
) -> Tuple[AdamW, Any]:
|
||||||
"""Create optimizer and learning rate scheduler."""
|
"""Create optimizer and learning rate scheduler."""
|
||||||
optimizer = AdamW(
|
optimizer = AdamW(
|
||||||
@@ -181,270 +141,215 @@ def create_optimizer_and_scheduler(
|
|||||||
weight_decay=config['weight_decay']
|
weight_decay=config['weight_decay']
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler = get_linear_schedule_with_warmup(
|
scheduler = get_constant_schedule_with_warmup(
|
||||||
optimizer,
|
optimizer,
|
||||||
num_warmup_steps=config['warmup_steps'],
|
num_warmup_steps=config['warmup_steps'],
|
||||||
num_training_steps=num_training_steps
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return optimizer, scheduler
|
return optimizer, scheduler
|
||||||
|
|
||||||
def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
|
def evaluate(
|
||||||
"""Evaluate model on validation set."""
|
model: RobertaForMaskedLM,
|
||||||
|
dataloader: DataLoader,
|
||||||
|
device: torch.device
|
||||||
|
) -> Tuple[float, float]:
|
||||||
|
"""Evaluate the model on the validation set.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (average loss, accuracy on masked tokens)
|
||||||
|
"""
|
||||||
model.eval()
|
model.eval()
|
||||||
total_loss: float = 0.0
|
total_loss = 0
|
||||||
total_acc: float = 0.0
|
total_correct = 0
|
||||||
|
total_predictions = 0
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for batch in tqdm(dataloader, desc='Validation'):
|
for batch in tqdm(dataloader, desc='Evaluating'):
|
||||||
|
batch = {
|
||||||
|
k: v.to(device) if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in batch.items()
|
||||||
|
}
|
||||||
outputs = model(**batch)
|
outputs = model(**batch)
|
||||||
total_loss += outputs.loss.item() if hasattr(outputs, 'loss') else outputs['loss'].item()
|
|
||||||
logits = outputs.logits if hasattr(outputs, 'logits') else outputs['logits']
|
|
||||||
total_acc += logits.argmax(dim=-1).eq(batch['labels']).sum().item()
|
|
||||||
|
|
||||||
avg_loss: float = total_loss / len(dataloader)
|
# Get loss
|
||||||
avg_acc: float = total_acc / len(dataloader.dataset)
|
total_loss += outputs['loss'].item()
|
||||||
model.train()
|
|
||||||
return avg_loss, avg_acc
|
# Calculate accuracy only on masked tokens
|
||||||
|
predictions = outputs['logits'].argmax(dim=-1) # [batch_size, seq_length]
|
||||||
|
labels = batch['labels'] # [batch_size, seq_length]
|
||||||
|
|
||||||
|
# Create mask for tokens we actually want to predict (ignore padding and unmasked tokens)
|
||||||
|
predict_mask = labels != -100 # -100 is the ignore index
|
||||||
|
|
||||||
|
# Calculate accuracy
|
||||||
|
correct = predictions[predict_mask] == labels[predict_mask]
|
||||||
|
total_correct += correct.sum().item()
|
||||||
|
total_predictions += predict_mask.sum().item()
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(dataloader)
|
||||||
|
accuracy = total_correct / total_predictions if total_predictions > 0 else 0
|
||||||
|
|
||||||
|
return avg_loss, accuracy
|
||||||
|
|
||||||
 def train_and_evaluate(
-    model: PreTrainedModel,
+    model: RobertaForMaskedLM,
     train_dataloader: DataLoader,
     valid_dataloader: DataLoader,
     optimizer: AdamW,
     scheduler: Any,
     config: Dict[str, Any],
+    device: torch.device,
     output_dir: Path,
-    log_weights: bool = False,
-    device: torch.device = torch.device('cpu')
 ) -> None:
-    """Train and evaluate model with deterministic behavior and comprehensive logging."""
-    # Enable deterministic algorithms for PyTorch 2.5
-    torch.use_deterministic_algorithms(True)
+    """Train and evaluate the model."""
+    model.train()

-    num_training_steps: int = config['epochs'] * len(train_dataloader)
-    best_valid_loss: float = float('inf')
+    best_valid_loss = float('inf')

-    # Save initial model state for reproducibility
-    torch.save({
-        'model_state': model.state_dict(),
-        'optimizer_state': optimizer.state_dict(),
-        'scheduler_state': scheduler.state_dict(),
-        'rng_state': torch.get_rng_state(),
-        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
-    }, output_dir / 'initial_state.pt')
+    for epoch in range(config['epochs']):
+        total_loss = 0

-    with tqdm(total=num_training_steps, desc='Training') as pbar:
-        for epoch_idx in range(config['epochs']):
-            model.train()
-            epoch_loss = 0.0
-            epoch_steps = 0
-
-            for train_idx, train_batch in enumerate(train_dataloader):
+        # Training loop
+        with tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{config["epochs"]}') as pbar:
+            for step, batch in enumerate(pbar):
                 # Move batch to device
-                train_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
-                               for k, v in train_batch.items()}
+                batch = {
+                    k: v.to(device) if isinstance(v, torch.Tensor) else v
+                    for k, v in batch.items()
+                }

                 # Forward pass
-                outputs = model(**train_batch)
-                train_loss = outputs.loss if hasattr(outputs, 'loss') else outputs['loss']
+                outputs = model(**batch)
+                loss = outputs['loss']

-                # Backward pass with deterministic behavior
-                train_loss.backward()
-                grad_norm = torch.nn.utils.clip_grad_norm_(
-                    model.parameters(),
-                    config['max_grad_norm']
-                )
+                # Backward pass
+                loss.backward()
+
+                # Clip gradients
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])

                 optimizer.step()
                 scheduler.step()
-                optimizer.zero_grad(set_to_none=False)  # Use False for determinism
+                optimizer.zero_grad()

                 # Update metrics
-                current_loss = train_loss.item()
-                epoch_loss += current_loss
-                epoch_steps += 1
-
-                # Calculate global step
-                step = train_idx + len(train_dataloader) * epoch_idx
-
-                # Prepare logging dictionary
-                log_dict = {
-                    'step': step,
-                    'epoch': epoch_idx,
-                    'train_loss': current_loss,
-                    'gradient_norm': grad_norm.item(),
-                    'learning_rate': scheduler.get_last_lr()[0],
-                }
-
-                # Add embedding weights if using tree model
-                if log_weights and 'embedding_weights' in outputs:
-                    weights = outputs['embedding_weights']
-                    log_dict.update({
-                        'token_weight': weights['token'],
-                        'tree_weight': weights['tree'],
-                        'sequential_weight': weights['sequential']
-                    })
-                    pbar_dict = {
-                        'epoch': f"{epoch_idx + 1}/{config['epochs']}",
-                        'train_loss': f"{current_loss:.3f}",
-                        'α': f"{weights['token']:.2f}",
-                        'β': f"{weights['tree']:.2f}",
-                        'γ': f"{weights['sequential']:.2f}"
-                    }
-                else:
-                    pbar_dict = {
-                        'epoch': f"{epoch_idx + 1}/{config['epochs']}",
-                        'train_loss': f"{current_loss:.3f}"
-                    }
-
-                # Log to wandb
-                wandb.log(log_dict)
+                total_loss += loss.item()

                 # Update progress bar
-                pbar.update(1)
-                pbar.set_postfix(pbar_dict)
+                pbar.set_postfix({'loss': f'{loss.item():.4f}'})

-                # Periodic evaluation
-                if (train_idx != 0 and train_idx % config['eval_every'] == 0) or train_idx == len(train_dataloader) - 1:
-                    model.eval()
-                    valid_loss, valid_acc = evaluate(model, valid_dataloader)
-                    model.train()
+                # Log to wandb if configured
+                log_values = {
+                    'train_loss': loss.item(),
+                    'grad_norm': grad_norm.item(),
+                    'learning_rate': scheduler.get_last_lr()[0]
+                }
+                if 'embedding_weights' in outputs and 'tree' in outputs['embedding_weights']:
+                    log_values.update({
+                        'token_weight': outputs['embedding_weights']['token'],
+                        'tree_weight': outputs['embedding_weights']['tree'],
+                        'seq_weight': outputs['embedding_weights']['sequential'],
+                    })
+                elif 'embedding_weights' in outputs:
+                    log_values.update({
+                        'token_weight': outputs['embedding_weights']['token'],
+                        'seq_weight': outputs['embedding_weights']['sequential'],
+                    })
+                if 'regularization_metrics' in outputs:
+                    log_values.update({f"regularization/{k}": v for k, v in outputs['regularization_metrics'].items()})
+                wandb.log(log_values)
+
+                # Evaluate periodically
+                if step % config['eval_every'] == 0 and step > 0:
+                    valid_loss, valid_acc = evaluate(model, valid_dataloader, device)

                     # Log validation metrics
-                    wandb.log({
-                        'valid_loss': valid_loss,
-                        'valid_acc': valid_acc,
-                        'step': step,
-                        'epoch': epoch_idx,
-                    })
+                    wandb.log({'valid_loss': valid_loss, 'valid_accuracy': valid_acc})

-                    # Save checkpoint if best model
+                    # Print validation metrics
+                    print(f"\nValidation - Loss: {valid_loss:.4f}, Accuracy: {valid_acc:.4f}")
+
+                    # Save best model
                     if valid_loss < best_valid_loss:
                         best_valid_loss = valid_loss
+                        model.save_pretrained(output_dir / f'checkpoint-{epoch}-{step}')
+
+                    model.train()  # Resume training mode

-                        # Save complete state
-                        torch.save({
-                            'epoch': epoch_idx,
-                            'step': step,
-                            'model_state': model.state_dict(),
-                            'optimizer_state': optimizer.state_dict(),
-                            'scheduler_state': scheduler.state_dict(),
-                            'rng_state': torch.get_rng_state(),
-                            'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
-                            'best_valid_loss': best_valid_loss,
-                            'config': config,
-                            'device_info': get_device_info()
-                        }, output_dir / 'best_model.pt')
-
-                        # Log best metrics
-                        wandb.run.summary['best_valid_loss'] = best_valid_loss
-                        wandb.run.summary['best_valid_acc'] = valid_acc
-                        wandb.run.summary['best_epoch'] = epoch_idx
-                        wandb.run.summary['best_step'] = step
-
-            # End of epoch logging
-            avg_epoch_loss = epoch_loss / epoch_steps
-            wandb.log({
-                'epoch': epoch_idx,
-                'epoch_avg_loss': avg_epoch_loss,
-            })
-
-            # Save end of epoch checkpoint
-            torch.save({
-                'epoch': epoch_idx,
-                'model_state': model.state_dict(),
-                'optimizer_state': optimizer.state_dict(),
-                'scheduler_state': scheduler.state_dict(),
-                'rng_state': torch.get_rng_state(),
-                'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
-                'train_loss': avg_epoch_loss,
-                'config': config,
-                'device_info': get_device_info()
-            }, output_dir / f'checkpoint_epoch_{epoch_idx}.pt')
-
-    # End of training logging
-    wandb.run.summary['final_epoch'] = config['epochs'] - 1
-    wandb.run.summary['total_steps'] = num_training_steps
-
-    logger.info(f'Training completed. Best validation loss: {best_valid_loss:.4f}')
-    logger.info(f'Model checkpoints saved in {output_dir}')
-
-def load_all_chunks(chunks_dir: Path) -> Any:
-    """Load and concatenate all dataset chunks."""
-    logger.info(f"Loading dataset chunks from {chunks_dir}")
-    chunks = []
-    for chunk_path in sorted(chunks_dir.glob('chunk_*'))[:5]:
-        chunks.append(load_from_disk(str(chunk_path)))
-    dataset = concatenate_datasets(chunks)
-    logger.info(f"Loaded {len(dataset)} examples from {len(chunks)} chunks")
-    return dataset

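Note on checkpointing above: best checkpoints are now written with save_pretrained() instead of raw torch.save() state dicts, so they reload in one call. A minimal sketch; the checkpoint directory name is hypothetical:

from transformers import RobertaForMaskedLM

# Hypothetical path matching the f'checkpoint-{epoch}-{step}' naming used above.
model = RobertaForMaskedLM.from_pretrained("outputs/checkpoint-0-1000")
model.eval()  # ready for evaluation or further fine-tuning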
 def create_base_dataloaders(
     processed_data_dir: Path,
     tokenizer: PreTrainedTokenizer,
     config: Dict[str, Any],
-    device: torch.device,
-    dataset_class: Type[Dataset] = PreprocessedBaseDataset,
+    base_training=False,
 ) -> Tuple[DataLoader, DataLoader]:
-    """Create reproducible dataloaders from pre-processed data."""
-    # Load chunks
-    chunks_dir = processed_data_dir / 'chunks'
-    full_dataset = load_all_chunks(chunks_dir)
+    """Create dataloaders from pre-processed parquet data."""
+    # Load chunks using datasets library
+    chunks_dir = processed_data_dir / 'data'
+    dataset = load_dataset(
+        "parquet",
+        data_files=str(chunks_dir / "*.parquet"),
+        split="train"
+    )
+
+    contents = dataset['content'][:config['batch_size']]
+    with open("contents.txt", "w") as f:
+        for content in contents:
+            f.write(content)
+            f.write("\n\n\n")
+            f.write("#" * 80)
+            f.write("\n\n\n")
+
+    # Remove columns that are not needed
+    columns_to_remove = dataset.column_names
+    columns_to_remove.remove('input_ids')
+    columns_to_remove.remove('attention_mask')
+
+    if not base_training:
+        columns_to_remove.remove('depths')
+        columns_to_remove.remove('sibling_idxs')
+
+    dataset = dataset.remove_columns(columns_to_remove)
+
+    logging.info(f"Loaded dataset:\n{dataset}")

     # Calculate split sizes
-    dataset_size = len(full_dataset)
+    dataset_size = len(dataset)
     train_size = int(config['train_size'] * dataset_size)
     val_size = dataset_size - train_size

-    # Create splits using indices to avoid device issues
-    indices = torch.arange(dataset_size, device=device)
-
-    # Create a deterministic generator on the correct device
-    generator = torch.Generator(device=device)
-    generator.manual_seed(42)
-
-    # Shuffle indices deterministically
-    shuffled_indices = indices[torch.randperm(len(indices), device=device, generator=generator)]
-    train_indices = shuffled_indices[:train_size].cpu()
-    valid_indices = shuffled_indices[train_size:].cpu()
-
-    # Create train/validation splits using the shuffled indices
-    train_dataset = torch.utils.data.Subset(full_dataset, train_indices.tolist())
-    valid_dataset = torch.utils.data.Subset(full_dataset, valid_indices.tolist())
-
-    # Create dataset wrappers
-    train_dataset = dataset_class(train_dataset)
-    valid_dataset = dataset_class(valid_dataset)
-
-    logger.info(f"Created train dataset with {len(train_dataset)} samples")
-    logger.info(f"Created validation dataset with {len(valid_dataset)} samples")
-
-    # Use TreeDataCollator for both base and tree models
-    data_collator = TreeDataCollator(
+    # Create splits
+    splits = dataset.train_test_split(
+        train_size=train_size,
+        test_size=val_size,
+        seed=config['seed']
+    )
+    train_dataset = splits['train']
+    valid_dataset = splits['test']
+
+    # Create data collator for MLM
+    data_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer,
         mlm=True,
         mlm_probability=config['mlm_probability']
     )

-    # Create train dataloader without generator
     train_dataloader = DataLoader(
         train_dataset,
-        batch_size=config['batch'],
-        shuffle=False,  # We already shuffled the data
+        batch_size=config['batch_size'],
+        shuffle=False,  # We'll let the dataset handle shuffling
         collate_fn=data_collator,
-        num_workers=0,  # Use single worker for reproducibility
-        drop_last=True  # Ensure consistent batch sizes
+        num_workers=0,
+        drop_last=True,
     )

-    # Create validation dataloader
     valid_dataloader = DataLoader(
         valid_dataset,
-        batch_size=config['batch'],
+        batch_size=config['batch_size'],
         shuffle=False,
         collate_fn=data_collator,
         num_workers=0,
-        drop_last=True
+        drop_last=True,
     )

     return train_dataloader, valid_dataloader
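Note on the new data path in create_base_dataloaders: pre-tokenized parquet chunks are loaded with the datasets library, split deterministically, and MLM masking is deferred to DataCollatorForLanguageModeling at batch time. A condensed, self-contained sketch; the tokenizer checkpoint, file paths, split fraction, and seed below are illustrative assumptions, not values taken from this repository:

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

# Hypothetical checkpoint and paths, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
dataset = load_dataset("parquet", data_files="data/processed/data/*.parquet", split="train")

# Keep only the columns the model consumes.
keep = {"input_ids", "attention_mask"}
dataset = dataset.remove_columns([c for c in dataset.column_names if c not in keep])

# Deterministic train/validation split; fraction and seed are illustrative.
splits = dataset.train_test_split(train_size=0.95, seed=1234)

# The collator applies dynamic masking per batch and sets labels to -100 elsewhere.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
train_loader = DataLoader(splits["train"], batch_size=32, collate_fn=collator, drop_last=True)
valid_loader = DataLoader(splits["test"], batch_size=32, collate_fn=collator, drop_last=True)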