# -*- coding: utf-8 -*-
"""run - gpt.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1YlyKQShvsB_4qBTfjdRm2ngeYAKpPtxt
"""

!pip install transformers

from google.colab import drive

drive.mount('/content/gdrive/')

import pandas as pd
import torch
import transformers
import csv
import tensorflow as tf
import re
import numpy as np

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
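
# Optional tweak (an assumption, not part of the original notebook): for
# inference-only use the model can be put into eval mode, and on a GPU Colab
# runtime it could be moved to the device together with the tokenized inputs:
#   model.eval()
#   model.to("cuda")  # predict() would then also need inputs.to("cuda")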

def predict(text):
    # Keep only the last 50 whitespace-separated tokens as context.
    text = " ".join(text.split()[-50:])
    # Tokenize; `inputs` holds the input_ids and attention_mask tensors.
    inputs = tokenizer(text, return_tensors='pt')
    with torch.no_grad():
        # Logits for the next token, i.e. the last position of the sequence.
        logits = model(**inputs).logits[:, -1, :]
    result = ""
    # Two most likely next tokens; softmax is taken over just these two logits.
    top = torch.topk(logits, 2)
    probs = tf.nn.softmax(top.values[0]).numpy().tolist()
    print(probs)
    for i in range(2):
        predicted_word = tokenizer.decode(top.indices[0][i], skip_special_tokens=True).split()[-1]
        sentence_score = probs[i]
        result += f"{predicted_word}:{sentence_score} "
    # Reserve the remaining probability mass for all other words.
    result = result + " :0.2"
    return result
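
# Minimal usage sketch (the context string below is made up for illustration):
# predict() returns the two most likely next words with their scores plus a
# trailing catch-all entry, i.e. a string of the form "w1:p1 w2:p2  :0.2".
#   print(predict("she looked out of the"))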

def predict_doc(input_path, output_path):
    # Column 6 of the TSV holds the text whose next word is predicted.
    data = pd.read_csv(input_path, sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)[6]
    # Remove literal "\\n" markers and non-alphanumeric characters that precede a space.
    data = data.replace('\\\\n', "", regex=True)
    data = data.apply(lambda x: re.sub('[^a-zA-Z0-9] ', '', x))
    cnt = len(data)
    with open(output_path, 'w') as file:
        for i, row in enumerate(data):
            try:
                result = predict(row)
            except Exception:
                # If prediction fails (e.g. an empty decode), emit only the catch-all mass.
                result = ":1.0"
            print(f"{i}/{cnt} {result}")
            file.write(result + '\n')
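
# Each line of the output file holds the prediction for the corresponding input
# row, in the "word:prob word:prob  :mass" format produced by predict()
# (or the ":1.0" fallback).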
# predict_doc('gdrive/MyDrive/dev-0/in.tsv.xz', 'gdrive/MyDrive/dev-0/out.tsv')
predict_doc('gdrive/MyDrive/test-A/in.tsv.xz', 'gdrive/MyDrive/test-A/out.tsv')