import CompuCell
from PySteppables import *
from PySteppablesExamples import *
import os,inspect
from math import *
import time
import random

class InitialConditions(SteppableBasePy):
    """Build the initial tubule and compartmentalize its cells.

    ``start`` draws a LUMEN cylinder surrounded by rings of 10 CYTO cells
    (a new ring every ``cd`` slices along z).  ``step`` then:

    * at mcs 0, carves APICAL/BASAL/LATERAL compartments out of every CYTO
      cell (``PixelatingB``) and remembers one cluster as the future "sick"
      cell;
    * at mcs ``Trelax``, switches that cluster's LATERAL compartment to
      LATERAL2 and resets its compartment target volumes using the
      ``pASick/pBSick/pLSick`` fractions.

    Parameters: ``cd`` cell diameter in lattice units, ``targV`` target volume
    (stored only), ``LamV`` base volume lambda, ``percApical/percBasal/
    percLateral`` healthy compartment fractions, ``pASick/pBSick/pLSick``
    sick compartment fractions, ``Trelax`` mcs at which the sick cell is
    switched on.
    """
    def __init__(self,simulator,frequency,cd,targV,LamV,percApical,percBasal,percLateral,pASick,pBSick,pLSick,Trelax):
        SteppableBasePy.__init__(self,simulator,frequency)
        self.cd=cd;                   self.targV=targV;           self.LamV=LamV;
        self.percApical=percApical;   self.percBasal=percBasal;   self.percLateral=percLateral;
        self.pASick=pASick;           self.pBSick=pBSick;         self.pLSick=pLSick;
        self.Trelax=Trelax;

        # Fraction of the carved region that stays in the CYTO compartment
        self.pC=1. - percApical - percLateral - percBasal;
        self.growth=0.2*cd
        # clusterId of the cluster that turns sick at mcs==Trelax (set in step(0))
        self.sickCellId=None

    def start(self):
        """Draw the tube: a lumen cylinder plus rings of 10 CYTO cells along z."""
        # Tube axis sits at one third of the lattice in x and y
        Lx3=int(floor(self.dim.x/3));    Ly3=int(floor(self.dim.y/3))
        nRing=10                     # cells on the perimeter of each xy ring
        dAng=2.*pi/nRing             # angular width of one cell
        AngPlus=0.                   # rotation offset of the current ring
        # Mid-radius of the epithelial ring: its circumference holds nRing
        # cell diameters, i.e. Rm ~ 1.6*cd
        Rm=nRing*self.cd/(2.*pi)
        # Float halves on both sides (cd/2 would truncate for integer cd in Python 2)
        Rmin=Rm-self.cd/2.;          Rmax=Rm+self.cd/2.
        lumen=self.newCell(self.LUMEN)

        for z in range(self.dim.z):
            if z % self.cd == 0:
                # New ring of cells, rotated by half a cell width w.r.t. the previous one
                ring=[self.newCell(self.CYTO) for i in range(nRing)]
                AngPlus+=pi/nRing
            for x in range(self.dim.x):
                dx=x-Lx3
                for y in range(self.dim.y):
                    dy=y-Ly3
                    R=sqrt(dx*dx+dy*dy)
                    if R<=Rmin:
                        self.cellField[x,y,z]=lumen
                    elif R<=Rmax:
                        # Polar angle in [0, 2*pi), shifted by the ring offset
                        ang=acos(dx/R)
                        if dy<0:
                            ang=2.*pi-ang
                        ang=(ang+AngPlus)%(2.*pi)
                        # floor() can land on nRing right at the 2*pi edge; wrap it to 0
                        self.cellField[x,y,z]=ring[int(floor(ang/dAng)) % nRing]

        # Freeze each CYTO cell at its drawn volume
        for cell in self.cellListByType(self.CYTO):
            cell.targetVolume=cell.volume
            cell.lambdaVolume=self.LamV

    def step(self,mcs):
        """Compartmentalize cells at mcs 0; switch on the sick cell at mcs Trelax."""
        if mcs==0:
            # Snapshot: PixelatingB creates new cells while we iterate
            cytoCells=[cell for cell in self.cellListByType(self.CYTO)]
            for cell in cytoCells:
                x=cell.xCOM;   y=cell.yCOM;  z=cell.zCOM
                # Pick a cell in a fixed region near the middle of the tube (last match wins)
                if x>(self.dim.x/3.+2.) and y>(self.dim.y/3.+10.):
                    if (self.dim.z/2.-self.cd/2.) < z < (self.dim.z/2.+self.cd/2.):
                        self.sickCellId=cell.clusterId
                self.PixelatingB(cell)
        elif mcs==self.Trelax:
            if self.sickCellId is None:
                # No cell satisfied the selection criteria at mcs 0
                return
            compList=self.inventory.getClusterCells(self.sickCellId)
            V=0.
            for cell2 in compList:
                V+=cell2.volume
                if cell2.type==self.LATERAL:
                    cell2.type=self.LATERAL2
            S=6.*(pow(V,(2./3.)))
            for cell2 in compList:
                if cell2.type==self.APICAL:
                    cell2.targetVolume=int(S*self.pASick*1.2+.5)
                elif cell2.type==self.BASAL:
                    cell2.targetVolume=int(S*self.pBSick*1.2+.5)
                elif cell2.type==self.LATERAL2:
                    # Lateral compartments were just retyped to LATERAL2 above
                    cell2.targetVolume=int(S*self.pLSick*1.2+.5)

    def PixelatingB(self,cell):
        """Carve APICAL/BASAL/LATERAL compartments out of CYTO ``cell``.

        Existing compartments in the cell's cluster are reused; missing ones
        are created.  About ``1.2*S`` randomly chosen pixels of ``cell`` (with
        ``S=6*V**(2/3)``) are handed to the compartments with probabilities
        ``percApical/percBasal/percLateral``; the rest stay CYTO.  Target
        volumes and lambdas of all four compartments are then reset, and the
        new compartments are moved into ``cell``'s cluster.
        """
        S=6.*pow(cell.volume,(2./3.))
        nPix=int(floor(1.2*S))
        aCell=None;  bCell=None;  lCell=None
        for cell2 in self.inventory.getClusterCells(cell.clusterId):
            if cell2.type==self.APICAL:
                aCell=cell2
            elif cell2.type==self.BASAL:
                bCell=cell2
            elif cell2.type==self.LATERAL:
                lCell=cell2
        if not aCell:
            aCell=self.newCell(self.APICAL)  #Create Apical domain
        if not bCell:
            bCell=self.newCell(self.BASAL)   #Create Basal domain
        if not lCell:
            lCell=self.newCell(self.LATERAL) #Create Lateral domain

        pList=[[p.pixel.x,p.pixel.y,p.pixel.z] for p in CellPixelList(self.pixelTrackerPlugin,cell)]
        random.shuffle(pList)
        # Cumulative thresholds for the random compartment choice
        thrA=self.percApical
        thrB=thrA+self.percBasal
        thrL=thrB+self.percLateral
        # Slicing caps the count at the cell's size (small cells have < 1.2*S pixels)
        for x,y,z in pList[:nPix]:
            r=random.random()
            if r<thrA:
                self.cellField[x,y,z]=aCell
            elif r<thrB:
                self.cellField[x,y,z]=bCell
            elif r<thrL:
                self.cellField[x,y,z]=lCell

        #CYTO subcell
        cell.targetVolume=int(floor(cell.targetVolume-1.2*S+.5))
        cell.lambdaVolume=4*self.LamV
        #APICAL subcell
        aCell.targetVolume=int(floor(1.2*S*self.percApical+.5))
        aCell.lambdaVolume=4*self.LamV
        self.inventory.reassignClusterId(aCell,cell.clusterId)
        #BASAL subcell
        bCell.targetVolume=int(floor(1.2*S*self.percBasal+.5))
        bCell.lambdaVolume=4*self.LamV
        self.inventory.reassignClusterId(bCell,cell.clusterId)
        #LATERAL subcell
        lCell.targetVolume=int(floor(1.2*S*self.percLateral+.5))
        lCell.lambdaVolume=4*self.LamV
        self.inventory.reassignClusterId(lCell,cell.clusterId)
        
