import CompuCell
from PySteppables import *
from PySteppablesExamples import *
from math import *
import random

class InitialConditions(SteppableBasePy):
    """Seed the lattice with one four-compartment cell.

    A cube of side ``cd`` placed around the lattice centre is filled voxel
    by voxel; every voxel is handed at random to the cytoplasm, apical,
    basal or lateral (CADHERIN8) compartment.

    Parameters stored on the instance:
        cd          -- edge length of the seeded cube (used with ``range``,
                       so it has to be an integer).
        LamV        -- ``lambdaVolume`` given to every compartment.
        percCyto, percApical, percBasal, percLateral
                    -- fractions used both as assignment probabilities and
                       as shares of the cluster target volume ``cd**3``.
    """

    def __init__(self,simulator,frequency,cd,LamV,percCyto,percApical,percBasal,percLateral):
        SteppableBasePy.__init__(self,simulator,frequency)
        self.cd = cd
        self.LamV = LamV
        self.percCyto = percCyto
        self.percApical = percApical
        self.percBasal = percBasal
        self.percLateral = percLateral
        # Target volume of the whole cluster (cube of side cd).
        self.targV = cd*cd*cd

    def start(self):
        """Create the four compartments and paint the initial cube."""
        cyto = self.newCell(self.CYTO)
        basal = self.newCell(self.BASAL)
        cad8 = self.newCell(self.CADHERIN8)
        apical = self.newCell(self.APICAL)

        # All membrane compartments join the cluster of the cytoplasm.
        for compartment in (basal, cad8, apical):
            self.inventory.reassignClusterId(compartment, cyto.clusterId)

        # Lower corner of the cube: lattice centre minus half an edge.
        halfEdge = int(self.cd/2)
        xStart = int(self.dim.x/2) - halfEdge
        yStart = int(self.dim.y/2) - halfEdge
        zStart = int(self.dim.z/2) - halfEdge

        # Cumulative probability limits; whatever is left goes to cad8.
        limits = (
            (self.percCyto, cyto),
            (self.percCyto+self.percApical, apical),
            (self.percCyto+self.percApical+self.percBasal, basal),
        )

        for dx in range(self.cd):
            x = xStart + dx
            for dy in range(self.cd):
                y = yStart + dy
                for dz in range(self.cd):
                    z = zStart + dz
                    draw = random.random()
                    owner = cad8
                    for limit, compartment in limits:
                        if draw < limit:
                            owner = compartment
                            break
                    self.cellField[x,y,z] = owner

        # Volume constraint: each compartment gets its share of targV.
        shares = (
            (cyto, self.percCyto),
            (basal, self.percBasal),
            (cad8, self.percLateral),
            (apical, self.percApical),
        )
        for compartment, fraction in shares:
            compartment.targetVolume = fraction*self.targV
            compartment.lambdaVolume = self.LamV


########################################################

