This commit is contained in:
parent
96ce6341cb
commit
16db0c341e
@ -18,13 +18,19 @@ import sys
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from Levenshtein import distance as levenshtein_distance
|
from Levenshtein import distance as levenshtein_distance
|
||||||
from sacred import Experiment
|
from sacred import Experiment
|
||||||
|
import traceback
|
||||||
|
|
||||||
ex = Experiment("CNN")
|
ex = Experiment("CNN")
|
||||||
ex.observers.append(FileStorageObserver('sacred_file_observer'))
|
ex.observers.append(FileStorageObserver('sacred_file_observer'))
|
||||||
ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017',
|
try:
|
||||||
|
ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017',
|
||||||
db_name='sacred'))
|
db_name='sacred'))
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||||
|
|
||||||
|
|
||||||
class CNN(nn.Module):
|
class CNN(nn.Module):
|
||||||
def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len):
|
def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len):
|
||||||
super(CNN, self).__init__()
|
super(CNN, self).__init__()
|
||||||
@ -206,4 +212,3 @@ def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probabili
|
|||||||
if mode == 'eval':
|
if mode == 'eval':
|
||||||
cnn.eval()
|
cnn.eval()
|
||||||
evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len)
|
evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user