from cc3d.core.PySteppables import *

import os
import sys
from math import *
import random
import math

# Baseline target volume: initial stem-cell target and post-division target.
tVol = 1500

# Per-MCS death counters (reset and incremented by GrowthSteppable.step).
NoOfDeathBM = 0     # cells killed for not touching the basement membrane
NoOfDeathThre = 0   # cells killed for rising above the height threshold

# Cells whose y centre of mass exceeds kill_Thr * dim.y are killed.
kill_Thr = 0.85
# Growth/division start only once mcs % restart exceeds this many MCS.
# (Author's note: save the state at MCS 500 and reuse it as the initial condition.)
Thr_mcs = 100
# Minimum SBML "R" value a cell needs in order to grow / divide.
Thr_N = 4
# Target-volume increment per growth step (scaled by a random factor).
GR = 0.5
# Restart period in MCS; SBML states are re-initialised when mcs % restart == 0.
# Tunable.
restart = 1_000_000
# Integration step size for the Delta-Notch SBML model.
stepSize = 0.3
# WNT level separating the Paneth branch from Goblet/Enterocyte (default 100).
WNT_Thr = 100
# Delta level used as the fate threshold (default 18).
D_Thr = 18

from cc3d.core.PySteppables import *
import numpy as np

class InitialConditionSteppable(SteppableBasePy):
    """Build the crypt geometry and prune misplaced stem cells at start-up.

    ``start`` (run once) performs, in order:

    1. give every STEMCELL target volume ``tVol`` and lambdaVolume 5;
    2. paint one CT cell as the central lumen: a cylinder parallel to y for
       ``y >= radi`` capped by a half-sphere for ``y < radi``;
    3. paint one BM cell over everything outside a larger, concentric
       cylinder/half-sphere (the outer wall);
    4. kill every STEMCELL that sits above ``kill_Thr * dim.y`` or does not
       touch the BM.

    Painting overwrites whatever occupied those pixels before.
    """

    def __init__(self, frequency=1):
        SteppableBasePy.__init__(self, frequency)

    @staticmethod
    def _crypt_dist2(x, y, z, radi):
        """Return the squared distance of (x, y, z) from the crypt centre-line.

        For ``y >= radi`` this is the distance to the y-parallel axis through
        (radi, *, radi); below that it is the distance to the point
        (radi, radi, radi), which produces the half-sphere crypt bottom.
        """
        d2 = (x - radi) ** 2 + (z - radi) ** 2
        if y < radi:
            d2 += (y - radi) ** 2
        return d2

    def _touches_bm(self, cell):
        """Return True if *cell* has at least one BM-type neighbour."""
        # Medium neighbours come back as None, hence the truthiness guard.
        return any(
            neighbor and neighbor.type == self.BM
            for neighbor, _ in self.get_cell_neighbor_data_list(cell)
        )

    def start(self):
        # Crypt axis/centre use dim.x for every coordinate.
        # NOTE(review): assumes dim.x == dim.z (and y >= dim.x / 2) -- confirm.
        radi = int(self.dim.x / 2)

        for cell in self.cell_list_by_type(self.STEMCELL):
            cell.targetVolume = tVol
            cell.lambdaVolume = 5

        # Central lumen (radius (radi - 5) / 2).
        lumen_r2 = ((radi - 5) / 2) ** 2
        tube = self.new_cell(self.CT)
        for x, y, z in self.every_pixel():
            if self._crypt_dist2(x, y, z, radi) <= lumen_r2:
                self.cell_field[x, y, z] = tube
        for cell in self.cell_list_by_type(self.CT):
            # Pin the lumen to the volume it was painted with; very stiff.
            cell.targetVolume = cell.volume
            cell.lambdaVolume = 15000

        # Outer wall: every pixel at or beyond radius (radi - 5).
        wall_r2 = (radi - 5) ** 2
        wall = self.new_cell(self.BM)
        for x, y, z in self.every_pixel():
            if self._crypt_dist2(x, y, z, radi) >= wall_r2:
                self.cell_field[x, y, z] = wall

        # Force the lattice corner back to Medium.
        self.cell_field[0, 0, 0] = CompuCell.getMediumCell()
        for cell in self.cell_list_by_type(self.BM):
            # Pin the wall to its painted volume; effectively rigid.
            cell.targetVolume = cell.volume
            cell.lambdaVolume = 10000000

        # Collect first, kill afterwards, so neighbour types are not changed
        # while other cells are still being checked.
        kill_height = self.dim.y * kill_Thr
        cells_to_die = [
            cell
            for cell in self.cell_list_by_type(self.STEMCELL)
            if cell.yCOM > kill_height or not self._touches_bm(cell)
        ]
        for cell in cells_to_die:
            # Shrink to nothing and turn into type 0 (Medium).
            cell.targetVolume = 0
            cell.type = 0
