new changes
parent c80c591e7c
commit 3b7bc5d6d2
@@ -1,13 +1,13 @@
 {
     "extra_embeddings": true,
-    "run_name": "tree-seq-non-sinusoidal",
-    "data_dir": "./data/codeparrot-clean-parsed-starencoder-classes-encoded/",
-    "output_dir": "./outputs/tree-seq-non-sinusoidal",
+    "run_name": "no-sinusoidal",
+    "data_dir": "./data/codeparrot-clean-parsed-starencoder-no-comments/",
+    "output_dir": "./outputs/long-no-comments-starencoder-no-sinusoidal",
     "seed": 420,
     "mlm_probability": 0.15,
     "batch_size": 32,
-    "epochs": 1,
-    "eval_every": 5000,
+    "epochs": 3,
+    "eval_every": 10000,
     "learning_rate": 5e-4,
     "weight_decay": 0.1,
     "max_grad_norm": 1.0,
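
For orientation, here is a sketch of the effective training configuration after this change, written as the Python dict that json.load would return for the keys visible in this hunk (the real config.json may contain further entries not shown in the diff):

    config = {
        "extra_embeddings": True,
        "run_name": "no-sinusoidal",
        "data_dir": "./data/codeparrot-clean-parsed-starencoder-no-comments/",
        "output_dir": "./outputs/long-no-comments-starencoder-no-sinusoidal",
        "seed": 420,
        "mlm_probability": 0.15,
        "batch_size": 32,
        "epochs": 3,          # was 1
        "eval_every": 10000,  # was 5000
        "learning_rate": 5e-4,
        "weight_decay": 0.1,
        "max_grad_norm": 1.0,
    }
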
@@ -41,16 +41,17 @@ def load_config(config_path: Path) -> dict:
         return json.load(f)
 
 def main():
-    set_seed(config['seed'])
-
     # Setup paths
     current_dir = Path(__file__).parent
     config = load_config(current_dir / 'config.json')
     data_dir = Path(config['data_dir'])
     output_dir = Path(config['output_dir'])
 
+    # Set seed
+    set_seed(config['seed'])
+
     # Initialize W&B
-    wandb.init(project='codeparrot-starencoder', config=config, name=config['run_name'])
+    wandb.init(project='codeparrot-starencoder-no-comments', config=config, name=config['run_name'])
 
     # Upload the training files to W&B
     wandb.save(__file__)
@@ -87,7 +88,6 @@ def main():
     if config['extra_embeddings']:
         columns_to_remove.remove('depths')
         columns_to_remove.remove('sibling_idxs')
-        columns_to_remove.remove('node_types_encoded')
     dataset = dataset.remove_columns(columns_to_remove)
     logger.info(f'Loaded dataset:\n{dataset}')
 
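
Besides renaming the W&B project, the first hunk above fixes an ordering problem: set_seed(config['seed']) used to run before load_config had produced config. A minimal sketch of the corrected setup, assuming config is not defined anywhere else in the module:

    def main():
        # Load the config first; nothing below can read it before this point.
        current_dir = Path(__file__).parent
        config = load_config(current_dir / 'config.json')
        data_dir = Path(config['data_dir'])
        output_dir = Path(config['output_dir'])

        # Set seed (moved here so that config['seed'] is actually available)
        set_seed(config['seed'])
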
@@ -9,17 +9,16 @@ from transformers import AutoConfig, BertForMaskedLM
 from tree_codebert import TreePositionalEmbedding
 
 class TreeStarEncoderForPreTraining(BertForMaskedLM):
-    def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512, log: bool = True):
+    def __init__(self, config: AutoConfig, max_depth: int = 32, max_seq_length: int = 512):
         super().__init__(config)
         self.config = config
-        self.log = log
 
-        self.fusion_layer = nn.Sequential(
-            nn.Linear(config.hidden_size * 4, config.hidden_size),
-            nn.GELU(),
-            nn.Dropout(config.hidden_dropout_prob),
-            nn.LayerNorm(config.hidden_size)
-        )
+        # self.fusion_layer = nn.Sequential(
+        #     nn.Linear(config.hidden_size * 4, config.hidden_size),
+        #     nn.GELU(),
+        #     nn.Dropout(config.hidden_dropout_prob),
+        #     nn.LayerNorm(config.hidden_size)
+        # )
 
         # Override config to set max_seq_length
         config.max_position_embeddings = max_seq_length
@@ -32,13 +31,13 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
 
         self.seq_pos_embeddings = nn.Embedding(max_seq_length, config.hidden_size)
 
-        # Initialize sequential position embeddings with sinusoidal pattern
-        position = torch.arange(max_seq_length).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
-        pe = torch.zeros(max_seq_length, config.hidden_size)
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        self.seq_pos_embeddings.weight.data.copy_(pe)
+        # # Initialize sequential position embeddings with sinusoidal pattern
+        # position = torch.arange(max_seq_length).unsqueeze(1)
+        # div_term = torch.exp(torch.arange(0, config.hidden_size, 2) * (-math.log(10000.0) / config.hidden_size))
+        # pe = torch.zeros(max_seq_length, config.hidden_size)
+        # pe[:, 0::2] = torch.sin(position * div_term)
+        # pe[:, 1::2] = torch.cos(position * div_term)
+        # self.seq_pos_embeddings.weight.data.copy_(pe)
 
         # New node type embeddings
         self.node_type_embeddings = nn.Embedding(217, config.hidden_size)
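
For context, the block commented out above is the standard sinusoidal position-encoding initialization; with it disabled, seq_pos_embeddings keeps the default random nn.Embedding initialization, which matches the "no-sinusoidal" run name. A standalone sketch of the same computation (the helper name is illustrative, not part of the codebase):

    import math
    import torch

    def sinusoidal_table(max_seq_length: int, hidden_size: int) -> torch.Tensor:
        # pe[pos, 2i]   = sin(pos / 10000^(2i / hidden_size))
        # pe[pos, 2i+1] = cos(pos / 10000^(2i / hidden_size))
        # Assumes an even hidden_size, as in the original code.
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, hidden_size, 2) * (-math.log(10000.0) / hidden_size))
        pe = torch.zeros(max_seq_length, hidden_size)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe
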
@@ -51,7 +50,6 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
         labels: Optional[torch.Tensor] = None,
         depths: Optional[torch.Tensor] = None,
         sibling_idxs: Optional[torch.Tensor] = None,
-        node_types: Optional[torch.Tensor] = None,
         output_attentions: bool = False,
         **kwargs
     ) -> Dict[str, torch.Tensor]:
@@ -70,11 +68,14 @@ class TreeStarEncoderForPreTraining(BertForMaskedLM):
         else:
             tree_embeddings = torch.zeros_like(token_embeddings)
 
-        # Get node type embeddings
-        node_type_embeddings = self.node_type_embeddings(node_types)
+        # # Get node type embeddings
+        # node_type_embeddings = self.node_type_embeddings(node_types)
 
-        combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
-        combined_embeddings = self.fusion_layer(combined)
+        # combined = torch.cat([token_embeddings, tree_embeddings, seq_embeddings, node_type_embeddings], dim=-1)
+        # combined_embeddings = self.fusion_layer(combined)
+
+        # Add the embeddings instead of concatenating
+        combined_embeddings = token_embeddings + tree_embeddings + seq_embeddings
 
         combined_embeddings = self.norm(combined_embeddings)
 
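
Taken together, the model changes replace the concatenate-then-project fusion (four streams through a Linear(4 * hidden_size, hidden_size)) with a plain element-wise sum of three streams, dropping the node-type embeddings from the forward pass entirely. A sketch of the new combination step, using an illustrative helper name that does not exist in the codebase:

    import torch
    import torch.nn as nn

    def combine_embeddings(token_emb: torch.Tensor,
                           tree_emb: torch.Tensor,
                           seq_emb: torch.Tensor,
                           norm: nn.Module) -> torch.Tensor:
        # All inputs share the shape (batch, seq_len, hidden_size), so the sum
        # keeps the hidden size unchanged and no fusion projection is needed.
        return norm(token_emb + tree_emb + seq_emb)
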