added training code for treecodebert

Patryk Bartkowiak 2024-11-05 16:50:19 +00:00
parent 9299c56bb1
commit 96fc1041cf
5 changed files with 1685 additions and 114 deletions

code/pyproject.toml

@@ -36,3 +36,4 @@ distribution = true
 [tool.pdm.scripts]
 run_training = {cmd = "src/train_codebert_mlm.py"}
+run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
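Assuming PDM's standard script runner, the new entry can presumably be invoked with pdm run run_tree_training, mirroring pdm run run_training for the existing trainer.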

code/src/run_inference.py (new file, 72 additions)

@@ -0,0 +1,72 @@
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM, DataCollatorForLanguageModeling
from pathlib import Path

def load_model_and_tokenizer(model_path: Path, tokenizer_name: str = 'microsoft/codebert-base'):
    # Load the pre-trained tokenizer
    tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)

    # Load the trained model
    state_dict = torch.load(model_path, weights_only=True)
    corrected_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
    model = RobertaForMaskedLM.from_pretrained('roberta-base')
    model.load_state_dict(corrected_state_dict)

    return model, tokenizer

def run_inference(model, tokenizer, input_text: str, max_length: int = 512):
    # Tokenize the input text
    inputs = tokenizer(
        input_text,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=max_length
    )

    # Use DataCollatorForLanguageModeling for MLM-style dynamic masking
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    inputs = data_collator([inputs])  # Collate the input and add random masks

    # Squeeze the batch dimension
    for k, v in inputs.items():
        inputs[k] = v.squeeze(0)

    print(tokenizer.decode(inputs['input_ids'][0][inputs['attention_mask'][0] == 1], skip_special_tokens=False))

    with torch.no_grad():
        # Run inference
        outputs = model(**inputs).logits.squeeze(0)

    # Get predicted token ids (argmax over logits)
    predicted_token_ids = outputs.argmax(dim=-1)

    # Ignore padding tokens
    predicted_token_ids = predicted_token_ids[inputs['attention_mask'][0] == 1]

    # Decode predicted token ids to text
    predicted_text = tokenizer.decode(predicted_token_ids, skip_special_tokens=False)

    return predicted_text

def main_inference():
    # Load the trained model and tokenizer
    model_path = Path('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt')  # Update this to your trained model path
    model, tokenizer = load_model_and_tokenizer(model_path)

    # Define input string
    input_string = """def compute_area(radius):
    # This function calculates the area of a circle
    pi = 3.14159
    return pi * radius ** 2"""

    # Run inference
    print(f"Input:\n{input_string}")
    print("\n\n")
    print("Masked input:")
    output = run_inference(model, tokenizer, input_string)
    print("\n\n")
    print(f"Output:\n{output}")

if __name__ == "__main__":
    main_inference()
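A possible follow-up, not part of this commit: DataCollatorForLanguageModeling sets labels to -100 at every position it did not mask, so those labels can be used to report predictions only at the masked positions instead of argmax-decoding the whole sequence. A minimal sketch under that assumption (predict_masked_tokens is a hypothetical helper name; it reuses the model and tokenizer returned by load_model_and_tokenizer above):

import torch
from transformers import DataCollatorForLanguageModeling

def predict_masked_tokens(model, tokenizer, text: str, mlm_probability: float = 0.15):
    # Hypothetical helper, not part of this commit
    # Tokenize without padding; the collator pads the single-example batch itself
    enc = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_probability)
    # Collate one example; squeeze the batch dim that return_tensors='pt' added
    batch = collator([{k: v.squeeze(0) for k, v in enc.items()}])
    with torch.no_grad():
        logits = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask']).logits
    masked = batch['labels'][0] != -100  # labels are -100 everywhere except masked positions
    predicted_ids = logits[0, masked].argmax(dim=-1)
    original_tokens = tokenizer.convert_ids_to_tokens(batch['labels'][0][masked].tolist())
    predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_ids.tolist())
    return list(zip(original_tokens, predicted_tokens))

For example, predict_masked_tokens(model, tokenizer, input_string) would return (original, predicted) token pairs for whichever positions the collator happened to mask.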

code/src/test_dataset.ipynb (new file, 1471 additions)

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long

View File

@@ -382,7 +382,7 @@ def train_and_evaluate(
     if train_loss is not None:
         train_loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
+        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
         optimizer.step()
         scheduler.step()
         optimizer.zero_grad()
@@ -394,7 +394,7 @@ def train_and_evaluate(
         # Update progress bar with all three weights
         pbar.update(1)
         pbar.set_postfix({
-            'loss': f"{current_loss:.3f}",
+            'train_loss': f"{current_loss:.3f}",
             'α': f"{weights['token']:.2f}",
             'β': f"{weights['tree']:.2f}",
             'γ': f"{weights['sequential']:.2f}"
@@ -403,10 +403,12 @@ def train_and_evaluate(
         # Log all three weights separately
         step = train_idx + len(train_dataloader) * epoch_idx
         wandb.log({
-            'loss': current_loss,
+            'train_loss': current_loss,
             'token_weight': weights['token'],
             'tree_weight': weights['tree'],
             'sequential_weight': weights['sequential'],
+            'gradient_norm': norm.item(),
+            'learning_rate': scheduler.get_last_lr()[0],
             'step': step,
         })
@@ -425,7 +427,7 @@ def train_and_evaluate(
             torch.save(model.state_dict(), output_dir / 'best_model.pt')
     else:
         pbar.update(1)
-        pbar.set_postfix({'loss': 'N/A'})
+        pbar.set_postfix({'train_loss': 'N/A'})
 logger.info(f'Best validation loss: {best_valid_loss}')
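For context on the gradient_norm entry added above: torch.nn.utils.clip_grad_norm_ returns the total gradient norm computed over all parameters before clipping, so capturing its return value is enough to log it via norm.item() without an extra pass over the gradients. A standalone toy sketch (not the project's training loop):

import torch

# Toy model and loss, purely for illustration
model = torch.nn.Linear(8, 1)
loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
# clip_grad_norm_ clips in place and returns the pre-clip total norm as a tensor
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"gradient norm before clipping: {norm.item():.4f}")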