# -*- coding: utf-8 -*-
"""
Created on Sat Apr  5 19:47:25 2014

@author: denis
"""

import numpy as np

"""Les fonctions alternatives sont optionnelles : 
elles sont juste là pour montrer une alternative possible"""


def echange_ligne(A,i,j):  # L_i <-> L_j
    nbr_col=np.size(A[0])
    for k in range(nbr_col):
        A[i,k],A[j,k]=A[j,k],A[i,k]

def echange_ligne_alternatif(A,i,j):
    A[i,:],A[j,:]=np.copy(A[j,:]),np.copy(A[i,:])

def transvection(A,i,j,mu): #  L_i <- L_i + mu L_j
    nbr_col=np.size(A[0])
    for k in range(nbr_col):
        A[i,k]=np.copy(A[i,k])+ mu*np.copy(A[j,k])
        
def transvection_alternatif(A,i,j,mu): #  L_i <- L_i + mu L_j
    A[i]=np.copy(A[i])+ mu*np.copy(A[j])

def chercher_pivot(A,i):
    nbr_lignes=np.shape(A)[0]
    k=i                             # On commence ligne i
    for p in range(i+1,nbr_lignes):
        if abs(A[p,i])>abs(A[k,i]):
            k=p
    return k

def resolution(A0,Y0):
    """Résolution de A0 X = Y0, avec A0 carrée inversible"""
    A=np.copy(A0) #On copie, pour ne pas modifier la matrice A0 de départ
    Y=np.copy(Y0) #idem Y0
    n=np.shape(A)[0] #nombre de lignes
    assert np.shape(A)[1]==n # On vérifie que A est carrée
    # Mise sous forme triangulaire :
    for i in range(n-1):
        k=chercher_pivot(A,i)
        if k>i: # on échange si A[i,i] n'est pas le plus grand pivot
            echange_ligne(A,i,k)
            echange_ligne(Y,i,k)
        for k in range(i+1,n):
            mu=-A[k,i]/float(A[i,i]) #Important !!! Sinon -A[k,i]/float(A[i,i])=0 lorsqu'on l'applique à Y.
            transvection(A,k,i,mu)
            transvection(Y,k,i,mu)
    # Remontée de pivot
    for i in range(n-1,-1,-1):
        Y[i,0]=(Y[i,0]-np.sum(A[i,i+1:n]*Y[i+1:n,0]))/A[i,i]
    return A,Y

"""Complexité de resolution : O(n^3)"""

#Tests :
A1=np.array([[1.,2,3],[0,4,5],[0,0,6]])
A2=np.dot(A1,np.transpose(A1))
A=A2
print(A)
Y=np.array([[1.],[5],[2]])
print(Y)
B,X=resolution(A,Y)
print('-----A triangulaire :-----')
print(B)
print('-----X maison :-----')
print(X)
print('AX-Y')
print(np.dot(A,X)-Y)
print('-----X via inv :-----')
X=np.dot(np.linalg.inv(A),Y)
print(X)
print('AX-Y')
print(np.dot(A,X)-Y)
#print('-----no gag : A et Y n\'ont pas changé en route-----')
#print(A)
#print(Y)
