import numpy as np
from numpy.linalg import inv
import matplotlib.pyplot as plt
from math import cos, sin, tan, asin, pi
import pandas as pd
from scipy.signal import butter, filtfilt

def get_data(i, data):
    '''
    Pick sample ``i`` from each of the first three per-axis sequences.

    :param i: sample index
    :param data: indexable container whose items 0, 1, 2 are the x, y, z
                 sequences
    :return: tuple (data[0][i], data[1][i], data[2][i])
    '''
    return tuple(data[axis][i] for axis in range(3))

def Euler_accel(ax, ay, az, g=9.8):
    '''
    Estimate roll (phi) and pitch (theta) from an accelerometer reading.

    Inverts ax = g*sin(theta) and ay = -g*cos(theta)*sin(phi), i.e. treats
    the sensed acceleration as gravity only.

    :param ax: accelerometer x component
    :param ay: accelerometer y component
    :param az: accelerometer z component (not used)
    :param g: gravity magnitude in the same units as ax/ay (default 9.8)
    :return: (phi, theta) in radians
    '''
    def _saturated_asin(value):
        # Noise can push |value| slightly past 1, which makes math.asin raise
        # ValueError; saturate to +/-pi/2 instead. NaN compares False both
        # ways, so it passes through unchanged as before.
        if value > 1.0:
            value = 1.0
        elif value < -1.0:
            value = -1.0
        return asin(value)

    theta = _saturated_asin(ax / g)
    phi = _saturated_asin(-ay / (g * cos(theta)))
    return phi, theta

def sec(theta):
    '''Return the secant of ``theta`` (radians), i.e. 1 / cos(theta).'''
    cosine = cos(theta)
    return 1 / cosine

def Ajacob(xhat, rates, dt):
    '''
    Discrete state-transition Jacobian A = I + dt * df/dx of the Euler model.

    f is the Euler-angle rate model
        phi_dot   = p + q*sin(phi)*tan(theta) + r*cos(phi)*tan(theta)
        theta_dot = q*cos(phi) - r*sin(phi)
        psi_dot   = q*sin(phi)*sec(theta) + r*cos(phi)*sec(theta)
    It does not depend on psi, so the third column of df/dx is zero. p only
    enters additively, so it does not appear in the Jacobian.

    :param xhat: State Variables (phi, theta, psi); shape (3,) or (3, 1)
    :param rates: angular speed (p, q, r); only q and r are used
    :param dt: time step used to make the discrete form
    :return: (3, 3) float ndarray
    '''
    # Flatten so a (3, 1) state yields plain scalars rather than size-1
    # arrays (implicit array->scalar conversion is deprecated in NumPy).
    state = np.ravel(xhat)
    phi, theta = state[0], state[1]
    q, r = rates[1], rates[2]

    sin_phi, cos_phi = sin(phi), cos(phi)
    tan_t = tan(theta)
    sec_t = 1 / cos(theta)

    jac = np.array([
        [q * cos_phi * tan_t - r * sin_phi * tan_t,
         q * sin_phi * sec_t ** 2 + r * cos_phi * sec_t ** 2,
         0.0],
        [-q * sin_phi - r * cos_phi,
         0.0,
         0.0],
        [q * cos_phi * sec_t - r * sin_phi * sec_t,
         q * sin_phi * sec_t * tan_t + r * cos_phi * sec_t * tan_t,
         0.0],
    ], dtype=float)

    return np.eye(3) + jac * dt

def fx(xhat, rates, dt):
    '''
    Propagate the Euler-angle state one step with forward Euler integration.

    Uses the body-rate to Euler-angle-rate model
        phi_dot   = p + q*sin(phi)*tan(theta) + r*cos(phi)*tan(theta)
        theta_dot = q*cos(phi) - r*sin(phi)
        psi_dot   = q*sin(phi)*sec(theta) + r*cos(phi)*sec(theta)

    :param xhat: State Variables (phi, theta, psi); any array-like with
                 3 elements, e.g. shape (3,) or (3, 1)
    :param rates: angular speed (p, q, r)
    :param dt: integration time step
    :return: predicted state xhat + xdot*dt as a (3, 1) float ndarray
    '''
    # Normalise to a float column vector. Indexing [i, 0] then yields
    # scalars, so math.sin/cos never receive size-1 arrays.
    state = np.asarray(xhat, dtype=float).reshape(-1, 1)
    phi, theta = state[0, 0], state[1, 0]
    p, q, r = rates[0], rates[1], rates[2]

    tan_t = tan(theta)
    sec_t = 1 / cos(theta)

    xdot = np.array([
        [p + q * sin(phi) * tan_t + r * cos(phi) * tan_t],
        [q * cos(phi) - r * sin(phi)],
        [q * sin(phi) * sec_t + r * cos(phi) * sec_t],
    ])

    return state + xdot * dt

def Euler_EKF(z, rates, dt):
    '''
    Run one predict/update step of the Euler-angle extended Kalman filter.

    The filter state (x, P) and the matrices Q, H, R live in module-level
    globals. They are initialised on the first call, while ``firstRun`` is
    True.

    :param z: measured (phi, theta), e.g. from the accelerometer; any
              array-like with 2 elements
    :param rates: angular speed (p, q, r)
    :param dt: time step used by the process model
    :return: (phi, theta, psi) estimates as scalars
    '''
    global firstRun
    global Q, H, R
    global x, P
    if firstRun:
        H = np.array([[1, 0, 0], [0, 1, 0]])  # measure phi and theta only
        Q = np.array([[0.0001, 0, 0], [0, 0.0001, 0], [0, 0, 0.1]])
        R = 10 * np.eye(2)
        # Float column vector, the same (3, 1) shape the update below produces.
        x = np.zeros((3, 1))
        P = 10 * np.eye(3)
        firstRun = False

    # Predict/update runs on every call, including the first one, so the
    # first measurement is used instead of being discarded.
    A = Ajacob(x, rates, dt)
    Xp = fx(x, rates, dt)  # Xp : State Variable Prediction
    Pp = A @ P @ A.T + Q  # Error Covariance Prediction

    K = (Pp @ H.T) @ inv(H @ Pp @ H.T + R)  # K : Kalman Gain

    x = Xp + K @ (np.reshape(z, (-1, 1)) - H @ Xp)  # Update State Variable Estimation
    P = Pp - K @ H @ Pp  # Update Error Covariance Estimation

    # Return scalars regardless of the internal (3, 1) state shape.
    phi, theta, psi = x[:, 0]
    return phi, theta, psi

