import struct


from utils.utils_common import log
from kernel.common.kernel_params_helper import DimsHelper

"""
C++ kernel parameter struct expected by the kernel:

struct KernelSoftmax_fp16x16Param {
    uint8_t Co_g;
    uint8_t X_g;
    uint16_t outer_g;
    uint16_t step_Ci;
    uint16_t step_Co;
    uint16_t step_Yi;
    uint16_t step_Yo;
    dims_2d_param dimsI_il;
    dims_2d_param dimsI_ol;
    dims_2d_param dimsO_il;
    dims_2d_param dimsO_ol;
    dims_2d_param dimsM_il;
    dims_2d_param dimsM_ol;
};
"""


class SoftmaxDims:
    """Shape and tiling description of one softmax subvolume.

    Stores the subvolume extents (``Y`` x ``Msubv`` x ``Nsubv``), the fixed
    kernel granularities, and the loop trip counts derived from them.

    Args:
        Y: Outer (batch) extent of the subvolume.
        Msubv: Number of rows in the input subvolume.
        Nsubv: Number of columns in the input subvolume.
        aie_rows: Number of rows in the AIE array.
        aie_cols: Number of columns in the AIE array.
    """

    def __init__(
        self,
        Y,
        Msubv,
        Nsubv,
        aie_rows=4,
        aie_cols=3,
    ):
        # Fixed kernel granularities along the N, Y, X (row) and C (column) axes.
        self.granN, self.granY, self.granX, self.granC = 1, 1, 4, 32

        # Parameter-buffer size and bit width (not read in the visible code).
        self.param_size = 1024
        self.param_bits = 8

        self.Msubv = Msubv  # rows of the input subvolume
        self.Nsubv = Nsubv  # columns of the input subvolume
        self.Y = Y

        self.aie_rows = aie_rows
        self.aie_cols = aie_cols

        # Each outer-loop step consumes granN * granY * granX rows;
        # each inner-loop step consumes granC columns.
        rows_per_outer_step = self.granN * self.granY * self.granX
        self.inner_loop_num_iterations = Nsubv // self.granC
        self.outer_loop_num_iterations = Msubv // rows_per_outer_step


def assert_equal_dims(dim2d_A, dim2d_B, scale=1):
    """Assert that two 2-D dims descriptors match, up to an increment scale.

    ``num0`` must be equal; ``inc0`` and ``inc1`` of ``dim2d_A`` must equal
    the corresponding values of ``dim2d_B`` multiplied by ``scale``.

    Raises:
        AssertionError: on the first mismatching field.
        KeyError: if a descriptor lacks one of the compared keys.
    """
    assert dim2d_A["num0"] == dim2d_B["num0"]
    for increment_key in ("inc0", "inc1"):
        assert dim2d_A[increment_key] == dim2d_B[increment_key] * scale


