import numpy as np
import scipy.optimize as opt

from collections import namedtuple
import time

from auxiliary_functions import *

# Seed the global NumPy RNG so the observations and the random
# initialisations are reproducible from run to run.
np.random.seed(5555)

##### PARAMETERS
# Number of observations drawn in each simulation.
# Memory note: n_obs=100000 combined with n_grid_M_tot=101*101 requires
# about 15.2GiB.
n_obs = 10_000

# Per-coordinate STANDARD DEVIATION of the noise: the variance is 3/4 on each
# of the two axes and sigma holds its square root.
sigma = np.sqrt(np.full(2, 3 / 4))

# Parameter values to scan.
#   nu: half-width of the grid [-nu, nu]^2 on which Mn is evaluated
#   m : number of support points to estimate
nu_values = [12]
m_values = [12]

# Number of independent simulations to run.
nsim = 1

# Requested TOTAL number of points of the 2D grid on [-nu, nu]^2 used for Mn.
# The per-axis count is round(sqrt(n_grid_M_tot)), bumped to the next odd
# integer if needed so that 0 lies on the grid.
n_grid_M_tot = 101 ** 2

# Random initialisation: centres drawn uniformly in
# [-square_size, square_size]^2.
square_size = 3

# Also try a start with the m centres evenly spaced on the unit circle?
use_circle_init = True
# Standard deviation of the Gaussian jitter added to those evenly spaced
# centres (0 means exactly regular spacing).
circle_init_perturbation = 0


##### ESTIMATION

# Per-axis number of grid points for Mn, and the index of 0 in that 1D grid.
# np.sqrt returns a NumPy float; depending on the NumPy version, round() on it
# may not give a Python int, so cast explicitly: np.linspace needs an integer
# `num`, and index_zero is used to index arrays.
n_grid_M_1D = int(round(np.sqrt(n_grid_M_tot)))
n_grid_M_1D += 1 - n_grid_M_1D % 2  # make it odd so that 0 is a grid point
# Position of the value 0 in np.linspace(-nu, nu, n_grid_M_1D).
index_zero = (n_grid_M_1D - 1) // 2

# Record type gathering the outcome of one (simulation, nu, m) fit.
simOutput = namedtuple('simOutput', ['sim_ind', 'sigma', 'n_obs', 'm', 'nu',
                                     'est_coefs', 'cvg_success', 'nit',
                                     'init_used', 'n_grid_M', 'score'
                                    ])
bloc = []

# observations[:, s, 0] and observations[:, s, 1] store Y1 and Y2 of
# simulation s.
observations = np.zeros((n_obs, nsim, 2))
sim_values = np.arange(nsim)

def _circle_init(m, perturbation):
    """Return a flattened (m, 3) start point for Mn.

    Column 0 (weight logits) is zero; columns 1-2 place the m centres evenly
    on the unit circle, plus Gaussian jitter of standard deviation
    `perturbation`.
    """
    angles = 2 * np.pi * np.arange(m) / m
    theta = np.zeros((m, 3))
    # randn is drawn even when perturbation == 0, which keeps the RNG stream
    # (and hence the later square initialisation) independent of its value.
    theta[:, 1] = np.cos(angles) + perturbation * np.random.randn(m)
    theta[:, 2] = np.sin(angles) + perturbation * np.random.randn(m)
    return theta.reshape(-1)


def _square_init(m, half_width):
    """Return a flattened (m, 3) start point for Mn.

    Column 0 (weight logits) is zero; columns 1-2 are centres drawn uniformly
    in [-half_width, half_width]^2.
    """
    theta = np.zeros((m, 3))
    theta[:, 1] = np.random.uniform(low=-half_width, high=half_width, size=m)
    theta[:, 2] = np.random.uniform(low=-half_width, high=half_width, size=m)
    return theta.reshape(-1)


def _fit_Mn(theta0, mn_args):
    """Minimise Mn with BFGS from theta0; return (OptimizeResult, score)."""
    fit = opt.minimize(Mn, x0=theta0, method='BFGS', args=mn_args)
    # OptimizeResult.fun is Mn evaluated at fit.x: no need to call Mn again.
    return fit, fit.fun


def _score_key(candidate):
    """Ranking key for (init_used, fit, score): NaN scores rank as +inf."""
    score = candidate[2]
    return np.inf if np.isnan(score) else score


# Start timer
start_1 = time.time()

