from cc3d.cpp.PlayerPython import * 
from cc3d import CompuCellSetup
from cc3d.core.PySteppables import *
import numpy as np

class ChannelSteppable(SteppableBasePy):
    """Drive cells through a walled channel and record velocity/pressure fields.

    Every ``timeInterval`` MCS this steppable:
      * seeds new CELLs at empty sites of a source column at x = rdC,
      * deletes CELLs whose x centre of mass has passed ``xSink``,
      * refreshes the per-cell VELOCITY and PRESSURE fields, and
      * folds the current pixel pressure into a per-pixel running mean
        (AVGPRESSURE).

    "Pressure" throughout is the proxy ``targetVolume - volume``.
    """

    def __init__(self, frequency=1):
        SteppableBasePy.__init__(self, frequency)

        self.rdC = 5            # cell radius; also the y spacing of source sites
        self.szW = 2            # wall thickness
        self.tgV = 100.         # cell target volume
        self.lbdV = 10.         # cell lambda volume
        self.Fx = -200.         # force x-component (assigned to lambdaVecX)
        self.timeInterval = 10  # MCS between source/sink events and field updates
        # Pixels contribute to the average pressure only when
        # avgXMin < x < avgXMax; the lower bound keeps the average away from
        # the cell source.
        # NOTE(review): 480 looks tied to a specific lattice width - confirm.
        self.avgXMin = 10
        self.avgXMax = 480

        # vector and scalar fields created from CC3D Python/Extra Fields
        self.vectorField = self.create_vector_field_cell_level_py("VELOCITY")
        self.scalarField = self.create_scalar_field_cell_level_py("PRESSURE")
        self.scalarField2 = self.create_scalar_field_py("AVGPRESSURE")

    def start(self):
        """Initialise existing cells, the pressure-averaging arrays and the wall."""
        self.xSink = self.dim.x - 3 * self.rdC  # sink x position
        for cell in self.cell_list_by_type(self.CELL):
            cell.lambdaVolume = self.lbdV
            cell.targetVolume = self.tgV
            cell.lambdaVecX = self.Fx        # force the cell along the x-axis
            cell.dict["oldXcm"] = cell.xCOM  # previous x centre of mass
            cell.dict["oldYcm"] = cell.yCOM  # previous y centre of mass

        # Per-pixel running mean of pressure, the number of samples folded
        # into it, and (informational) the MCS minus one interval at which the
        # pixel was first sampled. Arrays are indexed [x, y], so this assumes
        # a 2-D lattice.
        self.pAvg = np.zeros((self.dim.x, self.dim.y))
        self.pAvgCount = np.zeros((self.dim.x, self.dim.y), dtype=int)
        self.pAvgStartMCS = np.zeros((self.dim.x, self.dim.y))

        # Put all WALL cells into one big cell. Iterate over a snapshot so the
        # cell inventory is not mutated while being iterated (merging
        # presumably empties and removes the source cell - NOTE(review):
        # confirm against CC3D merge_cells).
        firstWall = None
        for cell in list(self.cell_list_by_type(self.WALL)):
            if firstWall is None:
                firstWall = cell
                print(cell.id, firstWall.id)
            else:
                print(cell.id, firstWall.id)  # log before `cell` is merged away
                self.merge_cells(cell, firstWall)

    def step(self, mcs):
        """Run the source/sink events and field updates every timeInterval MCS."""
        if mcs % self.timeInterval != 0:
            return
        self._seed_source_cells()
        self._delete_cells_at_sink()
        self._update_velocity_and_pressure_fields()
        if mcs > 0:
            self._update_average_pressure_field(mcs)

    def _seed_source_cells(self):
        """Create a CELL at every Medium site of the source column x = rdC."""
        x = self.rdC
        for y in range(self.szW + self.rdC, self.dim.y - self.rdC, self.rdC):
            if self.cell_field[x, y, 0]:
                continue  # site already belongs to a cell
            newCell = self.new_cell(self.CELL)
            self.cell_field[x, y, 0] = newCell
            newCell.targetVolume = self.tgV
            newCell.lambdaVolume = self.lbdV
            newCell.lambdaVecX = self.Fx
            # A new cell starts at its seed pixel, so its first velocity is 0.
            newCell.dict["oldXcm"] = x
            newCell.dict["oldYcm"] = y

    def _delete_cells_at_sink(self):
        """Delete every CELL whose x centre of mass has crossed xSink."""
        # Collect first, then delete, so the cell inventory is not modified
        # while it is being iterated.
        sunk = [cell for cell in self.cell_list_by_type(self.CELL)
                if cell.xCOM > self.xSink]
        for cell in sunk:
            self.delete_cell(cell)

    @staticmethod
    def _minimum_image(delta, length):
        """Correct a displacement for a periodic axis of the given length."""
        if delta < -length / 2.:
            return delta + length
        if delta > length / 2.:
            return delta - length
        return delta

    def _update_velocity_and_pressure_fields(self):
        """Store each CELL's mean velocity over the last interval and its pressure."""
        fieldV = self.vectorField
        fieldS = self.scalarField
        for cell in self.cell_list_by_type(self.CELL):
            delX = self._minimum_image(cell.xCOM - cell.dict["oldXcm"], self.dim.x)
            delY = self._minimum_image(cell.yCOM - cell.dict["oldYcm"], self.dim.y)
            fieldV[cell] = [delX / self.timeInterval, delY / self.timeInterval, 0.]
            fieldS[cell] = cell.targetVolume - cell.volume

            cell.dict["oldXcm"] = cell.xCOM  # storing current cell x CM
            cell.dict["oldYcm"] = cell.yCOM  # storing current cell y CM

    def _update_average_pressure_field(self, mcs):
        """Fold current pixel pressures into each pixel's running mean.

        Only pixels with avgXMin < x < avgXMax that belong to a cell are
        sampled. Medium pixels are skipped and do not count as samples, so the
        stored value is the mean over the intervals in which the pixel was
        occupied.
        """
        fieldAvgPress = self.scalarField2
        for x, y, z in self.every_pixel():
            if not (self.avgXMin < x < self.avgXMax):
                continue
            cell = self.cell_field[x, y, z]
            if not cell:
                continue  # Medium
            if self.pAvgCount[x, y] == 0:
                self.pAvgStartMCS[x, y] = mcs - self.timeInterval
            self.pAvgCount[x, y] += 1
            cellPress = cell.targetVolume - cell.volume
            # Incremental mean: avg_n = avg_{n-1} + (p - avg_{n-1}) / n
            self.pAvg[x, y] += (cellPress - self.pAvg[x, y]) / self.pAvgCount[x, y]
            fieldAvgPress[x, y, z] = self.pAvg[x, y]

    def finish(self):
        """
        Called after the last MCS to wrap up the simulation
        """

    def on_stop(self):
        """
        Called if the simulation is stopped before the last MCS
        """