def setup_softmax_kernel_params(sfmx_dims: SoftmaxDims) -> bytes:
    """Build the packed ``KernelSoftmax_fp16x16Param`` bytes for one subvolume.

    Derives loop counts, byte strides and six 2-D dims descriptors
    (``num0``/``inc0``/``inc1`` dicts built with ``DimsHelper``) for the
    input (I), output (O) and a third (M) buffer from ``sfmx_dims``.
    Sanity-checks them with asserts, logs them, and packs everything
    little-endian with the format ``"<2B5H1I2i1I2i1I2i1I2i1I2i1I2i"``
    (84 bytes, no padding).

    Args:
        sfmx_dims: Subvolume shape (``Y``, ``Msubv``, ``Nsubv``) and
            granularities (``granN``/``granY``/``granX``/``granC``).

    Returns:
        The packed parameter bytes, in the field order of the C++ struct
        listed at the top of this module.

    Raises:
        AssertionError: if the shape violates a tiling requirement, or the
            derived dims do not match the expected strides.
        struct.error: if a value does not fit its packed field width
            (e.g. ``Co_g``/``X_g`` > 255, or a uint16 field > 65535).
    """
    # Raw subvolume shape: N (fixed at 1) x Y x X (= Msubv rows) x C (= Nsubv cols).
    N = 1
    Y = sfmx_dims.Y
    X = sfmx_dims.Msubv
    C = sfmx_dims.Nsubv

    # L1 layout: NYCXC64
    granN = sfmx_dims.granN
    granY = sfmx_dims.granY
    granX = sfmx_dims.granX
    granC = sfmx_dims.granC
    granXC = granX * granC
    x_g = N * X // granX  # number of granX-row groups
    co_g = C // granC  # number of granC-wide channel groups
    # Stride between the two channel groups visited by the inner loop. With
    # 2-byte elements this equals X rows of a 64-wide tile, matching the
    # dimsI_il inc0 + inc1 assertion below.
    Step_Ci = N * X * granC * granX
    Step_Co = N * X * granC * granX
    # Strides for the M buffer. The `// 8` (bits -> bytes) together with the
    # 1/16 scale checked by assert_equal_dims below is consistent with a
    # 1-bit-per-element buffer (presumably the mask) -- TODO confirm.
    Step_M_C = (2 * granXC // 8) * x_g
    Step_M_Y = co_g // 2 * Step_M_C
    # NOTE(review): the DimsHelper constructor argument looks like a base
    # offset folded into the outer-loop increments; confirm in
    # kernel_params_helper.
    # NOTE(review): `-co_g // 2` is floor(-co_g / 2), which differs from
    # -(co_g // 2) when co_g is odd, while Step_M_Y uses co_g // 2. All shapes
    # exercised in __main__ have an even co_g.
    dims = DimsHelper()
    dims1 = DimsHelper(-co_g // 2 * Step_Ci)
    dims2 = DimsHelper(-co_g // 2 * Step_Co)
    dims3 = DimsHelper(-Step_M_Y)
    inner_loop_range = 2  # channel groups handled per inner-loop pass

    # requirements:
    assert C % granC == 0 and X % granX == 0
    assert co_g >= inner_loop_range  # at least one full inner-loop pass
    # With granC == 32 this means Nsubv >= 2 * 32 == 64.

    sizeof_dtype_O0 = 2
    sizeof_dtype_I0 = 2
    # The input subvolume must fit in 128 KiB.
    assert sfmx_dims.Msubv * sfmx_dims.Nsubv * sizeof_dtype_I0 < (128 * 1024)

    # kernel_parameter_setup (C field types noted per line):
    Co_g = co_g  # uint8_t
    X_g = x_g  # uint8_t
    outer_g = (N * Y * X * C) // (granN * granY * granX * granC)  # uint16_t
    step_Ci = Step_Ci  # uint16_t
    step_Co = Step_Co  # uint16_t
    step_Yi = 2 * X * C  # uint16_t
    step_Yo = 2 * X * C  # uint16_t

    # assert the value range for step_Yi and step_Yo (uint16_t fields)
    assert 0 <= step_Yi and step_Yi <= 65535
    assert 0 <= step_Yo and step_Yo <= 65535

    # Inner loop: 2 iterations over channel groups. Outer loop: X_g iterations
    # over row groups. `(X_g)` / `(2)` are plain ints, not 1-tuples.
    # NOTE(review): dimsI_il uses sizeof_dtype_O0 for an input stride; it
    # equals sizeof_dtype_I0 (both 2), so the result is unaffected.
    dimsI_il = dims.from_steps((2), (32 * int(sizeof_dtype_O0), step_Ci))
    dimsI_ol = dims1.from_steps((X_g), (64 * int(sizeof_dtype_I0) * granX, step_Yi))
    dimsO_il = dims.from_steps((2), (32 * int(sizeof_dtype_O0), step_Ci))
    dimsO_ol = dims2.from_steps((X_g), (64 * int(sizeof_dtype_O0) * granX, step_Yo))
    dimsM_il = dims.from_steps((2), (4, Step_M_C))
    dimsM_ol = dims3.from_steps((X_g), (2 * granXC // 8, Step_M_Y))

    log("--------------------------------")
    log(f"dims.Msubv:{sfmx_dims.Msubv}")
    log(f"dims.Nsubv:{sfmx_dims.Nsubv}")
    log("--------------------------------")
    log(f"Co_g: {Co_g}")
    log(f"X_g: {x_g}")
    log(f"outer_g: {outer_g}")
    log(f"step_Ci: {step_Ci}")
    log(f"step_Co: {step_Co}")
    log(f"step_Yi: {step_Yi}")
    log(f"step_Yo: {step_Yo}")

    log("dimsI_il:", dimsI_il)
    log("dimsI_ol:", dimsI_ol)

    log("dimsO_il:", dimsO_il)
    log("dimsO_ol:", dimsO_ol)

    log("dimsM_il:", dimsM_il)
    log("dimsM_ol:", dimsM_ol)

    # Input and output traverse memory identically.
    assert step_Ci == step_Co
    assert_equal_dims(dimsI_il, dimsO_il)
    assert_equal_dims(dimsI_ol, dimsO_ol)

    # first v32 elements in W64, then 2nd v32 elements in W64, alternating between the 2, hence the num0 shall be 1 (as 2 minus 1)
    assert dimsI_il["num0"] == 1
    assert dimsI_il["inc0"] == granC * sizeof_dtype_I0  # jumping over 32 elements
    assert (
        dimsI_il["inc1"] == (sfmx_dims.Msubv * 64 - granC) * sizeof_dtype_I0
    )  # 64 from W64 requirement

    # Outer loop: num0 is again "count minus 1"; the increments undo the
    # inner loop's advance and step to the next granX-row group.
    assert dimsI_ol["num0"] == sfmx_dims.Msubv // granX - 1
    assert dimsI_ol["inc0"] == (
        (granX * 64 * sizeof_dtype_I0)
        - (dimsI_il["inc0"] + dimsI_il["inc1"]) * (sfmx_dims.Nsubv // 64)
    )
    assert dimsI_ol["inc1"] == (0 - (sfmx_dims.Msubv - granX) * (64 * sizeof_dtype_I0))

    # M-buffer increments are exactly 1/16 of the input's.
    assert_equal_dims(dimsI_il, dimsM_il, scale=16)
    assert_equal_dims(dimsI_ol, dimsM_ol, scale=16)

    packed_params = struct.pack(
        "<2B5H1I2i1I2i1I2i1I2i1I2i1I2i",
        Co_g,  # B
        X_g,  # B
        outer_g,  # H
        step_Ci,  # H
        step_Co,  # H
        step_Yi,  # H    NOTE(review): an older comment said the kernel ignores this and it is set to 0, but the computed value is packed -- confirm
        step_Yo,  # H    NOTE(review): an older comment said the kernel ignores this and it is set to 0, but the computed value is packed -- confirm
        # No reserved padding byte: 2B + 5H already occupy 12 bytes, so the
        # 1I2i groups below start at offset 12.
        dimsI_il["num0"],
        dimsI_il["inc0"],
        dimsI_il["inc1"],  # 1I2i
        dimsI_ol["num0"],
        dimsI_ol["inc0"],
        dimsI_ol["inc1"],  # 1I2i
        dimsO_il["num0"],
        dimsO_il["inc0"],
        dimsO_il["inc1"],  # 1I2i
        dimsO_ol["num0"],
        dimsO_ol["inc0"],
        dimsO_ol["inc1"],  # 1I2i
        dimsM_il["num0"],
        dimsM_il["inc0"],
        dimsM_il["inc1"],  # 1I2i
        dimsM_ol["num0"],
        dimsM_ol["inc0"],
        dimsM_ol["inc1"],  # 1I2i
    )
    return packed_params


def softmax_layer_params(
    CoreInputAddr: int,
    CoreMaskAddr: int,
    CoreOutputAddr: int,
    QdqParamAddr: int,
    dqBufferAddr: int,
    qBufferAddr: int,
    TrueNumCols: int,
    Msubv: int,
    Nsubv: int,
    num_elem_subv: int,
    msk_num_bytes: int,
    sign_A: int,
    sign_O: int,
    sfmx_dims: SoftmaxDims,
) -> bytes:
    """Generate the layer parameters for the SOFTMAX operation.

    Layout: thirteen little-endian uint32 words followed by the packed kernel
    parameters from ``setup_softmax_kernel_params(sfmx_dims)``.

    The words are, in order:

    * ``CoreInputAddr``, ``CoreMaskAddr``, ``CoreOutputAddr``,
      ``QdqParamAddr``, ``dqBufferAddr``, ``qBufferAddr``, each offset
      by ``0xE0000``;
    * ``TrueNumCols``, ``Msubv``, ``Nsubv``, ``num_elem_subv``,
      ``msk_num_bytes``, ``sign_A``, ``sign_O``.

    Raises:
        OverflowError: if any word is negative or does not fit in 32 bits
            (checked after the address offset is applied, and before the
            kernel parameters are built).
    """
    # Offset added to every core buffer address.
    # NOTE(review): presumably the base of the core's local data memory as
    # seen by the kernel -- confirm against the memory map.
    core_addr_offset = 0xE0000

    words = (
        CoreInputAddr + core_addr_offset,
        CoreMaskAddr + core_addr_offset,
        CoreOutputAddr + core_addr_offset,
        QdqParamAddr + core_addr_offset,
        dqBufferAddr + core_addr_offset,
        qBufferAddr + core_addr_offset,
        TrueNumCols,
        Msubv,
        Nsubv,
        num_elem_subv,
        msk_num_bytes,
        sign_A,
        sign_O,
    )
    header = b"".join(
        word.to_bytes(length=4, byteorder="little", signed=False) for word in words
    )
    layer_params = header + setup_softmax_kernel_params(sfmx_dims)

    log("Number of bytes softmax layer params:", len(layer_params))
    return layer_params


def copy_layer_params(
    CoreInputAddr: int,
    CoreOutputAddr: int,
    num_elem_subv: int,
    ifmbytes: int,
    ofmbytes: int
) -> bytes:
    """Generate the layer parameters for the copy operation.

    Layout: five little-endian uint32 words, in order ``CoreInputAddr``,
    ``CoreOutputAddr`` (both offset by ``0xE0000``), ``num_elem_subv``,
    ``ifmbytes``, ``ofmbytes``.

    Raises:
        OverflowError: if any word is negative or does not fit in 32 bits
            (checked after the address offset is applied).
    """
    # Offset added to every core buffer address.
    # NOTE(review): presumably the base of the core's local data memory as
    # seen by the kernel -- confirm against the memory map.
    core_addr_offset = 0xE0000

    words = (
        CoreInputAddr + core_addr_offset,
        CoreOutputAddr + core_addr_offset,
        num_elem_subv,
        ifmbytes,
        ofmbytes,
    )
    return b"".join(
        word.to_bytes(length=4, byteorder="little", signed=False) for word in words
    )


# Kernel param struct size: 2 x B, 5 x H, 6 x (1I2i)
# 2*1 + 5*2 + (3*4)*6 == 2 + 10 + 72 == 84 bytes


if __name__ == "__main__":
    # Each entry is (batchSize, Msubv, Nsubv) == (Y, X, C).
    test_subvolume_shapes = [(1, 4, 256), (1, 8, 1024), (1, 4, 2048), (3, 4, 320)]

    # Must match the format used by setup_softmax_kernel_params.
    kernel_param_format = "<2B5H1I2i1I2i1I2i1I2i1I2i1I2i"
    expected_size = struct.calcsize(kernel_param_format)
    assert expected_size == 84

    for test_id, (Y, Msubv, Nsubv) in enumerate(test_subvolume_shapes):
        sfmx_dims = SoftmaxDims(Y, Msubv, Nsubv)
        packed_params = setup_softmax_kernel_params(sfmx_dims)
        assert len(packed_params) == expected_size

        with open(f"krn_param_toArray{test_id}.bin", "wb") as file:
            file.write(packed_params)

        unpacked_values = struct.unpack(kernel_param_format, packed_params)
        log(unpacked_values)  # Tuple Typed

        # softmax_layer_params requires every argument (the original call
        # passed only four and raised TypeError). All values other than
        # sfmx_dims are placeholders, used only to report the total size.
        layer_params = softmax_layer_params(
            CoreInputAddr=0,
            CoreMaskAddr=0,
            CoreOutputAddr=0,
            QdqParamAddr=0,
            dqBufferAddr=0,
            qBufferAddr=0,
            TrueNumCols=Nsubv,
            Msubv=Msubv,
            Nsubv=Nsubv,
            num_elem_subv=Msubv * Nsubv,
            msk_num_bytes=0,
            sign_A=0,
            sign_O=0,
            sfmx_dims=sfmx_dims,
        )
        log("layer param size:", len(layer_params))
