Compare commits

1 Commit

Author    SHA1          Message                Date
          e7951d0867    Fit GPT2 finetuning    2023-09-25 01:29:56 +02:00
7 changed files with 17934 additions and 53800 deletions

3 file diffs suppressed because they are too large.

@@ -201,7 +201,7 @@ def predict_words(dataset):
 src = tokenizer.encode(text, return_tensors="pt", truncation=True).to(device)
 output = model.generate(src, max_length=len(src[0]) + 1, do_sample=True, top_k=0, temperature=0.8,
                         num_return_sequences=1, no_repeat_ngram_size=2, output_scores=True)
-probs, idxs = torch.softmax(output.scores[0][-1], dim=0).topk(50)
+probs, idxs = torch.softmax(output.scores[0][-1], dim=0).topk(30)
 current_output = ''
 accumulated_probability = 0
 for prob, token_id in zip(probs, idxs):
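
The only visible change narrows the candidate pool for the next-word suggestion from the 50 to the 30 highest-probability tokens. The hunk cuts off at the for loop, so the continuation below is only a sketch of how such a top-k candidate loop is typically finished; the "gpt2" checkpoint, the prompt, the prob_threshold cutoff, and the loop body are placeholders, not code from this repository.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Sketch only: model name, prompt, and threshold are assumptions.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

text = "The meeting is scheduled for next"
src = tokenizer.encode(text, return_tensors="pt", truncation=True).to(device)
output = model.generate(src, max_length=len(src[0]) + 1, do_sample=True, top_k=0,
                        temperature=0.8, num_return_sequences=1,
                        no_repeat_ngram_size=2, output_scores=True,
                        return_dict_in_generate=True)  # added so output.scores is available

# Distribution over the vocabulary for the generated position,
# truncated to the 30 most likely tokens (50 before this commit).
probs, idxs = torch.softmax(output.scores[0][-1], dim=0).topk(30)

# Assumed continuation of the loop the hunk cuts off: collect candidate
# words until roughly 90% of the truncated probability mass is covered.
current_output = ''
accumulated_probability = 0.0
prob_threshold = 0.9  # placeholder cutoff, not taken from the repo
for prob, token_id in zip(probs, idxs):
    current_output += tokenizer.decode(token_id.item()).strip() + ' '
    accumulated_probability += prob.item()
    if accumulated_probability >= prob_threshold:
        break
print(current_output)

Under that reading, shrinking the pool from 50 to 30 drops the long tail of unlikely tokens before the accumulation loop even starts, so the suggestion list stays shorter and more focused.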

3 file diffs suppressed because they are too large.