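"""Flask service that extracts seller, VAT id and gross total from an uploaded invoice.

The uploaded file is stored in S3 and analysed with an asynchronous Amazon
Textract FORMS job; selected fields from the extracted key-value pairs are
returned as JSON.
"""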
from flask import Flask, request, jsonify
|
|
import boto3
|
|
import re
|
|
import os
|
|
from flask_cors import CORS
|
|
import time
|
|
|
|
# from dotenv import load_dotenv
|
|
|
|
# load_dotenv()
|
|
|
|
# AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
|
|
# AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
|
|
# AWS_SESSION_TOKEN = os.getenv("AWS_SESSION_TOKEN")
|
|
# AWS_REGION = os.getenv("AWS_REGION")
|
|
# BUCKET_NAME = os.getenv("AWS_S3_BUCKET_NAME", "s464979-test")
|
|
|
|
# Deployment settings are fixed in code: the AWS region used by both clients
# and the bucket that uploaded invoices are written to.
AWS_REGION = "us-east-1"
BUCKET_NAME = "s464979-test-2"

# No explicit credentials are passed to either client, so boto3 resolves them
# through its default lookup (environment, shared config, attached role, ...).
s3_client = boto3.client('s3', region_name=AWS_REGION)
textract_client = boto3.client('textract', region_name=AWS_REGION)

# WSGI application object; CORS is enabled for the whole app.
application = Flask(__name__)
CORS(application)
|
|
|
|
@application.route('/')
def hello_world():
    """Serve a fixed greeting at the root URL."""
    greeting = 'Hello, World!'
    return greeting
|
|
|
|
@application.route('/invoice', methods=['POST'])
def process_invoice():
    """Handle an invoice upload and return the fields extracted by Textract.

    Expects a multipart POST with the document in the ``file`` field. The
    file is uploaded to ``BUCKET_NAME`` under its original filename, analysed
    with a Textract FORMS job, and the resulting key-value pairs are searched
    for the seller and gross-total entries.

    Returns:
        A JSON response with ``seller``, ``vat_id`` and ``total`` on success.
        A 400 JSON error when no file (or a file without a name) was sent.
        A 502 JSON error when the Textract job reports failure.
    """
    upload = request.files.get('file')
    # A form submitted without a chosen file arrives with an empty filename;
    # an empty name cannot be used as the S3 object key, so reject it here
    # instead of letting the upload call fail with a 500.
    if upload is None or not upload.filename:
        return jsonify({"error": "No file provided in the 'file' form field"}), 400
    file_name = upload.filename

    s3_client.upload_fileobj(upload, BUCKET_NAME, file_name)

    response = textract_client.start_document_analysis(
        DocumentLocation={'S3Object': {'Bucket': BUCKET_NAME, 'Name': file_name}},
        FeatureTypes=['FORMS']
    )
    job_id = response['JobId']

    try:
        wait_for_textract_job(job_id)
    except RuntimeError as exc:
        # wait_for_textract_job raises RuntimeError when the job status is
        # FAILED; report it as an upstream failure rather than a bare 500.
        return jsonify({"error": str(exc)}), 502

    pages = get_textract_job_results(job_id)
    kvs = get_kv_map(pages)

    print("Extracted Key-Value Pairs:", kvs)

    # Keys are matched exactly as Textract returned them (Polish invoice
    # labels); a missing key yields an empty string in the response.
    sprzedawca = kvs.get("Sprzedawca:", "")
    nip = re.search(r"\b\d{10}\b", sprzedawca)
    wartosc_brutto = kvs.get("Wartosc brutto", "")

    return jsonify({
        "seller": sprzedawca,
        "vat_id": nip.group(0) if nip else "",
        "total": wartosc_brutto
    })
|
|
|
|
|
|
def get_textract_job_results(job_id):
    """Collect every result page of a Textract document-analysis job.

    Calls ``get_document_analysis`` repeatedly, following ``NextToken`` until
    a response arrives without one, and returns the raw responses in the
    order they were received.
    """
    collected = []
    request_args = {'JobId': job_id}
    while True:
        result = textract_client.get_document_analysis(**request_args)
        collected.append(result)

        token = result.get('NextToken', None)
        if not token:
            # No continuation token: this was the last page.
            return collected
        request_args = {'JobId': job_id, 'NextToken': token}
|
|
|
|
|
|
def get_kv_map(pages):
    """Build a ``{key text: value text}`` dict from Textract result pages.

    ``pages`` is a list of result dicts, each carrying a ``'Blocks'`` list.
    Every KEY_VALUE_SET block whose entity types include ``'KEY'`` contributes
    one entry; a key without a linked value block maps to ``""``. When the
    same key text occurs more than once, the last occurrence wins.
    """
    # Index all blocks by id first so relationships can be resolved across pages.
    blocks_by_id = {}
    for result_page in pages:
        blocks_by_id.update((blk['Id'], blk) for blk in result_page['Blocks'])

    pairs = {}
    for result_page in pages:
        for blk in result_page['Blocks']:
            if blk['BlockType'] != 'KEY_VALUE_SET':
                continue
            if 'KEY' not in blk.get('EntityTypes', []):
                continue

            linked_value = find_value_block(blk, blocks_by_id)
            label = get_text(blk, blocks_by_id)
            if linked_value:
                pairs[label] = get_text(linked_value, blocks_by_id)
            else:
                pairs[label] = ""

    return pairs
|
|
|
|
|
|
def find_value_block(key_block, block_map):
    """Return the VALUE block linked to a KEY block, or ``None``.

    Args:
        key_block: A Textract block dict; its ``'Relationships'`` entries of
            type ``'VALUE'`` list the ids of the associated value blocks.
        block_map: Dict mapping block id to block dict.

    Returns:
        The first referenced value block that is present in ``block_map``,
        or ``None`` when the key has no VALUE relationship or none of the
        referenced ids can be resolved.
    """
    # 'Relationships' may be absent (or null); treat both as "no links".
    for rel in key_block.get('Relationships') or []:
        if rel['Type'] != 'VALUE':
            continue
        for value_id in rel['Ids']:
            value_block = block_map.get(value_id)
            # Skip ids that are not in the map instead of raising KeyError.
            if value_block is not None:
                return value_block
    return None
|
|
|
|
|
|
def get_text(block, block_map):
    """Return the text of a block, assembled from its direct CHILD blocks.

    Only the immediate children are inspected (no recursion): each WORD child
    contributes its ``'Text'``, and each SELECTION_ELEMENT child whose
    ``'SelectionStatus'`` is ``'SELECTED'`` contributes ``'X'``. Children of
    any other type are ignored.

    Args:
        block: A Textract block dict.
        block_map: Dict mapping block id to block dict.

    Returns:
        The collected pieces joined by single spaces; ``""`` when the block
        has no usable children.
    """
    parts = []
    # 'Relationships' may be absent (or null); treat both as "no children".
    for rel in block.get('Relationships') or []:
        if rel['Type'] != 'CHILD':
            continue
        for child_id in rel['Ids']:
            child = block_map.get(child_id)
            if child is None:
                # Unresolvable id: skip it rather than fail the whole document.
                continue
            kind = child['BlockType']
            if kind == 'WORD':
                parts.append(child['Text'])
            elif kind == 'SELECTION_ELEMENT' and child['SelectionStatus'] == 'SELECTED':
                parts.append('X')
    return " ".join(parts)
|
|
|
|
|
|
def extract_text_from_textract(result):
    """Join the text of every LINE block in one Textract result with spaces."""
    line_texts = (
        blk['Text']
        for blk in result['Blocks']
        if blk['BlockType'] == 'LINE'
    )
    return " ".join(line_texts)
|
|
|
|
def extract_vat_id(text):
    """Return the first standalone run of exactly ten digits in ``text``.

    Returns ``None`` when the text contains no such run.
    """
    found = re.search(r'\b\d{10}\b', text)
    if found is None:
        return None
    return found.group(0)
|
|
|
|
|
|
def wait_for_textract_job(job_id, poll_interval=2, timeout=None):
    """Poll a Textract document-analysis job until it reaches a terminal status.

    Args:
        job_id: Id returned by ``start_document_analysis``.
        poll_interval: Seconds to sleep between status checks.
        timeout: Maximum number of seconds to wait, or ``None`` (the default)
            to wait indefinitely.

    Returns:
        The ``get_document_analysis`` response that reported ``SUCCEEDED`` or
        ``PARTIAL_SUCCESS``. With ``PARTIAL_SUCCESS`` the job is finished but
        the results may be incomplete.

    Raises:
        RuntimeError: The job status is ``FAILED``.
        TimeoutError: ``timeout`` elapsed before a terminal status was seen.
    """
    deadline = None if timeout is None else time.monotonic() + timeout

    while True:
        response = textract_client.get_document_analysis(JobId=job_id)
        status = response['JobStatus']
        print(f"Status Textract Job: {status}")

        # PARTIAL_SUCCESS is terminal too; without this branch the loop would
        # keep polling a finished job forever.
        if status in ('SUCCEEDED', 'PARTIAL_SUCCESS'):
            return response
        if status == 'FAILED':
            message = "Textract job failed"
            detail = response.get('StatusMessage')
            if detail:
                # Surface the reason reported by the service, when present.
                message = f"{message}: {detail}"
            raise RuntimeError(message)

        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Textract job {job_id} did not finish within {timeout} seconds"
            )

        time.sleep(poll_interval)
|
|
|
|
if __name__ == "__main__":
    # Direct-execution entry point: start Flask's built-in server on port
    # 5000 with debug mode on (intended for local development).
    application.run(debug=True, port=5000)
|