all done
This commit is contained in:
parent
c0894d950a
commit
eb10e5db4a
10685
dev-0/out.tsv
10685
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
8
hf.py
8
hf.py
@ -6,12 +6,12 @@ import regex as re
|
||||
import sys
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
model = AutoModelForCausalLM.from_pretrained('gpt2')
|
||||
model = AutoModelForCausalLM.from_pretrained('gpt2').to('cuda')
|
||||
|
||||
|
||||
for line in sys.stdin:
|
||||
input_text = line.split('\t')[-2].rstrip()
|
||||
input_ids = tokenizer.encode(input_text, return_tensors='pt')
|
||||
input_ids = tokenizer.encode(input_text, return_tensors='pt').to('cuda')
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(input_ids)
|
||||
@ -39,5 +39,5 @@ for line in sys.stdin:
|
||||
continue
|
||||
unknow_prob = 1 - sum_probs
|
||||
string_to_print += f":{unknow_prob}"
|
||||
|
||||
print(string_to_print)
|
||||
string_to_print = re.sub(' +', ' ', string_to_print)
|
||||
print(string_to_print.rstrip().strip())
|
||||
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user