challam-roberta-with-date weekday
This commit is contained in:
parent
4057595363
commit
d97bdda8b3
64
challam-roberta-with-date/00_preprocess.py
Normal file
64
challam-roberta-with-date/00_preprocess.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
import calendar
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
def to_fractional_year(d: datetime.datetime) -> float:
|
||||||
|
"""
|
||||||
|
Converts a date stamp to a fractional year (i.e. number like `1939.781`)
|
||||||
|
"""
|
||||||
|
is_leap = calendar.isleap(d.year)
|
||||||
|
t = d.timetuple()
|
||||||
|
day_of_year = t.tm_yday
|
||||||
|
day_time = (60 * 60 * t.tm_hour + 60 * t.tm_min + t.tm_sec) / (24 * 60 * 60)
|
||||||
|
|
||||||
|
days_in_year = 366 if is_leap else 365
|
||||||
|
|
||||||
|
return d.year + ((day_of_year - 1 + day_time) / days_in_year)
|
||||||
|
|
||||||
|
def fractional_to_date(fractional):
|
||||||
|
eps = 0.0001
|
||||||
|
year = int(fractional)
|
||||||
|
is_leap = calendar.isleap(year)
|
||||||
|
|
||||||
|
modulus = fractional % 1
|
||||||
|
|
||||||
|
days_in_year = 366 if is_leap else 365
|
||||||
|
|
||||||
|
day_of_year = int( days_in_year * modulus + eps )
|
||||||
|
|
||||||
|
d = datetime.datetime(year, 1,1) + datetime.timedelta(days = day_of_year )
|
||||||
|
|
||||||
|
return d
|
||||||
|
|
||||||
|
dates = (datetime.datetime(1825,10,30),
|
||||||
|
datetime.datetime(1825,10,31),
|
||||||
|
datetime.datetime(1900,1,1),
|
||||||
|
datetime.datetime(1900,12,1),
|
||||||
|
datetime.datetime(1900,12,31),
|
||||||
|
datetime.datetime(1930,2,28),
|
||||||
|
datetime.datetime(1932,2,29),
|
||||||
|
)
|
||||||
|
|
||||||
|
for d in dates:
|
||||||
|
inverted = fractional_to_date(to_fractional_year(d))
|
||||||
|
assert d == inverted
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def convert_to_year_text_format(line):
|
||||||
|
_,_,_,fractional_year, _, _, text_l, text_r = line.split('\t')
|
||||||
|
date = fractional_to_date(float(fractional_year))
|
||||||
|
year = date.year
|
||||||
|
month = date.month
|
||||||
|
day = date.day
|
||||||
|
weekday = date.weekday()
|
||||||
|
return f'year: {year} month: {month} day: {day} weekday: {weekday} text: \t' + text_l + '\t' + text_r
|
||||||
|
|
||||||
|
|
||||||
|
def convert_dataset(f_in_path, f_out_path):
|
||||||
|
with open(f_in_path,'r') as f_in, open(f_out_path, 'w') as f_out:
|
||||||
|
for line in f_in:
|
||||||
|
out = convert_to_year_text_format(line)
|
||||||
|
f_out.write(out)
|
||||||
|
|
||||||
|
convert_dataset('../dev-0/in.tsv', './dev-0-date.tsv')
|
||||||
|
convert_dataset('../test-A/in.tsv', './test-A-date.tsv')
|
47
challam-roberta-with-date/04_predict.py
Normal file
47
challam-roberta-with-date/04_predict.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from tqdm import tqdm
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
def get_formatted(text):
|
||||||
|
answers = unmasker(text, top_k=15)
|
||||||
|
answers = {x['token_str']:x['score'] for x in answers}
|
||||||
|
empty = 1 - sum(answers.values())
|
||||||
|
answers[''] = empty
|
||||||
|
answers_str =''
|
||||||
|
for k,v in answers.items():
|
||||||
|
answers_str += k.strip()+':'+str(v) + ' '
|
||||||
|
return answers_str.rstrip(' ').lstrip(' ')
|
||||||
|
|
||||||
|
def write(f_path_in, f_path_out):
|
||||||
|
with open(f_path_in) as f_in, open(f_path_out,'w') as f_out:
|
||||||
|
i = 0
|
||||||
|
for line in tqdm(f_in,total=10_600):
|
||||||
|
char_context = 400
|
||||||
|
i+=1
|
||||||
|
#print(i)
|
||||||
|
is_ok = False
|
||||||
|
while not is_ok:
|
||||||
|
try:
|
||||||
|
date, left_text, right_text = line.rstrip().split('\t')
|
||||||
|
l_in = date + left_text[-char_context:] + ' <mask> ' + right_text[:char_context]
|
||||||
|
a = get_formatted(l_in)
|
||||||
|
is_ok = True
|
||||||
|
except:
|
||||||
|
print('lowering context')
|
||||||
|
char_context -= 50
|
||||||
|
if char_context < 60:
|
||||||
|
a = ':1'
|
||||||
|
print('lower threshold context exceeded')
|
||||||
|
is_ok = True
|
||||||
|
|
||||||
|
f_out.write(a + '\n')
|
||||||
|
#left_text = line.rstrip().split('\t')[-2]
|
||||||
|
#right_text = line.rstrip().split('\t')[-1]
|
||||||
|
#l_in = left_text[-char_context:] + ' <mask> ' + right_text[:char_context]
|
||||||
|
#a = get_formatted(l_in)
|
||||||
|
|
||||||
|
#f_out.write(a + '\n')
|
||||||
|
|
||||||
|
model = 'with_date/checkpoint-396000'
|
||||||
|
unmasker = pipeline('fill-mask', model=model, device=0)
|
||||||
|
write('./dev-0-date.tsv', '../dev-0/out.tsv')
|
||||||
|
write('./test-A-date.tsv', '../test-A/out.tsv')
|
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user