added training code for treecodebert
parent 9299c56bb1
commit 96fc1041cf
@@ -36,3 +36,4 @@ distribution = true
 [tool.pdm.scripts]
 run_training = {cmd = "src/train_codebert_mlm.py"}
+run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
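With this [tool.pdm.scripts] entry, the new tree-based training should be launchable through PDM in the same way as the existing script, e.g. pdm run run_tree_training (and pdm run run_training for the plain CodeBERT MLM run). This assumes a standard PDM setup with the pyproject.toml at the project root, which the diff itself does not show.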
code/src/run_inference.py  (new file, 72 lines)
@@ -0,0 +1,72 @@
+import torch
+from transformers import RobertaTokenizer, RobertaForMaskedLM, DataCollatorForLanguageModeling
+from pathlib import Path
+
+def load_model_and_tokenizer(model_path: Path, tokenizer_name: str = 'microsoft/codebert-base'):
+    # Load the pre-trained tokenizer
+    tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
+
+    # Load the trained model
+    state_dict = torch.load(model_path, weights_only=True)
+    corrected_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+    model = RobertaForMaskedLM.from_pretrained('roberta-base')
+    model.load_state_dict(corrected_state_dict)
+
+    return model, tokenizer
+
+def run_inference(model, tokenizer, input_text: str, max_length: int = 512):
+    # Tokenize the input text
+    inputs = tokenizer(
+        input_text,
+        return_tensors='pt',
+        padding='max_length',
+        truncation=True,
+        max_length=max_length
+    )
+
+    # Use DataCollatorForLanguageModeling for MLM-style dynamic masking
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
+    inputs = data_collator([inputs])  # Collate the input and add random masks
+
+    # Squeeze the batch dimension
+    for k, v in inputs.items():
+        inputs[k] = v.squeeze(0)
+
+    print(tokenizer.decode(inputs['input_ids'][0][inputs['attention_mask'][0] == 1], skip_special_tokens=False))
+
+    with torch.no_grad():
+        # Run inference
+        outputs = model(**inputs).logits.squeeze(0)
+
+    # Get predicted token ids (argmax over logits)
+    predicted_token_ids = outputs.argmax(dim=-1)
+
+    # Ignore padding tokens
+    predicted_token_ids = predicted_token_ids[inputs['attention_mask'][0] == 1]
+
+    # Decode predicted token ids to text
+    predicted_text = tokenizer.decode(predicted_token_ids, skip_special_tokens=False)
+
+    return predicted_text
+
+def main_inference():
+    # Load the trained model and tokenizer
+    model_path = Path('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt')  # Update this to your trained model path
+    model, tokenizer = load_model_and_tokenizer(model_path)
+
+    # Define input string
+    input_string = """def compute_area(radius):
+    # This function calculates the area of a circle
+    pi = 3.14159
+    return pi * radius ** 2"""
+
+    # Run inference
+    print(f"Input:\n{input_string}")
+    print("\n\n")
+    print("Masked input:")
+    output = run_inference(model, tokenizer, input_string)
+    print("\n\n")
+    print(f"Output:\n{output}")
+
+if __name__ == "__main__":
+    main_inference()
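Note that the masking inside run_inference() is random: DataCollatorForLanguageModeling masks roughly 15% of tokens, so every call masks different positions. For reference, below is a minimal, self-contained sketch of the same masked-prediction mechanics with a single hand-placed mask token. It loads the stock microsoft/codebert-base checkpoint rather than the trained best_model.pt, so it only illustrates the flow, not the quality of the model trained in this commit.

# Sketch only: stock CodeBERT checkpoint, one explicit mask token.
import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
model = RobertaForMaskedLM.from_pretrained('microsoft/codebert-base')
model.eval()

# Mask a single token in a small code snippet.
code = f"def compute_area(radius):\n    return 3.14159 * {tokenizer.mask_token} ** 2"
inputs = tokenizer(code, return_tensors='pt')

with torch.no_grad():
    logits = model(**inputs).logits

# Take the highest-scoring token at the masked position and decode it.
mask_pos = (inputs['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
predicted_ids = logits[0, mask_pos].argmax(dim=-1)
print(tokenizer.decode(predicted_ids))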
code/src/test_dataset.ipynb  (new file, 1471 lines)
File diff suppressed because it is too large
@@ -382,7 +382,7 @@ def train_and_evaluate(
 
             if train_loss is not None:
                 train_loss.backward()
-                torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
+                norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
                 optimizer.step()
                 scheduler.step()
                 optimizer.zero_grad()
@@ -394,7 +394,7 @@ def train_and_evaluate(
                 # Update progress bar with all three weights
                 pbar.update(1)
                 pbar.set_postfix({
-                    'loss': f"{current_loss:.3f}",
+                    'train_loss': f"{current_loss:.3f}",
                     'α': f"{weights['token']:.2f}",
                     'β': f"{weights['tree']:.2f}",
                     'γ': f"{weights['sequential']:.2f}"
@@ -403,10 +403,12 @@ def train_and_evaluate(
                 # Log all three weights separately
                 step = train_idx + len(train_dataloader) * epoch_idx
                 wandb.log({
-                    'loss': current_loss,
+                    'train_loss': current_loss,
                     'token_weight': weights['token'],
                     'tree_weight': weights['tree'],
                     'sequential_weight': weights['sequential'],
+                    'gradient_norm': norm.item(),
+                    'learning_rate': scheduler.get_last_lr()[0],
                     'step': step,
                 })
 
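Aside: torch.nn.utils.clip_grad_norm_ clips the gradients in place and returns the total gradient norm computed before clipping, which is why capturing its return value as norm is enough to log 'gradient_norm' above. A tiny self-contained illustration (toy linear model, not the project's training loop):

import torch

model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

# Returns the total (pre-clipping) gradient norm as a tensor; gradients
# are rescaled in place only if that norm exceeds max_norm.
norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(norm.item())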
@@ -425,7 +427,7 @@ def train_and_evaluate(
                     torch.save(model.state_dict(), output_dir / 'best_model.pt')
             else:
                 pbar.update(1)
-                pbar.set_postfix({'loss': 'N/A'})
+                pbar.set_postfix({'train_loss': 'N/A'})
 
     logger.info(f'Best validation loss: {best_valid_loss}')
 