#
#
class GrowthSteppable(SteppableBasePy):
    """Grow epithelial cells and remove those that leave the crypt.

    ``start`` randomises each stem cell's initial target volume.
    ``step`` grows cells, then kills cells that are too high or that have
    lost contact with the basement membrane (BM). It also updates the global
    per-step counters ``NoOfDeathThre`` and ``NoOfDeathBM``.
    """

    def __init__(self, frequency=1):
        SteppableBasePy.__init__(self, frequency)

    def start(self):
        # Stochastic initial target volume in [0.95, 1.5] * tVol.
        for stem in self.cell_list_by_type(self.STEMCELL):
            scale = random.uniform(0.95, 1.5)
            stem.targetVolume = tVol * scale

    def step(self, mcs):
        global NoOfDeathThre, NoOfDeathBM
        NoOfDeathThre = 0
        NoOfDeathBM = 0

        in_growth_window = mcs % restart > Thr_mcs
        kill_height = self.dim.y * kill_Thr
        doomed = []

        lineage = self.cell_list_by_type(
            self.STEMCELL, self.PANETH, self.GOBLET, self.ENTEROCYTE)
        for cell in lineage:
            # Live cells grow by a jittered GR once the window is open and
            # their SBML "R" value exceeds Thr_N.
            if cell.targetVolume and in_growth_window and cell.sbml.DN['R'] > Thr_N:
                cell.targetVolume += GR * random.uniform(0.95, 1.15)

            if cell.yCOM > kill_height:
                doomed.append(cell)
                NoOfDeathThre += 1

            # Medium neighbours are None, so guard before reading .type.
            on_membrane = any(
                neighbor and neighbor.type == self.BM
                for neighbor, _ in self.get_cell_neighbor_data_list(cell)
            )
            if not on_membrane:
                doomed.append(cell)
                if mcs > 0:
                    NoOfDeathBM += 1

        # Kill: shrink to zero and switch to type 0 (Medium).
        for cell in doomed:
            cell.targetVolume = 0
            cell.type = 0
     
#        
class DeltaNotchClassSteppable(SteppableBasePy):
    """Attach the Delta-Notch SBML model ("DN") and drive cell fate.

    Every MCS, each epithelial cell receives the mean Delta (``Davg``) and
    Notch (``Navg``) of its STEMCELL/PANETH neighbours as SBML inputs. The
    SBML models are then advanced by one step. After MCS 200, cell types are
    reassigned from the local WNT level and the cell's own Delta ``D``:

    * WNT > WNT_Thr and D > D_Thr  -> PANETH
    * WNT > WNT_Thr and D <= D_Thr -> type unchanged
    * WNT <= WNT_Thr and D > D_Thr -> GOBLET
    * otherwise                    -> ENTEROCYTE
    """

    def __init__(self, frequency=1):
        SteppableBasePy.__init__(self, frequency)

    def start(self):
        model_file = 'Simulation/PFLI.sbml'
        # Tight tolerances and a stiff integrator; added to work around
        # SBML convergence errors.
        options = {'relative': 1e-6, 'absolute': 1e-9, 'steps': 1000, 'stiff': True}
        self.set_sbml_global_options(options)
        self.add_sbml_to_cell_types(model_file=model_file, model_name="DN",
                                    cell_types=[self.STEMCELL], step_size=stepSize)

    def setInitialConditions(self):
        """Reset each epithelial cell's DN state to randomised initial values."""
        for cell in self.cell_list_by_type(self.STEMCELL, self.PANETH,
                                           self.GOBLET, self.ENTEROCYTE):
            # Dict literal evaluates in order D, N, B, R (same RNG order as before).
            state = {
                'D': random.uniform(48.8, 50.0),
                'N': random.uniform(0.1, 0.2),
                'B': 0,
                'R': random.uniform(0.1, 0.2),
            }
            # Previously passed modelName=self.Key: no such attribute, and the
            # keyword used elsewhere (and by the API) is model_name.
            self.set_sbml_state(model_name='DN', cell=cell, state=state)

    def _mean_neighbor_delta_notch(self, cell):
        """Return (mean D, mean N) over STEMCELL/PANETH neighbours of *cell*.

        Returns (0.0, 0.0) when there is no such neighbour.
        """
        d_sum = 0.0
        n_sum = 0.0
        count = 0
        for neighbor, _ in self.get_cell_neighbor_data_list(cell):
            # Medium neighbours are None.
            if neighbor and neighbor.type in (self.STEMCELL, self.PANETH):
                dn = neighbor.sbml.DN
                d_sum += dn['D']
                n_sum += dn['N']
                count += 1
        if count:
            return d_sum / count, n_sum / count
        return d_sum, n_sum

    def _assign_fate(self, cell, wnt):
        """Set *cell*'s type from the WNT level and its own SBML Delta."""
        high_delta = cell.sbml.DN['D'] > D_Thr
        if wnt > WNT_Thr:
            if high_delta:
                cell.type = self.PANETH
        elif high_delta:
            cell.type = self.GOBLET
        else:
            cell.type = self.ENTEROCYTE

    def step(self, mcs):
        if mcs > 0 and mcs % restart == 0:
            self.setInitialConditions()

        wnt_field = self.field.WNT
        for cell in self.cell_list_by_type(self.STEMCELL, self.PANETH,
                                           self.GOBLET, self.ENTEROCYTE):
            d_avg, n_avg = self._mean_neighbor_delta_notch(cell)
            self.set_sbml_state(model_name='DN', cell=cell,
                                state={'Davg': d_avg, 'Navg': n_avg})

            if mcs > 200:
                # Sample WNT at the rounded centre of mass.
                wnt = wnt_field[int(cell.xCOM + .5),
                                int(cell.yCOM + .5),
                                int(cell.zCOM + .5)]
                self._assign_fate(cell, wnt)
        self.timestep_sbml()
        
