aitech-sd-lab/NLU_lab_7-8/create_datasets.py

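"""Builds NLU train/test datasets from tab-separated dialog files.

Each input row is expected to hold (speaker, text, annotation). User turns are
tokenized, intents and slot values are extracted from the annotation column,
and the results are appended to train-pl.conllu / test-pl.conllu in a
CoNLL-U-like format with B-/I- slot labels.
"""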
from pathlib import Path
import random
import re
import sys

import nltk
import pandas as pd
from nltk.tokenize import word_tokenize

# nltk.download('punkt')  # uncomment on first run to fetch the tokenizer models


class LineContent:
    """One processed user turn: raw text, intent string, slots, and token rows."""

    def __init__(self, _text, _intents, _slots, _tokens):
        self.text = _text
        self.intent = _intents
        self.slots = _slots
        self.tokens = _tokens


def process_file(file):
    def get_intents(intent_row):
        # drop the parenthesized argument lists and keep only the intent names
        intents = re.sub(r'\(+.*?\)+', '', intent_row)
        intents = [intent.strip() for intent in intents.split('&')]
        return ' '.join(intents)

    def get_slots(intent_row):
        intents = [intent.strip() for intent in intent_row.split('&')]
        slots = []
        for intent in intents:
            intent = intent.replace("'", '').replace('"', '')
            match = re.search(r'\(+.*?\)+', intent)
            intent_content = match.group() if match else ''
            if '=' in intent_content:
                slots_count = intent_content.count('=')
                intent_content = intent_content.replace('(', '').replace(')', '')
                if slots_count > 1:
                    intent_content = [slot.strip() for slot in intent_content.split(',')]
                else:
                    intent_content = [intent_content]
                # keep only entries that actually carry a value after '='
                slots = [slot for slot in intent_content if slot and slot[-1] != '=']
                # iterate over a copy: removing items while iterating the same
                # list would skip elements
                for slot in slots[:]:
                    if slot.count('=') > 1:
                        slots.remove(slot)
                        slots.append(re.sub(r'^.*?=', '', slot))
        slots = [slot.split('=') for slot in slots if len(slot.split('=')) == 2]
        # if a slot name still carries a comma-separated prefix, keep its last word
        slots = [[slot[0].split()[-1], slot[1]] if ',' in slot[0] else slot
                 for slot in slots]
        return slots

    def get_tokens(text, intents, slots):
        def tokenize(text):
            # shield an e-mail address from the tokenizer, then restore it afterwards
            email = re.search(r'[^ ]+@[^ ]+', text)
            if email:
                email_address = email.group()
                text = text.replace(email_address, '@')
                text = text.replace("'", '')
                tokens = word_tokenize(text)
                tokens = [token.replace('@', email_address) for token in tokens]
            else:
                text = text.replace("'", '')
                tokens = word_tokenize(text)
            return tokens

        text_tokens = tokenize(text)
        for slot in slots:
            slot[-1] = tokenize(slot[-1])
        # start with every token unlabelled
        formatted_tokens = []
        for index, text_token in enumerate(text_tokens):
            formatted_tokens.append([str(index + 1), text_token, intents, 'NoLabel'])
        # mark slot values with B-/I- labels wherever their token sequence matches
        for slot in slots:
            value_len = len(slot[-1])
            if value_len == 0:
                continue  # guard against slot values that tokenized to nothing
            for i in range(len(formatted_tokens) - value_len + 1):
                if slot[-1][0].lower() != formatted_tokens[i][1].lower():
                    continue
                found = all(slot[-1][j].lower() == formatted_tokens[i + j][1].lower()
                            for j in range(1, value_len))
                if found:
                    formatted_tokens[i][3] = f'B-{slot[0]}'
                    for k in range(1, value_len):
                        formatted_tokens[i + k][3] = f'I-{slot[0]}'
        return formatted_tokens

    print('Processed:', file.name)
    df = pd.read_csv(file, sep='\t', header=None)
    lines_contents = []
    for _, row in df.iterrows():
        # keep only non-empty user turns (pd.notna filters NaN cells, which are truthy)
        if row[0] == 'user' and pd.notna(row[1]) and row[1]:
            text = row[1]
            intents = get_intents(row[2])
            slots = get_slots(row[2])
            tokens = get_tokens(text, intents, slots)
            lines_contents.append(LineContent(text, intents, slots, tokens))
    return lines_contents


def write_to_files(lines_contents):
    def format_slots(slots):
        # each slot renders as "value:name"; empty slots render as ''
        return ','.join(
            ':'.join([''.join(slot[-1]), slot[0]]) if len(slot) > 0 else ''
            for slot in slots
        )

    def format_tokens(tokens):
        return '\n'.join('\t'.join(token) for token in tokens)

    def format_content(content):
        return (f'# text: {content.text}\n'
                f'# intent: {content.intent}\n'
                f'# slots: {format_slots(content.slots)}\n'
                f'{format_tokens(content.tokens)}\n\n')

    # shuffle, then split 80/20 into train and test
    random.shuffle(lines_contents)
    split_index = int(len(lines_contents) * 0.8)
    contents_train = lines_contents[:split_index]
    contents_test = lines_contents[split_index:]
    with open('train-pl.conllu', 'a', encoding='utf-8') as train_f, \
            open('test-pl.conllu', 'a', encoding='utf-8') as test_f:
        for content in contents_train:
            train_f.write(format_content(content))
        for content in contents_test:
            test_f.write(format_content(content))

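# Each record written above has the following shape (placeholder values in
# angle brackets; derived from format_content):
#
#   # text: <utterance>
#   # intent: <intent names>
#   # slots: <value>:<name>,...
#   1<TAB><token><TAB><intent names><TAB><B-name / I-name / NoLabel>
#   ...
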
def main():
    path = sys.argv[1]
    data_dir = Path(path)  # avoid shadowing the built-in dir()
    for file in data_dir.glob('*'):
        processed_contents = process_file(file)
        write_to_files(processed_contents)


if __name__ == '__main__':
    main()
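
# Usage sketch (the directory path below is a hypothetical example): pass a
# directory of tab-separated dialog files as the single argument; output is
# appended to train-pl.conllu and test-pl.conllu in the working directory.
#
#   python create_datasets.py data/dialogs/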