class ContactInhibition(SteppableBasePy):
    """Grow CYTO cells at a rate that falls with cell-cell contact.

    For every CYTO cell the steppable tracks ``"Scell"`` in the cell's
    Python attribute dictionary: the fraction of the cluster's external
    surface (contact with medium or with cells of other clusters) that is
    shared with non-LUMEN cells of other clusters.  ``step`` feeds that
    fraction through a Hill-type function to obtain the growth increment.

    Parameters:
        gFactor     -- multiplier applied to the growth term.
        hillCoef    -- exponent of the Hill-type function.
        critAlpha   -- contact fraction scale; stored as
                       ``critAlpha**hillCoef``.
        percApical, percBasal, percLateral
                    -- fractions used to set the target volume of the
                       APICAL, BASAL and CADHERIN8 compartments.
    """

    def __init__(self,simulator,frequency,gFactor,hillCoef,critAlpha,percApical,percBasal,percLateral):
        SteppableBasePy.__init__(self,simulator,frequency)
        self.gFactor=gFactor
        self.hillCoef=hillCoef
        self.percApical=percApical
        self.percBasal=percBasal
        self.percLateral=percLateral
        self.critAlphaHill=pow(critAlpha,hillCoef)

    def _contactRatio(self,cell):
        """Return (contact with other cells) / (total external contact).

        Both sums run over every compartment of ``cell``'s cluster, so the
        numerator and the denominator always describe the same set of
        compartments.  Contact with LUMEN counts towards the total only.
        Returns 0.0 when the cluster has no external surface at all, which
        avoids a division by zero.
        """
        cellContact=0
        totalContact=0.
        for compartment in self.inventory.getClusterCells(cell.clusterId):
            for neighbor in self.getCellNeighbors(compartment):
                other=neighbor.neighborAddress
                if not other:  # Medium
                    totalContact+=neighbor.commonSurfaceArea
                elif other.clusterId != cell.clusterId:
                    totalContact+=neighbor.commonSurfaceArea
                    if other.type != self.LUMEN:
                        cellContact+=neighbor.commonSurfaceArea
        if not totalContact:
            return 0.
        return cellContact/totalContact

    def start(self):
        """Initialise ``"Scell"`` of every CYTO cell with its current contact ratio."""
        for cell in self.cellListByType(self.CYTO):
            cellDict=CompuCell.getPyAttrib(cell)
            cellDict["Scell"]=self._contactRatio(cell)

    def step(self,mcs):
        """Update the contact ratio, grow the cytoplasm, resize the compartments."""
        for cell in self.cellListByType(self.CYTO):
            compList=self.inventory.getClusterCells(cell.clusterId)

            # Cluster target volume BEFORE this step's growth is applied.
            Vol=0
            for cell2 in compList:
                Vol+=cell2.targetVolume

            # Surface estimate derived from the cluster target volume.
            # NOTE(review): 6, 1.2 and 1.67 are empirical factors whose
            # origin is not visible in this file.
            S=6.*1.2*Vol**(2./3.)*1.67

            # Contact inhibition: exponential moving average of the ratio
            # (98 % previous value, 2 % current measurement).
            cellDict=CompuCell.getPyAttrib(cell)
            cellDict["Scell"]=.98*cellDict["Scell"]+.02*self._contactRatio(cell)
            aN=pow(cellDict["Scell"],self.hillCoef)
            # G equals 1 without contact and 0 when the ratio reaches 1.
            G=self.critAlphaHill*(1.-aN)/(self.critAlphaHill + aN)
            cell.targetVolume+=G*self.gFactor

            # Compartments grow proportionally to the surface estimate.
            for cell2 in compList:
                if cell2.type == self.APICAL:
                    cell2.targetVolume=S*self.percApical
                elif cell2.type == self.BASAL:
                    cell2.targetVolume=S*self.percBasal
                elif cell2.type == self.CADHERIN8:
                    cell2.targetVolume=S*self.percLateral
                
###############################################################

