removing comments

This commit is contained in:
Patryk Bartkowiak 2024-12-30 18:23:11 +00:00
parent 7ca1b79497
commit c80c591e7c
4 changed files with 6338 additions and 6660 deletions

File diff suppressed because it is too large


@@ -0,0 +1,158 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from tree_sitter import Language, Parser\n",
"import tree_sitter_python as tspython\n",
"from datasets import load_dataset"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# dataset = load_dataset('json', data_files={'train': '/work/s452638/datasets/CodeSearchNet/python/train.jsonl'}, split='train').select(range(10))\n",
"# print(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"PY_LANGUAGE = Language(tspython.language())\n",
"parser = Parser(PY_LANGUAGE)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"def remove_docstrings_and_comments_from_code(code):\n",
" # Parse the code\n",
" tree = parser.parse(bytes(code, \"utf8\"))\n",
" cursor = tree.walk()\n",
"\n",
" # Traverse the tree and collect all docstrings\n",
" to_remove = []\n",
"\n",
" def traverse_tree(cursor, prev_node_type=None):\n",
" node_type = cursor.node.type\n",
" node_text = cursor.node.text.decode(\"utf-8\")\n",
" # Check if the current node is a function or class definition\n",
" if node_type == \"string\" and node_text.startswith('\"\"\"') and node_text.endswith('\"\"\"') and prev_node_type == \"expression_statement\":\n",
" to_remove.append((cursor.node.start_byte, cursor.node.end_byte))\n",
" if cursor.node.type == \"comment\":\n",
" to_remove.append((cursor.node.start_byte, cursor.node.end_byte))\n",
"\n",
" # Traverse children\n",
" if cursor.goto_first_child():\n",
" while True:\n",
" traverse_tree(cursor, node_type)\n",
" if not cursor.goto_next_sibling():\n",
" break\n",
" cursor.goto_parent()\n",
"\n",
" return node_type\n",
"\n",
" # Start traversing from the root\n",
" traverse_tree(cursor)\n",
"\n",
" # Remove docstrings from code\n",
" code_without_docstrings = code\n",
" for start, end in sorted(to_remove, reverse=True):\n",
" code_without_docstrings = code_without_docstrings[:start] + code_without_docstrings[end:]\n",
"\n",
" # Remove empty lines\n",
" code_without_docstrings = \"\\n\".join(\n",
" line for line in code_without_docstrings.splitlines() if line.strip()\n",
" )\n",
"\n",
" return code_without_docstrings"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def test():\n",
" str_variable = \"\"\"\n",
" This is a string\n",
" \"\"\"\n",
" print(\"Hello World\")\n",
"class Test:\n",
" def __init__(self):\n",
" pass\n"
]
}
],
"source": [
"test_code = '''\n",
"###########################\n",
"# This is a test code\n",
"###########################\n",
"\n",
"def test():\n",
" \"\"\"\n",
" This is a test function\n",
" \"\"\"\n",
" str_variable = \"\"\"\n",
" This is a string\n",
" \"\"\"\n",
" print(\"Hello World\")\n",
"\n",
"class Test:\n",
" \"\"\"\n",
" This is a test class\n",
" \"\"\"\n",
" def __init__(self):\n",
" pass\n",
"'''\n",
"\n",
"print(remove_docstrings_and_comments_from_code(test_code))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -17,10 +17,55 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
def remove_docstrings_and_comments_from_code(code, parser):
    # Parse the code
    tree = parser.parse(bytes(code, "utf8"))
    cursor = tree.walk()

    # Traverse the tree and collect all docstrings
    to_remove = []

    def traverse_tree(cursor, prev_node_type=None):
        node_type = cursor.node.type
        node_text = cursor.node.text.decode("utf-8")
        # Check if the current node is a docstring: a triple-quoted string directly inside an expression statement
        if node_type == "string" and node_text.startswith('"""') and node_text.endswith('"""') and prev_node_type == "expression_statement":
            to_remove.append((cursor.node.start_byte, cursor.node.end_byte))
        if cursor.node.type == "comment":
            to_remove.append((cursor.node.start_byte, cursor.node.end_byte))

        # Traverse children
        if cursor.goto_first_child():
            while True:
                traverse_tree(cursor, node_type)
                if not cursor.goto_next_sibling():
                    break
            cursor.goto_parent()

        return node_type

    # Start traversing from the root
    traverse_tree(cursor)

    # Remove the collected spans, from the end backwards so earlier byte offsets stay valid
    code_without_docstrings = code
    for start, end in sorted(to_remove, reverse=True):
        code_without_docstrings = code_without_docstrings[:start] + code_without_docstrings[end:]

    # Remove empty lines
    code_without_docstrings = "\n".join(
        line for line in code_without_docstrings.splitlines() if line.strip()
    )

    return code_without_docstrings

def process_example(code, tokenizer):
    # Instantiate parser locally to avoid global references
    parser = Parser(Language(tspython.language()))

    # Remove docstrings and comments
    code = remove_docstrings_and_comments_from_code(code, parser)

    # Parse the code into an AST
    tree = parser.parse(bytes(code, "utf8"))
    root_node = tree.root_node
@@ -121,7 +166,7 @@ def process_batch(batch, tokenizer):
def main():
    current_dir = Path(__file__).parent
    input_dir = current_dir.parent / 'data' / 'codeparrot-clean'
    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-classes'
    output_dir = current_dir.parent / 'data' / 'codeparrot-clean-parsed-starencoder-no-comments'
    output_dir.mkdir(parents=True, exist_ok=True)

    # Initialize tokenizer

@@ -13,6 +13,9 @@ from transformers import (
    DataCollatorForLanguageModeling,
    AutoModelForMaskedLM
)
import random
import numpy as np
import torch
from tree_codebert import TreeCodeBERTForPreTraining
from tree_starencoder import TreeStarEncoderForPreTraining
@@ -24,11 +27,22 @@
)
logger = logging.getLogger(__name__)

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def load_config(config_path: Path) -> dict:
    with open(config_path, 'r') as f:
        return json.load(f)

def main():
    set_seed(config['seed'])

    # Setup paths
    current_dir = Path(__file__).parent
    config = load_config(current_dir / 'config.json')