This commit is contained in:
parent
96ce6341cb
commit
16db0c341e
@ -18,13 +18,19 @@ import sys
|
||||
from tqdm import tqdm
|
||||
from Levenshtein import distance as levenshtein_distance
|
||||
from sacred import Experiment
|
||||
import traceback
|
||||
|
||||
ex = Experiment("CNN")
|
||||
ex.observers.append(FileStorageObserver('sacred_file_observer'))
|
||||
try:
|
||||
ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017',
|
||||
db_name='sacred'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
|
||||
|
||||
class CNN(nn.Module):
|
||||
def __init__(self, kernel_size, hidden_layers, channels, embedding_size, in_alphabet, max_len):
|
||||
super(CNN, self).__init__()
|
||||
@ -206,4 +212,3 @@ def run(kernel_size, hidden_layers, data_file, epochs, teacher_forcing_probabili
|
||||
if mode == 'eval':
|
||||
cnn.eval()
|
||||
evaluate_monte_carlo(cnn, 1, data, batch_size, in_alphabet, max_len)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user