This commit is contained in:
s464966 2024-05-20 00:10:24 +02:00
commit bf2d451045
4 changed files with 108 additions and 0 deletions

25
word2vec/README.md Normal file
View File

@ -0,0 +1,25 @@
Sport Texts Classification Challenge - Ball
======================
Guess whether the sport is connected to the ball for a Polish article. Evaluation metrics: Accuracy, Likelihood.
Classes
-------
* `1` — ball
* `0` — no-ball
Directory structure
-------------------
* `README.md` — this file
* `config.txt` — configuration file
* `train/` — directory with training data
* `train/train.tsv` — sample train set
* `dev-0/` — directory with dev (test) data
* `dev-0/in.tsv` — input data for the dev set
* `dev-0/expected.tsv` — expected (reference) data for the dev set
* `test-A` — directory with test data
* `test-A/in.tsv` — input data for the test set
* `test-A/expected.tsv` — expected (reference) data for the test set

1
word2vec/config.txt Normal file
View File

@ -0,0 +1 @@
--metric Likelihood --metric Accuracy --precision 5

BIN
word2vec/geval Normal file

Binary file not shown.

82
word2vec/word2vec_3.py Normal file
View File

@ -0,0 +1,82 @@
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from gensim.models import KeyedVectors
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
import sys
# Ustawienie kodowania na utf-8
sys.stdout.reconfigure(encoding='utf-8')
# Funkcja do ładowania danych
def load_and_prepare_data(file_path, text_column='text', label_column='label', is_test=False):
if is_test:
df = pd.read_csv(file_path, sep='\t', header=None, names=[text_column], on_bad_lines='skip')
df['label'] = 0 # Dodanie kolumny label z wartością domyślną
else:
df = pd.read_csv(file_path, sep='\t', header=None, names=[label_column, text_column], on_bad_lines='skip')
df[text_column].fillna('', inplace=True)
return df
# 1. Przygotowanie danych
train_df = load_and_prepare_data('./train/train.tsv')
dev_df = load_and_prepare_data('./dev-0/in.tsv', is_test=True)
test_df = load_and_prepare_data('./test-A/in.tsv', is_test=True)
dev_labels = pd.read_csv('./dev-0/expected.tsv', sep='\t', header=None, names=['label'])
# 2. Przygotowanie modelu word2vec
word2vec_path = './word2vec_100_3_polish.bin'
w2v_model = KeyedVectors.load(word2vec_path)
# 3. Przekształcenie tekstów na wektory
def text_to_vector(text, model_w2v, vector_size=100):
if not isinstance(text, str):
return np.zeros((vector_size,))
words = text.split()
words = [word for word in words if word in model_w2v]
if words:
return np.mean(model_w2v[words], axis=0)
else:
return np.zeros((vector_size,))
# 4. Budowa i trenowanie modelu
def build_and_train_model(X_train, y_train):
model = Sequential([
Dense(256, activation='relu', input_dim=100),
Dropout(0.3),
Dense(128, activation='relu'),
Dropout(0.3),
Dense(64, activation='relu'),
Dropout(0.3),
Dense(32, activation='relu'),
Dropout(0.3),
Dense(1, activation='sigmoid')
])
model.compile(optimizer=Adam(learning_rate=0.0005), loss='binary_crossentropy', metrics=['accuracy'])
weights = class_weight.compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
weights_dict = dict(enumerate(weights))
model.fit(X_train, y_train, epochs=50, batch_size=32, class_weight=weights_dict, validation_split=0.2)
return model
# 5. Generowanie wyników
def generate_results(df, model, filepath, text_column='text'):
X_test = np.array([text_to_vector(text, w2v_model) for text in df[text_column]])
predictions = (model.predict(X_test) > 0.5).astype(int)
pd.DataFrame(predictions).to_csv(filepath, sep='\t', index=False, header=False)
# Przekształcenie tekstów na wektory
X_train = np.array([text_to_vector(text, w2v_model) for text in train_df['text']])
y_train = train_df['label'].values
# Trenowanie modelu
model = build_and_train_model(X_train, y_train)
# Generowanie wyników dla zbioru walidacyjnego i testowego
generate_results(dev_df, model, "./dev-0/out.tsv")
generate_results(test_df, model, "./test-A/out.tsv")