From b9815844a470397d31f6dfa723bb59c9de9d8af4 Mon Sep 17 00:00:00 2001 From: Iwona Christop Date: Tue, 3 May 2022 21:54:24 +0200 Subject: [PATCH] Ready to go --- heSaidEdgar.ipynb | 232 ++++++++++++++++++++++++++++++++++++++++++---- main.ipynb | 2 +- main.py | 126 ++++++++++++++++++++++++- 3 files changed, 337 insertions(+), 23 deletions(-) diff --git a/heSaidEdgar.ipynb b/heSaidEdgar.ipynb index 9a6e4d4..39644b8 100644 --- a/heSaidEdgar.ipynb +++ b/heSaidEdgar.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 23, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -16,32 +16,41 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/spacy/util.py:833: UserWarning: [W095] Model 'en_pipeline' (0.0.0) was trained with spaCy v3.3 and may not be 100% compatible with the current version (3.2.4). If you see errors or degraded performance, download a newer compatible model or retrain your custom model with the current spaCy version. For more details and available updates, run: python -m spacy validate\n", + " warnings.warn(warn_msg)\n" + ] + } + ], "source": [ "import spacy\n", "from spacy import displacy\n", "\n", - "nlp = spacy.load('NER')\n", + "ner = spacy.load('NER')\n", "\n", - "text = NDAs[9]\n", - "doc = nlp(text)\n", + "# text = NDAs[9]\n", + "# doc = nlp(text)\n", "\n", - "effective_date = []\n", - "jurisdiction = []\n", - "party = []\n", - "term = []\n", + "# effective_date = []\n", + "# jurisdiction = []\n", + "# party = []\n", + "# term = []\n", "\n", - "for word in doc.ents:\n", - " if word.label_ == 'effective_date':\n", - " effective_date.append(word.text)\n", - " elif word.label_ == 'jurisdiction':\n", - " jurisdiction.append(word.text)\n", - " elif word.label_ == 'party':\n", - " party.append(word.text)\n", - " else:\n", - " term.append(word.text)" + "# for word in doc.ents:\n", + "# if word.label_ == 'effective_date':\n", + "# effective_date.append(word.text)\n", + "# elif word.label_ == 'jurisdiction':\n", + "# jurisdiction.append(word.text)\n", + "# elif word.label_ == 'party':\n", + "# party.append(word.text)\n", + "# else:\n", + "# term.append(word.text)" ] }, { @@ -132,6 +141,189 @@ " print(word.text, '-->', word.label_)" ] }, + { + "cell_type": "code", + "execution_count": 108, + "metadata": {}, + "outputs": [], + "source": [ + "months = {'01': 'January', '02': 'February', '03': 'March', \n", + " '04': 'April', '05': 'May', '06': 'June',\n", + " '07': 'July', '08': 'August', '09': 'September',\n", + " '10': 'October', '11': 'November', '12': 'December'}\n", + "\n", + "punctuation = '!\"#$%&\\'()*+,-./:;<=>?@[\\\\\\\\]^_`{|}~'\n", + "\n", + "document = ner(NDAs[4])\n", + "\n", + "effectiveDate = []\n", + "\n", + "for word in document.ents:\n", + " if word.label_ == 'effective_date':\n", + " effectiveDate.append(word.text)\n", + "\n", + "try:\n", + " effectiveDate = { date : effectiveDate.count(date) for date in effectiveDate }\n", + " effectiveDate = max(effectiveDate, key=effectiveDate.get)\n", + " for char in punctuation: effectiveDate = effectiveDate.replace(char, '')\n", + " # Get month\n", + " for d in effectiveDate.split():\n", + " if d in list(months.values()):\n", + " month = list(months.keys())[list(months.values()).index(d)]\n", + " elif int(d) < 32:\n", + " day = d\n", + " elif int(d) > 1900 and int(d) < 2030:\n", + " year = d\n", + " effectiveDate = year + '-' + month + '-' + day\n", + "except:\n", + " pass\n", + "\n", + "# effectiveDate = '2011-07-13'" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "metadata": {}, + "outputs": [], + "source": [ + "states = ['Alabama', 'New York']\n", + "\n", + "document = ner(NDAs[6])\n", + "\n", + "jurisdiction = []\n", + "\n", + "for word in document.ents:\n", + " if word.label_ == 'jurisdiction':\n", + " if word.text not in states:\n", + " for state in states:\n", + " if word.text in state:\n", + " jurisdiction.append(state)\n", + " else:\n", + " jurisdiction.append(text)\n", + "\n", + "try:\n", + " jurisdiction = { state : jurisdiction.count(state) for state in jurisdiction }\n", + " jurisdiction = max(jurisdiction, key=jurisdiction.get).replace(' ', '_')\n", + "except:\n", + " pass\n", + "\n", + "# jurisdiction = 'New_York'" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'New_York'" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jurisdiction" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [], + "source": [ + "document = ner(NDAs[9])\n", + "\n", + "party = []\n", + "\n", + "for word in document.ents:\n", + " if word.label_ == 'party':\n", + " party.append(word.text)\n", + "\n", + "party = list(dict.fromkeys(party))\n", + "party = [ p.replace(' ', '_') for p in party]\n", + "# party = ['CompuDyne_Corporation']" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "wordToNumber = {1 : 'one', 2 : 'two', 3 : 'three', 4 : 'four', 5 : 'five',\n", + " 6 : 'six', 7 : 'seven', 8 : 'eight', 9 : 'nine', 10 : 'ten',\n", + " 11 : 'eleven', 12 : 'twelve', 13 : 'thirteen', 14 : 'fourteen',\n", + " 15 : 'fifteen', 16 : 'sixteen', 17 : 'seventeen', 18 : 'eighteen',\n", + " 19 : 'nineteen', 20 : 'twenty',\n", + " 30 : 'thirty', 40 : 'forty', 50 : 'fifty', 60 : 'sixty',\n", + " 70 : 'seventy', 80 : 'eighty', 90 : 'ninety' }\n", + "\n", + "document = ner(NDAs[7])\n", + "\n", + "term = []\n", + "\n", + "for word in document.ents:\n", + " if word.label_ == 'term':\n", + " term.append(word.text)\n", + "\n", + "try:\n", + " term = { time : term.count(time) for time in term }\n", + " term = max(term, key=term.get)\n", + " term = term.split()\n", + " term[0] = str(list(wordToNumber.keys())[list(wordToNumber.values()).index(term[0])])\n", + " term = '_'.join(term)\n", + "except:\n", + " pass\n", + "\n", + "# term = '3_years'" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'3_years'" + ] + }, + "execution_count": 62, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "term" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "3" + ] + }, + "execution_count": 56, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(wordToNumber.keys())[list(wordToNumber.values()).index(term[0])]" + ] + }, { "cell_type": "code", "execution_count": null, @@ -159,7 +351,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.9.2" }, "orig_nbformat": 4 }, diff --git a/main.ipynb b/main.ipynb index ad6e880..3b1a161 100644 --- a/main.ipynb +++ b/main.ipynb @@ -1852,7 +1852,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.9.2" }, "orig_nbformat": 4 }, diff --git a/main.py b/main.py index 6b23e6e..b6770fb 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,34 @@ import lzma +from matplotlib.pyplot import getp import spacy +import csv + + +months = {'01': 'January', '02': 'February', '03': 'March', + '04': 'April', '05': 'May', '06': 'June', + '07': 'July', '08': 'August', '09': 'September', + '10': 'October', '11': 'November', '12': 'December'} + +punctuation = '!"#$%&\'()*+,-./:;<=>?@[\\\\]^_`{|}~' + +states = ['Alabama', 'Alaska', 'Arizona', 'Arkansas', 'California', + 'Colorado', 'Connecticut', 'Delaware', 'Florida', 'Georgia', + 'Hawaii', 'Idaho', 'Illinois', 'Indiana', 'Iowa', + 'Kansas', 'Kentucky', 'Louisiana', 'Maine', 'Maryland', + 'Massachusetts', 'Michigan', 'Minnesota', 'Mississippi', 'Missouri', + 'Montana', 'Nebraska', 'Nevada', 'New Hampshire', 'New Jersey', + 'New Mexico', 'New York', 'North Carolina', 'North Dakota', 'Ohio', + 'Oklahoma', 'Oregon', 'Pennsylvania', 'Rhode Island', 'South Carolina', + 'South Dakota', 'Tennessee', 'Texas', 'Utah', 'Vermont', + 'Virginia', 'Washington', 'West Virginia', 'Wisconsin', 'Wyoming'] + +wordToNumber = {1 : 'one', 2 : 'two', 3 : 'three', 4 : 'four', 5 : 'five', + 6 : 'six', 7 : 'seven', 8 : 'eight', 9 : 'nine', 10 : 'ten', + 11 : 'eleven', 12 : 'twelve', 13 : 'thirteen', 14 : 'fourteen', + 15 : 'fifteen', 16 : 'sixteen', 17 : 'seventeen', 18 : 'eighteen', + 19 : 'nineteen', 20 : 'twenty', + 30 : 'thirty', 40 : 'forty', 50 : 'fifty', 60 : 'sixty', + 70 : 'seventy', 80 : 'eighty', 90 : 'ninety' } def readInput(dir): @@ -9,11 +38,104 @@ def readInput(dir): NDAs.append(line.decode('utf-8')) return NDAs +def getEffectiveDate(document): + effectiveDate = [] + + for word in document.ents: + if word.label_ == 'effective_date': + effectiveDate.append(word.text) + + #if len(effectiveDate) > 0: + try: + effectiveDate = { date : effectiveDate.count(date) for date in effectiveDate } + effectiveDate = max(effectiveDate, key=effectiveDate.get) + for char in punctuation: effectiveDate = effectiveDate.replace(char, '') + for d in effectiveDate.split(): + if d in list(months.values()): + month = list(months.keys())[list(months.values()).index(d)] + elif int(d) < 32: + day = d + elif int(d) > 1900 and int(d) < 2030: + year = d + effectiveDate = year + '-' + month + '-' + day + except: + effectiveDate = '' + + return effectiveDate # effectiveDate = '2011-07-13' + +def getJurisdiction(document): + jurisdiction = [] + + for word in document.ents: + if word.label_ == 'jurisdiction': + if word.text not in states: + for state in states: + if word.text in state: + jurisdiction.append(state) + else: + jurisdiction.append(word.text) + + if len(jurisdiction) > 0: + jurisdiction = { state : jurisdiction.count(state) for state in jurisdiction } + jurisdiction = max(jurisdiction, key=jurisdiction.get).replace(' ', '_') + else: + jurisdiction = '' + + return jurisdiction # jurisdiction = 'New_York' + +def getParties(document): + party = [] + + for word in document.ents: + if word.label_ == 'party': + party.append(word.text) + + party = list(dict.fromkeys(party)) + party = [ p.replace(' ', '_') for p in party] + + return party # party = ['CompuDyne_Corporation'] + +def getTerm(document): + term = [] + + for word in document.ents: + if word.label_ == 'term': + term.append(word.text) + + if len(term) > 0: + term = { time : term.count(time) for time in term } + term = max(term, key=term.get) + term = term.split() + term[0] = str(list(wordToNumber.keys())[list(wordToNumber.values()).index(term[0])]) + term = '_'.join(term) + else: term = '' + + return term # term = '3_years' + if __name__ == '__main__': NDAs = readInput('train/in.tsv.xz') ner = spacy.load('NER') - for nda in NDAs: - print('pass') \ No newline at end of file + predicted = [''] * len(NDAs) + + document = ner(NDAs[9]) + + for i in range(len(NDAs)): + document = ner(NDAs[i]) + + ed = getEffectiveDate(document) + j = getJurisdiction(document) + p = getParties(document) + t = getTerm(document) + + if len(ed) > 0: predicted[i] += 'effective_date=' + ed + ' ' + if len(j) > 0: predicted[i] += 'jurisdiction=' + j + ' ' + if len(p) > 0: + for party in p: predicted[i] += 'party=' + party + ' ' + if len(t) > 0: predicted[i] += 'term=' + t + + with open('train/out.tsv', 'w', newline='') as f: + writer = csv.writer(f) + writer.writerows(predicted) \ No newline at end of file