########################################################

class ContactInhibition(SteppableBasePy):
    """Contact-inhibited growth of CYTO cells and rescaling of their compartments.

    For every CYTO cluster the steppable tracks ``Scell``: the fraction of the
    cluster's external contact area (cells + lumen + medium) that touches
    other non-lumen cells, exponentially averaged with weight 1/50 per step.
    A Hill-type factor ``G = a^h (1 - Scell^h) / (a^h + Scell^h)`` drives the
    CYTO target volume; compartment target volumes are reset to fractions of
    ``S = 6*1.2*Vol**(2/3)``.  Clusters containing a LATERAL2 compartment use
    the ``*Sick`` parameter set.
    """
    def __init__(self,simulator,frequency,gFactor,hillCoef,critAlpha,gFactorSick,hillCoefSick,
                 critAlphaSick,percApical,percBasal,percLateral,pASick,pBSick,pLSick):
        SteppableBasePy.__init__(self,simulator,frequency)
        self.gFactor=gFactor;          self.hillCoef=hillCoef;
        self.gFactorSick=gFactorSick;  self.hillCoefSick=hillCoefSick;
        self.percApical=percApical;    self.percBasal=percBasal;        self.percLateral=percLateral;
        self.pASick=pASick;            self.pBSick=pBSick;              self.pLSick=pLSick;
        # Pre-raised critical alphas (a^h) used by the Hill factor
        self.critAlphaHill=pow(critAlpha,hillCoef);
        self.critAlphaSickHill=pow(critAlphaSick,hillCoefSick);

    def _contactSurfaces(self,cell,compList):
        """Return ``(Scell, ST)`` summed over all compartments in ``compList``.

        ``ST`` is the contact area with medium and with anything outside
        ``cell``'s cluster; ``Scell`` is the part of it that is not LUMEN.
        Contacts between compartments of the same cluster are ignored.
        """
        Scell=0.;  ST=0.
        for cell2 in compList:
            for neigh in self.getCellNeighbors(cell2):
                other=neigh.neighborAddress
                if not other:                              # Medium
                    ST+=neigh.commonSurfaceArea
                elif other.clusterId != cell.clusterId:
                    ST+=neigh.commonSurfaceArea
                    if other.type != self.LUMEN:
                        Scell+=neigh.commonSurfaceArea
        return Scell,ST

    def start(self):
        """Initialize each CYTO cell's ``Scell`` attribute from current contacts."""
        for cell in self.cellListByType(self.CYTO):
            cellDict=CompuCell.getPyAttrib(cell)
            compList=self.inventory.getClusterCells(cell.clusterId)
            # Accumulated over the whole cluster (as in step), not just the last compartment
            Scell,ST=self._contactSurfaces(cell,compList)
            cellDict["Scell"]=Scell/ST if ST>0 else 0.

    def step(self,mcs):
        """Update ``Scell``, grow CYTO cells and rescale their compartments."""
        for cell in self.cellListByType(self.CYTO):
            compList=self.inventory.getClusterCells(cell.clusterId)
            Vol=0;  sick=False
            for cell2 in compList:
                Vol+=cell2.targetVolume
                if cell2.type==self.LATERAL2:
                    sick=True
            Scell,ST=self._contactSurfaces(cell,compList)

            S=6.*1.2*pow(Vol,(2./3.))
            cellDict=CompuCell.getPyAttrib(cell)
            # Exponential moving average; skip the update if there is no external contact
            if ST>0:
                cellDict["Scell"]=cellDict["Scell"]*49./50. + Scell/ST/50.

            if sick:
                hill=self.hillCoefSick;  critHill=self.critAlphaSickHill;  gFactor=self.gFactorSick
                fractions={self.APICAL:self.pASick, self.BASAL:self.pBSick, self.LATERAL2:self.pLSick}
            else:
                hill=self.hillCoef;      critHill=self.critAlphaHill;      gFactor=self.gFactor
                fractions={self.APICAL:self.percApical, self.BASAL:self.percBasal, self.LATERAL:self.percLateral}

            aN=pow(cellDict["Scell"],hill)
            G=critHill*(1.-aN)/(critHill+aN)
            cell.targetVolume+=G*gFactor
            for cell2 in compList:
                frac=fractions.get(cell2.type)
                if frac is not None:
                    cell2.targetVolume=S*frac
                  

