import numpy as np
import scipy

##### MODEL-SPECIFIC: generate Y

def generateY(n_obs, sigma, radius, rng=None):
    '''Generate n_obs noisy observations of points on a circle.

    Draws angles theta ~ Uniform[0, 2*pi) and returns
        Y1 = radius * cos(theta) + sigma[0] * N(0, 1)
        Y2 = radius * sin(theta) + sigma[1] * N(0, 1)
    with a separate standard normal draw for each coordinate's noise.
        - sigma must support sigma[0] and sigma[1] (noise scale per axis)
        - rng: optional numpy random Generator / RandomState, to make the
          draws reproducible independently of the global state. By default
          the global np.random state is used, producing exactly the same
          stream as before.
        Output: tuple (Y1, Y2) of 1D numpy arrays of length n_obs.'''
    # np.random (module) and Generator/RandomState all expose uniform and
    # standard_normal; np.random.randn(n) is standard_normal(n) underneath,
    # so the default path is draw-for-draw identical to the old code.
    source = np.random if rng is None else rng
    theta = source.uniform(low=0, high=1, size=n_obs) * 2*np.pi
    eps1 = sigma[0] * source.standard_normal(n_obs)
    eps2 = sigma[1] * source.standard_normal(n_obs)
    # Observations
    Y1 = radius * np.cos(theta) + eps1
    Y2 = radius * np.sin(theta) + eps2
    return (Y1, Y2)


##### ESTIMATION FUNCTIONS

def phiY(t1, t2, Y1, Y2):
    '''Estimator of phi_Y at each (s1,s2) with s1 in t1 and s2 in t2.

    Computes the empirical joint characteristic function
        Phitilde(s1, s2) = (1/n) * sum_j exp(i * (s1 * Y1[j] + s2 * Y2[j])).
        - t1 and t2 must be 1D numpy arrays or floats
        - Y1 and Y2 are the n paired observations (array-like, same length)
        Output: value of Phitilde on the grid [t1 x t2] (2D complex numpy
        array of shape (len(t1), len(t2))).

    Because exp(i(s1 y1 + s2 y2)) = exp(i s1 y1) * exp(i s2 y2), the average
    over observations is one matrix product. Memory is O((len1 + len2) * n)
    instead of the O(len1 * len2 * n) needed to materialise every phase.'''
    t1 = np.atleast_1d(t1)
    t2 = np.atleast_1d(t2)
    Y1 = np.asarray(Y1)
    Y2 = np.asarray(Y2)
    n_obs = Y1.size
    wave1 = np.exp(1j * np.outer(t1, Y1))  # (len1, n_obs)
    wave2 = np.exp(1j * np.outer(t2, Y2))  # (len2, n_obs)
    # Entry (a, b) is sum_j wave1[a, j] * wave2[b, j]; dividing by n_obs
    # gives the mean over observations.
    return (wave1 @ wave2.T) / n_obs

def phiY1D(t, Y1D):
    '''Empirical characteristic function of the one-dimensional sample Y1D.

    For every s in t, returns (1/n) * sum_j exp(i * s * Y1D[j]).
        - t must be a 1D numpy array or a float
        - Y1D is a numpy array of the n observations
        Output: value of Phitilde on the grid t (1D complex numpy array).'''
    grid = np.array(np.atleast_1d(t))
    # Row a holds the phases s_a * Y1D[j] for all observations j.
    phase = grid[:, np.newaxis] * Y1D[np.newaxis, :]
    return np.exp(1j * phase).mean(axis=1)


def softmax(eta):
    '''Map real scores eta to positive weights that sum to one.

        weight_k = exp(eta_k) / sum_j exp(eta_j)

    The largest score is subtracted before exponentiating. This leaves the
    result mathematically unchanged, but every exponent is then <= 0, so exp
    cannot overflow. Subtracting the mean instead gives inf/inf = nan as
    soon as max(eta) - mean(eta) exceeds ~709.
        - eta: 1D array-like of real scores (may be empty)
        Output: 1D float numpy array of the same length as eta.'''
    eta = np.asarray(eta, dtype=float)
    if eta.size == 0:
        # No components: nothing to normalise (np.max would raise).
        return eta
    temp = np.exp(eta - np.max(eta))
    return temp / np.sum(temp)

