# aitech-sd-lab/data/create_slot_datasets.py
from pathlib import Path
import sys
import pandas as pd
from nltk.tokenize import word_tokenize
import re
import random
class LineContent:
    """Container for one parsed user turn.

    Holds the raw utterance, its space-joined intent names, its slot pairs and
    the per-token rows that are later written out as dataset lines.
    """

    def __init__(self, _text, _intents, _slots, _tokens):
        # Everything is stored unchanged; note that the intents argument is
        # exposed under the singular attribute name ``intent``.
        self.text, self.intent, self.slots, self.tokens = (
            _text,
            _intents,
            _slots,
            _tokens,
        )
# E-mail-like substrings (the same pattern the tokenizer always used). The
# capturing group makes re.split() also return the matched addresses,
# interleaved with the plain-text pieces around them.
_EMAIL_PATTERN = re.compile(r"([^ ]+@[^ ,]+)")
# One parenthesised argument list of a dialogue act, e.g. "(name=Jan)".
_ACT_ARGS_PATTERN = re.compile(r"\(+.*?\)+")


def _get_intents(intent_row):
    """Return the act names in ``intent_row`` joined by single spaces.

    Parenthesised argument lists are removed and the remainder is split on
    '&', so ``"inform(name=Jan)&request(email)"`` becomes
    ``"inform request"``.
    """
    without_args = _ACT_ARGS_PATTERN.sub('', intent_row)
    return ' '.join(intent.strip() for intent in without_args.split('&'))


def _get_slots(intent_row):
    """Return ``[name, value]`` pairs from every argument list in ``intent_row``.

    Quotes and parentheses are stripped, each argument list is split on '&',
    and only ``name=value`` arguments with a non-empty name and value are
    kept (the value is everything after the first '='). Rows without any
    parenthesised arguments yield ``[]``. Example:
    ``"inform(name='Jan')&request(email)"`` -> ``[['name', 'Jan']]``.
    """
    slots = []
    # findall (not search): later acts in an '&'-joined row carry slots too.
    for args in _ACT_ARGS_PATTERN.findall(intent_row):
        args = args.strip().replace('(', '').replace(')', '')
        args = args.replace("'", '').replace('"', '')
        for arg in args.split('&'):
            name, sep, value = arg.partition('=')
            name, value = name.strip(), value.strip()
            if sep and name and value:
                slots.append([name, value])
    return slots


def _tokenize(text):
    """Tokenize ``text`` with NLTK, keeping e-mail-like substrings whole.

    Apostrophes are removed from the non-e-mail parts before tokenizing;
    matched e-mail addresses are emitted verbatim as single tokens.
    """
    tokens = []
    # With a capturing group, re.split returns plain text at even indices
    # and the matched addresses at odd indices.
    for index, piece in enumerate(_EMAIL_PATTERN.split(text)):
        if index % 2:
            tokens.append(piece)
        elif piece:
            tokens.extend(word_tokenize(piece.replace("'", '')))
    return tokens


def _get_tokens(text, intents, slots):
    """Build rows ``[index, token, intents, label]`` for every token of ``text``.

    Indices are 1-based strings. Every label starts as 'NoLabel'; each
    (case-insensitive) occurrence of a slot value's token sequence is then
    labelled ``B-<name>`` on its first token and ``I-<name>`` on the rest.

    NOTE: each slot's value is replaced IN PLACE by its token list; the slots
    stored on LineContent are later rendered by joining those tokens.
    """
    text_tokens = _tokenize(text)
    rows = [[str(position), token, intents, 'NoLabel']
            for position, token in enumerate(text_tokens, start=1)]
    lowered = [token.lower() for token in text_tokens]
    for slot in slots:
        slot[-1] = _tokenize(slot[-1])
        value = [token.lower() for token in slot[-1]]
        width = len(value)
        if not width:
            # The value tokenized to nothing; there is nothing to label.
            continue
        for start in range(len(lowered) - width + 1):
            if lowered[start:start + width] == value:
                rows[start][3] = f'B-{slot[0]}'
                for offset in range(1, width):
                    rows[start + offset][3] = f'I-{slot[0]}'
    return rows


