aaaa
This commit is contained in:
parent
a459bfbb6f
commit
8a8e1a8307
20726
dev-0/out.tsv
20726
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
47
run.py
47
run.py
@ -36,28 +36,6 @@ def train_model(data, model):
|
||||
model[w2][w1] /= total_count
|
||||
|
||||
|
||||
def predict(word, model):
    """Format the five most probable entries stored for ``word`` as a prediction line.

    The result is a space-separated list of ``token:probability`` pairs
    followed by a final ``:probability`` entry (empty token) holding the
    probability mass not covered by the listed tokens, never less than 0.01.

    Args:
        word: Key to look up in ``model``.
        model: Mapping of word -> mapping of candidate token -> probability.
            Presumably the nested structure built by ``train_model``;
            verify against the caller.

    Returns:
        The formatted prediction string, or a fixed fallback distribution
        when ``word`` has no candidates or their probabilities sum to zero.
    """
    fallback = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"

    # ``.get`` instead of ``model[word]``: an unseen word must neither raise
    # KeyError (plain dict) nor insert an empty entry (defaultdict).
    candidates = model.get(word) or {}
    top = Counter(candidates).most_common(5)

    total_prob = 0.0
    parts = []
    # Distinct loop names: the original loop re-bound the ``word`` parameter.
    for candidate, prob in top:
        total_prob += prob
        parts.append(f"{candidate}:{prob}")

    if not total_prob:
        return fallback

    # Leftover mass goes to the empty token, with a floor of 0.01.
    if 1 - total_prob >= 0.01:
        parts.append(f":{1 - total_prob}")
    else:
        parts.append(":0.01")

    return " ".join(parts)
|
||||
|
||||
|
||||
def predict_data(read_path, save_path, model):
|
||||
data = get_csv(read_path)
|
||||
|
||||
@ -67,9 +45,32 @@ def predict_data(read_path, save_path, model):
|
||||
if len(words) < 3:
|
||||
prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
|
||||
else:
|
||||
prediction = predict(words[-1], model)
|
||||
prediction = predict(words[0], model)
|
||||
f.write(prediction + "\n")
|
||||
|
||||
|
||||
def predict(word, model):
    """Format the six most probable entries stored for ``word`` as a prediction line.

    The result is a space-separated list of ``token:probability`` pairs
    followed by a final ``:probability`` entry (empty token) holding the
    probability mass not covered by the listed tokens, never less than 0.01.

    Args:
        word: Key to look up in ``model``.
        model: Mapping of word -> mapping of candidate token -> probability.
            Presumably the nested structure built by ``train_model``;
            verify against the caller.

    Returns:
        The formatted prediction string, or a fixed fallback distribution
        when ``word`` has no candidates or their probabilities sum to zero.
    """
    fallback = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"

    # ``.get`` instead of ``model[word]``: an unseen word must neither raise
    # KeyError (plain dict) nor insert an empty entry (defaultdict).
    candidates = model.get(word) or {}
    top = Counter(candidates).most_common(6)

    total_prob = 0.0
    parts = []
    # Distinct loop names: the original loop re-bound the ``word`` parameter.
    for candidate, prob in top:
        total_prob += prob
        parts.append(f"{candidate}:{prob}")

    if total_prob == 0.0:
        return fallback

    # Leftover mass goes to the empty token, with a floor of 0.01.
    rem_prob = 1 - total_prob
    if rem_prob < 0.01:
        rem_prob = 0.01
    parts.append(f":{rem_prob}")

    return " ".join(parts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
14696
test-A/out.tsv
14696
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
12
utils.py
12
utils.py
@ -5,9 +5,19 @@ from csv import QUOTE_NONE
|
||||
|
||||
ENCODING = "utf-8"
|
||||
|
||||
REP = re.compile(r"[{}\[\]\&%^$*#\(\)@\t\n0123456789]+")
|
||||
REM = re.compile(r"'s|[\-]\\n|\-\\n|\p{P}")
|
||||
|
||||
|
||||
def clean_text(text):
    """Normalise raw text before it is tokenised for the model.

    Steps, in order:
      1. stringify, lowercase and strip surrounding whitespace;
      2. map the typographic apostrophe to a plain ``'``;
      3. resolve literal ``\\n`` markers (backslash + ``n``): a hyphenated
         line break is joined, any other marker becomes a space;
      4. expand ``won't``, ``'ll`` and ``'m`` while the apostrophe is still
         present;
      5. apply ``REM`` (removes ``'s`` and all remaining punctuation);
      6. apply ``REP`` (turns runs of digits/symbols/whitespace controls
         into a single space).

    Args:
        text: Any value; it is converted with ``str()``.

    Returns:
        The cleaned string.

    NOTE(review): relies on the module-level ``REM``/``REP`` patterns and on
    ``re`` supporting ``\\p{P}`` (presumably the third-party ``regex`` module
    imported as ``re``) -- confirm against the file's imports.
    """
    res = str(text).lower().strip()
    res = res.replace("’", "'")

    # Handle literal "\n" markers before REM runs: REM's \p{P} would delete
    # only the backslash and leave a stray "n" glued to the neighbouring words.
    res = res.replace("-\\n", "").replace("\\n", " ")

    # Contractions must be expanded before REM strips every apostrophe;
    # running them afterwards (as before) could never match anything.
    res = res.replace("won't", "will not")
    res = res.replace("'ll", " will")
    res = res.replace("'m", " am")

    # "'s" is deliberately left to REM, which deletes it -- the same outcome
    # the previous statement order produced.
    res = REM.sub("", res)
    return REP.sub(" ", res)
|
||||
|
||||
|
||||
def get_csv(fname):
|
||||
|
Loading…
Reference in New Issue
Block a user