import struct 

from kernel.common.kernel_params_helper import (
    DimsHelper,
)

from utils.utils_common import (
    log,
    iceil
)


# Layout of the kernel-params blob. The stride-1 and stride-2 kernels share one
# params struct on the wrapper/kernel side, so both setup functions pack through
# _pack_dwc_qdq_a16w8_params and this single format string (little-endian,
# standard sizes, no padding -> 113 bytes):
#   4I      stride, ifm_ch_num, wgt_size, coeff_size
#   4H      outer_loop, reserved, incS_0, incA_0   (uint16: must be 0..65535)
#   2I 3i   dims_A3 num0/num1, inc0/inc1/inc2
#   1I 2i   dims_A2 num0, inc0/inc1
#   2I 3i   dims_O3 num0/num1, inc0/inc1/inc2
#   1I 2i   dims_O2 num0, inc0/inc1
#   1I 2i   dims_W2 num0, inc0/inc1
#   1I 2i   dims_C2 num0, inc0/inc1
#   1B      sign_byte
_DWC_QDQ_A16W8_PARAMS_FMT = '<4I 4H 2I 3i 1I 2i 2I 3i 1I 2i 1I 2i 1I 2i 1B'

# (log name, keys) for each dims descriptor, in the order they are logged and
# packed; must match the dims section of _DWC_QDQ_A16W8_PARAMS_FMT.
_DWC_QDQ_A16W8_DIMS_LAYOUT = (
    ("dims_A3", ("num0", "num1", "inc0", "inc1", "inc2")),
    ("dims_A2", ("num0", "inc0", "inc1")),
    ("dims_O3", ("num0", "num1", "inc0", "inc1", "inc2")),
    ("dims_O2", ("num0", "inc0", "inc1")),
    ("dims_W2", ("num0", "inc0", "inc1")),
    ("dims_C2", ("num0", "inc0", "inc1")),
)