###############################################################

class Mitosis(MitosisSteppableBase):
    """Divide compartmentalized cells whose total volume exceeds ``2*targV``.

    For each dividing cluster, the division direction is taken from the
    apical and basal compartment positions (``gettingVector``). All
    compartments are first folded back into the CYTO cell
    (``dePixelating``), the CYTO cell is divided, and both daughters are
    re-compartmentalized (``PixelatingB``) using either the healthy
    (LATERAL) or the sick (LATERAL2) fractions.
    """
    def __init__(self,simulator,frequency,targV,LamV,percApical,percBasal,percLateral,pASick,pBSick,pLSick):
        MitosisSteppableBase.__init__(self,simulator,frequency)
        self.targV=targV;             self.LamV=LamV;
        self.percApical=percApical;   self.percBasal=percBasal;   self.percLateral=percLateral;
        self.percCell=1. - percApical - percBasal - percLateral;
        self.pASick=pASick;           self.pBSick=pBSick;         self.pLSick=pLSick;

    def step(self,mcs):
        """Find oversized clusters, divide their CYTO cell, then re-pixelate daughters."""
        div=[];  self.pixelate=[];  Vdouble=2.*self.targV
        #Checking mitosis condition
        for cluster in self.clusterList:
            clusterVolume=0;  latType=self.LATERAL;  cyto=None
            for cell in CompartmentList(cluster):
                if cell.type==self.LUMEN:
                    break
                clusterVolume+=cell.volume
                if cell.type==self.CYTO:
                    cyto=cell
                if cell.type==self.LATERAL2:
                    latType=self.LATERAL2
            # cyto is reset per cluster so a cluster without CYTO never reuses another's
            if cyto is not None and clusterVolume>Vdouble:
                div.append([cyto,latType])

        #Dividing
        for cell,latType in div:
            self.latType=latType          # read by gettingVector
            v=self.gettingVector(cell)
            self.dePixelating(cell)
            if v==[0,0,0]:
                self.divideCellAlongMinorAxis(cell)
            else:
                self.divideCellOrientationVectorBased(cell,v[0],v[1],v[2])
        #rePixelating (updateAttributes fills self.pixelate during division)
        for cell,latType in self.pixelate:
            self.PixelatingB(cell,latType)

    def updateAttributes(self):
        """Copy type and ``Scell`` to the child, reset both target volumes, queue re-pixelation."""
        parentCell=self.mitosisSteppable.parentCell
        childCell=self.mitosisSteppable.childCell
        childCell.type=parentCell.type
        parentCell.targetVolume=self.targV
        childCell.targetVolume=self.targV
        #Dictionary
        pCellDict=CompuCell.getPyAttrib(parentCell)
        cCellDict=CompuCell.getPyAttrib(childCell)
        cCellDict["Scell"]=pCellDict["Scell"]
        #Add to pixelating list
        self.pixelate.append([childCell,self.latType])
        self.pixelate.append([parentCell,self.latType])

    def _unitVector(self,cell,cell2):
        """Unit vector from ``cell2``'s center of mass to ``cell``'s; [0,0,0] if they coincide."""
        d=[cell.xCOM-cell2.xCOM, cell.yCOM-cell2.yCOM, cell.zCOM-cell2.zCOM]
        norm=sqrt(d[0]*d[0]+d[1]*d[1]+d[2]*d[2])
        if norm==0.:
            return [0.,0.,0.]
        return [d[0]/norm, d[1]/norm, d[2]/norm]

    def gettingVector(self,cell):
        """Return the division direction for CYTO ``cell``.

        The result is the cross product of the unit vectors apical->cyto and
        basal->cyto. It is [0,0,0] (meaning "divide along the minor axis")
        when the lateral compartment (of type ``self.latType``) has fewer
        than 2 neighbor contacts with LATERAL/LATERAL2 cells.
        """
        vA=[0,0,0];  vB=[0,0,0];  nLL=0
        for cell2 in self.inventory.getClusterCells(cell.clusterId):
            if not cell2.volume:
                continue
            if cell2.type==self.latType:
                for neighbor in self.getCellNeighbors(cell2):
                    other=neighbor.neighborAddress
                    if other and (other.type==self.LATERAL or other.type==self.LATERAL2):
                        nLL+=1
            elif cell2.type==self.APICAL:  #Apical
                vA=self._unitVector(cell,cell2)
            elif cell2.type==self.BASAL:   #Basal
                vB=self._unitVector(cell,cell2)

        if nLL<2:
            return [0,0,0]
        # V = Va x Vb
        return [vA[1]*vB[2]-vA[2]*vB[1],
                vA[2]*vB[0]-vA[0]*vB[2],
                vA[0]*vB[1]-vA[1]*vB[0]]

    def dePixelating(self,cell):
        """Hand every pixel of the other compartments in ``cell``'s cluster to ``cell``.

        ``cell.targetVolume`` grows by the number of pixels absorbed.
        """
        L=[]
        for cell2 in self.inventory.getClusterCells(cell.clusterId):
            if cell2.id != cell.id:
                # Collect the compartment's pixels (not the cyto's own) as copies,
                # since the tracker's entries change once the field is modified
                for pixel in self.getCellPixelList(cell2):
                    L.append(CompuCell.Point3D(pixel.pixel))
        for pt in L:
            self.cellField.set(pt,cell)
        cell.targetVolume+=len(L)

    def PixelatingB(self,cell,latType):
        """Carve APICAL/BASAL/``latType`` compartments out of CYTO ``cell``.

        Existing compartments in the cluster are reused, missing ones created.
        About ``1.2*S`` random pixels (``S=6*V**(2/3)``) are distributed with
        the healthy fractions; compartment target volumes use the healthy
        fractions for LATERAL and the ``*Sick`` fractions otherwise.
        """
        S=6.*pow(cell.volume,(2./3.))
        nPix=int(floor(1.2*S))
        aCell=None;  bCell=None;  lCell=None
        for cell2 in self.inventory.getClusterCells(cell.clusterId):
            if cell2.type==self.APICAL:
                aCell=cell2
            elif cell2.type==self.BASAL:
                bCell=cell2
            elif cell2.type==latType:
                lCell=cell2
        if not aCell:
            aCell=self.newCell(self.APICAL)   # Create apical domain
        if not bCell:
            bCell=self.newCell(self.BASAL)    # Create basal domain
        if not lCell:
            lCell=self.newCell(latType)       # Create lateral domain

        pList=[[p.pixel.x,p.pixel.y,p.pixel.z] for p in CellPixelList(self.pixelTrackerPlugin,cell)]
        random.shuffle(pList)
        # NOTE(review): pixels are seeded with the healthy fractions even for
        # sick (LATERAL2) cells; only the target volumes below differ.
        thrA=self.percApical
        thrB=thrA+self.percBasal
        thrL=thrB+self.percLateral
        # Slicing caps the count at the cell's size (small cells have < 1.2*S pixels)
        for x,y,z in pList[:nPix]:
            r=random.random()
            if r<thrA:
                self.cellField[x,y,z]=aCell
            elif r<thrB:
                self.cellField[x,y,z]=bCell
            elif r<thrL:
                self.cellField[x,y,z]=lCell

        if latType==self.LATERAL:
            fracA,fracB,fracL=self.percApical,self.percBasal,self.percLateral
        else: #latType == self.LATERAL2
            fracA,fracB,fracL=self.pASick,self.pBSick,self.pLSick

        #CYTO subcell
        cell.targetVolume=int(floor(cell.targetVolume-1.2*S+0.5))
        cell.lambdaVolume=4*self.LamV
        #APICAL subcell
        aCell.targetVolume=int(floor(1.2*S*fracA+0.5))
        aCell.lambdaVolume=4*self.LamV
        self.inventory.reassignClusterId(aCell,cell.clusterId)
        #BASAL subcell
        bCell.targetVolume=int(floor(1.2*S*fracB+0.5))
        bCell.lambdaVolume=4*self.LamV
        self.inventory.reassignClusterId(bCell,cell.clusterId)
        #LATERAL subcell
        lCell.targetVolume=int(floor(1.2*S*fracL+0.5))
        lCell.lambdaVolume=4*self.LamV
        self.inventory.reassignClusterId(lCell,cell.clusterId)
            
            
