'''
Map unary op shapes to the AIE-4 dataflow architecture.
External facing functions are documented below.

    get_uniop_mappings - choose a single tiling (tensor, padded-tensor and
    subvolume dimensions plus a spatial split mode) for a unary op shape;
    groupnorm is routed to a dedicated tiler, every other op to the
    generic unary tiler
'''
from math import lcm
from utils.utils_common import overlay_3x4_core_stack_addr, iceil, log
from scheduler.uniop.uniop_common import UnaryShape, UnaryMapping


class uniop_mappings:
    """
    Compute tilings (``UnaryMapping`` objects) for unary operations.

    The constructor copies the fields it needs from a ``UnaryShape``.
    ``generateMapping`` tiles generic unary ops and ``groupnormMapping``
    tiles groupnorm.

    Validation failures raise ``AssertionError`` explicitly (rather than via
    bare ``assert``) so the checks are still enforced under ``python -O``.
    """

    # Ops that reserve a larger L1 safety margin (1024 * 10 instead of 1024 * 4).
    _LARGE_MARGIN_FUNCTIONS = frozenset(
        {"swish", "tanh", "sigmoid", "elu", "silu", "gelu"})
    # Element-wise ops whose tensor may be reshaped (X * 16, C // 16) when a
    # granX-row slice of C does not fit in L1.
    _ELEMENTWISE_FUNCTIONS = frozenset(
        {"dequant", "quant", "silu", "gelu", "copy"})

    def __init__(self, shape: UnaryShape = None):
        """
        Capture the fields of ``shape`` used by the tilers.

        Args:
            shape: The unary-op shape to tile. The ``None`` default is kept
                only for signature compatibility; a real shape is required
                (``None`` fails with ``AttributeError`` on the first access).
        """
        self.function = shape.function
        self.TensorDim = shape.TensorDim
        self.SpatialSplitMode = shape.SpatialSplitMode
        self.IfmBytes = shape.ifmbytes
        self.OfmBytes = shape.ofmbytes
        log("splitmode", shape.SpatialSplitMode)

    def generateMapping(self) -> UnaryMapping:
        """
        Tile a unary op for the ``N1X12C1`` spatial split.

        Any other split mode gets a placeholder mapping in which every
        dimension argument is the unmodified ``TensorDim``.

        For ``N1X12C1`` the tensor must be 3-D ``(N, X, C)``. C is padded up
        to a multiple of 64. The subvolume is ``(1, m * granX, factorC *
        granC)``, where ``factorC`` covers the padded C (and is at least
        ``min_Csubv // granC``). ``m`` is the largest multiplier in
        ``[1, ceil(N * X / (granX * 12))]`` whose input+output footprint
        still fits in the available L1 capacity (``m`` is never below 1).

        Returns:
            UnaryMapping built from (TensorDim, PaddedTensorDim, SubVolumeDim,
            SubVolumeDim, (1, 2, C), 'N1X12C1'). TensorDim is the reshaped
            shape when an element-wise op was reshaped.

        Raises:
            AssertionError: TensorDim is not 3-D, l2norm has C <= 64, or even
                a granX-row slice of C does not fit in L1.
        """
        log("splitmode", self.SpatialSplitMode)

        function = self.function
        TensorDim = self.TensorDim

        if self.SpatialSplitMode != "N1X12C1":
            # Only N1X12C1 is tiled here; other modes get a dummy mapping.
            return UnaryMapping(TensorDim, TensorDim, TensorDim, TensorDim, TensorDim, 'N1X12C1')

        if len(TensorDim) != 3:
            raise AssertionError(f"unary op expects a 3-D (N, X, C) tensor, got {TensorDim}")
        (N, X, C) = TensorDim
        log("TensorDim:", TensorDim)

        # Pad C up to a multiple of 64 (the kernel's channel granularity).
        if C % 64 != 0:
            C = iceil(C, 64)

        # Available L1 space: stack address minus an op-dependent margin.
        margin_factor = 10 if function in self._LARGE_MARGIN_FUNCTIONS else 4
        Cap = overlay_3x4_core_stack_addr() - 1024 * margin_factor

        # Checked against the caller's (unpadded) C.
        if function == "l2norm" and not TensorDim[2] > 64:
            raise AssertionError(f"l2norm requires C > 64, got C={TensorDim[2]}")

        # Kernel granularity
        granX = 4
        granC = 64
        min_Csubv = 128
        IfmBytes = self.IfmBytes
        OfmBytes = self.OfmBytes
        # One input plus one output tensor live in L1 at the same time.
        io_bytes = IfmBytes + OfmBytes

        # Element-wise ops may be reshaped: when a granX-row slice of the full
        # C overflows L1, fold a factor of 16 from C into X (padded C is a
        # multiple of 64, so C // 16 is exact).
        if function in self._ELEMENTWISE_FUNCTIONS and C * granX * io_bytes >= Cap:
            X = X * 16
            C = C // 16
            TensorDim = (N, X, C)   # report the reshaped tensor to the caller

        log("C: ", C)
        if C * granX * io_bytes >= Cap:
            raise AssertionError(
                f"{function}: a {granX}-row slice of C={C} needs "
                f"{C * granX * io_bytes} bytes, exceeding L1 capacity {Cap}")

        minfactorC = min_Csubv // granC  # hold for softmax/layernorm/l2norm
        factorC = max(iceil(C, granC) // granC, minfactorC)
        # 12 matches the "X12" in 'N1X12C1'; presumably X is split over 12 cores.
        x_parts = 12
        MaxMul_granX = iceil(N * X, granX * x_parts) // (granX * x_parts)
        log("MaxMul_granX: ", MaxMul_granX)
        log("factorC: ", factorC)
        log("factorC * granX * granC * IfmBytes * numIOtensors: ", factorC * granX * granC * io_bytes)

        # Grow the X multiplier while (m + 1) * granX rows still fit in L1
        # and stay within the per-part X budget.
        bytes_per_mul = granX * factorC * granC * io_bytes
        MinMul_granX = 1
        while (MinMul_granX + 1) * bytes_per_mul <= Cap and MinMul_granX + 1 <= MaxMul_granX:
            MinMul_granX += 1
            log("MinMul_granX:", MinMul_granX)

        PaddedTensorDim = (N, X, C)
        SubVolumeDim = (1, MinMul_granX * granX, factorC * granC)

        log("TensorDim: ", TensorDim)  # ReShaped TensorDim
        log("PaddedTensorDim: ", PaddedTensorDim)
        log("SubVolumeDim: ", SubVolumeDim)

        # NOTE(review): the meaning of the fixed (1, 2, C) argument is not
        # visible here; confirm against UnaryMapping.
        return UnaryMapping(TensorDim, PaddedTensorDim, SubVolumeDim, SubVolumeDim, (1, 2, C), 'N1X12C1')

    def groupnormMapping(self) -> UnaryMapping:
        """
        Tile a groupnorm op with a fixed 32 groups and 2 passes.

        The active split mode is ``N1X4C3``. Groups are divided (rounded up)
        across ``AieCols`` columns, and ``subvX`` is the smallest row count
        for which ``subvX * group_size`` is a multiple of ``core_vec_len``.
        X is not padded; C is padded up to a multiple of 64.

        Returns:
            UnaryMapping with the subvolume used for all three subvolume
            arguments, plus ``Npass=2`` and ``Ngroups=32``.

        Raises:
            AssertionError: C is not a multiple of 32, or the per-column
                channel slice is not a whole number of 4-byte words.
        """
        splitmode = 'N1X4C3'
        # splitmode = 'N1X1C12'  # alternate split kept for experimentation
        TensorDim = self.TensorDim
        (_, X, C) = TensorDim
        IfmBytes = self.IfmBytes
        NG = 32  # Currently fixed to 32 groups
        ncores = 12
        AieCols = 3
        core_vec_len = 32
        wordlen = 4  # 4 byte word boundary
        if C % NG != 0:
            raise AssertionError("Groupnorm C should be multiple of 32")
        # If C is not multiple of 64 Pad it to multiple of 64
        if C % 64 != 0:
            C = iceil(C, 64)

        # NOTE(review): group_size is derived from the *padded* C. When padding
        # happens (e.g. C=96 -> 128) it differs from the true group size
        # (96 // 32 = 3 vs 128 // 32 = 4); confirm the kernel expects this.
        group_size = C // NG
        # Smallest row count whose subvX * group_size elements fill whole vectors.
        subvX = lcm(group_size, core_vec_len) // group_size
        log("subvX:", subvX)

        if splitmode == "N1X4C3":
            # Number of groups each AIE column processes (rounded up).
            groups_per_col = iceil(NG, AieCols) // AieCols
            subvC = groups_per_col * group_size
            if (subvC * IfmBytes) % wordlen != 0:
                raise AssertionError("SubvC size should be multiple of word len")
            SubVolumeDim = (1, subvX, subvC)
        elif splitmode == "N1X1C12":
            groups_per_core = iceil(NG, ncores) // ncores
            subvC = groups_per_core * group_size
            SubVolumeDim = (1, subvX, subvC)
        else:
            raise AssertionError("Unsupported SpatialSplitMode for groupnorm")
        paddX = X  # X is left unpadded (iceil(X, subvX) was considered)
        PaddedTensorDim = (1, paddX, C)
        Npass = 2  # Currently fixed to 2 passes
        return UnaryMapping(TensorDim, PaddedTensorDim, SubVolumeDim, SubVolumeDim, SubVolumeDim, splitmode, Npass=Npass, Ngroups=NG)


def get_uniop_mappings(shape: UnaryShape) -> UnaryMapping:
    """
    Return the tiling chosen for a unary op shape.

    Groupnorm shapes are tiled by the dedicated groupnorm tiler; every other
    function goes through the generic unary tiler.
    """
    tiler = uniop_mappings(shape)
    if shape.function != "groupnorm":
        return tiler.generateMapping()
    return tiler.groupnormMapping()
