import struct
import os
from utils.utils_common import log
from kernel.common.kernel_params_helper import DimsHelper

'''
struct KernelLayerNorm_fp16x16_Param {
    uint16_t order_64;
    uint16_t inner_g;
    uint16_t X_g;
    uint16_t outer_g;
    dims_2d_param dimsI_ol;
    dims_2d_param dimsO_ol;
    dims_2d_param dimsI_il;
    dims_2d_param dimsO_il;
};
'''

class LayernormDims:
    """Shape description of one layernorm subvolume plus its granularities.

    The constructor stores the four subvolume extents (N, Y, X, C), the
    ``order_select`` mode and the AIE array shape, and derives the
    per-dimension granularities used when the kernel parameters are computed.
    """

    def __init__(
        self,
        Nsubv,
        Ysubv,
        Xsubv,
        Csubv,
        order_select,
        aie_rows=4,
        aie_cols=3,
    ):
        # Granularities: N, Y and X are fixed; C depends on the selected order
        # (32 when order_select is 0, 64 for any other value).
        self.granN, self.granY, self.granX = 1, 1, 2
        if order_select == 0:
            self.granC = 32
        else:
            self.granC = 64

        # Extents of the input subvolume along N, Y, X and C.
        # NOTE(review): the previous comments described both Ysubv and Xsubv as
        # "number of rows in input subvolume" -- confirm the intended meaning.
        self.Nsubv, self.Ysubv = Nsubv, Ysubv
        self.Xsubv, self.Csubv = Xsubv, Csubv
        self.order_select = order_select

        # Shape of the AIE array (rows x cols).
        self.aie_rows, self.aie_cols = aie_rows, aie_cols

