removing comments
This commit is contained in:
parent
7ca1b79497
commit
c80c591e7c
File diff suppressed because it is too large
Load Diff
158
code/src/notebooks/test_removing_comments.ipynb
Normal file
158
code/src/notebooks/test_removing_comments.ipynb
Normal file
@ -0,0 +1,158 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from tree_sitter import Language, Parser\n",
|
||||
"import tree_sitter_python as tspython\n",
|
||||
"from datasets import load_dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# dataset = load_dataset('json', data_files={'train': '/work/s452638/datasets/CodeSearchNet/python/train.jsonl'}, split='train').select(range(10))\n",
|
||||
"# print(dataset)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"PY_LANGUAGE = Language(tspython.language())\n",
|
||||
"parser = Parser(PY_LANGUAGE)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def remove_docstrings_and_comments_from_code(code):\n",
|
||||
" # Parse the code\n",
|
||||
" tree = parser.parse(bytes(code, \"utf8\"))\n",
|
||||
" cursor = tree.walk()\n",
|
||||
"\n",
|
||||
" # Traverse the tree and collect all docstrings\n",
|
||||
" to_remove = []\n",
|
||||
"\n",
|
||||
" def traverse_tree(cursor, prev_node_type=None):\n",
|
||||
" node_type = cursor.node.type\n",
|
||||
" node_text = cursor.node.text.decode(\"utf-8\")\n",
|
||||
" # Check if the current node is a function or class definition\n",
|
||||
" if node_type == \"string\" and node_text.startswith('\"\"\"') and node_text.endswith('\"\"\"') and prev_node_type == \"expression_statement\":\n",
|
||||
" to_remove.append((cursor.node.start_byte, cursor.node.end_byte))\n",
|
||||
" if cursor.node.type == \"comment\":\n",
|
||||
" to_remove.append((cursor.node.start_byte, cursor.node.end_byte))\n",
|
||||
"\n",
|
||||
" # Traverse children\n",
|
||||
" if cursor.goto_first_child():\n",
|
||||
" while True:\n",
|
||||
" traverse_tree(cursor, node_type)\n",
|
||||
" if not cursor.goto_next_sibling():\n",
|
||||
" break\n",
|
||||
" cursor.goto_parent()\n",
|
||||
"\n",
|
||||
" return node_type\n",
|
||||
"\n",
|
||||
" # Start traversing from the root\n",
|
||||
" traverse_tree(cursor)\n",
|
||||
"\n",
|
||||
" # Remove docstrings from code\n",
|
||||
" code_without_docstrings = code\n",
|
||||
" for start, end in sorted(to_remove, reverse=True):\n",
|
||||
" code_without_docstrings = code_without_docstrings[:start] + code_without_docstrings[end:]\n",
|
||||
"\n",
|
||||
" # Remove empty lines\n",
|
||||
" code_without_docstrings = \"\\n\".join(\n",
|
||||
" line for line in code_without_docstrings.splitlines() if line.strip()\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return code_without_docstrings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"def test():\n",
|
||||
" str_variable = \"\"\"\n",
|
||||
" This is a string\n",
|
||||
" \"\"\"\n",
|
||||
" print(\"Hello World\")\n",
|
||||
"class Test:\n",
|
||||
" def __init__(self):\n",
|
||||
" pass\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"test_code = '''\n",
|
||||
"###########################\n",
|
||||
"# This is a test code\n",
|
||||
"###########################\n",
|
||||
"\n",
|
||||
"def test():\n",
|
||||
" \"\"\"\n",
|
||||
" This is a test function\n",
|
||||
" \"\"\"\n",
|
||||
" str_variable = \"\"\"\n",
|
||||
" This is a string\n",
|
||||
" \"\"\"\n",
|
||||
" print(\"Hello World\")\n",
|
||||
"\n",
|
||||
"class Test:\n",
|
||||
" \"\"\"\n",
|
||||
" This is a test class\n",
|
||||
" \"\"\"\n",
|
||||
" def __init__(self):\n",
|
||||
" pass\n",
|
||||
"'''\n",
|
||||
"\n",
|
||||
"print(remove_docstrings_and_comments_from_code(test_code))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "base",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.10"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -17,10 +17,55 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def remove_docstrings_and_comments_from_code(code, parser):
    """Strip docstrings and ``#`` comments from Python source via tree-sitter.

    Args:
        code: Python source as a ``str``.
        parser: a ``tree_sitter.Parser`` configured with the Python grammar.

    Returns:
        The source with bare triple-quoted string statements (docstrings) and
        comments removed, and lines left empty by the removal dropped.
        Strings bound to a name (e.g. ``x = \"\"\"...\"\"\"``) are kept, because
        their parent node is an assignment, not an expression statement.
    """
    source_bytes = bytes(code, "utf8")
    tree = parser.parse(source_bytes)
    cursor = tree.walk()

    # Byte spans (start_byte, end_byte) of every node scheduled for deletion.
    to_remove = []

    def traverse_tree(cursor, prev_node_type=None):
        node_type = cursor.node.type
        node_text = cursor.node.text.decode("utf-8")
        # A triple-quoted string whose parent is a bare expression statement
        # is a docstring. Accept both quote styles — the original check only
        # matched the double-quoted form and left '''-docstrings in place.
        is_triple_quoted = (
            (node_text.startswith('"""') and node_text.endswith('"""'))
            or (node_text.startswith("'''") and node_text.endswith("'''"))
        )
        if node_type == "string" and is_triple_quoted and prev_node_type == "expression_statement":
            to_remove.append((cursor.node.start_byte, cursor.node.end_byte))
        if node_type == "comment":
            to_remove.append((cursor.node.start_byte, cursor.node.end_byte))

        # Depth-first walk; the parent's node type is passed down so a child
        # string can tell whether it sits inside an expression statement.
        if cursor.goto_first_child():
            while True:
                traverse_tree(cursor, node_type)
                if not cursor.goto_next_sibling():
                    break
            cursor.goto_parent()

        return node_type

    # Start traversing from the root.
    traverse_tree(cursor)

    # start_byte/end_byte are BYTE offsets, so splice on the encoded bytes.
    # The original sliced the str with byte offsets, which corrupts the
    # output after any multi-byte (non-ASCII) character in the source.
    # Removing from the end backwards keeps earlier offsets valid.
    stripped = source_bytes
    for start, end in sorted(to_remove, reverse=True):
        stripped = stripped[:start] + stripped[end:]

    # Drop lines that became empty (or were empty) after removal.
    return "\n".join(
        line for line in stripped.decode("utf8").splitlines() if line.strip()
    )
