challenging-america-word-ga.../run.py

59 lines
1.7 KiB
Python
Raw Normal View History

2022-04-29 13:01:36 +02:00
# -*- coding: utf-8 -*-
"""run - gpt.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1YlyKQShvsB_4qBTfjdRm2ngeYAKpPtxt
"""
# Requires the `transformers` package (`pip install transformers`).
# The notebook cell used the IPython shell escape `!pip install transformers`,
# which is not valid syntax in a .py file, so it is kept only as this note.
import csv
import re

import numpy as np
import pandas as pd
import tensorflow as tf
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    from google.colab import drive
except ImportError:
    # Not running inside Colab: the gdrive/... paths used below must then
    # exist on the local filesystem.
    drive = None
else:
    # Mount Google Drive at /content/gdrive/ (Colab only).
    drive.mount('/content/gdrive/')

# Pretrained GPT-2 tokenizer and causal language model used by predict().
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
def predict(text, top_k=2, context_words=50, rest_prob=0.2, *,
            lm_tokenizer=None, lm_model=None):
    """Predict the word following ``text`` as a "word:prob ... :rest" line.

    The last ``context_words`` whitespace-separated words of ``text`` are fed
    to the causal language model. The ``top_k`` highest-scoring next tokens
    are then turned into candidate words.

    Args:
        text: Left context preceding the gap.
        top_k: Number of highest-logit next tokens to consider.
        context_words: How many trailing words of ``text`` to keep (positive).
        rest_prob: Constant written as the catch-all ``:rest`` probability.
        lm_tokenizer: Tokenizer to use; defaults to the module-level
            ``tokenizer``.
        lm_model: Causal LM to use; defaults to the module-level ``model``.

    Returns:
        Space-separated ``word:prob`` entries followed by ``:<rest_prob>``.
        The probabilities are a softmax over the ``top_k`` logits only, i.e.
        renormalised among the candidates rather than taken from the full
        vocabulary. Candidates whose decoded token contains no
        non-whitespace text (such as a newline token) are skipped.

    Raises:
        ValueError: If ``text`` contains no words.
    """
    tok = tokenizer if lm_tokenizer is None else lm_tokenizer
    lm = model if lm_model is None else lm_model

    context = " ".join(text.split()[-context_words:])
    if not context:
        # With an empty prompt there is no last position to read logits from.
        raise ValueError("predict() needs at least one word of context")

    inputs = tok(context, return_tensors="pt")
    with torch.no_grad():
        # Logits for the token following the final input position.
        logits = lm(**inputs).logits[:, -1, :]

    top = torch.topk(logits, top_k)
    # Softmax over the k selected logits only. Doing this in torch avoids
    # routing a torch tensor through tf.nn.softmax just for one call.
    probs = torch.softmax(top.values[0], dim=-1).tolist()

    entries = []
    for token_id, prob in zip(top.indices[0].tolist(), probs):
        # Decoded tokens may carry a leading space; split() drops it.
        # A whitespace-only token (e.g. "\n") yields no word. Indexing [-1]
        # on it would raise IndexError and lose the whole row, so skip it.
        words = tok.decode(token_id, skip_special_tokens=True).split()
        if words:
            entries.append(f"{words[-1]}:{prob}")
    entries.append(f":{rest_prob}")
    return " ".join(entries)
# Literal backslash-n escapes as they appear inside the TSV text.
_ESCAPED_HYPHEN_BREAK = re.compile(r"-\\n")
_ESCAPED_NEWLINE = re.compile(r"\\n")
# A non-alphanumeric character directly before a space (e.g. "word, next").
_PUNCT_BEFORE_SPACE = re.compile(r"[^a-zA-Z0-9](?= )")


def _read_tsv_column(input_path, column):
    """Return field ``column`` of every line of a TSV file, one per line.

    ``.xz``, ``.gz`` and ``.bz2`` files are decompressed transparently. Lines
    with too few fields (including blank lines) yield ``""``, so the result
    always has exactly one entry per input line.
    """
    import bz2
    import gzip
    import lzma

    openers = {".xz": lzma.open, ".gz": gzip.open, ".bz2": bz2.open}
    path = str(input_path)
    opener = next(
        (fn for suffix, fn in openers.items() if path.endswith(suffix)), open
    )
    values = []
    with opener(path, "rt", encoding="utf-8") as handle:
        for line in handle:
            fields = line.rstrip("\r\n").split("\t")
            # NOTE(review): a tab inside the text would shift fields; such
            # lines are kept (not dropped) so output alignment is preserved.
            values.append(fields[column] if len(fields) > column else "")
    return values


def _clean_context(text):
    """Normalise escaped line breaks and punctuation before spaces."""
    # "-\n" presumably marks a word hyphenated across a line break: rejoin it.
    text = _ESCAPED_HYPHEN_BREAK.sub("", text)
    # Any other escaped line break separates words, so it becomes a space.
    text = _ESCAPED_NEWLINE.sub(" ", text)
    # Drop punctuation directly before a space but keep the space itself,
    # so neighbouring words are not glued together.
    return _PUNCT_BEFORE_SPACE.sub("", text)


def predict_doc(input_path, output_path, column=6, *, predict_fn=None):
    """Write one prediction line to ``output_path`` per line of ``input_path``.

    Field ``column`` (0-based) of each TSV line is cleaned and passed to the
    predictor, and the returned string becomes one output line. Every input
    line produces exactly one output line. A row whose prediction raises
    gets the fallback ``":1.0"``.

    Args:
        input_path: TSV file, optionally ``.xz``/``.gz``/``.bz2`` compressed.
        output_path: Destination text file; overwritten.
        column: Index of the field holding the left context.
        predict_fn: Callable mapping a context string to an output line;
            defaults to the module-level ``predict``.
    """
    predictor = predict if predict_fn is None else predict_fn
    # Read line by line instead of via pandas: read_csv with
    # on_bad_lines='skip' (and its default blank-line skipping) drops rows,
    # which shifts every later prediction against the input.
    contexts = [_clean_context(t) for t in _read_tsv_column(input_path, column)]
    total = len(contexts)
    with open(output_path, "w", encoding="utf-8") as out:
        for i, context in enumerate(contexts):
            try:
                result = predictor(context)
            except Exception as err:
                # Best effort per row: report the failure and emit the
                # fallback. Ctrl-C (KeyboardInterrupt) still stops the run.
                print(f"{i}/{total} failed: {err!r}")
                result = ":1.0"
            print(f"{i}/{total} {result}")
            out.write(result + "\n")
if __name__ == "__main__":
    # Only the test-A split is predicted; add "dev-0" to the tuple to also
    # regenerate the development-set predictions.
    for split in ("test-A",):
        predict_doc(
            f"gdrive/MyDrive/{split}/in.tsv.xz",
            f"gdrive/MyDrive/{split}/out.tsv",
        )