def TmPhi2D(eta, center, gridrange, gridpoints):
    '''Compute Tm phi over a square grid [grid_1D x grid_1D].

    Tm phi is the characteristic function of the discrete mixture putting
    weight softmax(eta)[k] on the point center[k]:
        TmPhi(t1, t2) = sum_k w_k * exp(i*(center[k,0]*t1 + center[k,1]*t2)).
        - grid_1D is np.linspace(-gridrange,gridrange,gridpoints)
        - eta is a 1D array of size m; center is an array of shape (m, 2)
        Output: evaluation of Tm Phi on [grid x grid], a complex array of
        shape (gridpoints, gridpoints).'''
    t1D = np.linspace(-gridrange, gridrange, gridpoints)
    center = np.asarray(center)
    weight = softmax(eta)
    exp1 = np.exp(1j * np.outer(center[:, 0], t1D))  # (m, gridpoints)
    exp2 = np.exp(1j * np.outer(center[:, 1], t1D))  # (m, gridpoints)
    # Entry (a, b) = sum_k weight[k] * exp1[k, a] * exp2[k, b]: one matrix
    # product replaces the loop over mixture components.
    return (weight[:, None] * exp1).T @ exp2

def TmPhi1D(eta, center1D, gridrange, gridpoints):
    '''Compute Tm phi over the grid np.linspace(-gridrange,gridrange,gridpoints).

    One-dimensional version of the mixture characteristic function:
        TmPhi(t) = sum_k w_k * exp(i * center1D[k] * t),  w = softmax(eta).
        - eta is a 1D array of size m; center1D is a 1D array of size m
          (e.g. one coordinate column of the 2D centers)
        Output: complex 1D array of length gridpoints.'''
    t1D = np.linspace(-gridrange, gridrange, gridpoints)
    weight = softmax(eta)
    waves = np.exp(1j * np.outer(np.asarray(center1D), t1D))  # (m, gridpoints)
    # Weighted sum over the m components, done as a single vector-matrix product.
    return weight @ waves


def Mn(theta, Phitilde, Phitilde_dot0, Phitilde_0dot, gridrange, gridpoints):
    '''Empirical risk Mn of the mixture parameters theta.

    Approximates the integral over [-gridrange,gridrange]^2 of
        (t1,t2) -> | Phitilde(t1,t2) TmPhi(t1,0) TmPhi(0,t2)
                    - TmPhi(t1,t2) Phitilde(t1,0) Phitilde(0,t2) |^2
    by the grid mean times the area of the square (a Riemann sum with
    gridpoints points along each coordinate).
        - theta is a flat array of length 3*m; reshaped to (m, 3), column 0
          holds eta and columns 1-2 hold the 2D centers of the mixture.
        - Phitilde(t1,0) is Phitilde_dot0[t1], Phitilde(0,t2) is
          Phitilde_0dot[t2]; Phitilde is the 2D grid of values.
        Output: a nonnegative float.'''
    params = theta.reshape((theta.size // 3, 3))
    eta, center = params[:, 0], params[:, 1:]
    # Model marginals TmPhi(t1, 0) and TmPhi(0, t2), then the joint TmPhi.
    model_t1_0 = TmPhi1D(eta, center[:, 0], gridrange, gridpoints)
    model_0_t2 = TmPhi1D(eta, center[:, 1], gridrange, gridpoints)
    model_joint = TmPhi2D(eta, center, gridrange, gridpoints)
    model_marg = model_t1_0[:, np.newaxis] * model_0_t2[np.newaxis, :]
    data_marg = Phitilde_dot0[:, np.newaxis] * Phitilde_0dot[np.newaxis, :]
    residual = Phitilde * model_marg - model_joint * data_marg
    area = 4 * gridrange**2
    return area * np.mean(np.abs(residual)**2)