class Mitosis(MitosisSteppableBase):
    """Divide compartmentalised cells whose cluster volume has doubled.

    A cluster is divided in three phases (see ``step``):
      1. ``dePixelating`` merges the membrane compartments into the CYTO
         cell so that a single cell is divided;
      2. the CYTO cell is divided, along the vector returned by
         ``gettingVector`` or along its minor axis;
      3. ``PixelatingB`` redistributes pixels of parent and child to
         apical, basal and lateral (CADHERIN8) compartments.

    Parameters:
        cd          -- cell edge length; ``cd**3`` is the cluster target
                       volume and twice that value triggers division.
        LamV        -- ``lambdaVolume`` assigned after division.
        percApical, percBasal, percLateral
                    -- compartment fractions; the cytoplasm receives the
                       remainder.
    """

    def __init__(self,simulator,frequency,cd,LamV,percApical,percBasal,percLateral):
        MitosisSteppableBase.__init__(self,simulator,frequency)
        self.targV=cd*cd*cd
        self.LamV=LamV
        self.percApical=percApical
        self.percBasal=percBasal
        self.percLateral=percLateral
        self.percCyto=1.- percApical-percBasal- percLateral

    def step(self,mcs):
        """Find clusters above twice the target volume and divide them."""
        div=[]
        self.pixelate=[]
        Vdouble=2.*self.targV

        # Checking mitosis condition ----------
        for cluster in self.clusterList:
            clusterVolume=0
            cyto=None  # reset per cluster so a stale cell is never queued
            for cell in CompartmentList(cluster):
                if cell.type==self.LUMEN:
                    break
                clusterVolume+=cell.volume
                if cell.type==self.CYTO:
                    cyto=cell
            # Tested once per cluster: every cluster is queued at most once
            # and only when its CYTO compartment was actually found.
            if cyto is not None and clusterVolume>Vdouble:
                div.append(cyto)

        # Dividing ----------------------------
        for cell in div:
            # The vector needs the compartments, so compute it before merging.
            v=self.gettingVector(cell)
            self.dePixelating(cell)
            if v==[0,0,0]:
                self.divideCellAlongMinorAxis(cell)
            else:
                self.divideCellOrientationVectorBased(cell,v[0],v[1],v[2])

        # Re-pixelating -----------------------
        # self.pixelate is filled by updateAttributes for each division.
        for cell in self.pixelate:
            self.PixelatingB(cell)

    def updateAttributes(self):
        """Copy type and ``"Scell"`` to the child; queue both cells for re-pixelating."""
        parentCell=self.mitosisSteppable.parentCell
        childCell=self.mitosisSteppable.childCell
        childCell.type=parentCell.type
        # Dictionary
        pCellDict=CompuCell.getPyAttrib(parentCell)
        cCellDict=CompuCell.getPyAttrib(childCell)
        cCellDict["Scell"]=pCellDict["Scell"]
        self.pixelate.append(childCell)
        self.pixelate.append(parentCell)

    def _unitVectorFrom(self,compartment,cell):
        """Return the unit vector from ``compartment``'s COM to ``cell``'s COM.

        Returns ``[0,0,0]`` when both centres of mass coincide, instead of
        dividing by a zero length.
        """
        d=[cell.xCOM-compartment.xCOM,
           cell.yCOM-compartment.yCOM,
           cell.zCOM-compartment.zCOM]
        norm=sqrt(d[0]*d[0]+d[1]*d[1]+d[2]*d[2])
        if not norm:
            return [0,0,0]
        return [d[0]/norm,d[1]/norm,d[2]/norm]

    def gettingVector(self,cell):
        """Return the orientation vector passed to the division call.

        The vector is the cross product of the apical->cyto and
        basal->cyto unit vectors.  ``[0,0,0]`` is returned when the lateral
        compartment has fewer than two CADHERIN8 neighbours; ``step`` then
        falls back to division along the minor axis.
        """
        vA=[0,0,0]
        vB=[0,0,0]
        nLL=0  # number of lateral-lateral (CADHERIN8-CADHERIN8) contacts
        for cell2 in self.inventory.getClusterCells(cell.clusterId):
            if not cell2.volume:
                continue
            if cell2.type == self.CADHERIN8:
                for neighbor in self.getCellNeighbors(cell2):
                    other=neighbor.neighborAddress
                    if other and other.type == self.CADHERIN8:
                        nLL+=1
            elif cell2.type == self.APICAL:
                vA=self._unitVectorFrom(cell2,cell)
            elif cell2.type == self.BASAL:
                vB=self._unitVectorFrom(cell2,cell)

        if nLL<2:
            return [0,0,0]
        # V = Va x Vb
        return [vA[1]*vB[2]-vA[2]*vB[1],
                vA[2]*vB[0]-vA[0]*vB[2],
                vA[0]*vB[1]-vA[1]*vB[0]]

    def dePixelating(self,cell):
        """Merge every other compartment of ``cell``'s cluster into ``cell``.

        The pixels of each compartment are copied to a list first, because
        reassigning a pixel changes the list being iterated.  The same
        copy / set / ``cleanDeadCells`` sequence is used by
        ``LumenFlux.Merging`` in this file.
        NOTE(review): confirm that ``cleanDeadCells`` removes the emptied
        compartment in the CompuCell3D version in use.
        """
        compartments=[cell2 for cell2 in self.inventory.getClusterCells(cell.clusterId)
                      if cell2.id != cell.id]
        for cell2 in compartments:
            pixels=[CompuCell.Point3D(pixelData.pixel)
                    for pixelData in self.getCellPixelList(cell2)]
            for pt in pixels:
                self.cellField.set(pt,cell)
            self.cleanDeadCells()

        # Freeze the merged cell at its current size for the division.
        cell.targetVolume=cell.volume

    def PixelatingB(self,cell):
        """Redistribute pixels of ``cell`` to its membrane compartments.

        Compartments already present in the cluster are reused; any missing
        one is created, so a cluster holding only some of them no longer
        fails on an unbound name.
        """
        apical=None
        basal=None
        cad8=None
        for cell2 in self.inventory.getClusterCells(cell.clusterId):
            if cell2.type == self.APICAL:
                apical=cell2
            elif cell2.type == self.BASAL:
                basal=cell2
            elif cell2.type == self.CADHERIN8:
                cad8=cell2

        if apical is None:
            apical=self.newCell(self.APICAL)     # Create apical domain
        if basal is None:
            basal=self.newCell(self.BASAL)       # Create basal domain
        if cad8 is None:
            cad8=self.newCell(self.CADHERIN8)    # Create lateral domain

        # Snapshot the coordinates: the loop below takes pixels away from
        # ``cell`` and would otherwise modify the list while iterating it.
        voxels=[(pixelData.pixel.x,pixelData.pixel.y,pixelData.pixel.z)
                for pixelData in CellPixelList(self.pixelTrackerPlugin,cell)]
        for x,y,z in voxels:
            ran=random.random()
            if ran < self.percCyto:
                pass  # pixel stays in the cytoplasm
            elif ran < (self.percCyto+self.percApical):
                self.cellField[x,y,z]=apical
            elif ran < (self.percCyto+self.percApical+self.percBasal):
                self.cellField[x,y,z]=basal
            else:
                self.cellField[x,y,z]=cad8

        cell.targetVolume=self.percCyto*self.targV
        cell.lambdaVolume=self.LamV
        basal.targetVolume=self.percBasal*self.targV
        basal.lambdaVolume=self.LamV
        cad8.targetVolume=self.percLateral*self.targV
        cad8.lambdaVolume=self.LamV
        apical.targetVolume=self.percApical*self.targV
        apical.lambdaVolume=self.LamV

        self.inventory.reassignClusterId(basal,cell.clusterId)
        self.inventory.reassignClusterId(cad8,cell.clusterId)
        self.inventory.reassignClusterId(apical,cell.clusterId)
        
        
