#Title banner printed when the script starts
_banner_lines = (
    "-------------------------------------------------------------",
    "     YOUNG'S INTERFERENCE IN THE SINGLE-PARTICLE REGIME",
    "written by David Papoular           david.papoular@u-cergy.fr",
    "Part of Problem 1 for the ICFP-M2 Quantum Mechanics tutorials",
    "                         2024-09-18",
    "-------------------------------------------------------------",
)
for _line in _banner_lines:
    print(_line)
print("\n\n")                          #vertical gap before the program output

import numpy as np                   #for exp function
import matplotlib.pyplot as plt      #static plots
from matplotlib import animation     #animations
import scipy.integrate as integrate  #integral used for normalisation
from scipy import stats              #probability distributions

######################################################################
# SIMULATION PARAMETERS
# Plausible values for neon atoms may be extracted from:
# [Shimizu et al, PRA 46, R17(R) (1992)]

a = 4e-3                 #Distance between the two holes   [mm]
d = 4                    #Distance between mask and screen [mm]
lambdadB = 1e-3          #de Broglie wavelength            [mm]

#Fringe spacing predicted near the centre of the screen (small angles):
#delta = lambdadB * d / a
delta = lambdadB*d/a     #Expected fringe spacing          [mm]
print("Expected fringe spacing near x=0: delta=", delta, " mm")

######################################################################
# WAVEFUNCTION AND PROBABILITY DISTRIBUTION

def l1(x):
    """Distance [mm] from hole 1 (located at x=+a/2 on the mask) to screen point x."""
    return ((x - 0.5*a)**2 + d**2)**0.5


def l2(x):
    """Distance [mm] from hole 2 (located at x=-a/2 on the mask) to screen point x."""
    return ((x + 0.5*a)**2 + d**2)**0.5


def psi(x):
    """Non-normalised wavefunction at screen position x.

    Sum of the two waves exp(i*k*l)/l emitted by the holes, with
    wavenumber k = 2*pi/lambdadB.
    """
    ik = 1j*2.*np.pi/lambdadB          #i times the wavenumber [1/mm]
    r1 = l1(x)
    r2 = l2(x)
    return np.exp(ik*r1)/r1 + np.exp(ik*r2)/r2


def psi2(x):
    """Squared modulus of the non-normalised wavefunction."""
    return abs(psi(x))**2


xM = 50.*delta                         #cutoff in the x direction [mm]
#psi2 is even in x (x -> -x swaps l1 and l2), so integrate over [0, xM]
#and double; quad returns (value, error estimate), keep the value
Zpsi2 = 2.*integrate.quad(psi2, 0, xM)[0]


def prob(x):
    """Probability density on [-xM, xM]: psi2 divided by its integral."""
    return psi2(x)/Zpsi2

#FIGURE 1: plot the normalised squared wavefunction prob(x)

xtab=np.linspace(-xM,xM,num=1000)      #Values of x for static plot
probtab=prob(xtab)                     #Values of prob(x) (normalised psi2) for static plot

plt.figure(1)                          #first figure window
plt.clf()                              #erase figure contents
#With the parameters above, xM = 50*delta = 50 mm, so the window is [-xM, xM]
plt.xlim(-50,50)                       #set lower, upper x values for plot
plt.xticks([-50,-25,0,25,50])          #ticks on x axis
#Raw strings: "\m" and "\p" are not valid Python escape sequences
#(warning, SyntaxWarning since Python 3.12); r"..." passes the LaTeX unchanged
plt.xlabel(r"Position x $[\mathrm{mm}]$",
           size=12)                    #x axis title
plt.ylim(0,0.16)                       #set lower, upper y values for plot
plt.yticks([0,.05,.1,.15])             #ticks on y axis
plt.ylabel(r"Squared wavefunction $|\psi(x)|^2$",
           size=12)                    #y axis title
plt.plot(xtab,probtab)                 #generate plot
plt.savefig("young_psi2.pdf")          #save figure to PDF file

print("Press 'q' or close Figure 1 window to launch animation")

plt.show()                             #display figure


######################################################################
#PROBABILITY DISTRIBUTION CORRESPONDING TO SQUARED WAVEFUNCTION

#Define probability density from squared wavefunction
class young_distribution(stats.rv_continuous):
    """Random variable whose probability density is the normalised |psi(x)|^2.

    Only the density is provided here; scipy's rv_continuous machinery
    derives the other methods (cdf, sampling via rvs, ...) from it.
    """

    def _pdf(self, x):
        #pdf = Probability Density Function, taken from the module-level prob
        density = prob(x)
        return density

#Create the corresponding scipy distribution.
#Here a and b are rv_continuous keyword arguments giving the support [a, b]:
#the density is 0 for x<-xM and x>xM. This a is unrelated to the hole
#spacing a of the simulation parameters.
distribution = young_distribution(a=-xM,b=xM)

nmeasurements=10000  #Number of measurements to simulate
xvallist=[]          #Empty list to be filled with measurement results

######################################################################
#ANIMATED RECONSTRUCTION OF THE PROBABILITY DENSITY

#Function updating animation: i is the frame index passed by FuncAnimation
def animate(i):
    """Simulate one more impact on the screen and redraw the histogram.

    Draws one sample from ``distribution`` and appends it to the
    module-level list ``xvallist``. Then clears the current figure and
    plots the histogram of all impacts recorded so far.

    Parameters
    ----------
    i : int
        Frame index supplied by FuncAnimation (0, 1, ...). The displayed
        count comes from len(xvallist) instead (see below).
    """
    #Draw new impact position and add it to the list
    xvallist.append(distribution.rvs())
    #Count the impacts actually recorded rather than using i+1: when
    #FuncAnimation has no init_func, it also calls this function for the
    #first frame as its initial draw, so the list may hold one more sample
    #than i+1.
    nimpacts=len(xvallist)

    #Create the histogram
    #Set density=True to choose the normalisation which reproduces prob(x),
    #i.e. the normalised psi2
    hist,bin_edges=np.histogram(xvallist,       #List of measurement results
                                bins=500,       #Number of bins
                                range=(-xM,xM), #histogram range
                                density=False)  #normalisation

    #For the bin whose edges are a and b, plot bar at .5*(a+b)
    #(all bins have the same width, so the first width serves for all)
    bin_coords=bin_edges[:-1]+.5*np.diff(bin_edges)[0]
    plt.clf()
    plt.title("#measurements: "+str(nimpacts)) #Display current measurement number
    plt.xlim(-50,50)
    plt.xticks([-50,-25,0,25,50])
    #Uncomment the next two lines if density=True is used in histogram
    #plt.yticks([0,.05,.1,.15])
    #plt.ylim(0,0.16)
    #Raw string keeps the LaTeX backslash without an invalid-escape warning
    plt.xlabel(r"Position x $[\mathrm{mm}]$",
               size=12)
    plt.ylabel("Number of counts per bin",
               size=12)
    plt.plot(bin_coords,hist,color='royalblue')

#FIGURE 2: animation
fig2=plt.figure(2)   #Second figure window
#Create animation. An explicit init_func is given because, without one,
#FuncAnimation uses the first frame (a call to animate(0)) as its initial
#draw, which would simulate one extra measurement before frame 0.
anim = animation.FuncAnimation(fig2,                 #appears as figure 2
                               animate,              #function updating animation
                               frames=nmeasurements, #number of frames
                               init_func=fig2.clf,   #blank initial frame, no sample drawn
                               interval=1,           #delay between frames [ms]
                               repeat=False)         #Stop after nmeasurements
plt.show()           #Display animation


    



