on this commit there is a weird loss spike (unresolved)

parent 0cd04e6131
commit b2d65059da
@@ -7,15 +7,18 @@ authors = [
 ]
 dependencies = [
     "wandb==0.18.5",
-    "torch==2.5.0",
+    "torch==2.5.1",
     "tqdm==4.66.5",
     "tree-sitter==0.23.1",
-    "transformers==4.45.2",
+    "transformers[torch]>=4.46.3",
     "datasets==3.0.1",
     "huggingface-hub==0.26.0",
     "matplotlib==3.9.2",
     "scikit-learn==1.5.2",
     "seaborn==0.13.2",
     "tree-sitter-python==0.23.4",
+    "ipykernel>=6.29.5",
+    "ipywidgets>=8.1.5",
 ]
 requires-python = "==3.11.*"
 readme = "README.md"
@@ -35,6 +38,5 @@ build-backend = "pdm.backend"
 distribution = true
 
 [tool.pdm.scripts]
-run_training = {cmd = "src/train_codebert_mlm.py"}
-run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
 parse_dataset = {cmd = "src/parse_dataset.py"}
+run_training = {cmd = "src/training.py"}
@@ -1,12 +1,13 @@
 {
-    "seed": 42,
+    "seed": 420,
     "mlm_probability": 0.15,
-    "batch": 16,
-    "epochs": 1,
-    "eval_every": 20000,
+    "batch_size": 32,
+    "epochs": 5,
+    "eval_every": 1000,
     "learning_rate": 5e-4,
-    "weight_decay": 0.01,
+    "weight_decay": 0.1,
     "max_grad_norm": 1.0,
-    "warmup_steps": 20000,
-    "train_size": 0.95
+    "warmup_steps": 1000,
+    "train_size": 0.95,
+    "fp16": true
 }
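For orientation (not part of the commit): a minimal sketch of how a JSON config like the one above is typically read; the repo's own load_config in training_utils.py is assumed to behave roughly like this, and the key list simply mirrors the keys visible in the diff.

import json
from pathlib import Path
from typing import Any, Dict

def load_training_config(path: Path) -> Dict[str, Any]:
    """Read the training config and sanity-check the keys used by training.py."""
    with open(path) as f:
        config = json.load(f)
    required = {'seed', 'mlm_probability', 'batch_size', 'epochs', 'eval_every',
                'learning_rate', 'weight_decay', 'max_grad_norm', 'warmup_steps',
                'train_size', 'fp16'}
    missing = required - config.keys()
    if missing:
        raise KeyError(f"config is missing keys: {sorted(missing)}")
    return config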
code/src/dataset_parsing_test.ipynb (4507 lines added)
File diff suppressed because it is too large
code/src/download_dataset.py (new file, 22 lines)
@@ -0,0 +1,22 @@
import logging
from pathlib import Path
from typing import List
from huggingface_hub import list_repo_files, hf_hub_download

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

def download_dataset(dataset_dir: Path) -> None:
    if not dataset_dir.exists():
        logger.info("Downloading the dataset...")
        dataset_dir.mkdir(parents=True, exist_ok=True)
        files_list: List[str] = list_repo_files(repo_id='bigcode/the-stack-dedup', repo_type='dataset')
        files_to_download: List[str] = [file for file in files_list if file.startswith('data/python/')]
        for file_name in files_to_download:
            hf_hub_download(repo_id='bigcode/the-stack-dedup', repo_type='dataset', filename=file_name, local_dir=dataset_dir)
        logger.info("Dataset downloaded successfully.")

if __name__ == '__main__':
    dataset_dir = Path('data/the-stack-python')
    download_dataset(dataset_dir)
@@ -1,128 +1,123 @@
import ast
from pathlib import Path
import logging
import multiprocessing
import tree_sitter_python as tspython
from tqdm import tqdm
from pathlib import Path
from pprint import pprint
from typing import Dict, Any, List
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import RobertaTokenizer
from dataclasses import dataclass
import numpy as np
import multiprocessing
from datasets import load_dataset, Dataset
from tree_sitter import Language, Parser
from transformers import AutoTokenizer

import warnings
warnings.filterwarnings("ignore", category=SyntaxWarning)

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

@dataclass
class ASTNodeInfo:
    """Stores structural information about an AST node."""
    node_type: str
    start_token_idx: int
    end_token_idx: int
    depth: int
    sibling_pos: int
    parent_idx: int

    def to_dict(self) -> Dict[str, Any]:
        return {
            'node_type': self.node_type,
            'start_token_idx': self.start_token_idx,
            'end_token_idx': self.end_token_idx,
            'depth': self.depth,
            'sibling_pos': self.sibling_pos,
            'parent_idx': self.parent_idx
        }
def analyze_code_with_codebert_and_treesitter(code_snippet: str, tokenizer: AutoTokenizer):
    """
    Map tokens to their original text and Tree-sitter AST position
    """
    # Initialize Tree-sitter
    PY_LANGUAGE = Language(tspython.language())
    parser = Parser(PY_LANGUAGE)

    def from_dict(data: Dict[str, Any]) -> 'ASTNodeInfo':
        return ASTNodeInfo(
            node_type=data['node_type'],
            start_token_idx=data['start_token_idx'],
            end_token_idx=data['end_token_idx'],
            depth=data['depth'],
            sibling_pos=data['sibling_pos'],
            parent_idx=data['parent_idx']
        )

def safe_ast_parse(content: str, max_length: int = 50000) -> bool:
    """Safely attempt to parse Python code."""
    if not content or not content.strip() or '\0' in content or len(content) > max_length:
        return False

    try:
        tree = ast.parse(content)
        return bool(list(ast.walk(tree)))
    except:
        return False

def process_and_extract_ast_info(content: str, max_ast_size: int = 1000) -> List[Dict[str, Any]]:
    """Process AST and extract node information."""
    try:
        tree = ast.parse(content)
        nodes_info = []

        def visit_node(node: ast.AST, depth: int, parent_idx: int, sibling_pos: int):
            if len(nodes_info) >= max_ast_size:
                return
    # Parse with Tree-sitter
    tree = parser.parse(bytes(code_snippet, "utf8"))

    encoded = tokenizer(
        code_snippet,
        add_special_tokens=True,
        return_offsets_mapping=True,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=512,
    )

    tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
    offset_mapping = encoded['offset_mapping'][0].tolist()

    def get_node_position(node, depth=0, idx=0):
        """Get depth and sibling index for a node"""
        return (depth, idx)

    def find_node_at_position(node, start_byte, depth=0, sibling_idx=0):
        """Find the most specific node containing the given position"""
        if start_byte >= node.start_byte and start_byte < node.end_byte:
            # Check children first for more specific nodes
            for idx, child in enumerate(node.children):
                result = find_node_at_position(child, start_byte, depth + 1, idx)
                if result:
                    return result
            # Return current node's position if no child contains the position
            return get_node_position(node, depth, sibling_idx)
        return None

    depths = []
    sibling_idxs = []
    node_texts = []
    for token, (start, end) in zip(tokens, offset_mapping):
        if token not in ['<s>', '</s>', '<pad>', '<unk>']:
            text = code_snippet[start:end] if start < len(code_snippet) else ""
            # Tree-sitter works with bytes, so convert position
            start_byte = len(code_snippet[:start].encode('utf8')) if start < len(code_snippet) else 0
            position = find_node_at_position(tree.root_node, start_byte) or (-1, -1)

            current_idx = len(nodes_info)
            if hasattr(node, 'lineno'):
                node_info = ASTNodeInfo(
                    node_type=type(node).__name__,
                    start_token_idx=node.lineno,
                    end_token_idx=getattr(node, 'end_lineno', node.lineno),
                    depth=min(depth, 31),
                    sibling_pos=min(sibling_pos, 31),
                    parent_idx=parent_idx
                )
                nodes_info.append(node_info.to_dict())

            if len(nodes_info) < max_ast_size:
                for i, child in enumerate(ast.iter_child_nodes(node)):
                    visit_node(child, depth + 1, current_idx, i)

        visit_node(tree, 0, -1, 0)
        return nodes_info
    except:
        return []
            depths.append(position[0])
            sibling_idxs.append(position[1])
            node_texts.append(text)
        else:
            depths.append(-1)
            sibling_idxs.append(-1)
            node_texts.append(None)

    return encoded['input_ids'], encoded['attention_mask'], tokens, depths, sibling_idxs, node_texts

def process_batch(examples: Dict[str, List[Any]], tokenizer: RobertaTokenizer) -> Dict[str, List[Any]]:
def process_batch(examples, tokenizer) -> Dict[str, List[Any]]:
    """Process a batch of examples."""
    contents = examples['content']

    processed_contents = []
    processed_ast_nodes = []

    processed_input_ids = []
    processed_attention_masks = []

    processed_attention_mask = []
    processed_tokens = []
    processed_depths = []
    processed_sibling_idxs = []
    processed_node_texts = []

    for content in contents:
        if safe_ast_parse(content):
            ast_nodes = process_and_extract_ast_info(content)
            if ast_nodes:
                try:
                    encoding = tokenizer(
                        content,
                        max_length=512,
                        truncation=True,
                        padding='max_length',
                        return_tensors='pt'
                    )

                    processed_contents.append(content)
                    processed_ast_nodes.append(ast_nodes)
                    processed_input_ids.append(encoding['input_ids'].squeeze(0).tolist())
                    processed_attention_masks.append(encoding['attention_mask'].squeeze(0).tolist())
                except:
                    continue

        try:
            input_ids, attention_mask, tokens, depths, sibling_idxs, node_texts = analyze_code_with_codebert_and_treesitter(content, tokenizer)
            processed_input_ids.append(input_ids[0])
            processed_attention_mask.append(attention_mask[0])
            processed_tokens.append(tokens)
            processed_depths.append(depths)
            processed_sibling_idxs.append(sibling_idxs)
            processed_node_texts.append(node_texts)

        except Exception as e:
            logger.error(f"Error processing example: {e}")
            # Return empty lists that will be filtered out later
            processed_input_ids.append([])
            processed_attention_mask.append([])
            processed_tokens.append([])
            processed_depths.append([])
            processed_sibling_idxs.append([])
            processed_node_texts.append([])

    return {
        'content': processed_contents,
        'ast_nodes': processed_ast_nodes,
        'input_ids': processed_input_ids,
        'attention_mask': processed_attention_masks,
        'attention_mask': processed_attention_mask,
        'tokens': processed_tokens,
        'depths': processed_depths,
        'sibling_idxs': processed_sibling_idxs,
        'node_texts': processed_node_texts,
    }

def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int = 10000):
@@ -131,14 +126,9 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int =

    # Create directory for chunks
    output_dir = Path(output_path).parent
    chunks_dir = output_dir / 'chunks'
    chunks_dir = output_dir / 'data'
    chunks_dir.mkdir(exist_ok=True)

    stats_accumulator = {
        'total_ast_nodes': 0,
        'total_samples': 0
    }


    pbar = tqdm(range(num_chunks), desc="Saving chunks", unit="chunk")
    for i in pbar:
        start_idx = i * chunk_size
@@ -147,28 +137,10 @@ def save_dataset_in_chunks(dataset: Dataset, output_path: str, chunk_size: int =
        # Select chunk from dataset
        chunk_dataset = dataset.select(range(start_idx, end_idx))

        # Update statistics
        stats_accumulator['total_ast_nodes'] += sum(len(nodes) for nodes in chunk_dataset['ast_nodes'])
        stats_accumulator['total_samples'] += len(chunk_dataset)

        # Save chunk using datasets native method
        chunk_path = chunks_dir / f'chunk_{i:04d}'
        chunk_dataset.save_to_disk(str(chunk_path))

        # Update progress bar postfix with current chunk info
        pbar.set_postfix({
            'samples': len(chunk_dataset),
            'path': str(chunk_path.name)
        })

    return stats_accumulator

def load_all_chunks(chunks_dir: Path) -> Dataset:
    """Load and concatenate all dataset chunks."""
    chunks = []
    for chunk_path in sorted(chunks_dir.glob('chunk_*')):
        chunks.append(load_from_disk(str(chunk_path)))
    return concatenate_datasets(chunks)
        if len(chunk_dataset) > 0:  # Only process if chunk has data
            # Save chunk using datasets native method
            chunk_path = chunks_dir / f'data-{i:05d}-of-{num_chunks:05d}.parquet'
            chunk_dataset.to_parquet(str(chunk_path))

def main():
    current_dir = Path(__file__).parent
@@ -176,14 +148,23 @@ def main():
    output_dir = current_dir.parent / 'data' / 'processed-python'
    output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
    tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)

    logger.info("Loading dataset...")
    dataset = load_dataset(str(input_dir))['train']
    # dataset = dataset.select(range(200_000)) # Limit dataset size for testing

    original_dataset_size = len(dataset)
    logger.info(f"Original dataset size: {original_dataset_size}")

    logger.info("Dataset:")
    pprint(dataset)
    columns_to_remove = dataset.column_names
    columns_to_remove.remove('content')
    columns_to_remove.remove('size')
    columns_to_remove.remove('avg_line_length')
    columns_to_remove.remove('max_line_length')
    columns_to_remove.remove('alphanum_fraction')
    columns_to_remove.remove('max_stars_count')
    logging.info(f"Columns to remove: {columns_to_remove}")

    num_proc = min(multiprocessing.cpu_count() - 1, 32)
    logger.info(f"Using {num_proc} processes for dataset processing")
@@ -193,35 +174,38 @@ def main():
        process_batch,
        fn_kwargs={'tokenizer': tokenizer},
        batched=True,
        remove_columns=dataset.column_names,
        remove_columns=columns_to_remove,
        desc="Processing dataset",
        num_proc=num_proc,
        load_from_cache_file=False
    )

    processed_dataset = processed_dataset.filter(
        lambda batch: [len(nodes) > 0 for nodes in batch['ast_nodes']],
        lambda batch: [len(tokens) > 0 for tokens in batch['tokens']],
        batched=True,
        num_proc=num_proc,
        desc="Filtering invalid examples"
    )
    reduced_dataset_size = len(processed_dataset)

    logger.info("Processed dataset:")
    pprint(processed_dataset)

    logger.info(f"Saving {len(processed_dataset)} processed examples in chunks...")
    stats_accumulator = save_dataset_in_chunks(
    reduced_dataset_size = len(processed_dataset)
    logger.info(f"Processed dataset size: {reduced_dataset_size}")

    if reduced_dataset_size == 0:
        logger.error("No valid examples found in dataset!")
        return

    logger.info(f"Saving {reduced_dataset_size} processed examples in chunks...")
    save_dataset_in_chunks(
        processed_dataset,
        str(output_dir / 'processed_dataset'),
        chunk_size=100_000
    )

    # Add derived statistics
    stats = {
        'original_dataset_size': original_dataset_size,
        'reduced_dataset_size': reduced_dataset_size,
        '% samples removed': (1 - reduced_dataset_size / original_dataset_size) * 100,
        'avg_ast_nodes': float(stats_accumulator['total_ast_nodes'] / stats_accumulator['total_samples'])
        'samples_removed_pct': (1 - reduced_dataset_size / original_dataset_size) * 100,
    }

    # Log stats with pprint
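As a quick orientation (not part of the commit): a hedged usage sketch of the new analyze_code_with_codebert_and_treesitter helper from the parse_dataset.py diff above; the snippet and printed layout are illustrative only.

from transformers import AutoTokenizer
from parse_dataset import analyze_code_with_codebert_and_treesitter

tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
snippet = "def add(a, b):\n    return a + b\n"

# Returns per-token tree positions: depth and sibling index, with -1 for special/padding tokens.
input_ids, attention_mask, tokens, depths, sibling_idxs, node_texts = \
    analyze_code_with_codebert_and_treesitter(snippet, tokenizer)

for tok, d, s in list(zip(tokens, depths, sibling_idxs))[:10]:
    print(f"{tok!r:>12}  depth={d:>3}  sibling={s:>3}")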
code/src/reduce_dataset.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import logging
import multiprocessing
from pathlib import Path
from datasets import load_dataset

from parse_dataset import save_dataset_in_chunks

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def main():
    current_dir = Path(__file__).parent
    input_dir = current_dir.parent / 'data' / 'processed-python'
    output_dir = current_dir.parent / 'data' / 'filtered-python'
    output_dir.mkdir(parents=True, exist_ok=True)

    num_proc = min(multiprocessing.cpu_count() - 1, 32)
    logger.info(f"Using {num_proc} processes for dataset processing")

    logger.info("Loading dataset...")
    dataset = load_dataset(str(input_dir))['data']
    logging.info(f"Dataset:\n{dataset}")

    logger.info("Filtering dataset by max_stars_count > 3...")
    filtered_dataset = dataset.filter(
        lambda batch: [example and example > 3 for example in batch['max_stars_count']],
        num_proc=num_proc,
        batched=True,
        desc="Filtering dataset"
    )

    filtered_dataset_size = len(filtered_dataset)
    logger.info(f"Filtered dataset size: {filtered_dataset_size}")

    if filtered_dataset_size == 0:
        logger.error("No examples found with max_stars_count > 3!")
        return

    logger.info(f"Saving {filtered_dataset_size} filtered examples...")
    save_dataset_in_chunks(filtered_dataset, output_dir, chunk_size=100_000)

    logger.info("Filtering and saving completed!")

if __name__ == "__main__":
    main()
@@ -1,12 +0,0 @@
{
    "seed": 42,
    "mlm_probability": 0.15,
    "batch": 16,
    "epochs": 1,
    "eval_every": 5000,
    "learning_rate": 5e-4,
    "weight_decay": 0.01,
    "max_grad_norm": 1.0,
    "warmup_steps": 5000,
    "train_size": 0.95
}
@@ -1,50 +0,0 @@
from pathlib import Path
import torch
from transformers import RobertaForMaskedLM, RobertaConfig, RobertaTokenizer

from training_utils import (
    set_seed, setup_wandb, setup_directories, load_config,
    setup_device, create_optimizer_and_scheduler,
    train_and_evaluate, create_base_dataloaders, set_deterministic_mode
)

def main() -> None:
    set_deterministic_mode()

    current_dir = Path(__file__).parent
    output_dir = setup_directories(current_dir, model_name='base')
    config = load_config(current_dir / 'tmp_config.json')

    setup_wandb(config, model_name='base')
    set_seed(config['seed'])
    device = setup_device()

    # Load tokenizer
    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')

    # Create dataloaders
    processed_data_dir = current_dir.parent / 'data' / 'processed-python'
    train_dataloader, valid_dataloader = create_base_dataloaders(
        processed_data_dir,
        tokenizer,
        config,
        device
    )

    # Model setup
    model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
    model = RobertaForMaskedLM(model_config)
    model = model.to(device)

    # Training setup and run
    num_training_steps = config['epochs'] * len(train_dataloader)
    optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)

    train_and_evaluate(
        model, train_dataloader, valid_dataloader,
        optimizer, scheduler, config, output_dir,
        log_weights=False, device=device
    )

if __name__ == "__main__":
    main()
@@ -1,227 +0,0 @@
import math
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, List, Optional
from transformers import RobertaConfig, RobertaTokenizer, RobertaForMaskedLM

from training_utils import (
    set_seed, setup_wandb, setup_directories, load_config,
    setup_device, create_optimizer_and_scheduler,
    train_and_evaluate, create_base_dataloaders, PreprocessedTreeDataset,
    set_deterministic_mode
)
from parse_dataset import ASTNodeInfo

class TreePositionalEmbedding(nn.Module):
    """Generates tree-aware positional embeddings for code tokens."""

    def __init__(self, d_model: int = 768, max_depth: int = 32):
        super().__init__()
        self.d_model = d_model
        self.max_depth = max_depth

        self.depth_embedding = nn.Embedding(max_depth, d_model)
        self.sibling_embedding = nn.Embedding(max_depth, d_model)
        self.combine = nn.Linear(d_model * 2, d_model)

        self._initialize_embeddings()

    def _initialize_embeddings(self):
        position = torch.arange(self.max_depth).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() *
                             (-math.log(10000.0) / self.d_model))

        pe = torch.zeros(self.max_depth, self.d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        with torch.no_grad():
            self.depth_embedding.weight.copy_(pe)
            self.sibling_embedding.weight.copy_(pe)

    def forward(self, input_ids: torch.Tensor, ast_nodes_batch: List[List[ASTNodeInfo]]) -> torch.Tensor:
        """Process batched input with corresponding AST nodes."""
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        embeddings = torch.zeros((batch_size, seq_len, self.d_model), device=device)

        # Process each item in the batch
        for batch_idx in range(batch_size):
            ast_nodes = ast_nodes_batch[batch_idx]
            # Process each position in the sequence
            for i in range(seq_len):
                containing_nodes = [
                    node for node in ast_nodes
                    if node.start_token_idx <= i < node.end_token_idx
                ]

                if containing_nodes:
                    node = max(containing_nodes, key=lambda n: n.depth)
                    depth = min(node.depth, self.max_depth - 1)
                    sibling_pos = min(node.sibling_pos, self.max_depth - 1)

                    depth_emb = self.depth_embedding(torch.tensor(depth, device=device))
                    sibling_emb = self.sibling_embedding(torch.tensor(sibling_pos, device=device))
                    embeddings[batch_idx, i] = self.combine(torch.cat([depth_emb, sibling_emb]))

        return embeddings

class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
    """CodeBERT model enhanced with normalized embedding weights for stable training."""

    def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512):
        super().__init__(config)

        self.tree_pos_embeddings = TreePositionalEmbedding(
            d_model=config.hidden_size,
            max_depth=max_depth
        )

        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)

        # Initialize sequential position embeddings with sinusoidal pattern
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
        pe = torch.zeros(max_seq_length, config.hidden_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.seq_pos_embeddings.weight.data.copy_(pe)

        # Initialize weights with small random values around 0
        self.alpha = nn.Parameter(torch.randn(1) * 0.02)
        self.beta = nn.Parameter(torch.randn(1) * 0.02)
        self.gamma = nn.Parameter(torch.randn(1) * 0.02)

        self.embedding_combination_layer_norm = nn.LayerNorm(config.hidden_size)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size)

        # Add dropout for regularization
        self.embedding_dropout = nn.Dropout(config.hidden_dropout_prob)

    def get_normalized_weights(self) -> torch.Tensor:
        """
        Compute softmax-normalized weights for embedding combination.
        Returns tensor of shape (3,) containing normalized [alpha, beta, gamma].
        """
        weights = torch.stack([self.alpha, self.beta, self.gamma])
        return torch.softmax(weights, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        ast_nodes: Optional[List[List[ASTNodeInfo]]] = None,
        output_attentions: bool = False,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Move tensors to device
        device = input_ids.device

        # Get embeddings
        token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)

        seq_positions = torch.arange(input_ids.size(1), device=device)
        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)

        # Get normalized weights
        norm_weights = self.get_normalized_weights()

        # Combine embeddings based on presence of AST nodes
        if ast_nodes is not None:
            tree_embeddings = self.tree_pos_embeddings(input_ids, ast_nodes)
            combined_embeddings = (
                norm_weights[0] * token_embeddings +
                norm_weights[1] * tree_embeddings +
                norm_weights[2] * seq_embeddings
            )
        else:
            # Redistribute tree weight to other components when no AST available
            token_seq_weights = torch.softmax(torch.stack([self.alpha, self.gamma]), dim=0)
            combined_embeddings = (
                token_seq_weights[0] * token_embeddings +
                token_seq_weights[1] * seq_embeddings
            )

        # Apply layer normalization and dropout
        combined_embeddings = self.embedding_combination_layer_norm(combined_embeddings)
        combined_embeddings = self.embedding_dropout(combined_embeddings)
        combined_embeddings = self.final_layer_norm(combined_embeddings)

        # Forward pass through transformer
        outputs = self.roberta(
            inputs_embeds=combined_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        # Calculate loss if labels provided
        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)
            )

        # Get normalized weights for logging
        norm_weights_cpu = norm_weights.detach().cpu()

        return {
            "loss": masked_lm_loss,
            "logits": prediction_scores,
            "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            "attentions": outputs.attentions,
            "embedding_weights": {
                "token": norm_weights_cpu[0].item(),
                "tree": norm_weights_cpu[1].item(),
                "sequential": norm_weights_cpu[2].item()
            }
        }

def main() -> None:
    set_deterministic_mode()

    current_dir = Path(__file__).parent
    output_dir = setup_directories(current_dir, model_name='tree')
    config = load_config(current_dir / 'tmp_config.json')

    setup_wandb(config, model_name='tree')
    set_seed(config['seed'])
    device = setup_device()

    # Load tokenizer
    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')

    # Create dataloaders with tree dataset
    processed_data_dir = current_dir.parent / 'data' / 'processed-python'
    train_dataloader, valid_dataloader = create_base_dataloaders(
        processed_data_dir,
        tokenizer,
        config,
        device,
        dataset_class=PreprocessedTreeDataset
    )

    # Model setup
    model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')
    model = TreeCodeBERTForPreTraining(model_config)
    model = model.to(device)

    # Training setup and run
    num_training_steps = config['epochs'] * len(train_dataloader)
    optimizer, scheduler = create_optimizer_and_scheduler(model, config, num_training_steps)

    train_and_evaluate(
        model, train_dataloader, valid_dataloader,
        optimizer, scheduler, config, output_dir,
        log_weights=True, device=device
    )

if __name__ == "__main__":
    main()
code/src/training.py (new file, 306 lines)
@@ -0,0 +1,306 @@
import math
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, Optional
from transformers import RobertaConfig, AutoTokenizer, RobertaForMaskedLM

from training_utils import *

MODEL = 'base-custom'  # 'base' or 'tree' or 'base-custom'
DATASET = 'filtered-python'  # 'processed-python-small' or 'processed-python'

class TreePositionalEmbedding(nn.Module):
    """Improved tree-aware positional embeddings that work directly with depth and sibling tensors."""

    def __init__(self, d_model: int = 768, max_depth: int = 32, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_depth = max_depth

        # Separate embeddings for different features
        self.depth_embedding = nn.Embedding(max_depth, d_model)
        self.sibling_embedding = nn.Embedding(max_depth, d_model)

        # Improved projection layers
        self.node_projection = nn.Sequential(
            nn.Linear(d_model * 2, d_model * 2),
            nn.GELU(),
            nn.Linear(d_model * 2, d_model),
            nn.Dropout(dropout)
        )

        # Layer norm for stability
        self.layer_norm = nn.LayerNorm(d_model)

        self._initialize_embeddings()

    def _initialize_embeddings(self):
        std = 0.02
        for embedding in [self.depth_embedding, self.sibling_embedding]:
            nn.init.normal_(embedding.weight, mean=0.0, std=std)

        # Initialize projection layers
        for layer in self.node_projection:
            if isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, mean=0.0, std=std)
                nn.init.zeros_(layer.bias)

    def forward(self, depths: torch.Tensor, sibling_idxs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            depths: Tensor of shape [batch_size, seq_len] containing depth values
            sibling_idxs: Tensor of shape [batch_size, seq_len] containing sibling positions
        Returns:
            Tensor of shape [batch_size, seq_len, d_model] containing tree-aware embeddings
        """
        # Clamp values to max_depth
        depths = torch.clamp(depths, 0, self.max_depth - 1)
        sibling_idxs = torch.clamp(sibling_idxs, 0, self.max_depth - 1)

        # Get embeddings for each feature
        depth_embeddings = self.depth_embedding(depths)  # [batch, seq_len, d_model]
        sibling_embeddings = self.sibling_embedding(sibling_idxs)  # [batch, seq_len, d_model]

        # Combine features
        combined = torch.cat([depth_embeddings, sibling_embeddings], dim=-1)
        embeddings = self.node_projection(combined)

        # Apply layer norm
        normalized_embeddings = self.layer_norm(embeddings)

        return normalized_embeddings

class TreeCodeBERTForPreTraining(RobertaForMaskedLM):
    """CodeBERT model enhanced with tree-structural information."""

    def __init__(self, config: RobertaConfig, max_depth: int = 32, max_seq_length: int = 512, base_custom=False):
        super().__init__(config)
        self.base_custom = base_custom

        self.tree_pos_embeddings = TreePositionalEmbedding(
            d_model=config.hidden_size,
            max_depth=max_depth,
            dropout=config.hidden_dropout_prob
        )

        self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)

        # Initialize sequential position embeddings with sinusoidal pattern
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
        pe = torch.zeros(max_seq_length, config.hidden_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.seq_pos_embeddings.weight.data.copy_(pe)

        # Initialize embedding weights equally
        if base_custom:
            initial_weight = math.log(1/2)
            self.embedding_weights = nn.Parameter(torch.full((2,), initial_weight))
        else:
            initial_weight = math.log(1/3)  # log(1/3) because we use softmax later
            self.embedding_weights = nn.Parameter(torch.full((3,), initial_weight))

        # Layer norms for embedding combination
        self.pre_combination_norm = nn.LayerNorm(config.hidden_size)
        self.post_combination_norm = nn.LayerNorm(config.hidden_size)

    def get_normalized_weights(self):
        """Get softmaxed weights for embedding combination."""
        return torch.softmax(self.embedding_weights, dim=0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        depths: Optional[torch.Tensor] = None,
        sibling_idxs: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        device = input_ids.device

        # Get normalized weights for embedding combination and calculate regularization
        weights = self.get_normalized_weights()

        # Calculate weight variance regularization
        # We want weights to remain somewhat balanced, so penalize high variance
        weight_variance = torch.var(weights)
        weight_reg_loss = 0.1 * weight_variance  # Adjustable coefficient

        # Add L2 regularization to prevent any weight from getting too close to 1
        # This helps maintain a more balanced contribution from each embedding type
        max_weight_penalty = torch.sum(torch.relu(weights - 0.8) ** 2)  # Penalize weights > 0.8
        l2_reg_loss = 0.05 * max_weight_penalty  # Adjustable coefficient

        # Get token embeddings
        token_embeddings = self.roberta.embeddings.word_embeddings(input_ids)
        token_embeddings = self.pre_combination_norm(token_embeddings)

        # Get sequential position embeddings
        seq_positions = torch.arange(input_ids.size(1), device=device)
        seq_embeddings = self.seq_pos_embeddings(seq_positions).unsqueeze(0).expand(input_ids.size(0), -1, -1)

        # Get tree positional embeddings if tree information is provided
        if depths is not None and sibling_idxs is not None and not self.base_custom:
            tree_embeddings = self.tree_pos_embeddings(depths, sibling_idxs)
        else:
            tree_embeddings = torch.zeros_like(token_embeddings)

        # Combine all embeddings using learned weights
        if self.base_custom:
            combined_embeddings = (
                weights[0] * token_embeddings +
                weights[1] * seq_embeddings
            )
        else:
            combined_embeddings = (
                weights[0] * token_embeddings +
                weights[1] * tree_embeddings +
                weights[2] * seq_embeddings
            )

        combined_embeddings = self.post_combination_norm(combined_embeddings)

        # Forward pass through base model
        outputs = self.roberta(
            inputs_embeds=combined_embeddings,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            **kwargs
        )

        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        # Calculate MLM loss if labels are provided
        masked_lm_loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            masked_lm_loss = loss_fct(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1)
            )

            # Add regularization losses to final loss
            if masked_lm_loss is not None:
                final_loss = masked_lm_loss + weight_reg_loss + l2_reg_loss
            else:
                final_loss = weight_reg_loss + l2_reg_loss
        else:
            final_loss = None

        # Prepare embedding weights for logging
        weights_cpu = weights.detach().cpu()
        if self.base_custom:
            embedding_weights = {
                "token": weights_cpu[0].item(),
                "sequential": weights_cpu[1].item()
            }
            reg_metrics = {
                "weight_variance": weight_variance.item(),
                "max_weight_penalty": max_weight_penalty.item(),
                "weight_reg_loss": weight_reg_loss.item(),
                "l2_reg_loss": l2_reg_loss.item()
            }
        else:
            embedding_weights = {
                "token": weights_cpu[0].item(),
                "tree": weights_cpu[1].item(),
                "sequential": weights_cpu[2].item()
            }
            reg_metrics = {
                "weight_variance": weight_variance.item(),
                "max_weight_penalty": max_weight_penalty.item(),
                "weight_reg_loss": weight_reg_loss.item(),
                "l2_reg_loss": l2_reg_loss.item()
            }

        return {
            "loss": final_loss,
            "logits": prediction_scores,
            "hidden_states": outputs.hidden_states if hasattr(outputs, "hidden_states") else None,
            "attentions": outputs.attentions if output_attentions else None,
            "embedding_weights": embedding_weights,
            "regularization_metrics": reg_metrics
        }

def main() -> None:
    current_dir = Path(__file__).parent
    output_dir = setup_directories(current_dir, model_name=MODEL)
    config = load_config(current_dir / 'config.json')

    set_deterministic_mode(config['seed'])

    setup_wandb(config, model_name=MODEL, script_path=__file__)
    set_seed(config['seed'])
    device = setup_device()

    # Load tokenizer
    logger.info("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)

    # Create dataloaders with tree dataset
    logger.info("Creating dataloaders...")
    processed_data_dir = current_dir.parent / 'data' / DATASET
    train_dataloader, valid_dataloader = create_base_dataloaders(
        processed_data_dir,
        tokenizer,
        config,
        base_training=True if MODEL == 'base' else False
    )

    # Initialize model config
    logger.info("Initializing model...")
    model_config = RobertaConfig.from_pretrained('microsoft/codebert-base')

    # Update W&B
    wandb.config.update({'model_config': model_config.__dict__})

    # Initialize model
    if MODEL == 'tree':
        model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512)
    elif MODEL == 'base':
        model = RobertaForMaskedLM(model_config)
    elif MODEL == 'base-custom':
        model = TreeCodeBERTForPreTraining(model_config, max_depth=32, max_seq_length=512, base_custom=True)
    else:
        raise ValueError(f"Invalid model type: {MODEL}")
    model = model.to(device)
    logging.info(f"Model initialized: {MODEL}")

    # Model precision
    if device.type == 'cuda' and config['fp16']:
        model = model.half()
        logging.info("Model set to half precision.")
        for param in model.parameters():
            logging.info(f"Parameter dtype: {param.dtype}")
            break

    # Training setup
    logger.info("Setting up optimizer and scheduler...")
    optimizer, scheduler = create_optimizer_and_scheduler(model, config)

    # Train and evaluate model
    logger.info("Starting training...")
    train_and_evaluate(
        model=model,
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        optimizer=optimizer,
        scheduler=scheduler,
        config=config,
        device=device,
        output_dir=output_dir
    )

    logger.info("Training completed!")
    final_output_dir = output_dir / 'final-model'
    model.save_pretrained(final_output_dir)
    tokenizer.save_pretrained(final_output_dir)
    logger.info(f"Model saved to {final_output_dir}")

if __name__ == "__main__":
    main()
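For intuition about the weighted embedding combination in training.py above, here is a small self-contained sketch (not part of the commit) of the same idea: softmax-normalized weights over token, tree, and sequential embeddings, plus the variance and max-weight penalties. The coefficients mirror the diff; tensor sizes are illustrative.

import math
import torch

# Three learnable scalars, initialized so softmax gives roughly 1/3 each (as in training.py).
embedding_weights = torch.nn.Parameter(torch.full((3,), math.log(1/3)))
weights = torch.softmax(embedding_weights, dim=0)  # ~[0.33, 0.33, 0.33]

batch, seq_len, d_model = 2, 8, 16
token_emb = torch.randn(batch, seq_len, d_model)
tree_emb = torch.randn(batch, seq_len, d_model)
seq_emb = torch.randn(batch, seq_len, d_model)

# Weighted sum of the three embedding types.
combined = weights[0] * token_emb + weights[1] * tree_emb + weights[2] * seq_emb

# The two regularizers from the diff: keep the weights balanced and below 0.8.
weight_reg_loss = 0.1 * torch.var(weights)
l2_reg_loss = 0.05 * torch.sum(torch.relu(weights - 0.8) ** 2)
print(weights.detach(), weight_reg_loss.item(), l2_reg_loss.item())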
@@ -8,38 +8,41 @@ import wandb
import numpy as np
import torch
from pathlib import Path
from typing import Dict, Any, Tuple, Type, Optional
from datasets import load_from_disk, concatenate_datasets
from typing import Dict, Any, Tuple, Optional
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader
from transformers import (
    PreTrainedModel,
    get_linear_schedule_with_warmup,
    get_constant_schedule_with_warmup,
    DataCollatorForLanguageModeling,
    PreTrainedTokenizer
    PreTrainedTokenizer,
    RobertaForMaskedLM
)
from tqdm import tqdm

from parse_dataset import ASTNodeInfo

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

def set_deterministic_mode() -> None:
def set_deterministic_mode(seed: int) -> None:
    """Enable deterministic mode for reproducibility."""
    # Set Python random seed
    random.seed(42)
    random.seed(seed)

    # Set NumPy random seed
    np.random.seed(42)
    np.random.seed(seed)

    # Set PyTorch random seed
    torch.manual_seed(42)
    torch.manual_seed(seed)

    # Set CUDA random seed
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)
        torch.cuda.manual_seed_all(seed)

    # Enable deterministic operations for PyTorch 2.x
    torch.use_deterministic_algorithms(True)
@@ -47,7 +50,7 @@ def set_deterministic_mode() -> None:
    torch.backends.cudnn.benchmark = False

    # Set environment variables for reproducibility
    os.environ['PYTHONHASHSEED'] = '42'
    os.environ['PYTHONHASHSEED'] = f'{seed}'
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'  # Required for CUDA >= 10.2

def get_device_info() -> Dict[str, Any]:
@@ -70,53 +73,6 @@ def get_device_info() -> Dict[str, Any]:

    return info

class TreeDataCollator(DataCollatorForLanguageModeling):
    """Custom data collator that handles both MLM and optional AST node information."""

    def torch_call(self, examples):
        # Check if we have AST nodes
        has_ast = 'ast_nodes' in examples[0]

        # Extract AST nodes before MLM processing if they exist
        ast_nodes = None
        if has_ast:
            ast_nodes = [e.pop('ast_nodes') for e in examples]

        # Process normal MLM features
        batch = super().torch_call(examples)

        # Add AST nodes back to batch if they existed
        if has_ast:
            batch['ast_nodes'] = ast_nodes

        return batch

    def __call__(self, examples):
        return self.torch_call(examples)

class PreprocessedBaseDataset(Dataset):
    """Dataset that uses pre-processed data without AST information."""
    def __init__(self, dataset: Any):
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.dataset[idx]
        return {
            'input_ids': torch.tensor(item['input_ids']),
            'attention_mask': torch.tensor(item['attention_mask']),
            'labels': torch.tensor(item['input_ids']).clone()
        }

class PreprocessedTreeDataset(PreprocessedBaseDataset):
    """Dataset that includes AST information."""
    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = super().__getitem__(idx)
        item['ast_nodes'] = [ASTNodeInfo.from_dict(node) for node in self.dataset[idx]['ast_nodes']]
        return item

def set_seed(seed: int) -> None:
    """Set random seeds for reproducibility."""
    random.seed(seed)
@@ -128,6 +84,7 @@ def set_seed(seed: int) -> None:
def setup_wandb(
    config: Dict[str, Any],
    model_name: str = 'base',
    script_path: Path = None,
    device_info: Optional[Dict[str, Any]] = None
) -> None:
    """Initialize W&B logging with reproducibility information."""
@@ -136,17 +93,21 @@ def setup_wandb(
    # Add reproducibility settings to config
    full_config = {
        **config,
        'random_seed': 42,
        'deterministic_mode': True,
        'device_info': device_info or get_device_info(),
    }

    wandb.init(
        project='codebert-training-test',
        project='new-codebert-tree-training',
        name=f'{model_name}_{curr_time}',
        config=full_config
    )

    # Upload the training script
    wandb.save(script_path)
    wandb.save(Path(script_path).parent / 'training_utils.py')
    logger.info(f'Saving script {script_path} to W&B')

def setup_directories(current_dir: Path, model_name: str = 'base') -> Path:
    """Create output directories for model artifacts."""
    curr_time: str = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M')
@@ -172,7 +133,6 @@ def setup_device() -> torch.device:
def create_optimizer_and_scheduler(
    model: PreTrainedModel,
    config: Dict[str, Any],
    num_training_steps: int
) -> Tuple[AdamW, Any]:
    """Create optimizer and learning rate scheduler."""
    optimizer = AdamW(
@@ -181,270 +141,215 @@ def create_optimizer_and_scheduler(
        weight_decay=config['weight_decay']
    )

    scheduler = get_linear_schedule_with_warmup(
    scheduler = get_constant_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config['warmup_steps'],
        num_training_steps=num_training_steps
    )

    return optimizer, scheduler

def evaluate(model: PreTrainedModel, dataloader: DataLoader) -> Tuple[float, float]:
    """Evaluate model on validation set."""
def evaluate(
    model: RobertaForMaskedLM,
    dataloader: DataLoader,
    device: torch.device
) -> Tuple[float, float]:
    """Evaluate the model on the validation set.

    Returns:
        Tuple of (average loss, accuracy on masked tokens)
    """
    model.eval()
    total_loss: float = 0.0
    total_acc: float = 0.0
    total_loss = 0
    total_correct = 0
    total_predictions = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validation'):
        for batch in tqdm(dataloader, desc='Evaluating'):
            batch = {
                k: v.to(device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }
            outputs = model(**batch)
            total_loss += outputs.loss.item() if hasattr(outputs, 'loss') else outputs['loss'].item()
            logits = outputs.logits if hasattr(outputs, 'logits') else outputs['logits']
            total_acc += logits.argmax(dim=-1).eq(batch['labels']).sum().item()

            # Get loss
            total_loss += outputs['loss'].item()

            # Calculate accuracy only on masked tokens
            predictions = outputs['logits'].argmax(dim=-1)  # [batch_size, seq_length]
            labels = batch['labels']  # [batch_size, seq_length]

            # Create mask for tokens we actually want to predict (ignore padding and unmasked tokens)
            predict_mask = labels != -100  # -100 is the ignore index

            # Calculate accuracy
            correct = predictions[predict_mask] == labels[predict_mask]
            total_correct += correct.sum().item()
            total_predictions += predict_mask.sum().item()

    avg_loss: float = total_loss / len(dataloader)
    avg_acc: float = total_acc / len(dataloader.dataset)
    model.train()
    return avg_loss, avg_acc
    avg_loss = total_loss / len(dataloader)
    accuracy = total_correct / total_predictions if total_predictions > 0 else 0

    return avg_loss, accuracy

def train_and_evaluate(
    model: PreTrainedModel,
    train_dataloader: DataLoader,
    valid_dataloader: DataLoader,
    optimizer: AdamW,
    scheduler: Any,
    config: Dict[str, Any],
    model: RobertaForMaskedLM,
    train_dataloader: DataLoader,
    valid_dataloader: DataLoader,
    optimizer: AdamW,
    scheduler: Any,
    config: Dict[str, Any],
    device: torch.device,
    output_dir: Path,
    log_weights: bool = False,
    device: torch.device = torch.device('cpu')
) -> None:
    """Train and evaluate model with deterministic behavior and comprehensive logging."""
    # Enable deterministic algorithms for PyTorch 2.5
    torch.use_deterministic_algorithms(True)
    """Train and evaluate the model."""

    num_training_steps: int = config['epochs'] * len(train_dataloader)
    best_valid_loss: float = float('inf')
    model.train()
    best_valid_loss = float('inf')

    # Save initial model state for reproducibility
    torch.save({
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'scheduler_state': scheduler.state_dict(),
        'rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
    }, output_dir / 'initial_state.pt')

    with tqdm(total=num_training_steps, desc='Training') as pbar:
        for epoch_idx in range(config['epochs']):
            model.train()
            epoch_loss = 0.0
            epoch_steps = 0

            for train_idx, train_batch in enumerate(train_dataloader):
    for epoch in range(config['epochs']):
        total_loss = 0

        # Training loop
        with tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{config["epochs"]}') as pbar:
            for step, batch in enumerate(pbar):
                # Move batch to device
                train_batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
                               for k, v in train_batch.items()}

                # Forward pass
                outputs = model(**train_batch)
                train_loss = outputs.loss if hasattr(outputs, 'loss') else outputs['loss']

                # Backward pass with deterministic behavior
                train_loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    config['max_grad_norm']
                )

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad(set_to_none=False)  # Use False for determinism

                # Update metrics
                current_loss = train_loss.item()
                epoch_loss += current_loss
                epoch_steps += 1

                # Calculate global step
                step = train_idx + len(train_dataloader) * epoch_idx

                # Prepare logging dictionary
                log_dict = {
                    'step': step,
                    'epoch': epoch_idx,
                    'train_loss': current_loss,
                    'gradient_norm': grad_norm.item(),
                    'learning_rate': scheduler.get_last_lr()[0],
                batch = {
                    k: v.to(device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }

                # Add embedding weights if using tree model
                if log_weights and 'embedding_weights' in outputs:
                    weights = outputs['embedding_weights']
                    log_dict.update({
                        'token_weight': weights['token'],
                        'tree_weight': weights['tree'],
                        'sequential_weight': weights['sequential']
                    })
                    pbar_dict = {
                        'epoch': f"{epoch_idx + 1}/{config['epochs']}",
                        'train_loss': f"{current_loss:.3f}",
                        'α': f"{weights['token']:.2f}",
                        'β': f"{weights['tree']:.2f}",
                        'γ': f"{weights['sequential']:.2f}"
                    }
                else:
                    pbar_dict = {
                        'epoch': f"{epoch_idx + 1}/{config['epochs']}",
                        'train_loss': f"{current_loss:.3f}"
                    }
                # Forward pass
                outputs = model(**batch)
                loss = outputs['loss']

                # Log to wandb
                wandb.log(log_dict)
                # Backward pass
                loss.backward()

                # Clip gradients
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                # Update metrics
                total_loss += loss.item()

                # Update progress bar
                pbar.update(1)
                pbar.set_postfix(pbar_dict)
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})

                # Periodic evaluation
                if (train_idx != 0 and train_idx % config['eval_every'] == 0) or train_idx == len(train_dataloader) - 1:
                    model.eval()
                    valid_loss, valid_acc = evaluate(model, valid_dataloader)
                    model.train()
                # Log to wandb if configured
                log_values = {
                    'train_loss': loss.item(),
                    'grad_norm': grad_norm.item(),
                    'learning_rate': scheduler.get_last_lr()[0]
                }
                if 'embedding_weights' in outputs and 'tree' in outputs['embedding_weights']:
                    log_values.update({
                        'token_weight': outputs['embedding_weights']['token'],
                        'tree_weight': outputs['embedding_weights']['tree'],
                        'seq_weight': outputs['embedding_weights']['sequential'],
                    })
                elif 'embedding_weights' in outputs:
                    log_values.update({
                        'token_weight': outputs['embedding_weights']['token'],
                        'seq_weight': outputs['embedding_weights']['sequential'],
                    })
                if 'regularization_metrics' in outputs:
                    log_values.update({f"regularization/{k}": v for k, v in outputs['regularization_metrics'].items()})
                wandb.log(log_values)

                # Evaluate periodically
                if step % config['eval_every'] == 0 and step > 0:
                    valid_loss, valid_acc = evaluate(model, valid_dataloader, device)

                    # Log validation metrics
                    wandb.log({
                        'valid_loss': valid_loss,
                        'valid_acc': valid_acc,
                        'step': step,
                        'epoch': epoch_idx,
                    })
                    wandb.log({'valid_loss': valid_loss, 'valid_accuracy': valid_acc})

                    # Save checkpoint if best model
                    # Print validation metrics
                    print(f"\nValidation - Loss: {valid_loss:.4f}, Accuracy: {valid_acc:.4f}")

                    # Save best model
                    if valid_loss < best_valid_loss:
                        best_valid_loss = valid_loss

                        # Save complete state
                        torch.save({
                            'epoch': epoch_idx,
                            'step': step,
                            'model_state': model.state_dict(),
                            'optimizer_state': optimizer.state_dict(),
                            'scheduler_state': scheduler.state_dict(),
                            'rng_state': torch.get_rng_state(),
                            'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
                            'best_valid_loss': best_valid_loss,
                            'config': config,
                            'device_info': get_device_info()
                        }, output_dir / 'best_model.pt')

                        # Log best metrics
                        wandb.run.summary['best_valid_loss'] = best_valid_loss
                        wandb.run.summary['best_valid_acc'] = valid_acc
                        wandb.run.summary['best_epoch'] = epoch_idx
                        wandb.run.summary['best_step'] = step

            # End of epoch logging
            avg_epoch_loss = epoch_loss / epoch_steps
            wandb.log({
                'epoch': epoch_idx,
                'epoch_avg_loss': avg_epoch_loss,
            })

            # Save end of epoch checkpoint
            torch.save({
                'epoch': epoch_idx,
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict(),
                'scheduler_state': scheduler.state_dict(),
                'rng_state': torch.get_rng_state(),
                'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
                'train_loss': avg_epoch_loss,
                'config': config,
                'device_info': get_device_info()
            }, output_dir / f'checkpoint_epoch_{epoch_idx}.pt')

    # End of training logging
    wandb.run.summary['final_epoch'] = config['epochs'] - 1
    wandb.run.summary['total_steps'] = num_training_steps

    logger.info(f'Training completed. Best validation loss: {best_valid_loss:.4f}')
    logger.info(f'Model checkpoints saved in {output_dir}')

def load_all_chunks(chunks_dir: Path) -> Any:
    """Load and concatenate all dataset chunks."""
    logger.info(f"Loading dataset chunks from {chunks_dir}")
    chunks = []
    for chunk_path in sorted(chunks_dir.glob('chunk_*'))[:5]:
        chunks.append(load_from_disk(str(chunk_path)))
    dataset = concatenate_datasets(chunks)
    logger.info(f"Loaded {len(dataset)} examples from {len(chunks)} chunks")
    return dataset
                        model.save_pretrained(output_dir / f'checkpoint-{epoch}-{step}')

                    model.train()  # Resume training mode

def create_base_dataloaders(
    processed_data_dir: Path,
    tokenizer: PreTrainedTokenizer,
    config: Dict[str, Any],
    device: torch.device,
    dataset_class: Type[Dataset] = PreprocessedBaseDataset,
    base_training=False,
) -> Tuple[DataLoader, DataLoader]:
    """Create reproducible dataloaders from pre-processed data."""
    # Load chunks
    chunks_dir = processed_data_dir / 'chunks'
    full_dataset = load_all_chunks(chunks_dir)
    """Create dataloaders from pre-processed parquet data."""

    # Load chunks using datasets library
    chunks_dir = processed_data_dir / 'data'
    dataset = load_dataset(
        "parquet",
        data_files=str(chunks_dir / "*.parquet"),
        split="train"
    )

    contents = dataset['content'][:config['batch_size']]
    with open("contents.txt", "w") as f:
        for content in contents:
            f.write(content)
            f.write("\n\n\n")
            f.write("#" * 80)
            f.write("\n\n\n")

    # Remove columns that are not needed
    columns_to_remove = dataset.column_names
    columns_to_remove.remove('input_ids')
    columns_to_remove.remove('attention_mask')

    if not base_training:
        columns_to_remove.remove('depths')
        columns_to_remove.remove('sibling_idxs')

    dataset = dataset.remove_columns(columns_to_remove)

    logging.info(f"Loaded dataset:\n{dataset}")

    # Calculate split sizes
    dataset_size = len(full_dataset)
    dataset_size = len(dataset)
    train_size = int(config['train_size'] * dataset_size)
    val_size = dataset_size - train_size

    # Create splits using indices to avoid device issues
    indices = torch.arange(dataset_size, device=device)
    # Create splits
    splits = dataset.train_test_split(
        train_size=train_size,
        test_size=val_size,
        seed=config['seed']
    )
    train_dataset = splits['train']
    valid_dataset = splits['test']

    # Create a deterministic generator on the correct device
    generator = torch.Generator(device=device)
    generator.manual_seed(42)

    # Shuffle indices deterministically
    shuffled_indices = indices[torch.randperm(len(indices), device=device, generator=generator)]
    train_indices = shuffled_indices[:train_size].cpu()
    valid_indices = shuffled_indices[train_size:].cpu()

    # Create train/validation splits using the shuffled indices
    train_dataset = torch.utils.data.Subset(full_dataset, train_indices.tolist())
    valid_dataset = torch.utils.data.Subset(full_dataset, valid_indices.tolist())

    # Create dataset wrappers
    train_dataset = dataset_class(train_dataset)
    valid_dataset = dataset_class(valid_dataset)

    logger.info(f"Created train dataset with {len(train_dataset)} samples")
    logger.info(f"Created validation dataset with {len(valid_dataset)} samples")

    # Use TreeDataCollator for both base and tree models
    data_collator = TreeDataCollator(
        tokenizer=tokenizer,
        mlm=True,
    # Create data collator for MLM
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=config['mlm_probability']
    )

    # Create train dataloader without generator

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config['batch'],
        shuffle=False,  # We already shuffled the data
        batch_size=config['batch_size'],
        shuffle=False,  # We'll let the dataset handle shuffling
        collate_fn=data_collator,
        num_workers=0,  # Use single worker for reproducibility
        drop_last=True  # Ensure consistent batch sizes
        num_workers=0,
        drop_last=True,
    )

    # Create validation dataloader
    valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=config['batch'],
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=data_collator,
        num_workers=0,
        drop_last=True
        drop_last=True,
    )

    return train_dataloader, valid_dataloader
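Since the new create_base_dataloaders feeds already-tokenized parquet rows straight into DataCollatorForLanguageModeling, here is a small illustrative sketch (not part of the commit) of what the collator produces and why evaluate counts accuracy only where labels != -100.

import torch
from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained('microsoft/codebert-base', use_fast=True)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Two pre-tokenized examples, as they would come out of the parquet chunks.
enc = tokenizer(["def f(): pass", "x = 1 + 2"], padding='max_length', truncation=True, max_length=16)
features = [{'input_ids': enc['input_ids'][i], 'attention_mask': enc['attention_mask'][i]} for i in range(2)]

batch = collator(features)
# Roughly 15% of positions get masked; everywhere else labels are -100 and are
# excluded from both the MLM loss and the masked-token accuracy in evaluate().
masked = batch['labels'] != -100
print(batch['input_ids'].shape, int(masked.sum()), "masked positions")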