From 1b0bf749e970d4bcd5a536a62c3016e6536a6dcb Mon Sep 17 00:00:00 2001 From: Kacper Kalinowski Date: Mon, 27 May 2024 18:08:12 +0200 Subject: [PATCH] . --- rnn/rnn.py | 44 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/rnn/rnn.py b/rnn/rnn.py index 5166cfc..4f7c92a 100644 --- a/rnn/rnn.py +++ b/rnn/rnn.py @@ -34,13 +34,20 @@ dev_sentences, dev_labels = preprocess_data(dev_sentences, dev_labels) test_sentences = preprocess_data(test_sentences) # Create a word index and label index -word2idx = {w: i + 2 for i, w in enumerate(set(word for sentence in train_sentences for word in sentence))} -word2idx[''] = 0 -word2idx[''] = 1 +special_tokens = ['', '', '', ''] +word2idx = {w: i + len(special_tokens) for i, w in enumerate(set(word for sentence in train_sentences for word in sentence))} +for i, token in enumerate(special_tokens): + word2idx[token] = i + idx2word = {i: w for w, i in word2idx.items()} -label2idx = {l: i + 1 for i, l in enumerate(set(label for label_list in train_labels for label in label_list))} -label2idx[''] = 0 +label2idx = { + 'O': 0, + 'B-PER': 1, 'I-PER': 2, + 'B-ORG': 3, 'I-ORG': 4, + 'B-LOC': 5, 'I-LOC': 6, + 'B-MISC': 7, 'I-MISC': 8 +} idx2label = {i: l for l, i in label2idx.items()} # Convert words and labels to integers @@ -56,7 +63,7 @@ X_dev, y_dev = encode_data(dev_sentences, dev_labels) X_test = encode_data(test_sentences) # Limit sequence length to avoid excessive memory usage -max_len = 100 # You can adjust this value to a reasonable limit based on your data and memory +max_len = 1000 # You can adjust this value to a reasonable limit based on your data and memory X_train = pad_sequences(X_train, padding='post', maxlen=max_len) y_train = pad_sequences(y_train, padding='post', maxlen=max_len) @@ -78,7 +85,7 @@ model = tf.keras.models.Sequential([ model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) # Train the model with a smaller batch size -history = model.fit(X_train, np.array(y_train), validation_data=(X_dev, np.array(y_dev)), epochs=5, batch_size=16) +history = model.fit(X_train, np.array(y_train), validation_data=(X_dev, np.array(y_dev)), epochs=25, batch_size=16) # Evaluate the model y_pred = model.predict(X_dev) @@ -97,19 +104,38 @@ print(classification_report( target_names=[idx2label[i] for i in list(label2idx.values())[1:]] )) +# Correct IOB labels function +def correct_iob_labels(predictions): + corrected = [] + for pred in predictions: + corrected_sentence = [] + prev_label = 'O' + for label in pred: + if label.startswith('I-') and (prev_label == 'O' or prev_label[2:] != label[2:]): + corrected_sentence.append('B-' + label[2:]) + else: + corrected_sentence.append(label) + prev_label = corrected_sentence[-1] + corrected.append(corrected_sentence) + return corrected + # Predict on test data y_test_pred = model.predict(X_test) y_test_pred = np.argmax(y_test_pred, axis=-1) y_test_pred_tags = [[idx2label[i] for i in row] for row in y_test_pred] +# Correct the predicted tags +y_pred_tags_corrected = correct_iob_labels(y_pred_tags) +y_test_pred_tags_corrected = correct_iob_labels(y_test_pred_tags) + # Save dev predictions -dev_predictions = [' '.join(tags) for tags in y_pred_tags] +dev_predictions = [' '.join(tags) for tags in y_pred_tags_corrected] with open('./dev0/out.tsv', 'w', encoding='utf-8') as f: for prediction in dev_predictions: f.write("%s\n" % prediction) # Save test predictions -test_predictions = [' '.join(tags) for tags in y_test_pred_tags] +test_predictions = [' '.join(tags) for tags in y_test_pred_tags_corrected] with open('./testA/out.tsv', 'w', encoding='utf-8') as f: for prediction in test_predictions: f.write("%s\n" % prediction)