#
class MitosisSteppable(MitosisSteppableBase):
    """Divide large stem cells and pass the parent's SBML state to the daughter.

    A STEMCELL is a division candidate when all of the following hold:
    the restart-relative time exceeds ``Thr_mcs``; its volume exceeds
    ``1.75 * tVol``; and its SBML "R" value exceeds ``Thr_N`` (the same gate
    GrowthSteppable uses). Candidates then divide with probability 1/3.
    ``NoOfDivCells`` counts the divisions in the current step.
    """

    def __init__(self, frequency=1):
        MitosisSteppableBase.__init__(self, frequency)

    def step(self, mcs):
        global NoOfDivCells
        NoOfDivCells = 0
        cells_to_divide = []
        for cell in self.cell_list_by_type(self.STEMCELL):
            # Read R from the SBML model: cell.dict['R'] was never populated
            # and raised KeyError once the other two conditions held.
            if (mcs % restart > Thr_mcs
                    and cell.volume > tVol * 1.75
                    and cell.sbml.DN['R'] > Thr_N):
                # randint(1, 30) % 30 < 10 holds for 10 of 30 outcomes (p = 1/3);
                # the expression is kept as-is to preserve the RNG stream.
                if random.randint(1, 30) % 30 < 10:
                    cells_to_divide.append(cell)
                    NoOfDivCells += 1

        for cell in cells_to_divide:
            # NOTE(review): confirm this method name against the installed
            # CC3D MitosisSteppableBase API.
            self.dividecell_along_major_axis(cell)
            # Other options: orientation-vector based or minor-axis division.

    def update_attributes(self):
        parent = self.parentcell
        child = self.childcell
        child.type = parent.type
        # Both cells restart from the baseline target volume.
        parent.targetVolume = tVol
        child.targetVolume = tVol
        child.lambdaVolume = parent.lambdaVolume
        # The daughter inherits the parent's SBML simulators (including "DN").
        self.copy_sbml_simulators(from_cell=parent, to_cell=child)
        # Mirror any D/N/B/R entries kept in the Python dict. Nothing in this
        # script stores them, so indexing them unconditionally raised KeyError.
        for key in ('D', 'N', 'B', 'R'):
            if key in parent.dict:
                child.dict[key] = parent.dict[key]

#        
class ExtraFieldsSteppable(SteppableBasePy):
    """Register extra cell-level scalar fields and print per-cell info.

    Creates the fields Delta, Notch, B_cat, NICD and TargetVolume. Only
    TargetVolume is filled (once, in ``start``); the other four are created
    and cleared but never written here.
    """

    def __init__(self, frequency=1):
        SteppableBasePy.__init__(self, frequency)
        for field_name in ("Delta", "Notch", "B_cat", "NICD", "TargetVolume"):
            self.create_scalar_field_cell_level_py(field_name)

    def start(self):
        # Fetch each field and clear it in place. The old code bound the return
        # value of .clear() (None) and then crashed on item assignment.
        for field in (self.field.Delta, self.field.Notch,
                      self.field.B_cat, self.field.NICD):
            field.clear()
        target_volume_field = self.field.TargetVolume
        target_volume_field.clear()

        for cell in self.cell_list_by_type(self.STEMCELL, self.PANETH,
                                           self.GOBLET, self.ENTEROCYTE):
            target_volume_field[cell] = cell.targetVolume

    def step(self, mcs):
        print("ExtraFieldsSteppable: This function is called every 1 MCS")
        for cell in self.cell_list:
            print("CELL ID=", cell.id, " CELL TYPE=", cell.type, " volume=", cell.volume)

