diff --git a/train_model.py b/train_model.py index 1bb0c99..d92104f 100644 --- a/train_model.py +++ b/train_model.py @@ -18,13 +18,19 @@ import sys from tqdm import tqdm from Levenshtein import distance as levenshtein_distance from sacred import Experiment +import traceback ex = Experiment("CNN") ex.observers.append(FileStorageObserver('sacred_file_observer')) -ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017', - db_name='sacred')) +try: + ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017', + db_name='sacred')) +except Exception as e: + traceback.print_exc() + device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + class CNN(nn.Module): def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len): super(CNN, self).__init__() @@ -206,4 +212,3 @@ def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probabili if mode == 'eval': cnn.eval() evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len) -