def butter_lowpass_filter(data, cutoff, fs, order):
    '''
    Low-pass filter ``data`` with a digital Butterworth filter via filtfilt.

    :param data: signal to filter
    :param cutoff: cutoff frequency in Hz
    :param fs: sampling rate in Hz
    :param order: Butterworth filter order
    :return: the filtered signal (forward-backward pass, see scipy filtfilt)
    '''
    # butter() wants the cutoff as a fraction of the Nyquist rate, fs / 2.
    wn = cutoff / (0.5 * fs)
    numerator, denominator = butter(order, wn, btype='low', analog=False)
    return filtfilt(numerator, denominator, data)

def normalize_data_vector(data, division):
    '''Return a new list holding every element of ``data`` divided by ``division``.'''
    return [value / division for value in data]

# EKF state shared with Euler_EKF through module globals; Euler_EKF fills
# these in on its first call while firstRun is True.
H, Q, R = None, None, None
x, P = None, None
firstRun = True

# Raw sensor counts are divided by 2**15.
# NOTE(review): presumably the signed 16-bit full-scale count -- confirm the
# resulting units match what Euler_accel (g = 9.8) and the rate model expect.
division_param = 32768

# Low-pass filter / demo-signal settings.
T = 5.0         # duration of the demo signal in seconds (n = T * fs samples)
fs = 30.0       # sample rate, Hz
cutoff = 2      # filter cutoff, Hz; a bit above the 1.2 Hz demo sine
order = 2       # Butterworth filter order

# Parsed as tab-separated text (sep='\t'), not as an Excel workbook.
df = pd.read_csv('raw_data_6d.xls', sep='\t')

gyroX, gyroY, gyroZ = [normalize_data_vector(df[column].values, division_param)
                       for column in ('%GyroX', 'GyroY', 'GyroZ')]
acceX, acceY, acceZ = [normalize_data_vector(df[column].values, division_param)
                       for column in ('AcceX', 'AcceY', 'AcceZ')]

# Optional smoothing: replace each axis vector v with
# butter_lowpass_filter(v, cutoff, fs, order) before building the lists below.

gyro_data = [gyroX, gyroY, gyroZ]
acce_data = [acceX, acceY, acceZ]

Nsamples = len(df)
EulerSaved = np.zeros([Nsamples, 3])
# NOTE(review): dt = 0.01 implies 100 Hz, while fs above is 30 Hz; fs only
# feeds the low-pass filter and the demo signal -- confirm the IMU rate.
dt = 0.01

for k in range(Nsamples):
    # k-th gyro rates and accelerometer reading, one value per axis.
    p, q, r = get_data(k, gyro_data)
    ax, ay, az = get_data(k, acce_data)
    # Accelerometer-only roll/pitch, used as the EKF measurement.
    phi_a, theta_a = Euler_accel(ax, ay, az)

    phi, theta, psi = Euler_EKF(np.array([phi_a, theta_a]), [p, q, r], dt)
    # The angles may come back as scalars or as size-1 arrays depending on
    # the filter state's shape; hstack flattens either form into 3 values.
    EulerSaved[k] = np.hstack([phi, theta, psi])


# Build the time axis from the sample index so it has exactly Nsamples
# points. np.arange with a float step can round up to one extra point,
# which would make plt.plot fail on mismatched lengths.
t = np.arange(Nsamples) * dt
PhiSaved = EulerSaved[:, 0] * 180/pi
ThetaSaved = EulerSaved[:, 1] * 180/pi
PsiSaved = EulerSaved[:, 2] * 180/pi

plt.figure()
plt.plot(t, PhiSaved)
plt.xlabel('Time [Sec]')
plt.ylabel('Roll angle [deg]')
plt.savefig('12_EulerEKF_roll.png')

plt.figure()
plt.plot(t, ThetaSaved)
plt.xlabel('Time [Sec]')
plt.ylabel('Pitch angle [deg]')
plt.savefig('12_EulerEKF_pitch.png')
plt.show()

# Yaw (psi) plot is currently disabled. Note that H in Euler_EKF measures only
# phi and theta, so psi is propagated by the rate model alone. To re-enable:
# plt.figure()
# plt.plot(t, PsiSaved)
# plt.xlabel('Time [Sec]')
# plt.ylabel('Psi angle [deg]')

n = int(T * fs)  # number of samples in the demo signal
# Time axis: n samples spaced 1/fs apart. The original code passed the
# scalar n into sin/cos, which produced a single number, not a waveform.
t_demo = np.arange(n) / fs
# 1.2 Hz sine wave
sig = np.sin(1.2 * 2 * np.pi * t_demo)
# Let's add some noise: 9 Hz and 12 Hz components (below Nyquist = fs / 2)
noise = 1.5 * np.cos(9 * 2 * np.pi * t_demo) + 0.5 * np.sin(12.0 * 2 * np.pi * t_demo)
data = sig + noise

print(data)