####################################################################################

class LumenFlux(SteppableBasePy): # NEW flux based lumen growth
    """Maintain, merge and nucleate LUMEN cells.

    Parameters:
        LamV    -- base ``lambdaVolume``; new lumen cells get 1.5 times it.
        Kl      -- rate constant of the lumen target-volume update.
        ApArea  -- apical area scale dividing the lumen/apical contact.
    """

    def __init__(self,simulator,frequency,LamV,Kl,ApArea):
        SteppableBasePy.__init__(self,simulator,frequency)
        self.boundaryStrategy=CompuCell.BoundaryStrategy.getInstance()
        self.maxNeighborIndex=self.boundaryStrategy.getMaxNeighborIndexFromNeighborOrder(1)
        self.LamV=LamV
        self.Kl=Kl
        self.ApArea=ApArea

    def Merging(self,cell,cellD):
        """Move every pixel and the target volume of ``cellD`` into ``cell``."""
        # Copy the points first: reassigning a pixel alters cellD's pixel list.
        points=[CompuCell.Point3D(pixelData.pixel)
                for pixelData in CellPixelList(self.pixelTrackerPlugin,cellD)]
        cell.targetVolume+=cellD.targetVolume
        for pt in points:
            self.cellField.set(pt,cell)
        self.cleanDeadCells()

    def step(self,mcs):
        """Update lumen volumes, merge touching lumens, nucleate new lumens."""
        # LUMEN MAINTENANCE
        mergePairs=[]   # (surviving lumen, lumen to absorb)
        absorbed=set()  # ids of lumens already scheduled for absorption
        for cell in self.cellListByType(self.LUMEN):  #Lumen cells
            sApical=0
            nCell=0
            for neigh in self.getCellNeighbors(cell):
                other=neigh.neighborAddress
                if not other:  #Lumen-Medium
                    # Lumen in contact with medium loses target volume.
                    cell.targetVolume-=neigh.commonSurfaceArea
                elif other.type == self.LUMEN:  #Lumen-Lumen
                    # The lower id survives; a lumen that is itself about to
                    # be absorbed does not absorb anything.
                    if (cell.id < other.id
                            and other.id not in absorbed
                            and cell.id not in absorbed):
                        mergePairs.append((cell,other))
                        absorbed.add(other.id)
                elif other.type == self.APICAL:
                    sApical+=neigh.commonSurfaceArea
                    nCell+=1

            if cell.targetVolume > 0:
                # float() keeps the ratio fractional under Python 2 even
                # when both the contact area and ApArea are integers.
                cell.targetVolume+=self.Kl*(nCell-sApical/float(self.ApArea))

        for survivor,donor in mergePairs:
            self.Merging(survivor,donor)

        # CREATING LUMEN
        NN=2
        for cell in self.cellListByType(self.APICAL):
            # Checking if apical cell has at least NN apical neighbors and
            # does not touch a lumen yet.  A flag is used for the lumen
            # test, so a large apical count cannot outweigh a lumen contact.
            n=0
            touchesLumen=False
            for neigh in self.getCellNeighbors(cell):
                other=neigh.neighborAddress
                if not other:
                    continue
                if other.type == self.APICAL:    #Apical x Apical
                    n+=1
                elif other.type == self.LUMEN:   #Apical x Lumen
                    touchesLumen=True
                    break

            if touchesLumen or n < NN:
                continue

            # Nucleating new lumen: looping over boundary pixels
            for bPixel in self.getCellBoundaryPixelList(cell):
                # Counted per pixel, as the criterion below is per pixel.
                foreignApical=0
                for i in range(self.maxNeighborIndex+1):
                    pixelNeighbor=self.boundaryStrategy.getNeighborDirect(bPixel.pixel,i)
                    if not pixelNeighbor.distance:
                        continue  # neighbor lies outside the lattice
                    cell2=self.cellField.get(pixelNeighbor.pt)
                    if cell2 and (cell2.id != cell.id) and (cell2.type == self.APICAL):
                        foreignApical+=1

                # If the pixel has at least 2 neighbor pixels belonging to
                # apical compartments of other cells, it becomes lumen.
                if foreignApical > 1:
                    lumenCell=self.potts.createCellG(bPixel.pixel)
                    lumenCell.type=self.LUMEN
                    lumenCell.targetVolume=5.+.5
                    lumenCell.lambdaVolume=1.5*self.LamV
                    break

##########################################


