from cc3d.core.PySteppables import *
from pathlib import Path
from collections import defaultdict


class CellSortingSteppable(SteppableBasePy):
    """Write one CSV row of contact energies and end-of-run metrics for a parameter-scan iteration.

    The row (written in ``finish``) holds, in order: C_ML, C_MD, C_LL, C_LD,
    C_DD, the value of ``getPotts().getEnergy()``, the summed heterotypic
    boundary count, and the number of cells.
    """

    # XML element names of the contact-energy parameters, in the column order
    # used for the CSV row written by ``finish``.
    _CONTACT_ENERGY_TAGS = ('C_ML', 'C_MD', 'C_LL', 'C_LD', 'C_DD')

    def __init__(self, frequency=1):
        SteppableBasePy.__init__(self, frequency)

    def start(self):
        """Derive ``self.csv_output_path`` and check the contact-energy parameters are readable."""
        # Example: if self.output_dir is
        #   C:\Users\m\CC3DWorkspace\CellSortingParameterScanWorkshop2020_output\scan_iteration_3\CellSortingParameterScanWorkshop2020
        # then
        #   iteration_dir = ...\CellSortingParameterScanWorkshop2020_output\scan_iteration_3
        #   scan_root     = ...\CellSortingParameterScanWorkshop2020_output
        #   csv_output_path = ...\CellSortingParameterScanWorkshop2020_output\scan_iteration_3.csv
        out_dir = Path(self.output_dir)
        iteration_dir = out_dir.parent
        scan_root = iteration_dir.parent
        # Append '.csv' instead of using with_suffix(): with_suffix() would
        # replace anything after a dot already present in the iteration name.
        self.csv_output_path = str(scan_root / (iteration_dir.name + '.csv'))

        # Read the parameters now so a missing or non-numeric element fails at
        # start-up rather than only in finish(), after the whole run.
        self._read_contact_energies()

    def _read_contact_energies(self):
        """Return the contact-energy XML parameters as floats, in ``_CONTACT_ENERGY_TAGS`` order."""
        return tuple(float(self.get_xml_element(tag).cdata) for tag in self._CONTACT_ENERGY_TAGS)

    def heterotypic_boundary_length(self):
        """Count boundary pixel pairs on the z=0 plane, keyed by ordered (type, neighbour type).

        For every pixel not on the lattice edge, each of its four nearest
        neighbours (+/-1 in x or y) for which ``are_cells_different`` is true
        adds 1 to ``counts[(center_type, neighbour_type)]``. A pixel with no
        cell (falsy, i.e. Medium) is recorded as type 0. Since both pixels of
        an interior contact are visited, such a contact is tallied under both
        (a, b) and (b, a).

        Only the z=0 plane is scanned (presumably a 2D lattice — confirm).

        :return: defaultdict(int) mapping (center_type, neighbour_type) -> count
        """
        counts = defaultdict(int)
        offsets = ((0, 1), (0, -1), (1, 0), (-1, 0))
        field = self.cell_field

        # Skip the outermost row/column so every neighbour index stays inside
        # the lattice; each axis must be bounded by its own dimension.
        for x in range(1, self.dim.x - 1):
            for y in range(1, self.dim.y - 1):
                cell = field[x, y, 0]
                cell_type = cell.type if cell else 0
                for dx, dy in offsets:
                    n_cell = field[x + dx, y + dy, 0]
                    if self.are_cells_different(n_cell, cell):
                        n_cell_type = n_cell.type if n_cell else 0
                        counts[(cell_type, n_cell_type)] += 1
        return counts

    def finish(self):
        """Write the single result row for this scan iteration to ``self.csv_output_path``.

        The file is opened in 'w' mode, so any existing file at that path is
        overwritten.
        """
        energy = self.simulator.getPotts().getEnergy()
        contact_energies = self._read_contact_energies()

        hl_len_dict = self.heterotypic_boundary_length()
        # Use one ordering of each heterotypic type pair so an interior contact
        # is counted once, not twice. Types 1 and 2 are presumably the two
        # sorting cell types (C_L*/C_D* suggest Light/Dark) — confirm in XML.
        tot_hl = hl_len_dict[(1, 2)] + hl_len_dict[(0, 1)] + hl_len_dict[(0, 2)]

        num_cells = len(self.cell_list)

        row = (*contact_energies, energy, tot_hl, num_cells)
        with open(self.csv_output_path, 'w', encoding='utf-8') as csv_out:
            csv_out.write(','.join(map(str, row)) + '\n')
