# fmt: on
import numpy as np

from OGOAT.src.Tiler.utils import compute_inverted_placement, factors
from OGOAT.src.utils.context import Logger


class TilingError(Exception):
    """Signals that the tiler could not derive a tiling for the layer."""


class MHATiler:

    def __init__(
        self,
        layer,
        device,
        overlay,
        kernel,
        logger: Logger,
    ):
        """Prepare the tiler state for a single MHA layer.

        Stores the collaborating objects, derives the MHA mode and bias flag
        from ``layer.op_type``, builds the per-operand kernel granularities,
        computes the padded shapes and creates the (initially empty)
        per-schedule result containers.

        Args:
            layer: layer description; ``op_type``, ``activations_shapes`` and
                the ``*_bytes`` attributes are read here.
            device: device description (only stored here).
            overlay: overlay description; ``overlay_name`` selects the kernel
                granularity entry.
            kernel: kernel description; provides ``kernel_granularity``,
                ``tdm_bytes``, ``coeff_bytes`` and ``placement_constraints``.
            logger: logger used for debug output by the other methods.
        """
        self.layer = layer
        self.overlay = overlay
        self.device = device
        self.kernel = kernel
        self.logger = logger

        op_type = layer.op_type
        self.mha_mode = "2p1" if "2p1" in op_type else "3p0"
        self.has_bias = "bias" in op_type

        # the first dimension is the batch size => keep only the last 2 dimensions
        self.activations_2d_shapes = {
            input_name: shape[-2:]
            for input_name, shape in self.layer.activations_shapes.items()
        }

        granularity = self.kernel.kernel_granularity[self.overlay.overlay_name]
        self.Mgran = granularity["Mgran"]
        self.Kgran = granularity["Kgran"]
        self.Ngran = granularity["Ngran"]

        # The ofm granularity depends on the MHA mode: M x K for "3p0",
        # M x N otherwise.
        if self.mha_mode == "3p0":
            ofm_granularity = np.array([self.Mgran, self.Kgran])
        else:
            ofm_granularity = np.array([self.Mgran, self.Ngran])
        self.kernel_granularity = {
            "ifm": np.array([self.Mgran, self.Kgran]),
            "wgt": np.array([self.Kgran, self.Ngran]),
            "ofm": ofm_granularity,
            "ifm_v": np.array([self.Ngran, self.Kgran]),
            "ifm_m": np.array([self.Mgran, self.Ngran]),
            "ifm_b": np.array([1, self.Ngran]),
        }

        self.compute_padded_shapes()

        schedule_list = [6]

        def per_schedule():
            # Every container maps schedule id -> its own fresh dict.
            return {sched: {} for sched in schedule_list}

        self.memtile_subvols = per_schedule()
        self.memtile_sublayers = per_schedule()
        self.memtile_iters = per_schedule()

        self.fits_in_memtile = per_schedule()
        self.valid_fits_in_memtile = per_schedule()

        self.valid_memtile_subvols = per_schedule()
        self.valid_memtile_sublayers = per_schedule()
        self.valid_memtile_iters = per_schedule()

        self.core_subvols = per_schedule()
        self.core_iters = per_schedule()

        self.core_validity_checks = per_schedule()
        self.valid_core_subvols = per_schedule()
        self.valid_core_iters = per_schedule()
        self.valid_core_subvids = per_schedule()

        # Names made available to the constraint formulas evaluated later on.
        self.vars_dict = {
            "ifm_bytes": layer.in_bytes,
            "wgt_bytes": layer.wgt_bytes,
            "ofm_bytes": layer.out_ofm_bytes,
            "bias_bytes": 8 if "qdq" in op_type else layer.wgt1_bytes,
            "tdm_bytes": kernel.tdm_bytes,
            "coeff_bytes": kernel.coeff_bytes,
            ## check other constraints
            "Mgran": self.Mgran,
            "Kgran": self.Kgran,
            "Ngran": self.Ngran,
        }

        # Any op type containing "mha" (case-insensitive) uses zero bias bytes.
        if "mha" in op_type.lower():
            self.vars_dict.update(bias_bytes=0)

        ## calculate bank-wise space formulas
        self.inverted_placement = compute_inverted_placement(
            self.kernel.placement_constraints
        )

    def get_padded_shape(
        self, shape: list[int], padding_factors: list[int]
    ) -> np.array:
        assert not any(dim == 0 for dim in shape), "padding by zero is not possible"
        return np.ceil(np.array(shape) / np.array(padding_factors)) * np.array(
            padding_factors
        )

    def compute_host_layer_padding(self) -> dict[str, list[int]]:
        """
        Compute and return the host layer padding for each inputs/output shape.
        The inner dimension of all matmul needs to be padded to 8 or 64 if dim > 64.
            - inner dimensions are k_dim and n_dim since they are the inner dimension
            of each resp. matmul
        """
        # Compute the padding factor of each inputs/outputs shapes
        host_padding_factors = {
            "ifm_q": [1, 8],  # M x K
            # If input is transposed the original 2d shape is NxK so K == inner_dim
            # otherwise the shape is KxN so N == inner_dim
            "ifm_k": [1, 8] if not self.layer.permK_3d else [8, 1],  # K x N
            "ifm_v": [1, 8],  # N x K
            "ifm_m": [1, 8],  # M x N
            "ifm_b": [1, 8],  # 1 x N
            "ofm": [1, 8],
        }

        host_layer_padding = dict()
        for input_name, shape in self.activations_2d_shapes.items():
            host_layer_padding[input_name] = self.get_padded_shape(
                shape, host_padding_factors[input_name]
            )

        return host_layer_padding

    def compute_padded_shapes(self):
        """
        1. Compute the host layer padding:
          - The shape that needs to be allocated on the host for each inputs and output
            The inner dimension of all matmul needs to be padded to 8 or 64 if dim > 64.
        2. Compute the overall padded shape of each inputs:
          - Compute the final shape after padding based on the host layer padded shapes and
            a padding factor constituted of the kernel kernel granularity of that input and the
            core split that we are exploring.
        3. Compute the padding that needs to be allocated on the dma for each inputs and output:
            dma layer padding = overall_padded_shape

        This function should set 3 attributes containing each padding kind computed and each shapes is over 2 dimension:
         - self.host_layer_padding
         - self.dma_layer_padding
         - self._padded_shapes
        """

        self.host_layer_padding = self.compute_host_layer_padding()

        v_core_splits = [
            self.overlay.core_splits["wgt"][1],  # N
            self.overlay.core_splits["wgt"][0],  # K
        ]
        b_core_splits = [
            1,
            self.overlay.core_splits["wgt"][1],  # N
        ]
        m_core_splits = [
            self.overlay.core_splits["ifm"][0],  # M
            self.overlay.core_splits["wgt"][1],  # N
        ]
        padding_factors = {
            "ifm_q": self.kernel_granularity["ifm"] * self.overlay.core_splits["ifm"],
            "ifm_k": self.kernel_granularity["wgt"] * self.overlay.core_splits["wgt"],
            "ofm": self.kernel_granularity["ofm"] * self.overlay.core_splits["ofm"],
            "ifm_v": self.kernel_granularity["ifm_v"] * v_core_splits,
            "ifm_m": self.kernel_granularity["ifm_m"] * m_core_splits,
            "ifm_b": self.kernel_granularity["ifm_b"] * b_core_splits,
        }

        self.dma_layer_padding = dict()
        for entry in self.host_layer_padding:
            padding_factor = padding_factors[entry]
            host_layer_padding = self.host_layer_padding[entry]

            self.dma_layer_padding[entry] = self.get_padded_shape(
                host_layer_padding, padding_factor
            )

        # FIXME: the rest of the tiler code requires ifm, wgt and ofm as entry
        # since they needs to be the same as the key for core split, mem split, etc
        # Ideally we should not needs to "copy" the dict and change the keys just for
        # that.
        self._padded_shapes = {
            "ifm": self.dma_layer_padding["ifm_q"],
            "wgt": self.dma_layer_padding["ifm_k"],
            "ofm": self.dma_layer_padding["ofm"],
        }

    def __str__(self) -> str:

        # FIXME: add string conversion to the Kernel, Overlay and Layer class
        str_repr = "MHA Tiler: {\n"
        str_repr += f"\tLayer: {self.layer}\n"
        str_repr += f"\tKernel: {self.kernel}\n"
        str_repr += f"\tOverlay: {vars(self.overlay)}\n"
        str_repr += f"\tKernel granularity: {str(self.kernel_granularity)}\n"
        str_repr += f"\t2D shapes: {str(self.activations_2d_shapes)}\n"
        str_repr += f"\tHost layer padding: {str(self.host_layer_padding)}\n"
        str_repr += f"\tDma layer padding: {str(self.dma_layer_padding)}\n"
        str_repr += f"\tValid memtile subvolume: {self.valid_memtile_subvols}\n"
        str_repr += f"\tValid memtile iters: {self.valid_memtile_iters}\n"
        str_repr += f"\tValid core subvolume: {self.valid_core_subvols}\n"
        str_repr += f"\tValid core iters: {self.valid_core_iters}\n"
        str_repr += "}"
        return str_repr

    def calculate_memtile_tilings(self):
        """Enumerate the memtile-level tilings for schedule 6 (ifm pin, wgt pin).

        For every operand in ``self.overlay.mem_splits`` the largest memtile
        shape (padded shape // memory split) and the smallest one (kernel
        granularity * core split / memory split) are computed. Candidate
        tilings are then built as rows of six columns::

            [m_ifm, k_ifm, k_wgt, n_wgt, m_ofm, n_ofm]

        and stored per schedule id in ``self.memtile_iters``,
        ``self.memtile_subvols`` and ``self.memtile_sublayers``.

        Raises:
            TilingError: if a minimum memtile shape exceeds the matching
                maximum, or if ``factors`` yields no candidate for the M
                dimension.
        """
        mem_splits = self.overlay.mem_splits
        core_splits = self.overlay.core_splits

        memtile_max_shapes = {}
        memtile_min_shapes = {}

        self.logger.debug("Calculate memtile tiling:")
        self.logger.debug("Overlay memory splits used: " + str(mem_splits))
        self.logger.debug("Start computing maximum and minimum memtile shape")
        invalid = False
        for operand in mem_splits:
            memtile_max_shapes[operand] = (
                self._padded_shapes[operand] // mem_splits[operand]
            )
            # captures how many core subvolumes are stored in each memtile
            split_ratio = core_splits[operand] / mem_splits[operand]  # Can be non-integer
            memtile_min_shapes[operand] = self.kernel_granularity[operand] * split_ratio

            if (memtile_min_shapes[operand] > memtile_max_shapes[operand]).any():
                invalid = True
                break

        # Logged before raising so the offending shapes reach the debug trace.
        self.logger.debug("memtile max shapes: " + str(memtile_max_shapes))
        self.logger.debug("memtile min shapes: " + str(memtile_min_shapes))

        if invalid:
            raise TilingError(
                "min memtiling shape cannot be bigger than the max memtiling shape"
            )

        mi_max = memtile_max_shapes["ifm"][0]
        mi_min = memtile_min_shapes["ifm"][0]
        ki_max = memtile_max_shapes["ifm"][1]
        kw_max = memtile_max_shapes["wgt"][0]
        nw_max = memtile_max_shapes["wgt"][1]

        schedule_tilings = {}

        # NOTE: schedules 1-5 (1: ifm pin / wgt full / ofm stream,
        # 2: ifm pin / wgt stream / ofm stream, 3: ifm full / wgt pin,
        # 4: ifm stream / wgt pin, 5: ifm stream / wgt stream) are disabled;
        # only schedule 6 is generated.

        # ##############
        # Schedule 6: ifm pin, wgt pin
        # ##############
        m_ifm_sublist = factors(mi_max, mi_min)
        if not len(m_ifm_sublist):
            raise TilingError(
                "No common factor found between max and min memtile size of the M dim"
            )

        # K and N are not split at the memtile level: only M is swept.
        k_ifm_sublist = [ki_max]
        k_wgt_sublist = [kw_max]
        n_wgt_sublist = [nw_max]

        m_ofm_sublist = [0]  # placeholder, derived from m_ifm below
        n_ofm_sublist = [0]  # placeholder, derived from n_wgt below

        # Cartesian product of the candidate lists, one tiling per row.
        tmp = np.array(
            np.meshgrid(
                m_ifm_sublist,
                k_ifm_sublist,
                k_wgt_sublist,
                n_wgt_sublist,
                m_ofm_sublist,
                n_ofm_sublist,
            )
        ).T.reshape(-1, 6)
        # k_wgt always follows k_ifm.
        tmp[:, 2] = tmp[:, 1]
        # split ratios need to be used when copying tensors
        tmp[:, 4] = tmp[:, 0] * (mem_splits["ifm"][0] / mem_splits["ofm"][0])
        if self.mha_mode == "2p1":
            tmp[:, 5] = tmp[:, 3] * (mem_splits["wgt"][1] / mem_splits["ofm"][1])
        elif self.mha_mode == "3p0":
            tmp[:, 5] = tmp[:, 3] * mem_splits["ifm"][1]
        schedule_tilings[6] = tmp
        self.logger.debug("schedule tiling: " + str(schedule_tilings))

        # Row of maximum shapes in the same column order as the tilings.
        max_shape_row = np.hstack(
            [
                memtile_max_shapes["ifm"],
                memtile_max_shapes["wgt"],
                memtile_max_shapes["ofm"],
            ]
        )
        for schedule_id, tiling in schedule_tilings.items():
            # temporal iterations: how many times each tiling fits in the max shape
            self.memtile_iters[schedule_id] = max_shape_row // tiling

            self.memtile_subvols[schedule_id] = {
                "ifm": tiling[:, [0, 1]],
                "wgt": tiling[:, [2, 3]],
                "ofm": tiling[:, [4, 5]],
            }

            # sublayers (used for array tiling) span all memory splits
            self.memtile_sublayers[schedule_id] = {
                "ifm": tiling[:, [0, 1]] * mem_splits["ifm"],
                "wgt": tiling[:, [2, 3]] * mem_splits["wgt"],
                "ofm": tiling[:, [4, 5]] * mem_splits["ofm"],
            }

        self.logger.debug("Memtile sublayers found: " + str(self.memtile_sublayers))
        self.logger.debug("Memtile subvolumes found: " + str(self.memtile_subvols))
        self.logger.debug("Memtile iters found: " + str(self.memtile_iters))

    def calculate_array_tilings(self):
        # """ NOTE
        # with our current understanding there is no benefit in creating subtilings for core from memtile subvolumes
        # with pinning or streaming scenarios there is no expected benefit from fetching multiple subvolumes of operand into memtile
        # and operating on them in smaller pieces, therefore the logic for subtiling is disabled for now
        # """
        # self.core_subvols = self.memtile_subvols
        # self.core_iters = self.memtile_iters

        # """ DISABLED FEATURE of array subtiling """
        # for each schedule, calculate the core subvolumes and core iterations
        # tilings which have the full k in the memtile

        self.logger.debug("Calculating array tilings:")
        self.logger.debug("Initial memtile sublayers: " + str(self.memtile_sublayers))
        for sched, memtile_sublayers in self.memtile_sublayers.items():

            core_min_shapes = self.kernel_granularity
            core_max_shapes = {}
            for operand in memtile_sublayers.keys():
                core_max_shapes[operand] = (
                    memtile_sublayers[operand] // self.overlay.core_splits[operand]
                )

            # FIXME: remove when not necessary anymore to force
            # subvolume of q (Sq) to be equal to 16 all the time.
            # No other values are currently supported
            # Also reintroduce the outer_loop_min constraint in the
            # kernel metadata file
            for core_max_shape in core_max_shapes["ifm"]:
                core_max_shape[0] = 16

            for core_max_shape in core_max_shapes["ofm"]:
                core_max_shape[0] = 16

            # self.context.logger.debug(core_max_shapes)
            # Independent variables are Msub Ksub and Nsub at the core level

            core_subvols = {}
            core_iters = {}

            for k in range(len(core_max_shapes["ofm"])):

                if sched == 1:
                    # ##############
                    # Schedule 1: ifm pin, wgt full (1 or 2 buffers of ifm, 1 for wgt, 1 for ofm)
                    # ##############
                    m_sublist = [core_max_shapes["ofm"][k][0]]
                    k_sublist = factors(
                        self._padded_shapes["ifm"][1], self.kernel_granularity["ifm"][1]
                    )
                    n_sublist = [core_max_shapes["ofm"][k][1]]

                elif sched == 2:
                    # ##############
                    # Schedule 2: ifm pin, wgt stream (1 or 2 buffers of ifm, 1 for wgt, 1 for ofm)
                    # ##############
                    m_sublist = [core_max_shapes["ofm"][k][0]]
                    k_sublist = [core_max_shapes["wgt"][k][0]]
                    n_sublist = [core_max_shapes["ofm"][k][1]]

                elif sched == 5:
                    # ##############
                    # Schedule 5: ifm stream, wgt stream (2 buffers of ifm, 2 for wgt, 1 for ofm)
                    # ##############
                    m_sublist = [core_max_shapes["ofm"][k][0]]
                    k_sublist = [core_max_shapes["wgt"][k][0]]
                    n_sublist = [core_max_shapes["ofm"][k][1]]

                if sched == 6:
                    # ##############
                    # Schedule 6: ifm pin, wgt pin (1 or 2 buffers of ifm, 1 or 2 for wgt, 1 or 2 for ofm)
                    # ##############
                    m_sublist = [core_max_shapes["ofm"][k][0]]
                    k_sublist = [core_max_shapes["wgt"][k][0]]
                    n_sublist = [core_max_shapes["ofm"][k][1]]
                else:
                    assert False, f"Unexpected schedule number {sched}"

                tmp = np.array(np.meshgrid(m_sublist, k_sublist, n_sublist)).T.reshape(
                    -1, 3
                )

                core_iters[k] = (
                    np.hstack((core_max_shapes["ifm"][k][0], core_max_shapes["wgt"][k]))
                    // tmp
                )
                core_subvols[k] = {
                    "ifm": tmp[:, :2],
                    "wgt": tmp[:, 1:],
                    "ofm": tmp[:, [0, 2]],
                }

            self.core_subvols[sched] = core_subvols
            self.core_iters[sched] = core_iters
        self.logger.debug("computed core subvolume: " + str(self.core_subvols))
        self.logger.debug("computed core iters: " + str(self.core_iters))

    def check_core_constraints(self):
        core_bank_capacity = 1024 * (
            self.device.core_data_memory // self.device.core_num_banks
        )  # bytes
        memtile_capacity = (
            self.device.memtile_capacity * self.device.memtile_rows * 1024
        )  # In bytes

        fits_within_memtiles = {}
        valid_subvolumes = {}
        valid_memtile_sublayers = {}
        valid_fits_within_memtiles = {}
        valid_memtile_iters = {}

        for sched, core_subvol_dict in self.core_subvols.items():
            # self.core_validity_checks.setdefault(sched,{})
            memtile_validity = []
            for mem_sub_id, subvols in core_subvol_dict.items():
                self.logger.debug(
                    f"Check core constraints schedule id '{sched}' and mem_sub_id '{mem_sub_id}' and sub volumes '{subvols}':"
                )

                self.vars_dict["Sq"] = subvols["ifm"][:, 0]
                assert self.vars_dict["Sq"] == 16, "Only Sq = 16 is supported currently"

                self.vars_dict["So"] = subvols["ofm"][:, 0]
                self.vars_dict["Skv"] = subvols["wgt"][:, 1]
                self.vars_dict["Dh"] = subvols["ifm"][:, 1]
                # check buffer placements
                max_space_available = self.device.core_data_memory * 1024
                total_space_required = 0
                for bank, formula in self.inverted_placement.items():
                    total_space_required += eval(formula, self.vars_dict)

                valid_placement = max_space_available >= total_space_required
                if not valid_placement:
                    self.logger.debug(
                        f"Memory allocation invalid for subvols {subvols}. Space required is '{total_space_required}' while core capacity is '{max_space_available}'"
                    )

                validity_checks = {"buffer_placement": valid_placement}

                # loop constraints eval
                for constraint, formula in self.kernel.other_constraints.items():
                    validity_checks[constraint] = eval(formula, self.vars_dict)

                # memtile constraints check
                # correction for wgt bias packing
                num_wgt_subvols = np.prod(
                    self.memtile_subvols[sched]["wgt"][mem_sub_id] / subvols["wgt"], 1
                )
                bias_bytes_per_subvol = (
                    subvols["wgt"][:, 1] * self.vars_dict["bias_bytes"]
                )
                bias_bytes_in_memtile = num_wgt_subvols * bias_bytes_per_subvol

                memtile_subvolume_sizes = {
                    "ifm": np.prod(self.memtile_subvols[sched]["ifm"][mem_sub_id])
                    * self.layer.in_bytes,
                    "wgt": np.prod(self.memtile_subvols[sched]["wgt"][mem_sub_id])
                    * self.layer.wgt_bytes,
                    "ofm": np.prod(self.memtile_subvols[sched]["ofm"][mem_sub_id])
                    * self.layer.out_ofm_bytes,
                }
                if self.overlay.mem_splits["ifm"][0] == 2:
                    memtile_subvolume_sizes["ifm"] = memtile_subvolume_sizes["ifm"] * 2

                if sched == 1:
                    # ##############
                    # Schedule 1: ifm pin, wgt full (1 or 2 buffers of ifm, 1 for wgt, 1 for ofm)
                    # ##############
                    space_required_1buff = (
                        memtile_subvolume_sizes["ifm"]
                        + (memtile_subvolume_sizes["wgt"] + bias_bytes_in_memtile)
                        + memtile_subvolume_sizes["ofm"]
                    )
                    space_required_2buff = (
                        memtile_subvolume_sizes["ifm"] * 2
                        + (memtile_subvolume_sizes["wgt"] + bias_bytes_in_memtile)
                        + memtile_subvolume_sizes["ofm"]
                    )
                    total_space_required = np.vstack(
                        (space_required_1buff, space_required_2buff)
                    ).T
                elif sched == 2:
                    # ##############
                    # Schedule 2: ifm pin, wgt stream (1 or 2 buffers of ifm, 1 for wgt, 1 for ofm)
                    # ##############
                    space_required_1buff = (
                        memtile_subvolume_sizes["ifm"]
                        + (memtile_subvolume_sizes["wgt"] + bias_bytes_in_memtile) * 2
                        + memtile_subvolume_sizes["ofm"]
                    )
                    space_required_2buff = (
                        memtile_subvolume_sizes["ifm"] * 2
                        + (memtile_subvolume_sizes["wgt"] + bias_bytes_in_memtile) * 2
                        + memtile_subvolume_sizes["ofm"]
                    )
                    total_space_required = np.vstack(
                        (space_required_1buff, space_required_2buff)
                    ).T
                # elif sched == 3:
                #     # ##############
                #     # Schedule 3: ifm full, wgt pin
                #     # ##############
                #     space_required_1buff = memtile_subvolume_sizes['ifm'] + (memtile_subvolume_sizes['wgt'] + bias_bytes_in_memtile) + memtile_subvolume_sizes['ofm'] * 2
                #     space_required_2buff = memtile_subvolume_sizes['ifm'] + (memtile_subvolume_sizes['wgt'] + bias_bytes_in_memtile)*2 + memtile_subvolume_sizes['ofm'] * 2
                #     total_space_required = np.vstack((space_required_1buff, space_required_2buff)).T
                # elif sched == 4:
                #     # ##############
                #     # Schedule 4: ifm stream, wgt pin
                #     # ##############
                #     space_required_1buff = memtile_subvolume_sizes['ifm']*2 + (memtile_subvolume_sizes['wgt'] + bias_bytes_in_memtile) + memtile_subvolume_sizes['ofm'] * 2
                #     space_required_2buff = memtile_subvolume_sizes['ifm']*2 + (memtile_subvolume_sizes['wgt'] + bias_bytes_in_memtile)*2 + memtile_subvolume_sizes['ofm'] * 2
                #     total_space_required = np.vstack((space_required_1buff, space_required_2buff)).T
                elif sched == 5:
                    # ##############
                    # Schedule 5: ifm stream, wgt stream
                    # ##############
                    space_required_2buff = (
                        memtile_subvolume_sizes["ifm"] * 2
                        + (memtile_subvolume_sizes["wgt"] + bias_bytes_in_memtile) * 2
                        + memtile_subvolume_sizes["ofm"] * 1
                    )
                    total_space_required = space_required_2buff.reshape(
                        (len(space_required_2buff), 1)
                    )
                elif sched == 6:
                    # ##############
                    # Schedule 6: ifm pin, wgt pin
                    # ##############
                    space_required_1buff = (
                        memtile_subvolume_sizes["ifm"]
                        + memtile_subvolume_sizes["wgt"]
                        + memtile_subvolume_sizes["ofm"]
                    )
                    space_required_2buff = (
                        memtile_subvolume_sizes["ifm"] * 2
                        + memtile_subvolume_sizes["wgt"] * 2
                        + memtile_subvolume_sizes["ofm"] * 2
                    )
                    total_space_required = np.vstack(
                        (space_required_1buff, space_required_2buff)
                    ).T
                else:
                    assert False, f"Unexpected schedule number {sched}"

                # self.context.logger.debug(self.overlay.core_splits, sched, mem_sub_id)
                fits_in_memtiles = total_space_required < (memtile_capacity - 16 * 1024)
                self.logger.debug("fit in memtile: " + str(fits_in_memtiles))
                self.fits_in_memtile[sched][mem_sub_id] = fits_in_memtiles
                memtile_validity.append(~np.all(~fits_in_memtiles))

                # subvols that satisfy all constraints
                valid_subvols = np.all(
                    np.hstack(
                        (np.vstack(list(validity_checks.values())).T, fits_in_memtiles)
                    ),
                    1,
                )
                # self.context.logger.debug(validity_checks)

                self.valid_core_subvids[sched][mem_sub_id] = valid_subvols

                self.core_validity_checks[sched][mem_sub_id] = validity_checks
                # self.core_validity_checks

                # self.valid_core_subvols.setdefault(sched,{})
                if ~np.all(~fits_in_memtiles):
                    self.valid_core_subvols[sched][mem_sub_id] = {
                        "ifm": subvols["ifm"][valid_subvols, :],
                        "wgt": subvols["wgt"][valid_subvols, :],
                        "ofm": subvols["ofm"][valid_subvols, :],
                    }

                    self.valid_memtile_iters[sched][mem_sub_id] = self.memtile_iters[
                        sched
                    ][mem_sub_id]
                    self.valid_fits_in_memtile[sched][mem_sub_id] = (
                        self.fits_in_memtile[sched][mem_sub_id]
                    )
                    self.valid_memtile_subvols[sched][mem_sub_id] = {
                        operand: self.memtile_subvols[sched][operand][mem_sub_id]
                        for operand in self.memtile_subvols[sched].keys()
                    }

                    self.logger.debug(
                        f"valid memtile subvolumes found for schedule number '{sched}' and subv id '{mem_sub_id}': "
                        + str(self.valid_memtile_subvols[sched][mem_sub_id])
                    )
                    self.logger.debug(
                        f"valid core tiling subvolumes found for schedule number '{sched}' and subv id '{mem_sub_id}' "
                        + str(self.valid_core_subvols[sched][mem_sub_id])
                    )
                self.valid_core_iters[sched][mem_sub_id] = self.core_iters[sched][
                    mem_sub_id
                ][valid_subvols]

            # fits_within_memtiles[sched] = np.array(memtile_validity)
            # valid_fits_within_memtiles[sched] = {k:v for k,v in self.fits_in_memtile[sched].items() if memtile_validity[k]}

            valid_subvolumes[sched] = {
                operand: self.memtile_subvols[sched][operand][
                    np.array(memtile_validity)
                ]
                for operand in self.memtile_subvols[sched].keys()
            }

            valid_memtile_sublayers[sched] = {
                operand: tiling * self.overlay.mem_splits[operand]
                for operand, tiling in valid_subvolumes[sched].items()
            }
            # valid_memtile_iters[sched] = self.memtile_iters[sched][np.array(memtile_validity)]

        # self.fits_in_memtile = fits_within_memtiles
        # self.valid_fits_in_memtile = valid_fits_within_memtiles

        # self.valid_memtile_subvols = valid_subvolumes
        # self.valid_memtile_sublayers = valid_memtile_sublayers
        # self.valid_memtile_iters = valid_memtile_iters


if __name__ == "__main__":
    # Manual smoke run: build a tiler for the "tst" layer of the sample json
    # and execute the three tiling stages in order.
    import json

    from overlay import Overlay
    from layer import Layer
    from kernel import Kernel
    from device import Device

    overlay = Overlay("4x4", "MHA", "M1K1N16")

    with open("OGOAT/src/Tiler/tst_layer.json") as layer_file:
        layer_dict = json.load(layer_file)["tst"]
    # Both activations are declared resident in L3 for this run.
    layer_dict["in_act_residency"] = "L3"
    layer_dict["out_act_residency"] = "L3"

    layer = Layer(layer_dict)
    device = Device("strix")
    kernel = Kernel(layer)

    tiler = MHATiler(layer, device, overlay, kernel, Logger.get_null_logger())
    tiler.calculate_memtile_tilings()
    tiler.calculate_array_tilings()
    tiler.check_core_constraints()
