import os

from OGOAT.src.Tiler.dataflow_tiling_opt import DataflowTilingOpt
from OGOAT.src.Tiler.tiling_result import TilingResult
from dataflow.concat.concat_run_tiler import ConcatDims
from enum import Enum

# Absolute path of the directory three levels above this file. abspath()
# normalises the path, so each os.pardir segment strips one trailing component.
parent_dir = os.path.abspath(
    os.path.join(__file__, os.pardir, os.pardir, os.pardir)
)


class Tile_Attr(Enum):
    """Selector for the memory-tile attribute returned by
    ``ConcatTilingOpt.getIfm`` / ``ConcatTilingOpt.getOfm``.
    """

    # NOTE: declaration order is significant. ``find_optimal_tiling`` iterates
    # over this enum and unpacks the results positionally as
    # (sub-volumes, iterations, sizes); reordering the members would silently
    # swap those values.
    SUBVOLS = 0  # sub-volume dimensions, returned as a 3-element list
    ITERS = 1  # iteration counts, returned as a 3-element list
    SIZES = 2  # memory-tile size, returned as a 1-element list


class ConcatTilingOpt(DataflowTilingOpt):
    """Tiling selection for the concat dataflow op.

    The memory-tile parameters are derived from a single representative
    input on each side: the one whose computed tile size in bytes is the
    largest (see ``find_optimal_tiling``).
    """

    def find_optimal_tiling(self) -> TilingResult:
        """Build the ``TilingResult`` for the concat op.

        Steps:
          1. Fetch the op dimensions from ``self.tiler``.
          2. For the IFM and OFM side separately, compute a byte size per
             input and keep the largest one together with its input index.
          3. Query ``getIfm`` / ``getOfm`` once per ``Tile_Attr`` member to
             obtain sub-volumes, iteration counts and sizes for that input.
          4. Populate a ``TilingResult`` with the memory-tile parameters,
             overlay info, kernel includes and layer info.

        Returns:
            The populated ``TilingResult``.

        Raises:
            ValueError: if ``dims.num_inputs`` is 0 (``max()`` of an empty
                iterable).
        """
        dims = self.tiler.get_op_dims()

        # Each candidate is a (size_in_bytes, input_index) tuple. Tuples
        # compare element-wise, so max() selects the largest size and, on
        # equal sizes, the higher input index. ``* bits // 8`` converts a
        # bit count to whole bytes (floor division).
        #
        # Only the Y extent uses the sub-volume value (Yis); X and C use the
        # full per-input extents Xi / Ci.
        ifm_mem_tile_size, ifm_n = max(
            (
                (dims.Yis[n] * dims.Xi[n] * dims.Ci[n] * dims.ifm_bits) // 8,
                n,
            )
            for n in range(dims.num_inputs)
        )
        # The OFM candidates share a single Y sub-volume (Yos) across inputs
        # and are likewise sized with the per-input Xi / Ci extents.
        ofm_memtile_size, ofm_n = max(
            ((dims.Yos * dims.Xi[n] * dims.Ci[n] * dims.ofm_bits) // 8, n)
            for n in range(dims.num_inputs)
        )

        # Tile_Attr iterates in declaration order (SUBVOLS, ITERS, SIZES),
        # which is what the positional unpacking below relies on.
        subV_ifm, iters_ifm, sizes_ifm = (
            self.getIfm(attr, ifm_n, ifm_mem_tile_size, dims)
            for attr in Tile_Attr
        )
        subV_ofm, iters_ofm, sizes_ofm = (
            self.getOfm(attr, ofm_n, ofm_memtile_size, dims)
            for attr in Tile_Attr
        )

        tiling_result = TilingResult(op_dims=dims)
        tiling_result.add_mem_tile_params(
            subV_ifm, subV_ofm, iters_ifm, iters_ofm, sizes_ifm, sizes_ofm
        )
        tiling_result.add_overlay_info(self.tiler.overlay, self.mode)
        tiling_result.add_kernel_info(kernel_includes=["super.hh"])
        tiling_result.add_layer_info(self.tiler.layer)

        return tiling_result

    def getIfm(
        self,
        tile_attr: Tile_Attr,
        ifm_n: int,
        ifm_memt_tile_size: int,
        dims: ConcatDims,
    ) -> list:
        """Return one IFM memory-tile attribute for input ``ifm_n``.

        Args:
            tile_attr: which attribute to return.
            ifm_n: index of the input whose dimensions are read.
            ifm_memt_tile_size: precomputed IFM tile size; returned as-is
                for ``Tile_Attr.SIZES``.
            dims: op dimensions to read from.

        Returns:
            ``SUBVOLS``: ``[Yis, Xis, Cis]`` of the input.
            ``ITERS``:   ``[Yi // Yis, Xi // Xis, Ci_loop]`` of the input.
            ``SIZES``:   ``[ifm_memt_tile_size]``.

        Raises:
            ValueError: if ``tile_attr`` is none of the above.
        """
        if tile_attr == Tile_Attr.SUBVOLS:
            return [dims.Yis[ifm_n], dims.Xis[ifm_n], dims.Cis[ifm_n]]
        if tile_attr == Tile_Attr.ITERS:
            return [
                dims.Yi[ifm_n] // dims.Yis[ifm_n],
                dims.Xi[ifm_n] // dims.Xis[ifm_n],
                dims.Ci_loop[ifm_n],
            ]
        if tile_attr == Tile_Attr.SIZES:
            return [ifm_memt_tile_size]
        raise ValueError("Invalid Tile Attribute")

    def getOfm(
        self,
        tpye: Tile_Attr,
        ofm_n: int,
        ofm_memtile_size: int,
        dims: ConcatDims,
    ) -> list:
        """Return one OFM memory-tile attribute for input ``ofm_n``.

        Args:
            tpye: which attribute to return. The misspelled parameter name
                is kept so that existing keyword callers keep working.
            ofm_n: index used for the per-input X / C values.
            ofm_memtile_size: precomputed OFM tile size; returned as-is for
                ``Tile_Attr.SIZES``.
            dims: op dimensions to read from.

        Returns:
            ``SUBVOLS``: ``[Yos, Xos[ofm_n], Cos[ofm_n]]``.
            ``ITERS``:   ``[Yo // Yos, Xo // Xos[ofm_n], Co_loop[ofm_n]]``.
            ``SIZES``:   ``[ofm_memtile_size]``.

        Raises:
            ValueError: if ``tpye`` is none of the above.
        """
        # Unlike the IFM side, the Y values (Yo, Yos) and Xo are not indexed
        # per input here; only Xos, Cos and Co_loop are.
        if tpye == Tile_Attr.SUBVOLS:
            return [dims.Yos, dims.Xos[ofm_n], dims.Cos[ofm_n]]
        if tpye == Tile_Attr.ITERS:
            return [
                dims.Yo // dims.Yos,
                dims.Xo // dims.Xos[ofm_n],
                dims.Co_loop[ofm_n],
            ]
        if tpye == Tile_Attr.SIZES:
            return [ofm_memtile_size]
        raise ValueError("Invalid Tile Attribute")
