added training code for treecodebert
parent 9299c56bb1
commit 96fc1041cf
@@ -36,3 +36,4 @@ distribution = true
 
 [tool.pdm.scripts]
 run_training = {cmd = "src/train_codebert_mlm.py"}
+run_tree_training = {cmd = "src/train_tree_codebert_mlm.py"}
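
With this script entry in place, the tree-based training run can be launched through PDM's script runner, e.g. (a usage sketch, assuming commands are issued from the directory containing this pyproject.toml):

    pdm run run_tree_training
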
72  code/src/run_inference.py  Normal file
@@ -0,0 +1,72 @@
+import torch
+from transformers import RobertaTokenizer, RobertaForMaskedLM, DataCollatorForLanguageModeling
+from pathlib import Path
+
+def load_model_and_tokenizer(model_path: Path, tokenizer_name: str = 'microsoft/codebert-base'):
+    # Load the pre-trained tokenizer
+    tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
+
+    # Load the trained model
+    state_dict = torch.load(model_path, weights_only=True)
+    corrected_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}
+    model = RobertaForMaskedLM.from_pretrained('roberta-base')
+    model.load_state_dict(corrected_state_dict)
+
+    return model, tokenizer
+
+def run_inference(model, tokenizer, input_text: str, max_length: int = 512):
+    # Tokenize the input text
+    inputs = tokenizer(
+        input_text,
+        return_tensors='pt',
+        padding='max_length',
+        truncation=True,
+        max_length=max_length
+    )
+
+    # Use DataCollatorForLanguageModeling for MLM-style dynamic masking
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
+    inputs = data_collator([inputs])  # Collate the input and add random masks
+
+    # Squeeze the batch dimension
+    for k, v in inputs.items():
+        inputs[k] = v.squeeze(0)
+
+    print(tokenizer.decode(inputs['input_ids'][0][inputs['attention_mask'][0] == 1], skip_special_tokens=False))
+
+    with torch.no_grad():
+        # Run inference
+        outputs = model(**inputs).logits.squeeze(0)
+
+    # Get predicted token ids (argmax over logits)
+    predicted_token_ids = outputs.argmax(dim=-1)
+
+    # Ignore padding tokens
+    predicted_token_ids = predicted_token_ids[inputs['attention_mask'][0] == 1]
+
+    # Decode predicted token ids to text
+    predicted_text = tokenizer.decode(predicted_token_ids, skip_special_tokens=False)
+
+    return predicted_text
+
+def main_inference():
+    # Load the trained model and tokenizer
+    model_path = Path('/sql/msc-patryk-bartkowiak/outputs/2024-10-21_20:15/best_model.pt')  # Update this to your trained model path
+    model, tokenizer = load_model_and_tokenizer(model_path)
+
+    # Define input string
+    input_string = """def compute_area(radius):
+    # This function calculates the area of a circle
+    pi = 3.14159
+    return pi * radius ** 2"""
+
+    # Run inference
+    print(f"Input:\n{input_string}")
+    print("\n\n")
+    print("Masked input:")
+    output = run_inference(model, tokenizer, input_string)
+    print("\n\n")
+    print(f"Output:\n{output}")
+
+if __name__ == "__main__":
+    main_inference()
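
A note on the new inference script: the "_orig_mod." prefix stripped in load_model_and_tokenizer() is what torch.compile prepends to parameter names, so the checkpoint is presumably saved from a compiled model. Also, DataCollatorForLanguageModeling masks tokens at random on every call, so the masked input and the reconstruction will differ between runs. Assuming the hard-coded checkpoint path in main_inference() exists on the machine, a minimal invocation is:

    python code/src/run_inference.py
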
1471  code/src/test_dataset.ipynb  Normal file
File diff suppressed because it is too large
@@ -382,7 +382,7 @@ def train_and_evaluate(
 
         if train_loss is not None:
             train_loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
+            norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config['max_grad_norm'])
             optimizer.step()
             scheduler.step()
             optimizer.zero_grad()
@@ -394,7 +394,7 @@ def train_and_evaluate(
             # Update progress bar with all three weights
             pbar.update(1)
             pbar.set_postfix({
-                'loss': f"{current_loss:.3f}",
+                'train_loss': f"{current_loss:.3f}",
                 'α': f"{weights['token']:.2f}",
                 'β': f"{weights['tree']:.2f}",
                 'γ': f"{weights['sequential']:.2f}"
@@ -403,10 +403,12 @@ def train_and_evaluate(
             # Log all three weights separately
             step = train_idx + len(train_dataloader) * epoch_idx
             wandb.log({
-                'loss': current_loss,
+                'train_loss': current_loss,
                 'token_weight': weights['token'],
                 'tree_weight': weights['tree'],
                 'sequential_weight': weights['sequential'],
+                'gradient_norm': norm.item(),
+                'learning_rate': scheduler.get_last_lr()[0],
                 'step': step,
             })
 
@@ -425,7 +427,7 @@ def train_and_evaluate(
                 torch.save(model.state_dict(), output_dir / 'best_model.pt')
         else:
             pbar.update(1)
-            pbar.set_postfix({'loss': 'N/A'})
+            pbar.set_postfix({'train_loss': 'N/A'})
 
     logger.info(f'Best validation loss: {best_valid_loss}')
 
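
For reference, the pattern these changes introduce — keeping the value returned by clip_grad_norm_ and logging it to wandb together with the current learning rate — can be sketched in isolation roughly as below. The toy model, optimizer, and scheduler are purely illustrative, and wandb runs in disabled mode so the snippet stays self-contained:

    import torch
    import wandb

    # Illustrative stand-ins; the real training script builds these itself.
    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=10)
    wandb.init(mode="disabled")

    for step in range(10):
        loss = model(torch.randn(4, 16)).pow(2).mean()
        loss.backward()
        # clip_grad_norm_ returns the total gradient norm computed before clipping
        norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        wandb.log({
            'train_loss': loss.item(),
            'gradient_norm': norm.item(),
            'learning_rate': scheduler.get_last_lr()[0],
            'step': step,
        })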