def process_file(file):
    """Parse one tab-separated dialogue file into LineContent records.

    Only rows whose first column is 'user' and whose second column holds
    non-blank text are kept. The third column is parsed as the dialogue-act
    string; a missing or empty act cell is treated as ''.

    Args:
        file: path of the file to read; its ``name`` attribute is printed.

    Returns:
        A list of LineContent, one per kept user turn, in file order.
    """
    print('Processed: ', file.name)
    df = pd.read_csv(file, sep='\t', header=None)
    lines_contents = []
    for _, row in df.iterrows():
        text = row[1]
        # Empty cells come back from pandas as float NaN, which is truthy,
        # so require a real, non-blank string rather than plain truthiness.
        if row[0] != 'user' or not isinstance(text, str) or not text.strip():
            continue
        acts = row.get(2)
        if not isinstance(acts, str):
            acts = ''
        intents = _get_intents(acts)
        slots = _get_slots(acts)
        tokens = _get_tokens(text, intents, slots)
        lines_contents.append(LineContent(text, intents, slots, tokens))
    return lines_contents
def write_to_files(lines_contents, train_path='train-pl.conllu',
                   test_path='test-pl.conllu', train_ratio=0.8):
    """Shuffle records, split them into train/test and append them to disk.

    Each record is written as three comment lines ("# text: ...",
    "# intent: ...", "# slots: value:name,...") followed by one TAB-separated
    line per token (index, token, intents with spaces replaced by '_',
    label) and a blank line. Multi-token slot values are rendered by
    concatenating their tokens without separators.

    Args:
        lines_contents: records exposing ``text``, ``intent``, ``slots`` and
            ``tokens`` (as produced by process_file). The caller's list is
            not reordered.
        train_path: file that receives the first ``train_ratio`` share.
        test_path: file that receives the remaining records.
        train_ratio: fraction (0..1) of the shuffled records used for training.

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1].

    Both files are opened in append mode, so one call per input file
    accumulates into the same datasets (and re-running the script appends
    duplicates).
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError(f'train_ratio must be within [0, 1], got {train_ratio}')

    def format_slots(slots):
        # Each slot is [name, value_tokens]; rendered as "value:name".
        return ','.join(f"{''.join(slot[-1])}:{slot[0]}" if slot else ''
                        for slot in slots)

    def format_tokens(tokens):
        return '\n'.join(
            f"{token[0]}\t{token[1]}\t{token[2].replace(' ', '_')}\t{token[3]}"
            for token in tokens)

    def format_content(content):
        return (f"# text: {content.text}\n"
                f"# intent: {content.intent}\n"
                f"# slots: {format_slots(content.slots)}\n"
                f"{format_tokens(content.tokens)}\n\n")

    # Shuffle a copy so the caller's list keeps its order.
    contents = list(lines_contents)
    random.shuffle(contents)
    split_at = int(len(contents) * train_ratio)

    with open(train_path, 'a', encoding='utf-8') as train_f, \
            open(test_path, 'a', encoding='utf-8') as test_f:
        train_f.writelines(format_content(c) for c in contents[:split_at])
        test_f.writelines(format_content(c) for c in contents[split_at:])
def main(argv=None):
    """Convert every data file in a directory into train/test datasets.

    Usage: ``python create_slot_datasets.py <input_dir>``. Each regular,
    non-hidden file in ``<input_dir>`` is processed in name order with
    process_file and appended to the datasets by write_to_files.

    Args:
        argv: argument vector to use instead of ``sys.argv`` (e.g. in tests).
    """
    args = sys.argv if argv is None else argv
    if len(args) < 2:
        sys.exit('usage: create_slot_datasets.py <input_dir>')
    input_dir = Path(args[1])
    if not input_dir.is_dir():
        sys.exit(f'not a directory: {input_dir}')
    # Sorted for a reproducible order; pathlib's '*' also matches
    # subdirectories and dotfiles, which are not dialogue files.
    for file in sorted(input_dir.glob('*')):
        if file.is_file() and not file.name.startswith('.'):
            write_to_files(process_file(file))


if __name__ == '__main__':
    main()