def _build_dwc_qdq_a16w8_dims(
    dims_reset: int, dims3_reset: int,
    H_g: int, Co_g: int,
    step_Ci: int, step_Hi: int, step_Wi: int,
    step_Co: int, step_Ho: int, step_Wo: int,
    w_inc_factor: int, ifm_bytes: int,
) -> tuple:
    """Build the six loop-dimension descriptors used by both kernel variants.

    Args:
        dims_reset: reset value for the activation 3-level descriptor (A3).
        dims3_reset: reset value for the output 3-level descriptor (O3).
        H_g, Co_g: whole-granule counts along H and Co.
        step_*: byte strides computed by the caller.
        w_inc_factor: multiplier applied to step_Wi / step_Wo in the A2 / O2
            inc1 terms (2 in the stride-1 setup, 1 in the stride-2 setup).
        ifm_bytes: bytes per input element, used for the weight stride.

    Returns:
        (dims_A3, dims_A2, dims_O3, dims_O2, dims_W2, dims_C2) as returned by
        DimsHelper.from_steps.
    """
    helper_A3 = DimsHelper(dims_reset)
    helper_A2 = DimsHelper(0)
    helper_O3 = DimsHelper(dims3_reset)
    helper_O2 = DimsHelper(0)
    helper_W2 = DimsHelper(0)
    helper_C2 = DimsHelper(0)
    # NOTE(review): the single-loop descriptors pass a bare int (not a 1-tuple)
    # as the loop count; presumably DimsHelper.from_steps accepts both -- confirm.
    dims_A3 = helper_A3.from_steps((2, Co_g // 2), (64, step_Ci, step_Hi * 4))
    dims_A2 = helper_A2.from_steps((H_g * Co_g), (0, (-H_g * 4 * step_Hi) + w_inc_factor * step_Wi))
    dims_O3 = helper_O3.from_steps((2, Co_g // 2), (64, step_Co, 4 * step_Ho))
    dims_O2 = helper_O2.from_steps((H_g * Co_g), (0, (-H_g * 4 * step_Ho) + w_inc_factor * step_Wo))
    dims_W2 = helper_W2.from_steps((Co_g), (3 * 64 * ifm_bytes, 0))
    dims_C2 = helper_C2.from_steps((Co_g), (64 * 4, 0))
    return dims_A3, dims_A2, dims_O3, dims_O2, dims_W2, dims_C2


def _pack_dwc_qdq_a16w8_params(
    Sy: int, ifm_ch_num: int, wgt_size: int, coeff_size: int,
    outer_loop: int, incS_0: int, incA_0: int,
    dims: tuple,
    sign_A: int, sign_W: int, sign_O: int,
) -> bytes:
    """Log the final kernel parameters and pack them into the shared struct.

    Args:
        dims: the 6-tuple returned by _build_dwc_qdq_a16w8_dims.
        sign_A, sign_W, sign_O: only bit 0 of each is used; they become bits
            0, 1 and 2 of the packed sign byte.

    Returns:
        Bytes laid out per _DWC_QDQ_A16W8_PARAMS_FMT.

    Raises:
        struct.error: if a value does not fit its field (e.g. a negative or
            >65535 outer_loop / incS_0 / incA_0, which are packed as uint16).
    """
    sign_byte = (sign_A & 0x1) | ((sign_W & 0x1) << 1) | ((sign_O & 0x1) << 2)

    log(f"outer_loop: {outer_loop}")
    log(f"incS_0: {incS_0}")
    log(f"incA_0: {incA_0}")
    dim_values = []
    for (name, keys), descriptor in zip(_DWC_QDQ_A16W8_DIMS_LAYOUT, dims):
        for key in keys:
            log(f"{name}['{key}']: {descriptor[key]}")
            dim_values.append(descriptor[key])
    log(f"sign_byte: {sign_byte}")

    return struct.pack(
        _DWC_QDQ_A16W8_PARAMS_FMT,
        Sy,                      # I - stride
        ifm_ch_num,              # I
        wgt_size,                # I
        coeff_size,              # I
        outer_loop,              # H
        0,                       # H (reserved)
        incS_0,                  # H
        incA_0,                  # H
        *(int(value) for value in dim_values),
        sign_byte,               # B
    )


def setup_dwc_qdq_a16w8_s1_params(
    Yos: int, Xos: int, Cos: int,
    Yis: int, Xis: int,
    Ky: int, Kx: int,
    Sy: int, Sx: int,
    Y_loop: int, X_loop: int, Co_loop: int,
    ifm_ch_num: int,
    wgt_size: int,
    coeff_size: int,
    sign_A: int, sign_W: int, sign_O: int,
) -> bytes:
    """Set up packed parameters for the stride-1 DWC QDQ A16W8 kernel.

    Computes the outer-loop count, byte strides and loop-dimension descriptors
    for the stride-1 variant and packs them into the params struct shared with
    the stride-2 variant (see _DWC_QDQ_A16W8_PARAMS_FMT).

    Ky, Kx, Y_loop, X_loop and Co_loop are only logged here; Sy, Sx, Xis,
    Yis, Xos, Yos and Cos feed the stride/loop computations.

    Returns:
        The packed kernel-params bytes.

    Raises:
        struct.error: if a computed value does not fit its packed field.
    """
    log(f"Ky: {Ky}")
    log(f"Y_loop: {Y_loop}")
    log(f"X_loop: {X_loop}")
    log(f"Co_loop: {Co_loop}")
    # NOTE: We make a design choice of CHWW2C64
    # For IFM subvol format for this format H_outer = 0 according to kernel spec
    # https://gitenterprise.xilinx.com/IPSP/AIE_SOL/blob/main/AIE4/kernel_lib/kernels/activated_dwc_qdq_int16x8_s1/activated_dwc_qdq_int16x8_s1.json#L78
    H_outer = 0
    has_actv_sum = 1  # Only logged - actv sum is always enabled by default in the kernel
    ifm_bytes = 2
    ofm_bytes = 2
    log(f"ofm_bytes: {ofm_bytes}")
    H_gran = 4
    W_gran = 2
    Co_gran = 32
    Kh_gran = 3
    log(f"Kh_gran: {Kh_gran}")
    Kw_gran = 4
    log(f"Kw_gran: {Kw_gran}")
    log(f"has_actv_sum_var: {has_actv_sum}")
    folded_Kw = Kx
    log(f"folded_Kw: {folded_Kw}")
    H_g = Yos // H_gran
    W_g = Xos // W_gran
    log(f"W_g: {W_g}")
    Co_g = Cos // Co_gran
    # Product of whole-granule counts along H, W and Co. Written as a product
    # of the floored counts on purpose: a chained `a // b * c // d ...` is
    # evaluated left to right and floors the intermediate products instead.
    outer_loop = H_g * W_g * Co_g
    Ci_ilb = min(64, Cos)
    Wo_ilb = min(64, Xos)
    log(f"Stride: {Sy}")
    log(f"Ci_ilb: {Ci_ilb}")
    log(f"Wo_ilb: {Wo_ilb}")
    log(f"ifm_bytes: {ifm_bytes}")
    log(f"Sy: {Sy}")
    log(f"Xis: {Xis}")

    # Byte strides; with H_outer == 0 a row holds Ci_ilb channels.
    row_ch = Cos if H_outer else Ci_ilb
    step_Hi = ifm_bytes * Sy * Xis * row_ch
    step_Wi = ifm_bytes * Ci_ilb * Sx
    step_Ci = ifm_bytes * Xis * (1 if H_outer else Yis) * Ci_ilb
    step_Kh = ifm_bytes * Xis * row_ch - 6 * 64
    reset = ifm_bytes * Xis * row_ch
    step_Kw = ifm_bytes * Ci_ilb * Sx
    log(f"step_Kw: {step_Kw}")
    step_Ho = ifm_bytes * Xos * row_ch
    step_Co = ifm_bytes * Xos * (1 if H_outer else Yos) * Ci_ilb
    qdq_terms = 2
    log(f"qdq_terms: {qdq_terms}")
    # Stride-1 specific: the output W step includes Sx.
    step_Wo = ifm_bytes * Ci_ilb * Sx
    log(f"Co_g: {Co_g}")
    log(f"step_Ci: {step_Ci}")
    log(f"step_Hi: {step_Hi}")
    log(f"sign_W: {sign_W}")

    dims = _build_dwc_qdq_a16w8_dims(
        dims_reset=-5 * reset,
        dims3_reset=-3 * step_Ho,
        H_g=H_g, Co_g=Co_g,
        step_Ci=step_Ci, step_Hi=step_Hi, step_Wi=step_Wi,
        step_Co=step_Co, step_Ho=step_Ho, step_Wo=step_Wo,
        w_inc_factor=2,
        ifm_bytes=ifm_bytes,
    )
    return _pack_dwc_qdq_a16w8_params(
        Sy, ifm_ch_num, wgt_size, coeff_size,
        outer_loop=outer_loop,
        incS_0=step_Ho,
        incA_0=step_Kh,
        dims=dims,
        sign_A=sign_A, sign_W=sign_W, sign_O=sign_O,
    )


def setup_dwc_qdq_a16w8_s2_params(
    Yos: int, Xos: int, Cos: int,
    Yis: int, Xis: int,
    Ky: int, Kx: int,
    Sy: int, Sx: int,
    Y_loop: int, X_loop: int, Co_loop: int,
    ifm_ch_num: int,
    wgt_size: int,
    coeff_size: int,
    sign_A: int, sign_W: int, sign_O: int,
) -> bytes:
    """Set up packed parameters for the stride-2 DWC QDQ A16W8 kernel.

    Computes the outer-loop count, byte strides and loop-dimension descriptors
    for the stride-2 variant and packs them into the params struct shared with
    the stride-1 variant (see _DWC_QDQ_A16W8_PARAMS_FMT).

    Ky, Kx, Y_loop, X_loop and Co_loop are only logged here; Sy, Sx, Xis,
    Yis, Xos, Yos and Cos feed the stride/loop computations.

    Returns:
        The packed kernel-params bytes.

    Raises:
        struct.error: if a computed value does not fit its packed field.
    """
    log(f"Ky: {Ky}")
    log(f"Y_loop: {Y_loop}")
    log(f"X_loop: {X_loop}")
    log(f"Co_loop: {Co_loop}")
    # NOTE: We make a design choice of CHWW2C64
    # For IFM subvol format for this format H_outer = 0 according to kernel spec
    # https://gitenterprise.xilinx.com/IPSP/AIE_SOL/blob/main/AIE4/kernel_lib/kernels/activated_dwc_qdq_int16x8_s1/activated_dwc_qdq_int16x8_s1.json#L78
    H_outer = 0
    has_actv_sum = 1  # Only logged - actv sum is always enabled by default in the S2 kernel
    ifm_bytes = 2
    ofm_bytes = 2
    log(f"ofm_bytes: {ofm_bytes}")
    H_gran = 4
    W_gran = 2
    Co_gran = 32
    Kh_gran = 3
    log(f"Kh_gran: {Kh_gran}")
    Kw_gran = 4
    log(f"Kw_gran: {Kw_gran}")
    log(f"has_actv_sum_var: {has_actv_sum}")
    folded_Kw = Kx
    log(f"folded_Kw: {folded_Kw}")
    H_g = Yos // H_gran
    log(f"H_g: {H_g}")
    W_g = Xos // W_gran
    log(f"W_g: {W_g}")
    Co_g = Cos // Co_gran
    log(f"Co_g: {Co_g}")
    # Stride-2 counts the full Xos (not W_g) along W, matching the W
    # multiplier of 1 used for the A2/O2 increments below.
    outer_loop = H_g * Xos * Co_g
    Ci_ilb = min(64, Cos)
    Wo_ilb = min(64, Xos)
    log(f"Stride: {Sy}")
    log(f"Ci_ilb: {Ci_ilb}")
    log(f"Wo_ilb: {Wo_ilb}")
    log(f"ifm_bytes: {ifm_bytes}")
    log(f"Sy: {Sy}")
    log(f"Xis: {Xis}")

    # Byte strides; with H_outer == 0 a row holds Ci_ilb channels.
    row_ch = Cos if H_outer else Ci_ilb
    step_Hi = ifm_bytes * Sy * Xis * row_ch
    log(f"step_Hi: {step_Hi}")
    step_Wi = ifm_bytes * Ci_ilb * Sx
    log(f"step_Wi: {step_Wi}")
    step_Ci = ifm_bytes * Xis * (1 if H_outer else Yis) * Ci_ilb
    log(f"step_Ci: {step_Ci}")
    step_Kh = ifm_bytes * Xis * row_ch - 6 * 64
    reset = ifm_bytes * Xis * row_ch
    step_Kw = ifm_bytes * Ci_ilb * Sx
    log(f"step_Kw: {step_Kw}")
    step_Ho = ifm_bytes * Xos * row_ch
    step_Co = ifm_bytes * Xos * (1 if H_outer else Yos) * Ci_ilb
    qdq_terms = 2
    log(f"qdq_terms: {qdq_terms}")
    # Stride-2 specific: the output W step has no Sx factor.
    step_Wo = ifm_bytes * Ci_ilb
    log(f"Co_g: {Co_g}")
    log(f"sign_W: {sign_W}")

    dims = _build_dwc_qdq_a16w8_dims(
        dims_reset=-8 * reset,
        dims3_reset=-3 * step_Ho,
        H_g=H_g, Co_g=Co_g,
        step_Ci=step_Ci, step_Hi=step_Hi, step_Wi=step_Wi,
        step_Co=step_Co, step_Ho=step_Ho, step_Wo=step_Wo,
        w_inc_factor=1,
        ifm_bytes=ifm_bytes,
    )
    return _pack_dwc_qdq_a16w8_params(
        Sy, ifm_ch_num, wgt_size, coeff_size,
        outer_loop=outer_loop,
        incS_0=step_Ho,
        incA_0=step_Kh,
        dims=dims,
        sign_A=sign_A, sign_W=sign_W, sign_O=sign_O,
    )


def generate_dwc_qdq_a16w8_params(
    Yos: int, Xos: int, Cos: int,
    Yis: int, Xis: int,
    Ky: int, Kx: int,
    Sy: int, Sx: int,
    Y_loop: int, X_loop: int, Co_loop: int,
    ifm_ch_num: int,
    sign_A: int, sign_W: int, sign_O: int,
) -> bytes:
    """Build the packed kernel-params blob for the DWC QDQ A16W8 kernel.

    Validates the stride, sizes the weight and coefficient buffers (each
    aligned to 128 via iceil), then hands off to the stride-1 or stride-2
    setup routine.

    Raises:
        ValueError: if the stride is anything other than 1x1 or 2x2.
    """
    stride_ok = Sy == Sx and Sy in (1, 2)
    if not stride_ok:
        raise ValueError(f"Unsupported stride configuration: Sy={Sy}, Sx={Sx}. Only stride 1x1 and 2x2 are supported.")

    alignment = 128
    # Weights are 1 byte each; the Ky x Kx footprint is padded up to at least
    # the kernel's 3 x 4 (Ky x Kx) granule.
    weight_taps = max(Ky, 3) * max(Kx, 4)
    wgt_size = iceil(weight_taps * Cos, alignment)
    # Two per-channel vectors (Co and C2), each Cos entries of 4-byte float.
    coeff_size = iceil(2 * Cos * 4, alignment)

    # Sy == Sx was enforced above, so Sy alone picks the kernel variant.
    setup = (
        setup_dwc_qdq_a16w8_s1_params if Sy == 1
        else setup_dwc_qdq_a16w8_s2_params
    )
    return setup(
        Yos, Xos, Cos,
        Yis, Xis,
        Ky, Kx,
        Sy, Sx,
        Y_loop, X_loop, Co_loop,
        ifm_ch_num,
        wgt_size,
        coeff_size,
        sign_A, sign_W, sign_O,
    )