def setup_layernorm_kernel_params(LRNdims: LayernormDims, order_select) -> bytes:
    """Pack the layernorm kernel parameter block (KernelLayerNorm_fp16x16_Param).

    Args:
        LRNdims: Subvolume extents (Nsubv, Ysubv, Xsubv, Csubv) and
            granularities (granN, granY, granX, granC).
        order_select: Written as the first uint16 field; a value of 1 selects
            two C steps for the inner-loop ``from_steps`` calls, any other
            value one.
            NOTE(review): ``granC`` comes from ``LRNdims`` (derived from its own
            ``order_select``) while this argument is used separately; callers
            are presumably expected to pass matching values -- confirm.

    Returns:
        56 little-endian bytes: four uint16 values (order_select, inner_g, X_g,
        outer_g) followed by four (uint32 num0, int32 inc0, int32 inc1) triples
        in the order dimsI_ol, dimsO_ol, dimsI_il, dimsO_il.

    Raises:
        AssertionError: If Csubv is not a multiple of granC or Xsubv is not a
            multiple of granX.
        struct.error: If a computed value does not fit its packed field.
    """
    # raw :
    N = LRNdims.Nsubv
    Y = LRNdims.Ysubv
    X = LRNdims.Xsubv
    C = LRNdims.Csubv

    # L1 : NYCXC64
    granN = LRNdims.granN
    granY = LRNdims.granY
    granX = LRNdims.granX
    granC = LRNdims.granC

    # requirements: checked up front so an unsupported shape fails before any
    # step or loop count is derived from a truncating integer division.
    assert C % granC == 0 and X % granX == 0, (
        f"Csubv ({C}) must be a multiple of granC ({granC}) and "
        f"Xsubv ({X}) must be a multiple of granX ({granX})"
    )

    sizeof_dtype_O0 = 2
    sizeof_dtype_I0 = 2

    # kernel_parameter_setup:
    co_g = C // granC  # uint8_t
    C_steps = 2 if order_select == 1 else 1

    Step_Ci = N * X * granC * sizeof_dtype_I0
    Step_Co = N * X * granC * sizeof_dtype_O0
    Step_Yi = N * X * C * sizeof_dtype_I0
    Step_Yo = N * X * C * sizeof_dtype_O0

    x_g = N * X // granX
    X_g = x_g

    dims1 = DimsHelper(-co_g * Step_Ci)
    dims2 = DimsHelper(-co_g * Step_Co)
    dims3 = DimsHelper(0)
    dims5 = DimsHelper(0)

    # inner_g(co_g, granC)
    inner_g = co_g * (granC // 32)

    # outer_g(N,Y,X,C,granN, granY, granX, granC)
    outer_g = (N * Y * X * C) // (granN * granY * granX * granC) * (granC // 32)

    # Inner-loop input dims. The num0/inc0/inc1 entries returned by from_steps
    # are overwritten below with the values the kernel actually consumes.
    dimsI_il = dims3.from_steps(C_steps, (32 * sizeof_dtype_I0, Step_Ci))

    # first v32 elements in W64, then 2nd v32 elements in W64, alternating
    # between the 2, hence the num0 shall be 1 (as 2 minus 1)
    dimsI_il["num0"] = 1
    dimsI_il["inc0"] = 32 * sizeof_dtype_I0  # jumping over 32 elements
    dimsI_il["inc1"] = (X * 64 - 32) * sizeof_dtype_I0  # 64 from W64 requirement

    # Inner-loop output dims mirror the input ones.
    dimsO_il = dims5.from_steps(C_steps, (32 * sizeof_dtype_O0, Step_Co))
    dimsO_il["num0"] = dimsI_il["num0"]
    dimsO_il["inc0"] = dimsI_il["inc0"]
    dimsO_il["inc1"] = dimsI_il["inc1"]

    # Outer-loop input dims; again the three packed entries are overridden.
    dimsI_ol = dims1.from_steps((X_g), (64 * int(sizeof_dtype_I0) * granX, Step_Yi))
    dimsI_ol["num0"] = X // granX - 1
    dimsI_ol["inc0"] = (granX * 64 * sizeof_dtype_I0) - (
        dimsI_il["inc0"] + dimsI_il["inc1"]
    ) * (C // 64)
    dimsI_ol["inc1"] = 0 - (X - granX) * (64 * sizeof_dtype_I0)

    # Outer-loop output dims mirror the input ones.
    dimsO_ol = dims2.from_steps((X_g), (64 * int(sizeof_dtype_O0) * granX, Step_Yo))
    dimsO_ol["num0"] = dimsI_ol["num0"]
    dimsO_ol["inc0"] = dimsI_ol["inc0"]
    dimsO_ol["inc1"] = dimsI_ol["inc1"]

    log("--------------------------------")
    log(f"dims.Nsubv:{LRNdims.Nsubv}")
    log(f"dims.Ysubv:{LRNdims.Ysubv}")
    log(f"dims.Xsubv:{LRNdims.Xsubv}")
    log(f"dims.Csubv:{LRNdims.Csubv}")
    log("--------------------------------")
    log(f"Co_g: {co_g}")
    log(f"X_g: {x_g}")
    log(f"inner_g: {inner_g}")
    log(f"outer_g: {outer_g}")
    log(f"step_Ci: {Step_Ci}")
    log(f"step_Co: {Step_Co}")
    log(f"step_Yi: {Step_Yi}")
    log(f"step_Yo: {Step_Yo}")

    log("dimsI_il:", dimsI_il)
    log("dimsI_ol:", dimsI_ol)

    log("dimsO_il:", dimsO_il)
    log("dimsO_ol:", dimsO_ol)

    packed_params = struct.pack(
        "<4H1I2i1I2i1I2i1I2i",
        order_select,  # H
        inner_g,  # H
        X_g,  # H
        outer_g,  # H
        dimsI_ol["num0"],
        dimsI_ol["inc0"],
        dimsI_ol["inc1"],  # 1I2i
        dimsO_ol["num0"],
        dimsO_ol["inc0"],
        dimsO_ol["inc1"],  # 1I2i
        dimsI_il["num0"],
        dimsI_il["inc0"],
        dimsI_il["inc1"],  # 1I2i --> 4x3 == 12 bytes
        dimsO_il["num0"],
        dimsO_il["inc0"],
        dimsO_il["inc1"],  # 1I2i
    )
    return packed_params


def layernorm_layer_params(
    CoreInputAddr: int,
    CoreWeightAddr: int,
    CoreOutputAddr: int,
    QdqParamAddr: int,
    dqBufferAddr: int,
    qBufferAddr: int,
    TrueNumCols: int,
    Msubv: int,
    Nsubv: int,
    num_elem_subv: int,
    sign_A: int,
    sign_O: int,
    layernorm_dims: LayernormDims,
) -> bytes:
    """Serialize the layernorm layer parameters followed by the kernel parameters.

    The result starts with twelve unsigned 32-bit little-endian fields
    (48 bytes): the six buffer addresses, each shifted by a fixed 0xE0000
    offset, then TrueNumCols, Msubv, Nsubv, num_elem_subv, sign_A and sign_O.
    The bytes returned by ``setup_layernorm_kernel_params`` for
    ``layernorm_dims`` are appended after that header.

    Args:
        CoreInputAddr, CoreWeightAddr, CoreOutputAddr, QdqParamAddr,
        dqBufferAddr, qBufferAddr: Addresses before the 0xE0000 offset is added.
        TrueNumCols, Msubv, Nsubv, num_elem_subv, sign_A, sign_O: Written
            unchanged as unsigned 32-bit values.
        layernorm_dims: Dimensions handed to the kernel parameter packer; its
            ``order_select`` attribute is forwarded as well.

    Returns:
        The concatenated header and kernel parameter bytes.

    Raises:
        OverflowError: If any header field is negative or does not fit in
            four bytes.
    """
    # NOTE(review): 0xE0000 is presumably the base of the core-local address
    # window -- its meaning is not visible in this file, confirm.
    addr_offset = 0xE0000

    header_fields = (
        CoreInputAddr + addr_offset,
        CoreWeightAddr + addr_offset,
        CoreOutputAddr + addr_offset,
        QdqParamAddr + addr_offset,
        dqBufferAddr + addr_offset,
        qBufferAddr + addr_offset,
        TrueNumCols,
        Msubv,
        Nsubv,
        num_elem_subv,
        sign_A,
        sign_O,
    )
    header = b"".join(
        field.to_bytes(length=4, byteorder="little", signed=False)
        for field in header_fields
    )

    # Named layer_params rather than ``bytes`` so the builtin is not shadowed.
    layer_params = header + setup_layernorm_kernel_params(
        layernorm_dims, order_select=layernorm_dims.order_select
    )

    log("Number of bytes layernorm layer params:", len(layer_params))
    return layer_params

if __name__ == "__main__":
    # Smoke test: pack the kernel parameters for each test shape, dump them to
    # a .bin file in the current directory and print the unpacked values.
    # Each shape is (Nsubv, Ysubv, Xsubv, Csubv).
    os.environ["LOG_ENABLED"] = "true"

    # Same layout that setup_layernorm_kernel_params packs.
    kernel_param_format = "<4H1I2i1I2i1I2i1I2i"

    test_subvolume_shapes = [(2, 2, 8, 256)]
    order_select = 1
    for test_id, subvolume_shape in enumerate(test_subvolume_shapes):

        Nsubv, Ysubv, Xsubv, Csubv = subvolume_shape

        layernorm_dims = LayernormDims(Nsubv, Ysubv, Xsubv, Csubv, order_select)

        packed_params = setup_layernorm_kernel_params(layernorm_dims, order_select)

        with open("krn_param_toArray" + str(test_id) + ".bin", "wb") as file:
            file.write(packed_params)

        unpacked_values = struct.unpack(kernel_param_format, packed_params)
        print(unpacked_values)  # Tuple Typed

        expected_size = struct.calcsize(kernel_param_format)
        print("expected kernel param size:", expected_size)

        # Keyword arguments keep this call in sync with the signature; the
        # earlier positional call omitted sign_A and sign_O and so bound
        # layernorm_dims to the wrong parameter.
        layer_params = layernorm_layer_params(
            CoreInputAddr=0,
            CoreWeightAddr=0,
            CoreOutputAddr=0,
            QdqParamAddr=0,
            dqBufferAddr=0,
            qBufferAddr=256,
            TrueNumCols=0,
            Msubv=0,
            Nsubv=0,
            num_elem_subv=0,
            sign_A=0,
            sign_O=0,
            layernorm_dims=layernorm_dims,
        )
        print("layer param size:", len(layer_params))