# Simulation loop: for every (nu, m), minimise Mn from each initialisation
# and keep the best fit.
for ind_sim, sim in enumerate(sim_values):
    # Generate and save the observations
    Y1, Y2 = generateY(n_obs, sigma, 2)
    observations[:, sim, 0] = Y1
    observations[:, sim, 1] = Y2

    for nu in nu_values:
        # Grid points on [-nu, nu] (same grid on both axes)
        grid_Mn = np.linspace(-nu, nu, n_grid_M_1D)

        # Pre-computation of Phitilde (estimator of Phi_Y) and of its column
        # and row through the grid point 0.
        Phitilde = phiY(grid_Mn, grid_Mn, Y1, Y2)
        Phitilde_dot0 = Phitilde[:, index_zero]
        Phitilde_0dot = Phitilde[index_zero, :]
        mn_args = (Phitilde, Phitilde_dot0, Phitilde_0dot, nu, n_grid_M_1D)

        for m in m_values:
            candidates = []  # entries: (init_used, OptimizeResult, score)

            # Start from points regularly spread on the unit circle.
            if use_circle_init:
                fit, score = _fit_Mn(
                    _circle_init(m, circle_init_perturbation), mn_args)
                candidates.append(("unif_circle", fit, score))
            print(".", end="")

            # Start from points drawn uniformly in the square.
            fit, score = _fit_Mn(_square_init(m, square_size), mn_args)
            candidates.append(("unif_square", fit, score))
            print(".", end="")

            # Keep the lowest score. min() returns the first minimum, so the
            # circle start still wins ties; NaN scores are ranked as +inf so
            # they can no longer be reported for a different fit, and the
            # choice never refers to a fit that was not run.
            init_used, res, min_score = min(candidates, key=_score_key)

            # Save results and display progress status
            bloc.append(simOutput(sim_ind=sim, sigma=sigma, n_obs=n_obs, m=m,
                                  nu=nu, est_coefs=res.x,
                                  cvg_success=res.success, nit=res.nit,
                                  init_used=init_used,
                                  # grid size actually used (the requested
                                  # n_grid_M_tot is rounded / made odd)
                                  n_grid_M=n_grid_M_1D ** 2,
                                  score=min_score
                                  ))
            print(f"{ind_sim+1}/{nsim}, nu={round(nu,2):<3}, "
                  f"m={m:<2}, Mn={min_score:.2e}")

end_1 = time.time()
duration_1 = end_1 - start_1
print("It took", round(duration_1, 3), "seconds.\n")



#%% PLOT observations and estimated support points
import matplotlib.pyplot as plt

# Take the first recorded fit and unpack its (m, 3) parameter matrix:
# column 0 holds the weight logits, columns 1-2 the centre coordinates.
res = bloc[0]
m = res.est_coefs.size // 3
thetaTemp = res.est_coefs.reshape((m, 3))
center = thetaTemp[:, 1:]

eta = thetaTemp[:, 0]
weight = softmax(eta)  # "softmax" defined in auxiliary_functions.py
barycenter = np.average(center, weights=weight, axis=0)
center = center - barycenter  # Ensure the signal is centered

fig, ax = plt.subplots(figsize=(5, 5))
# Marker area is proportional to weight**2 * m.
ax.scatter(center[:, 0], center[:, 1], s=weight**2 * 300 * m, c='red',
           label="Support points", zorder=10)

# Observations of the simulation the plotted fit comes from.
Y1 = observations[:, res.sim_ind, 0]
Y2 = observations[:, res.sim_ind, 1]
ax.plot(Y1, Y2, '.', color="black",
        label="Observations")

ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
ax.set_aspect("equal")
# The artists above carry labels, which are only shown by a legend. A fixed
# location avoids the slow 'best' placement search over many data points.
ax.legend(loc="upper right")

plt.savefig("plot_obs+supportpoints.pdf", format="pdf")
plt.show()


#%% PLOT estimated points + real support
from matplotlib.patches import Circle

fig, ax = plt.subplots(figsize=(5, 5))
# Estimated support points, same marker sizing as the previous plot.
ax.scatter(center[:, 0], center[:, 1], s=weight**2 * 300 * m, c='red')

# True support: the circle of radius 2 centred at the origin.
circle = Circle((0, 0), 2, fill=False)
ax.add_patch(circle)

ax.set_xlim(-4, 4)
ax.set_ylim(-4, 4)
# Equal aspect ratio so the circle is drawn round; the title takes vertical
# space, so the axes box would otherwise not be square.
ax.set_aspect("equal")

ax.set_title("True support and estimated support points")

plt.savefig("plot_truesupport+supportpoints.pdf", format="pdf")
plt.show()
