import struct
import math
from collections import namedtuple
from kernel.common.kernel_params_helper import (
    DimsHelper,
    sizeof,
)

from utils.utils_common import log, iceil, BaseDims, ceildiv


# Two-field record describing an L1 subvolume: R = rows, C = columns.
# The typename stays lowercase "subvolume" because it appears in repr()/log output.
Subvolume = namedtuple("subvolume", "R C")


def _broadcast_steps(subvol, gran: int, el_size: int) -> tuple[int, int]:
    """Return the (inner, outer) byte steps used to walk one input subvolume.

    Args:
        subvol: object exposing ``R`` (rows) and ``C`` (columns) of one
            input's L1 subvolume.
        gran: channel granularity, in elements.
        el_size: element size in bytes.

    Returns:
        ``(inner_step, outer_step)`` in bytes. ``inner_step`` is 0 when the
        input has no more than ``gran`` columns; ``outer_step`` is 0 when the
        input has a single row. A zero step presumably makes the kernel re-read
        the same data along that axis (i.e. broadcast) — confirm against
        ``DimsHelper.from_steps``.
    """
    inner_step = gran * el_size if subvol.C > gran else 0
    outer_step = subvol.C * el_size if subvol.R > 1 else 0
    return inner_step, outer_step


def generate_broadcast_params(
    ifm_bytes: int,
    ofm_bytes: int,
    ifms: list[tuple[int, int, int]],
    ofms: tuple[int, int, int],
    gran: int,
    core_qbuf_offset: int,
    core_dqbuf_offset: int,
    has_scalar_broadcast: bool,
    is_sub: int,
    sign_A: int,
    sign_W: int,
    sign_O: int,
) -> bytes:
    """Pack and transfer runtime params to matadd kernel wrapper.

    Args:
        ifm_bytes: bytes per input element; only ``== 2`` is forwarded, as a flag.
        ofm_bytes: bytes per output element; only ``== 2`` is forwarded, as a flag.
        ifms: shapes of the two inputs (A, B). Index 1 is read as rows and
            index 2 as columns.
        ofms: output shape. Index 1 is read as rows and index 2 as columns.
        gran: channel granularity used to group output columns.
        core_qbuf_offset: offset forwarded verbatim to the kernel.
        core_dqbuf_offset: offset forwarded verbatim to the kernel.
        has_scalar_broadcast: forwarded as a 0/1 flag.
        is_sub: forwarded as the last field.
        sign_A: signedness flag forwarded verbatim.
        sign_W: signedness flag forwarded verbatim.
        sign_O: signedness flag forwarded verbatim.

    Returns:
        84 bytes: 21 little-endian 32-bit fields laid out as ``<15IiiIiiI``.

    Raises:
        struct.error: if any field does not fit its 32-bit slot.
    """
    # Element counts; also used as the dq/q element totals below.
    ifm_a_L1_elements = math.prod(ifms[0])
    ifm_b_L1_elements = math.prod(ifms[1])
    ofm_L1_elements = math.prod(ofms)

    subvolume = Subvolume(R=ifms[0][1], C=ifms[0][2])
    subvolume_in1 = Subvolume(R=ifms[1][1], C=ifms[1][2])
    log("Subvolume 0:", subvolume)
    log("Subvolume 1:", subvolume_in1)
    R, C = ofms[1], ofms[2]

    # NOTE(review): floor division drops any partial granule when C is not a
    # multiple of gran — confirm callers guarantee C % gran == 0.
    C_g = C // gran  # channel granules in the output
    dims_x, dims_y = DimsHelper(), DimsHelper()
    outer_loop = R * C_g
    el_size = 2  # bdcast is always bfloat16 input

    dims_2d_t_x = dims_x.from_steps(C_g, _broadcast_steps(subvolume, gran, el_size))
    dims_2d_t_y = dims_y.from_steps(
        C_g, _broadcast_steps(subvolume_in1, gran, el_size)
    )

    dq_gran_e = 32  # from dq.json
    q_gran_e = 32  # from q.json
    dq_a_inner_g = ceildiv(ifm_a_L1_elements, dq_gran_e)
    dq_b_inner_g = ceildiv(ifm_b_L1_elements, dq_gran_e)
    q_inner_g = ceildiv(ofm_L1_elements, q_gran_e)

    log("el_size:", int(el_size))
    log("ifm_a_L1_elements:", int(ifm_a_L1_elements))
    log("ifm_b_L1_elements:", int(ifm_b_L1_elements))
    log("outer_loop:", outer_loop)
    log("dq_a_inner_g:", int(dq_a_inner_g))
    log("dq_b_inner_g:", int(dq_b_inner_g))
    log("q_inner_g", int(q_inner_g))
    log("dims_2d_t_x['num0']:", int(dims_2d_t_x["num0"]))
    log("dims_2d_t_x['inc0']:", int(dims_2d_t_x["inc0"]))
    log("dims_2d_t_x['inc1']:", int(dims_2d_t_x["inc1"]))
    log("dims_2d_t_y['num0']:", int(dims_2d_t_y["num0"]))
    log("dims_2d_t_y['inc0']:", int(dims_2d_t_y["inc0"]))
    log("dims_2d_t_y['inc1']:", int(dims_2d_t_y["inc1"]))
    log("core_qbuf_offset", core_qbuf_offset)
    log("core_dqbuf_offset", core_dqbuf_offset)
    log("has_scalar_broadcast", int(has_scalar_broadcast))
    log("is_sub", int(is_sub))

    # Field order must match the kernel-side struct exactly; the increments
    # are signed ("i"), everything else unsigned ("I").
    fmt_all = "<15IiiIiiI"

    return struct.pack(
        fmt_all,
        int(ifm_a_L1_elements),  # I
        int(ifm_b_L1_elements),  # I
        core_qbuf_offset,  # I
        core_dqbuf_offset,  # I
        dq_a_inner_g,  # I
        dq_b_inner_g,  # I
        q_inner_g,  # I
        int(has_scalar_broadcast),  # I
        int(sign_A),  # I
        int(sign_W),  # I
        int(sign_O),  # I
        int(ifm_bytes == 2),  # I
        int(ofm_bytes == 2),  # I
        outer_loop,  # I
        int(dims_2d_t_x["num0"]),  # I
        int(dims_2d_t_x["inc0"]),  # i
        int(dims_2d_t_x["inc1"]),  # i
        int(dims_2d_t_y["num0"]),  # I
        int(dims_2d_t_y["inc0"]),  # i
        int(dims_2d_t_y["inc1"]),  # i
        int(is_sub),  # I
    )
