aaaa
This commit is contained in:
parent
a459bfbb6f
commit
8a8e1a8307
20726
dev-0/out.tsv
20726
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
47
run.py
47
run.py
@ -36,28 +36,6 @@ def train_model(data, model):
|
|||||||
model[w2][w1] /= total_count
|
model[w2][w1] /= total_count
|
||||||
|
|
||||||
|
|
||||||
def predict(word, model):
|
|
||||||
predictions = dict(model[word])
|
|
||||||
most_common = dict(Counter(predictions).most_common(5))
|
|
||||||
|
|
||||||
total_prob = 0.0
|
|
||||||
str_prediction = ""
|
|
||||||
|
|
||||||
for word, prob in most_common.items():
|
|
||||||
total_prob += prob
|
|
||||||
str_prediction += f"{word}:{prob} "
|
|
||||||
|
|
||||||
if not total_prob:
|
|
||||||
return "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
|
|
||||||
|
|
||||||
if 1 - total_prob >= 0.01:
|
|
||||||
str_prediction += f":{1-total_prob}"
|
|
||||||
else:
|
|
||||||
str_prediction += f":0.01"
|
|
||||||
|
|
||||||
return str_prediction
|
|
||||||
|
|
||||||
|
|
||||||
def predict_data(read_path, save_path, model):
|
def predict_data(read_path, save_path, model):
|
||||||
data = get_csv(read_path)
|
data = get_csv(read_path)
|
||||||
|
|
||||||
@ -67,9 +45,32 @@ def predict_data(read_path, save_path, model):
|
|||||||
if len(words) < 3:
|
if len(words) < 3:
|
||||||
prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
|
prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
|
||||||
else:
|
else:
|
||||||
prediction = predict(words[-1], model)
|
prediction = predict(words[0], model)
|
||||||
f.write(prediction + "\n")
|
f.write(prediction + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def predict(word, model):
|
||||||
|
predictions = dict(model[word])
|
||||||
|
most_common = dict(Counter(predictions).most_common(6))
|
||||||
|
|
||||||
|
total_prob = 0.0
|
||||||
|
str_prediction = ""
|
||||||
|
|
||||||
|
for word, prob in most_common.items():
|
||||||
|
total_prob += prob
|
||||||
|
str_prediction += f"{word}:{prob} "
|
||||||
|
|
||||||
|
if total_prob == 0.0:
|
||||||
|
return "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
|
||||||
|
|
||||||
|
rem_prob = 1 - total_prob
|
||||||
|
if rem_prob < 0.01:
|
||||||
|
rem_prob = 0.01
|
||||||
|
|
||||||
|
str_prediction += f":{rem_prob}"
|
||||||
|
|
||||||
|
return str_prediction
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
14696
test-A/out.tsv
14696
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
12
utils.py
12
utils.py
@ -5,9 +5,19 @@ from csv import QUOTE_NONE
|
|||||||
|
|
||||||
ENCODING = "utf-8"
|
ENCODING = "utf-8"
|
||||||
|
|
||||||
|
REP = re.compile(r"[{}\[\]\&%^$*#\(\)@\t\n0123456789]+")
|
||||||
|
REM = re.compile(r"'s|[\-]\\n|\-\\n|\p{P}")
|
||||||
|
|
||||||
|
|
||||||
def clean_text(text):
|
def clean_text(text):
|
||||||
return re.sub(r"\p{P}", "", str(text).lower().replace("-\\n", "").replace("\\n", " "))
|
res = str(text).lower().strip()
|
||||||
|
res = res.replace("’", "'")
|
||||||
|
res = REM.sub("", res)
|
||||||
|
res = REP.sub(" ", res)
|
||||||
|
res = res.replace("'s", " is")
|
||||||
|
res = res.replace("'ll", " will")
|
||||||
|
res = res.replace("won't", "will not")
|
||||||
|
return res.replace("'m", " am")
|
||||||
|
|
||||||
|
|
||||||
def get_csv(fname):
|
def get_csv(fname):
|
||||||
|
Loading…
Reference in New Issue
Block a user