from tree_sitter import Language, Parser
from datasets import load_dataset
# Load the dataset
dataset = load_dataset('json', data_files={'train': '/work/s452638/datasets/CodeSearchNet/python/train.jsonl'}, split='train')
print(dataset)
Dataset({
    features: ['repo', 'path', 'func_name', 'original_string', 'language', 'code', 'code_tokens', 'docstring', 'docstring_tokens', 'sha', 'url', 'partition'],
    num_rows: 251820
})
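A quick sanity check on a single record confirms the fields used in the rest of the notebook; 'func_name' and 'code' come from the feature list above.
# Peek at one example; 'code' holds the raw source that will be cleaned below
sample = dataset[0]
print(sample['func_name'])
print(sample['code'][:200])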
# Build the language library if not already built
# This should be done only once, and the resulting .so file can be reused
Language.build_library(
    '/home/s452638/magisterka/build/my-languages.so',  # Output location of the compiled language library
    [
        '/home/s452638/magisterka/vendor/tree-sitter-python'  # Replace with the path to the tree-sitter-python grammar
    ]
)
/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tree_sitter/__init__.py:36: FutureWarning: Language.build_library is deprecated. Use the new bindings instead. warn("{} is deprecated. Use {} instead.".format(old, new), FutureWarning)
False
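The False above suggests the existing .so was reused rather than recompiled. Since the library only needs to be built once, a simple guard (a sketch reusing the same paths as above) keeps repeated runs from invoking the compiler at all:
import os

LIB_PATH = '/home/s452638/magisterka/build/my-languages.so'
GRAMMAR_PATH = '/home/s452638/magisterka/vendor/tree-sitter-python'

# Compile the grammar only if the shared library is not there yet
if not os.path.exists(LIB_PATH):
    Language.build_library(LIB_PATH, [GRAMMAR_PATH])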
# Load the language
PYTHON_LANGUAGE = Language('/home/s452638/magisterka/build/my-languages.so', 'python')
# Initialize the parser
parser = Parser()
parser.set_language(PYTHON_LANGUAGE)
/home/s452638/magisterka/magisterka_env/lib/python3.8/site-packages/tree_sitter/__init__.py:36: FutureWarning: Language(path, name) is deprecated. Use Language(ptr, name) instead. warn("{} is deprecated. Use {} instead.".format(old, new), FutureWarning)
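Both FutureWarnings point at the newer binding style, in which the grammar is installed as a pip package instead of being compiled into a local .so. A rough sketch of that variant (assuming the tree-sitter-python wheel is installed; the exact constructor arguments vary between py-tree-sitter releases, and this path is not used further here):
# Alternative for newer py-tree-sitter releases: pip install tree-sitter-python,
# then load the grammar from the package rather than from a compiled .so
import tree_sitter_python as tspython

PYTHON_LANGUAGE = Language(tspython.language(), 'python')  # the Language(ptr, name) form named in the warning
parser = Parser()
parser.set_language(PYTHON_LANGUAGE)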
def remove_docstrings_and_comments_from_code(code):
    # Parse the code
    tree = parser.parse(bytes(code, "utf8"))
    cursor = tree.walk()

    # Traverse the tree and collect the byte ranges of all docstrings and comments
    to_remove = []

    def traverse_tree(cursor, parent_node_type=None):
        node_type = cursor.node.type
        node_text = cursor.node.text.decode("utf-8")

        # A triple-double-quoted string directly inside an expression statement is treated as a docstring
        if node_type == "string" and node_text.startswith('"""') and node_text.endswith('"""') and parent_node_type == "expression_statement":
            to_remove.append((cursor.node.start_byte, cursor.node.end_byte))
        if node_type == "comment":
            to_remove.append((cursor.node.start_byte, cursor.node.end_byte))

        # Recurse into the children, passing the current node's type down as their parent type
        if cursor.goto_first_child():
            while True:
                traverse_tree(cursor, node_type)
                if not cursor.goto_next_sibling():
                    break
            cursor.goto_parent()
        return node_type

    # Start traversing from the root
    traverse_tree(cursor)

    # Remove the collected ranges back to front so earlier offsets stay valid.
    # Tree-sitter reports byte offsets, so slice the UTF-8 bytes rather than the str.
    code_bytes = bytes(code, "utf8")
    for start, end in sorted(to_remove, reverse=True):
        code_bytes = code_bytes[:start] + code_bytes[end:]
    return code_bytes.decode("utf-8")
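With the cleaning function defined, the whole split could be processed with datasets' map; a minimal sketch follows (the new column name 'code_no_docstrings' is arbitrary, and mapping all 251,820 rows takes a while):
def add_clean_code(example):
    # Store the cleaned source next to the original fields
    return {'code_no_docstrings': remove_docstrings_and_comments_from_code(example['code'])}

cleaned_dataset = dataset.map(add_clean_code)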
idx = 3
print(dataset[idx]['code'])
def gather_categories(imap, header, categories=None):
    """
    Find the user specified categories in the map and create a dictionary
    to contain the relevant data for each type within the categories.
    Multiple categories will have their types combined such that each
    possible combination will have its own entry in the dictionary.

    :type imap: dict
    :param imap: The input mapping file data keyed by SampleID

    :type header: list
    :param header: The header line from the input mapping file. This will
                   be searched for the user-specified categories

    :type categories: list
    :param categories: The list of user-specified category column name from
                       mapping file

    :rtype: dict
    :return: A sorted dictionary keyed on the combinations of all the types
             found within the user-specified categories. Each entry will
             contain an empty DataCategory namedtuple. If no categories are
             specified, a single entry with the key 'default' will be
             returned
    """
    # If no categories provided, return all SampleIDs
    if categories is None:
        return {"default": DataCategory(set(imap.keys()), {})}

    cat_ids = [header.index(cat) for cat in categories
               if cat in header and "=" not in cat]

    table = OrderedDict()
    conditions = defaultdict(set)
    for i, cat in enumerate(categories):
        if "=" in cat and cat.split("=")[0] in header:
            cat_name = header[header.index(cat.split("=")[0])]
            conditions[cat_name].add(cat.split("=")[1])

    # If invalid categories or conditions identified, return all SampleIDs
    if not cat_ids and not conditions:
        return {"default": DataCategory(set(imap.keys()), {})}

    #If only category column given, return column-wise SampleIDs
    if cat_ids and not conditions:
        for sid, row in imap.items():
            cat_name = "_".join([row[cid] for cid in cat_ids])
            if cat_name not in table:
                table[cat_name] = DataCategory(set(), {})
            table[cat_name].sids.add(sid)
        return table

    # Collect all condition names
    cond_ids = set()
    for k in conditions:
        try:
            cond_ids.add(header.index(k))
        except ValueError:
            continue
    idx_to_test = set(cat_ids).union(cond_ids)

    # If column name and condition given, return overlapping SampleIDs of column and
    # condition combinations
    for sid, row in imap.items():
        if all([row[header.index(c)] in conditions[c] for c in conditions]):
            key = "_".join([row[idx] for idx in idx_to_test])
            try:
                assert key in table.keys()
            except AssertionError:
                table[key] = DataCategory(set(), {})
            table[key].sids.add(sid)

    try:
        assert len(table) > 0
    except AssertionError:
        return {"default": DataCategory(set(imap.keys()), {})}
    else:
        return table
print(remove_docstrings_and_comments_from_code(dataset[idx]['code']))
def gather_categories(imap, header, categories=None):
    if categories is None:
        return {"default": DataCategory(set(imap.keys()), {})}
    cat_ids = [header.index(cat) for cat in categories
               if cat in header and "=" not in cat]
    table = OrderedDict()
    conditions = defaultdict(set)
    for i, cat in enumerate(categories):
        if "=" in cat and cat.split("=")[0] in header:
            cat_name = header[header.index(cat.split("=")[0])]
            conditions[cat_name].add(cat.split("=")[1])
    if not cat_ids and not conditions:
        return {"default": DataCategory(set(imap.keys()), {})}
    if cat_ids and not conditions:
        for sid, row in imap.items():
            cat_name = "_".join([row[cid] for cid in cat_ids])
            if cat_name not in table:
                table[cat_name] = DataCategory(set(), {})
            table[cat_name].sids.add(sid)
        return table
    cond_ids = set()
    for k in conditions:
        try:
            cond_ids.add(header.index(k))
        except ValueError:
            continue
    idx_to_test = set(cat_ids).union(cond_ids)
    for sid, row in imap.items():
        if all([row[header.index(c)] in conditions[c] for c in conditions]):
            key = "_".join([row[idx] for idx in idx_to_test])
            try:
                assert key in table.keys()
            except AssertionError:
                table[key] = DataCategory(set(), {})
            table[key].sids.add(sid)
    try:
        assert len(table) > 0
    except AssertionError:
        return {"default": DataCategory(set(imap.keys()), {})}
    else:
        return table
test_code = '''
def test():
    """
    This is a test function
    """
    str_variable = """
    This is a string
    """
    print("Hello World")
class Test:
    """
    This is a test class
    """
    def __init__(self):
        pass
'''
print(remove_docstrings_and_comments_from_code(test_code))
def test():

    str_variable = """
    This is a string
    """
    print("Hello World")
class Test:

    def __init__(self):
        pass
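Note that cutting out the docstring bytes leaves behind the blank (whitespace-only) lines where the strings used to sit. If that matters for downstream tokenization, a small post-processing pass like this sketch could normalize it (it also drops intentional blank lines, so it is not applied above):
def strip_empty_lines(code):
    # Drop every line that is empty or whitespace-only after the removal step
    return "\n".join(line for line in code.splitlines() if line.strip())

print(strip_empty_lines(remove_docstrings_and_comments_from_code(test_code)))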