####################################################################################

class LumenFlux(SteppableBasePy): # NEW flux based lumen growth
    """Flux-based growth, merging and nucleation of LUMEN cells.

    Each step:

    * every LUMEN cell loses target volume equal to its contact area with
      medium;
    * if its target volume is still positive, it gains
      ``Kl*(nApical - apicalArea/ApArea)``, where ``nApical`` counts its
      APICAL neighbors and ``apicalArea`` is their total contact area;
    * touching LUMEN cells are merged into the lower-id one;
    * the main lumen's target volume is reset to its initial volume;
    * APICAL cells with at least 2 APICAL neighbors (and no lumen contact)
      nucleate a small new LUMEN cell on a boundary pixel adjacent to two or
      more pixels of other APICAL cells.
    """
    def __init__(self,simulator,frequency,LamV,Kl,ApArea):
        SteppableBasePy.__init__(self,simulator,frequency)
        self.boundaryStrategy=CompuCell.BoundaryStrategy.getInstance()
        self.maxNeighborIndex=self.boundaryStrategy.getMaxNeighborIndexFromNeighborOrder(1)
        self.LamV=LamV;   self.Kl=Kl;  self.ApArea=ApArea;

    def start(self):
        """Record the main lumen (the last LUMEN cell listed) and fix its volume."""
        for cell in self.cellListByType(self.LUMEN):
            self.LumenID=cell.id
            self.LumenVol=cell.volume
            cell.lambdaVolume=5*self.LamV
            cell.targetVolume=self.LumenVol

    def Merging(self,cell,cellD):
        """Move all pixels and the target volume of ``cellD`` into ``cell``, then clean dead cells."""
        # Copy points first: the tracker's pixel set changes as the field is modified
        L=[CompuCell.Point3D(pixelData.pixel) for pixelData in CellPixelList(self.pixelTrackerPlugin,cellD)]
        cell.targetVolume+=cellD.targetVolume
        for pt in L:
            self.cellField.set(pt,cell)
        self.cleanDeadCells()

    def step(self,mcs):
        """Maintain, merge and nucleate lumen cells (see class docstring)."""
        #LUMEN MAINTENANCE
        pairs=[];  absorbed=set()
        for cell in self.cellListByType(self.LUMEN):
            sApical=0;  nCell=0
            for neigh in self.getCellNeighbors(cell):
                other=neigh.neighborAddress
                if other and other.type==self.LUMEN:           #Lumen-Lumen
                    # Each neighbor lumen is absorbed at most once, by the lower-id partner
                    if (cell.id<other.id) and (other.id not in absorbed) and (cell.id not in absorbed):
                        pairs.append((cell,other))
                        absorbed.add(other.id)
                elif not other:                                #Lumen-Medium
                    cell.targetVolume-=neigh.commonSurfaceArea
                elif other.type==self.APICAL:                  #Lumen-Apical
                    sApical+=neigh.commonSurfaceArea;  nCell+=1

            if cell.targetVolume>0:
                # float() avoids Python 2 integer division when both operands are ints
                cell.targetVolume+=self.Kl*(nCell-sApical/float(self.ApArea))
        for keeper,merged in pairs:
            self.Merging(keeper,merged)
        #Big lumen maintenance
        bigLumen=self.attemptFetchingCellById(self.LumenID)
        if bigLumen:
            bigLumen.targetVolume=self.LumenVol

        #CREATING LUMEN
        NN=2
        # Snapshot: createCellG adds cells while we iterate
        apicalCells=[cell for cell in self.cellListByType(self.APICAL)]
        for cell in apicalCells:
            #Checking if apical cell has at least NN apical neighbors
            n=0
            for neigh in self.getCellNeighbors(cell):
                other=neigh.neighborAddress
                if other and other.type==self.APICAL:     # Apical x Apical
                    n+=1
                elif other and other.type==self.LUMEN:    # Apical x Lumen
                    # NOTE(review): subtracts NN rather than vetoing outright, so a cell
                    # that already counted > 2*NN-1 apical neighbors can still nucleate
                    n-=NN
                    break
            if n<NN:
                continue

            #Nucleating new lumen: looping over boundary pixels
            for bPixel in self.getCellBoundaryPixelList(cell):
                ok=0   # pixels of other apical cells adjacent to this boundary pixel
                for i in range(self.maxNeighborIndex+1):
                    pixelNeighbor=self.boundaryStrategy.getNeighborDirect(bPixel.pixel,i)
                    if not pixelNeighbor.distance:
                        continue   # neighbor lies outside the lattice
                    cell2=self.cellField.get(pixelNeighbor.pt)
                    if cell2 and (cell2.id != cell.id) and (cell2.type==self.APICAL):
                        ok+=1
                #If pixel has 2 neighbor pixels of other apical cells, it becomes lumen
                if ok>1:
                    lumenCell=self.potts.createCellG(bPixel.pixel)
                    lumenCell.type=self.LUMEN
                    lumenCell.targetVolume=5.+.5
                    lumenCell.lambdaVolume=2*self.LamV
                    break


##########################################

