This commit is contained in:
s440054 2022-04-05 19:08:22 +02:00
parent a459bfbb6f
commit 8a8e1a8307
4 changed files with 17746 additions and 17735 deletions

File diff suppressed because it is too large Load Diff

47
run.py
View File

@ -36,28 +36,6 @@ def train_model(data, model):
model[w2][w1] /= total_count model[w2][w1] /= total_count
def predict(word, model):
    """Format the five likeliest successors of *word* as one prediction line.

    Args:
        word: Context word to look up in ``model``.
        model: Nested mapping where ``model[word][candidate]`` is the
            probability of ``candidate`` given ``word``.

    Returns:
        A string of the form ``"cand:prob cand:prob ... :rest"``. The
        trailing ``:rest`` entry carries the probability mass not covered
        by the listed candidates and is never smaller than 0.01. When
        ``word`` is unknown, or its candidates carry no probability at
        all, a fixed fallback distribution is returned instead.
    """
    fallback = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"

    # Test membership before indexing: ``model[word]`` on an unseen word
    # raises KeyError for a plain dict and silently inserts an empty entry
    # into a defaultdict, growing the model on every failed lookup.
    if word not in model:
        return fallback

    top_candidates = Counter(dict(model[word])).most_common(5)

    # Accumulate in a plain loop (not ``sum``) so the float result, and
    # therefore the printed remainder, is the same on every Python version.
    total_prob = 0.0
    parts = []
    for candidate, prob in top_candidates:
        total_prob += prob
        parts.append(f"{candidate}:{prob}")

    if not total_prob:
        return fallback

    if 1 - total_prob >= 0.01:
        parts.append(f":{1 - total_prob}")
    else:
        # Always leave a little mass for "any other word".
        parts.append(":0.01")
    return " ".join(parts)
def predict_data(read_path, save_path, model): def predict_data(read_path, save_path, model):
data = get_csv(read_path) data = get_csv(read_path)
@ -67,9 +45,32 @@ def predict_data(read_path, save_path, model):
if len(words) < 3: if len(words) < 3:
prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1" prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
else: else:
prediction = predict(words[-1], model) prediction = predict(words[0], model)
f.write(prediction + "\n") f.write(prediction + "\n")
def predict(word, model):
    """Format the six likeliest successors of *word* as one prediction line.

    Args:
        word: Context word to look up in ``model``.
        model: Nested mapping where ``model[word][candidate]`` is the
            probability of ``candidate`` given ``word``.

    Returns:
        A string of the form ``"cand:prob cand:prob ... :rest"``. The
        trailing ``:rest`` entry carries the probability mass not covered
        by the listed candidates, floored at 0.01. When ``word`` is
        unknown, or its candidates carry no probability at all, a fixed
        fallback distribution is returned instead.
    """
    fallback = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"

    # Test membership before indexing: ``model[word]`` on an unseen word
    # raises KeyError for a plain dict and silently inserts an empty entry
    # into a defaultdict, growing the model on every failed lookup.
    if word not in model:
        return fallback

    top_candidates = Counter(dict(model[word])).most_common(6)

    # Accumulate in a plain loop (not ``sum``) so the float result, and
    # therefore the printed remainder, is the same on every Python version.
    total_prob = 0.0
    parts = []
    for candidate, prob in top_candidates:
        total_prob += prob
        parts.append(f"{candidate}:{prob}")

    if total_prob == 0.0:
        return fallback

    # Always leave at least a little mass for "any other word".
    rem_prob = 1 - total_prob
    if rem_prob < 0.01:
        rem_prob = 0.01
    parts.append(f":{rem_prob}")
    return " ".join(parts)
if __name__ == "__main__": if __name__ == "__main__":
main() main()

File diff suppressed because it is too large Load Diff

View File

@ -5,9 +5,19 @@ from csv import QUOTE_NONE
ENCODING = "utf-8" ENCODING = "utf-8"
REP = re.compile(r"[{}\[\]\&%^$*#\(\)@\t\n0123456789]+")
REM = re.compile(r"'s|[\-]\\n|\-\\n|\p{P}")
def clean_text(text): def clean_text(text):
return re.sub(r"\p{P}", "", str(text).lower().replace("-\\n", "").replace("\\n", " ")) res = str(text).lower().strip()
res = res.replace("", "'")
res = REM.sub("", res)
res = REP.sub(" ", res)
res = res.replace("'s", " is")
res = res.replace("'ll", " will")
res = res.replace("won't", "will not")
return res.replace("'m", " am")
def get_csv(fname): def get_csv(fname):