Zadanie 2
This commit is contained in:
parent
9dc85acc07
commit
ec2728dc64
82
main.py
Normal file
82
main.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from tensorflow.keras.preprocessing.text import Tokenizer
|
||||||
|
from tensorflow.keras.preprocessing.sequence import pad_sequences
|
||||||
|
from tensorflow.keras.models import Sequential
|
||||||
|
from tensorflow.keras.layers import Dense, Embedding, Flatten
|
||||||
|
from tensorflow.keras.optimizers import Adam
|
||||||
|
import gensim.downloader as api
|
||||||
|
|
||||||
|
# Define the file paths
|
||||||
|
train_file_path = 'mnt/train/train.tsv'
|
||||||
|
dev_file_path = 'mnt/dev-0/in.tsv'
|
||||||
|
test_file_path = 'mnt/test-A/in.tsv'
|
||||||
|
|
||||||
|
# Load data with error handling for problematic lines
|
||||||
|
def load_tsv(file_path):
|
||||||
|
try:
|
||||||
|
df = pd.read_csv(file_path, sep='\t', header=None, names=['text', 'label'], on_bad_lines='skip')
|
||||||
|
return df
|
||||||
|
except pd.errors.ParserError as e:
|
||||||
|
print(f"Error parsing {file_path}: {e}")
|
||||||
|
# Attempt to read the file with a different approach
|
||||||
|
with open(file_path, 'r', encoding='utf-8') as file:
|
||||||
|
lines = file.readlines()
|
||||||
|
data = [line.strip().split('\t') for line in lines if len(line.strip().split('\t')) == 2]
|
||||||
|
return pd.DataFrame(data, columns=['text', 'label'])
|
||||||
|
|
||||||
|
# Load the data
|
||||||
|
train_df = load_tsv(train_file_path)
|
||||||
|
dev_df = pd.read_csv(dev_file_path, sep='\t', header=None, names=['text'])
|
||||||
|
test_df = pd.read_csv(test_file_path, sep='\t', header=None, names=['text'])
|
||||||
|
|
||||||
|
# Load pre-trained word2vec model from Google News
|
||||||
|
word2vec_model = api.load('word2vec-google-news-300')
|
||||||
|
|
||||||
|
# Tokenize and pad sequences
|
||||||
|
tokenizer = Tokenizer()
|
||||||
|
tokenizer.fit_on_texts(train_df['text'])
|
||||||
|
vocab_size = len(tokenizer.word_index) + 1
|
||||||
|
|
||||||
|
max_length = max(train_df['text'].apply(lambda x: len(x.split())))
|
||||||
|
|
||||||
|
X_train = tokenizer.texts_to_sequences(train_df['text'])
|
||||||
|
X_train = pad_sequences(X_train, maxlen=max_length, padding='post')
|
||||||
|
|
||||||
|
X_dev = tokenizer.texts_to_sequences(dev_df['text'])
|
||||||
|
X_dev = pad_sequences(X_dev, maxlen=max_length, padding='post')
|
||||||
|
|
||||||
|
X_test = tokenizer.texts_to_sequences(test_df['text'])
|
||||||
|
X_test = pad_sequences(X_test, maxlen=max_length, padding='post')
|
||||||
|
|
||||||
|
y_train = train_df['label'].values
|
||||||
|
|
||||||
|
# Create embedding matrix
|
||||||
|
embedding_matrix = np.zeros((vocab_size, word2vec_model.vector_size))
|
||||||
|
for word, index in tokenizer.word_index.items():
|
||||||
|
if word in word2vec_model:
|
||||||
|
embedding_matrix[index] = word2vec_model[word]
|
||||||
|
|
||||||
|
# Define the model
|
||||||
|
model = Sequential([
|
||||||
|
Embedding(vocab_size, word2vec_model.vector_size, weights=[embedding_matrix], input_length=max_length, trainable=False),
|
||||||
|
Flatten(),
|
||||||
|
Dense(10, activation='relu'),
|
||||||
|
Dense(1, activation='sigmoid')
|
||||||
|
])
|
||||||
|
|
||||||
|
model.compile(optimizer=Adam(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
model.fit(X_train, y_train, epochs=10, verbose=2, validation_split=0.2)
|
||||||
|
|
||||||
|
# Predictions
|
||||||
|
dev_predictions = (model.predict(X_dev) > 0.5).astype(int)
|
||||||
|
test_predictions = (model.predict(X_test) > 0.5).astype(int)
|
||||||
|
|
||||||
|
# Save predictions
|
||||||
|
dev_df['prediction'] = dev_predictions
|
||||||
|
test_df['prediction'] = test_predictions
|
||||||
|
|
||||||
|
dev_df[['prediction']].to_csv('/mnt/data/dev-0/out.tsv', sep='\t', index=False, header=False)
|
||||||
|
test_df[['prediction']].to_csv('/mnt/data/test-A/out.tsv', sep='\t', index=False, header=False)
|
1
mnt
Submodule
1
mnt
Submodule
@ -0,0 +1 @@
|
|||||||
|
Subproject commit 9cb2fb26126561611a5539564fac6b5dbcbb0ca2
|
Loading…
Reference in New Issue
Block a user