# challenging-america-word-ga.../run.py
# Last modified: 2022-06-19 13:28:09 +02:00
#
# 58 lines
# 1.8 KiB
# Python

# -*- coding: utf-8 -*-
"""run - gpt.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1YlyKQShvsB_4qBTfjdRm2ngeYAKpPtxt
"""
# Install the Hugging Face `transformers` package into the running interpreter.
# The Colab export used the IPython shell escape `!pip install transformers`,
# which is a SyntaxError in a plain .py file; invoking pip through subprocess
# works both inside a notebook cell and when this file is run as a script.
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers"])

# Mount Google Drive at /content/gdrive/ (Colab-only API).
from google.colab import drive
drive.mount('/content/gdrive/')
import pandas as pd
import torch
import transformers
import csv
import tensorflow as tf
import re
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
# Pretrained GPT-2 checkpoint: its tokenizer and causal language model are
# module-level globals read by `predict` below.
_CHECKPOINT = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_CHECKPOINT)
def predict(text, top_k=2, context_words=50, min_rest=0.01):
    """Predict the word following ``text`` with the global GPT-2 ``model``.

    Only the last ``context_words`` whitespace-separated words of ``text`` are
    used as context. The result is one line in the ``word:prob ... :rest``
    format: up to ``top_k`` candidate words with the model's next-token
    probabilities, followed by the mass reserved for every other word.

    NOTE(review): GPT-2 predicts sub-word tokens, so a candidate may be only
    the beginning of a word; its probability is that of the token.

    Args:
        text: Left context preceding the gap.
        top_k: Number of most probable next tokens to consider.
        context_words: Number of trailing words of ``text`` fed to the model.
        min_rest: Lower bound for the trailing ``:rest`` probability, so that
            words outside the candidate list never receive zero mass.

    Returns:
        A string such as ``"the:0.31 a:0.12 :0.57"``.
    """
    context = " ".join(text.split()[-context_words:])
    if not context:
        # The model cannot run on an empty input; condition on the
        # beginning-of-sequence token instead.
        context = tokenizer.bos_token
    inputs = tokenizer(context, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1, :]

    # Softmax over the whole vocabulary yields the model's real next-token
    # distribution (normalising over the top-k logits alone always sums to 1
    # and, together with a fixed remainder, over-allocates probability mass).
    probs = torch.softmax(logits, dim=-1)
    top = torch.topk(probs, top_k)

    candidates = {}
    for prob, token_id in zip(top.values.tolist(), top.indices.tolist()):
        pieces = tokenizer.decode(token_id, skip_special_tokens=True).split()
        # Whitespace-only / special tokens decode to no word at all, and a ':'
        # inside a word would corrupt the word:prob format -- skip both.
        if not pieces or ":" in pieces[-1]:
            continue
        word = pieces[-1]
        # Tokens such as " the" and "the" map to the same word: merge them.
        candidates[word] = candidates.get(word, 0.0) + prob

    rest = max(1.0 - sum(candidates.values()), min_rest)
    shown = " ".join(f"{word}:{prob}" for word, prob in candidates.items())
    return f"{shown} :{rest}" if shown else f":{rest}"
# Literal backslash-n sequences in the text field (presumably escaped line
# breaks in the TSV -- confirm against the dataset description).
_ESCAPED_NEWLINE = re.compile(r"\\n")
# A non-alphanumeric character directly followed by a space, e.g. "," in "a, b".
_PUNCT_BEFORE_SPACE = re.compile(r"[^a-zA-Z0-9] ")
# Line written when prediction for a row fails.
_FALLBACK_PREDICTION = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"


def _open_text(path):
    """Open ``path`` for UTF-8 text reading, decompressing ``.xz`` files."""
    import lzma

    # newline="\n": one record per LF-terminated line and no newline
    # translation, so record boundaries match the line-based output.
    if str(path).endswith(".xz"):
        return lzma.open(path, "rt", encoding="utf-8", newline="\n")
    return open(path, encoding="utf-8", newline="\n")


def _clean_context(text):
    """Normalise a raw context field without gluing adjacent words together."""
    # Replace with a space (not ""), otherwise "world\\nagain" becomes
    # "worldagain" and "hello, world" becomes "helloworld".
    text = _ESCAPED_NEWLINE.sub(" ", text)
    return _PUNCT_BEFORE_SPACE.sub(" ", text)


def _read_contexts(input_path, column):
    """Return the cleaned ``column``-th tab-separated field of every line."""
    contexts = []
    with _open_text(input_path) as src:
        for line in src:
            fields = line.rstrip("\r\n").split("\t")
            # Short or malformed lines still yield an (empty) context so the
            # output keeps exactly one line per input line.
            raw = fields[column] if len(fields) > column else ""
            contexts.append(_clean_context(raw))
    return contexts


def predict_doc(input_path, output_path, predictor=None, column=6):
    """Write one next-word prediction line per line of ``input_path``.

    Args:
        input_path: Tab-separated input file, optionally ``.xz``-compressed.
        output_path: File to write; line i holds the prediction for input
            line i, so the output stays aligned with the input.
        predictor: Callable mapping a context string to a prediction line;
            defaults to this module's ``predict``.
        column: Zero-based index of the tab-separated field used as context.

    A row whose prediction raises an ``Exception`` gets ``_FALLBACK_PREDICTION``
    and the error is printed; progress is printed for every row.
    """
    predictor = predict if predictor is None else predictor
    contexts = _read_contexts(input_path, column)
    cnt = len(contexts)
    with open(output_path, "w", encoding="utf-8") as file:
        for i, context in enumerate(contexts):
            try:
                result = predictor(context)
            except Exception as exc:  # best effort: keep going, but report it
                print(f"{i}/{cnt} prediction failed: {exc!r}")
                result = _FALLBACK_PREDICTION
            print(f"{i}/{cnt} {result}")
            file.write(result + "\n")
# Generate predictions for each (input, output) split on the mounted Drive,
# dev-0 first, then test-A.
for _in_path, _out_path in (
    ('gdrive/MyDrive/dev-0/in.tsv.xz', 'gdrive/MyDrive/dev-0/out.tsv'),
    ('gdrive/MyDrive/test-A/in.tsv.xz', 'gdrive/MyDrive/test-A/out.tsv'),
):
    predict_doc(_in_path, _out_path)