diff --git a/fine_tuning.py b/fine_tuning.py index e269f87..c12b7b9 100644 --- a/fine_tuning.py +++ b/fine_tuning.py @@ -1,4 +1,4 @@ -from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer +from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer import random import torch @@ -15,42 +15,32 @@ with open('test-A/in.tsv') as f: data_test_X = f.readlines() class CustomDataset(torch.utils.data.Dataset): - def __init__(self, encodings, labels): + def __init__(self, encodings, labels=None): self.encodings = encodings self.labels = labels def __getitem__(self, idx): item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} - item['labels'] = torch.tensor(self.labels[idx]) + if self.labels: + item["labels"] = torch.tensor(self.labels[idx]) return item def __len__(self): - return len(self.labels) + return len(self.encodings["input_ids"]) data_train = list(zip(data_train_X, data_train_Y)) -data_train = random.sample(data_train, 150000) +data_train = random.sample(data_train, 5000) -tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") -train_encodings = tokenizer([text[0] for text in data_train], truncation=True, padding=True) +tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") +train_X = tokenizer([text[0] for text in data_train], truncation=True, padding=True) +train_Y = [int(text[1]) for text in data_train] -train_dataset = CustomDataset(train_encodings, [int(text[1]) for text in data_train]) +train_dataset = CustomDataset(train_X, train_Y) -model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2) +model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2) -training_args = TrainingArguments("test_trainer") +training_args = TrainingArguments("model") trainer = Trainer( model=model, args=training_args, train_dataset=train_dataset) trainer.train() -with open('train/out.tsv', 'w') as writer: - for result in trainer.predict(data_train_X): - writer.write(str(result) + '\n') - -with open('dev-0/out.tsv', 'w') as writer: - for result in trainer.predict(data_dev_X): - writer.write(str(result) + '\n') - -with open('test-A/out.tsv', 'w') as writer: - for result in trainer.predict(data_test_X): - writer.write(str(result) + '\n') - diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..43c373c --- /dev/null +++ b/predict.py @@ -0,0 +1,59 @@ +from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer +import random +import torch +import numpy as np + +with open('train/in.tsv') as f: + data_train_X = f.readlines() + +with open('train/expected.tsv') as f: + data_train_Y = f.readlines() + +with open('dev-0/in.tsv') as f: + data_dev_X = f.readlines() + +with open('test-A/in.tsv') as f: + data_test_X = f.readlines() + +class CustomDataset(torch.utils.data.Dataset): + def __init__(self, encodings, labels=None): + self.encodings = encodings + self.labels = labels + + def __getitem__(self, idx): + item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} + if self.labels: + item["labels"] = torch.tensor(self.labels[idx]) + return item + + def __len__(self): + return len(self.encodings["input_ids"]) + +tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + +model_path = "model/checkpoint-1500" +model = BertForSequenceClassification.from_pretrained(model_path, num_labels=2) + +trainer = Trainer(model) + +with open('train/out.tsv', 'w') as writer: + train_encodings = tokenizer(data_train_X, truncation=True, padding=True) + train_dataset = CustomDataset(train_encodings) + raw_pred, _, _ = trainer.predict(train_dataset) + for result in np.argmax(raw_pred, axis=1): + writer.write(str(result) + '\n') + +with open('dev-0/out.tsv', 'w') as writer: + dev_encodings = tokenizer(data_dev_X, truncation=True, padding=True) + dev_dataset = CustomDataset(dev_encodings) + raw_pred, _, _ = trainer.predict(dev_dataset) + for result in np.argmax(raw_pred, axis=1): + writer.write(str(result) + '\n') + +with open('test-A/out.tsv', 'w') as writer: + test_encodings = tokenizer(data_test_X, truncation=True, padding=True) + test_dataset = CustomDataset(test_encodings) + raw_pred, _, _ = trainer.predict(test_dataset) + for result in np.argmax(raw_pred, axis=1): + writer.write(str(result) + '\n') + diff --git a/requirements.txt b/requirements.txt index 862a0c6..4ee2926 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,166 @@ +certifi==2021.5.30 +chardet==4.0.0 +click==8.0.1 +datasets==1.8.0 +dill==0.3.4 +filelock==3.0.12 +fsspec==2021.6.0 +huggingface-hub==0.0.8 +idna==2.10 joblib==1.0.1 +keyboard==0.13.5 +multiprocess==0.70.12.2 numpy==1.20.2 +packaging==20.9 pandas==1.2.4 +pyarrow==3.0.0 +pynput==1.7.3 +pyobjc==7.2 +pyobjc-core==7.2 +pyobjc-framework-Accessibility==7.2 +pyobjc-framework-Accounts==7.2 +pyobjc-framework-AddressBook==7.2 +pyobjc-framework-AdServices==7.2 +pyobjc-framework-AdSupport==7.2 +pyobjc-framework-AppleScriptKit==7.2 +pyobjc-framework-AppleScriptObjC==7.2 +pyobjc-framework-ApplicationServices==7.2 +pyobjc-framework-AppTrackingTransparency==7.2 +pyobjc-framework-AuthenticationServices==7.2 +pyobjc-framework-AutomaticAssessmentConfiguration==7.2 +pyobjc-framework-Automator==7.2 +pyobjc-framework-AVFoundation==7.2 +pyobjc-framework-AVKit==7.2 +pyobjc-framework-BusinessChat==7.2 +pyobjc-framework-CalendarStore==7.2 +pyobjc-framework-CallKit==7.2 +pyobjc-framework-CFNetwork==7.2 +pyobjc-framework-ClassKit==7.2 +pyobjc-framework-CloudKit==7.2 +pyobjc-framework-Cocoa==7.2 +pyobjc-framework-Collaboration==7.2 +pyobjc-framework-ColorSync==7.2 +pyobjc-framework-Contacts==7.2 +pyobjc-framework-ContactsUI==7.2 +pyobjc-framework-CoreAudio==7.2 +pyobjc-framework-CoreAudioKit==7.2 +pyobjc-framework-CoreBluetooth==7.2 +pyobjc-framework-CoreData==7.2 +pyobjc-framework-CoreHaptics==7.2 +pyobjc-framework-CoreLocation==7.2 +pyobjc-framework-CoreMedia==7.2 +pyobjc-framework-CoreMediaIO==7.2 +pyobjc-framework-CoreMIDI==7.2 +pyobjc-framework-CoreML==7.2 +pyobjc-framework-CoreMotion==7.2 +pyobjc-framework-CoreServices==7.2 +pyobjc-framework-CoreSpotlight==7.2 +pyobjc-framework-CoreText==7.2 +pyobjc-framework-CoreWLAN==7.2 +pyobjc-framework-CryptoTokenKit==7.2 +pyobjc-framework-DeviceCheck==7.2 +pyobjc-framework-DictionaryServices==7.2 +pyobjc-framework-DiscRecording==7.2 +pyobjc-framework-DiscRecordingUI==7.2 +pyobjc-framework-DiskArbitration==7.2 +pyobjc-framework-DVDPlayback==7.2 +pyobjc-framework-EventKit==7.2 +pyobjc-framework-ExceptionHandling==7.2 +pyobjc-framework-ExecutionPolicy==7.2 +pyobjc-framework-ExternalAccessory==7.2 +pyobjc-framework-FileProvider==7.2 +pyobjc-framework-FileProviderUI==7.2 +pyobjc-framework-FinderSync==7.2 +pyobjc-framework-FSEvents==7.2 +pyobjc-framework-GameCenter==7.2 +pyobjc-framework-GameController==7.2 +pyobjc-framework-GameKit==7.2 +pyobjc-framework-GameplayKit==7.2 +pyobjc-framework-ImageCaptureCore==7.2 +pyobjc-framework-IMServicePlugIn==7.2 +pyobjc-framework-InputMethodKit==7.2 +pyobjc-framework-InstallerPlugins==7.2 +pyobjc-framework-InstantMessage==7.2 +pyobjc-framework-Intents==7.2 +pyobjc-framework-IOSurface==7.2 +pyobjc-framework-iTunesLibrary==7.2 +pyobjc-framework-KernelManagement==7.2 +pyobjc-framework-LatentSemanticMapping==7.2 +pyobjc-framework-LaunchServices==7.2 +pyobjc-framework-libdispatch==7.2 +pyobjc-framework-LinkPresentation==7.2 +pyobjc-framework-LocalAuthentication==7.2 +pyobjc-framework-MapKit==7.2 +pyobjc-framework-MediaAccessibility==7.2 +pyobjc-framework-MediaLibrary==7.2 +pyobjc-framework-MediaPlayer==7.2 +pyobjc-framework-MediaToolbox==7.2 +pyobjc-framework-Metal==7.2 +pyobjc-framework-MetalKit==7.2 +pyobjc-framework-MetalPerformanceShaders==7.2 +pyobjc-framework-MetalPerformanceShadersGraph==7.2 +pyobjc-framework-MLCompute==7.2 +pyobjc-framework-ModelIO==7.2 +pyobjc-framework-MultipeerConnectivity==7.2 +pyobjc-framework-NaturalLanguage==7.2 +pyobjc-framework-NetFS==7.2 +pyobjc-framework-Network==7.2 +pyobjc-framework-NetworkExtension==7.2 +pyobjc-framework-NotificationCenter==7.2 +pyobjc-framework-OpenDirectory==7.2 +pyobjc-framework-OSAKit==7.2 +pyobjc-framework-OSLog==7.2 +pyobjc-framework-PassKit==7.2 +pyobjc-framework-PencilKit==7.2 +pyobjc-framework-Photos==7.2 +pyobjc-framework-PhotosUI==7.2 +pyobjc-framework-PreferencePanes==7.2 +pyobjc-framework-PushKit==7.2 +pyobjc-framework-Quartz==7.2 +pyobjc-framework-QuickLookThumbnailing==7.2 +pyobjc-framework-ReplayKit==7.2 +pyobjc-framework-SafariServices==7.2 +pyobjc-framework-SceneKit==7.2 +pyobjc-framework-ScreenSaver==7.2 +pyobjc-framework-ScreenTime==7.2 +pyobjc-framework-ScriptingBridge==7.2 +pyobjc-framework-SearchKit==7.2 +pyobjc-framework-Security==7.2 +pyobjc-framework-SecurityFoundation==7.2 +pyobjc-framework-SecurityInterface==7.2 +pyobjc-framework-ServiceManagement==7.2 +pyobjc-framework-Social==7.2 +pyobjc-framework-SoundAnalysis==7.2 +pyobjc-framework-Speech==7.2 +pyobjc-framework-SpriteKit==7.2 +pyobjc-framework-StoreKit==7.2 +pyobjc-framework-SyncServices==7.2 +pyobjc-framework-SystemConfiguration==7.2 +pyobjc-framework-SystemExtensions==7.2 +pyobjc-framework-UniformTypeIdentifiers==7.2 +pyobjc-framework-UserNotifications==7.2 +pyobjc-framework-UserNotificationsUI==7.2 +pyobjc-framework-VideoSubscriberAccount==7.2 +pyobjc-framework-VideoToolbox==7.2 +pyobjc-framework-Virtualization==7.2 +pyobjc-framework-Vision==7.2 +pyobjc-framework-WebKit==7.2 +pyparsing==2.4.7 python-dateutil==2.8.1 pytz==2021.1 +PyYAML==5.4.1 +regex==2021.4.4 +requests==2.25.1 +sacremoses==0.0.45 scikit-learn==0.24.2 scipy==1.6.3 six==1.16.0 sklearn==0.0 threadpoolctl==2.1.0 +tokenizers==0.10.3 +torch==1.9.0 +tqdm==4.49.0 +transformers==4.7.0 +typing-extensions==3.10.0.0 +urllib3==1.26.5 +xxhash==2.0.2