.
This commit is contained in:
parent
a137ba4043
commit
1b0bf749e9
44
rnn/rnn.py
44
rnn/rnn.py
@ -34,13 +34,20 @@ dev_sentences, dev_labels = preprocess_data(dev_sentences, dev_labels)
|
|||||||
test_sentences = preprocess_data(test_sentences)
|
test_sentences = preprocess_data(test_sentences)
|
||||||
|
|
||||||
# Create a word index and label index
|
# Create a word index and label index
|
||||||
word2idx = {w: i + 2 for i, w in enumerate(set(word for sentence in train_sentences for word in sentence))}
|
special_tokens = ['<PAD>', '<UNK>', '<BOS>', '<EOS>']
|
||||||
word2idx['<PAD>'] = 0
|
word2idx = {w: i + len(special_tokens) for i, w in enumerate(set(word for sentence in train_sentences for word in sentence))}
|
||||||
word2idx['<UNK>'] = 1
|
for i, token in enumerate(special_tokens):
|
||||||
|
word2idx[token] = i
|
||||||
|
|
||||||
idx2word = {i: w for w, i in word2idx.items()}
|
idx2word = {i: w for w, i in word2idx.items()}
|
||||||
|
|
||||||
label2idx = {l: i + 1 for i, l in enumerate(set(label for label_list in train_labels for label in label_list))}
|
label2idx = {
|
||||||
label2idx['<PAD>'] = 0
|
'O': 0,
|
||||||
|
'B-PER': 1, 'I-PER': 2,
|
||||||
|
'B-ORG': 3, 'I-ORG': 4,
|
||||||
|
'B-LOC': 5, 'I-LOC': 6,
|
||||||
|
'B-MISC': 7, 'I-MISC': 8
|
||||||
|
}
|
||||||
idx2label = {i: l for l, i in label2idx.items()}
|
idx2label = {i: l for l, i in label2idx.items()}
|
||||||
|
|
||||||
# Convert words and labels to integers
|
# Convert words and labels to integers
|
||||||
@ -56,7 +63,7 @@ X_dev, y_dev = encode_data(dev_sentences, dev_labels)
|
|||||||
X_test = encode_data(test_sentences)
|
X_test = encode_data(test_sentences)
|
||||||
|
|
||||||
# Limit sequence length to avoid excessive memory usage
|
# Limit sequence length to avoid excessive memory usage
|
||||||
max_len = 100 # You can adjust this value to a reasonable limit based on your data and memory
|
max_len = 1000 # You can adjust this value to a reasonable limit based on your data and memory
|
||||||
X_train = pad_sequences(X_train, padding='post', maxlen=max_len)
|
X_train = pad_sequences(X_train, padding='post', maxlen=max_len)
|
||||||
y_train = pad_sequences(y_train, padding='post', maxlen=max_len)
|
y_train = pad_sequences(y_train, padding='post', maxlen=max_len)
|
||||||
|
|
||||||
@ -78,7 +85,7 @@ model = tf.keras.models.Sequential([
|
|||||||
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
|
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
|
||||||
|
|
||||||
# Train the model with a smaller batch size
|
# Train the model with a smaller batch size
|
||||||
history = model.fit(X_train, np.array(y_train), validation_data=(X_dev, np.array(y_dev)), epochs=5, batch_size=16)
|
history = model.fit(X_train, np.array(y_train), validation_data=(X_dev, np.array(y_dev)), epochs=25, batch_size=16)
|
||||||
|
|
||||||
# Evaluate the model
|
# Evaluate the model
|
||||||
y_pred = model.predict(X_dev)
|
y_pred = model.predict(X_dev)
|
||||||
@ -97,19 +104,38 @@ print(classification_report(
|
|||||||
target_names=[idx2label[i] for i in list(label2idx.values())[1:]]
|
target_names=[idx2label[i] for i in list(label2idx.values())[1:]]
|
||||||
))
|
))
|
||||||
|
|
||||||
|
# Correct IOB labels function
|
||||||
|
def correct_iob_labels(predictions):
|
||||||
|
corrected = []
|
||||||
|
for pred in predictions:
|
||||||
|
corrected_sentence = []
|
||||||
|
prev_label = 'O'
|
||||||
|
for label in pred:
|
||||||
|
if label.startswith('I-') and (prev_label == 'O' or prev_label[2:] != label[2:]):
|
||||||
|
corrected_sentence.append('B-' + label[2:])
|
||||||
|
else:
|
||||||
|
corrected_sentence.append(label)
|
||||||
|
prev_label = corrected_sentence[-1]
|
||||||
|
corrected.append(corrected_sentence)
|
||||||
|
return corrected
|
||||||
|
|
||||||
# Predict on test data
|
# Predict on test data
|
||||||
y_test_pred = model.predict(X_test)
|
y_test_pred = model.predict(X_test)
|
||||||
y_test_pred = np.argmax(y_test_pred, axis=-1)
|
y_test_pred = np.argmax(y_test_pred, axis=-1)
|
||||||
y_test_pred_tags = [[idx2label[i] for i in row] for row in y_test_pred]
|
y_test_pred_tags = [[idx2label[i] for i in row] for row in y_test_pred]
|
||||||
|
|
||||||
|
# Correct the predicted tags
|
||||||
|
y_pred_tags_corrected = correct_iob_labels(y_pred_tags)
|
||||||
|
y_test_pred_tags_corrected = correct_iob_labels(y_test_pred_tags)
|
||||||
|
|
||||||
# Save dev predictions
|
# Save dev predictions
|
||||||
dev_predictions = [' '.join(tags) for tags in y_pred_tags]
|
dev_predictions = [' '.join(tags) for tags in y_pred_tags_corrected]
|
||||||
with open('./dev0/out.tsv', 'w', encoding='utf-8') as f:
|
with open('./dev0/out.tsv', 'w', encoding='utf-8') as f:
|
||||||
for prediction in dev_predictions:
|
for prediction in dev_predictions:
|
||||||
f.write("%s\n" % prediction)
|
f.write("%s\n" % prediction)
|
||||||
|
|
||||||
# Save test predictions
|
# Save test predictions
|
||||||
test_predictions = [' '.join(tags) for tags in y_test_pred_tags]
|
test_predictions = [' '.join(tags) for tags in y_test_pred_tags_corrected]
|
||||||
with open('./testA/out.tsv', 'w', encoding='utf-8') as f:
|
with open('./testA/out.tsv', 'w', encoding='utf-8') as f:
|
||||||
for prediction in test_predictions:
|
for prediction in test_predictions:
|
||||||
f.write("%s\n" % prediction)
|
f.write("%s\n" % prediction)
|
||||||
|
Loading…
Reference in New Issue
Block a user