|
||||
|
||||
def process_example(code, tokenizer):
|
||||
# Instantiate parser locally to avoid global references
|
||||
parser = Parser(Language(tspython.language()))
|
||||
|
||||
# Remove docstrings and comments
|
||||
code = remove_docstrings_and_comments_from_code(code, parser)
|
||||
|
||||
# Parse the code into an AST
|
||||
tree = parser.parse(bytes(code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
@ -121,7 +166,7 @@ def process_batch(batch, tokenizer):
|
||||
def main():
|
||||
current_dir = Path(__file__).parent
|
||||
input_dir = current_dir.parent / 'data' / 'codeparrot-clean'
|
||||
output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes'
|
||||
output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-no-comments'
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize tokenizer
|
||||
|
@ -13,6 +13,9 @@ from transformers import (
|
||||
DataCollatorForLanguageModeling,
|
||||
AutoModelForMaskedLM
|
||||
)
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from tree_codebert import TreeCodeBERTForPreTraining
|
||||
from tree_starencoder import TreeStarEncoderForPreTraining
|
||||
@ -24,11 +27,22 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def set_seed(seed: int):
    """Seed every RNG used in training (random, NumPy, torch) with *seed*.

    When CUDA is available, all device RNGs are seeded as well and cuDNN is
    forced into deterministic mode (benchmark off), trading some speed for
    run-to-run reproducibility.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
|
||||
|
||||
def load_config(config_path: Path) -> dict:
    """Read the JSON configuration file at *config_path* and return it as a dict."""
    return json.loads(Path(config_path).read_text())
|
||||
|
||||
def main():
|
||||
def main():
|
||||
set_seed(config['seed'])
|
||||
|
||||
# Setup paths
|
||||
current_dir = Path(__file__).parent
|
||||
config = load_config(current_dir / 'config.json')
|
||||
|
Loading…
Reference in New Issue
Block a user