import struct
from tiler.binary_tiler import BinaryL2Dims
from kernel.common.kernel_params_helper import (
    DimsHelper, sizeof,
)

from utils.utils_common import log, iceil

# Little-endian, no padding: 15 x 32-bit fields (60 bytes) for the int16 path.
_INT16_PARAMS_FMT = '<10IiiIii'
# Little-endian, no padding: H H I i i b b b B b (21 bytes) for add2d_int8x8.
_INT8_PARAMS_FMT = '<HHIii3bBb'


def _dims_2d(C_g, C, R, gran, el_size):
    '''Return ``DimsHelper().from_steps(C_g, steps)`` for a row of C elements.

    The two step values passed to ``from_steps`` are:
      * ``gran * el_size`` when ``C > gran``, else 0;
      * ``C * el_size`` when ``R > 1``, else 0.
    A fresh ``DimsHelper`` instance is used on every call so each
    descriptor is built from its own helper.
    '''
    return DimsHelper().from_steps(
        C_g,
        (
            (C > gran) * gran * el_size,
            (R > 1) * C * el_size,
        ),
    )


def _make_ctrl(sign_A: int, sign_W: int, sign_O: int, sign_srs: int) -> int:
    '''Pack four 0/1 flags into a control word.

    Bit layout: bit0 = sign_A, bit1 = sign_W, bit2 = sign_O, bit3 = sign_srs.
    Only the lowest bit of each argument is used.
    '''
    return (
        (sign_A & 1)
        | ((sign_W & 1) << 1)
        | ((sign_O & 1) << 2)
        | ((sign_srs & 1) << 3)
    )


def generate_binary_params(
    binary_dims: BinaryL2Dims,
    offset_bytes: int,
    minimum_core_subv: int,
    core_qbuf_offset: int,
    core_dqbuf_offset: int,
) -> bytes:
    '''Pack and transfer runtime params to matadd kernel wrapper.

    Two layouts are produced, selected by ``binary_dims.shape.ifm_bytes``:

    * ``ifm_bytes == 2``: ``_INT16_PARAMS_FMT`` - 15 little-endian 32-bit
      fields, including the q/dq buffer offsets and two loop descriptors.
    * otherwise: ``_INT8_PARAMS_FMT`` for the add2d_int8x8 kernel - one loop
      descriptor, three shift values, a ctrl bitfield and a max value.

    Args:
        binary_dims: tiling dims; ``max_subvolume`` and ``shape.ifm_bytes``
            are read, and all instance attributes are logged.
        offset_bytes: packed as the first field (uint32 in the int16 layout,
            uint16 in the int8 layout).
        minimum_core_subv: ``max_subvolume`` is passed through ``iceil`` with
            this value (presumably rounding up to a multiple - confirm).
        core_qbuf_offset: packed only in the int16 layout.
        core_dqbuf_offset: packed only in the int16 layout.

    Returns:
        The packed parameter bytes (60 bytes for int16, 21 bytes for int8).

    Raises:
        struct.error: if a value does not fit its field, e.g.
            ``offset_bytes >= 2**16`` in the int8 layout.
    '''
    R = 1
    log(tuple(binary_dims.__dict__.items()))
    log("minimum_core_subv:", minimum_core_subv)
    C = iceil(binary_dims.max_subvolume, minimum_core_subv)
    gran = 64
    # NOTE(review): floor division drops any tail if C is not a multiple of
    # gran; confirm minimum_core_subv is always a multiple of 64.
    C_g = C // gran
    is_int16 = binary_dims.shape.ifm_bytes == 2
    el_size = sizeof("int16") if is_int16 else sizeof("int8")
    sign_A, sign_O = 0, 0
    # The x and y descriptors are currently built from identical inputs.
    dims_2d_t_x = _dims_2d(C_g, C, R, gran, el_size)
    dims_2d_t_y = _dims_2d(C_g, C, R, gran, el_size)
    outer_loop = R * C_g

    # q/dq sub-kernel sizing: C is passed through iceil with the larger of
    # the two loop spans, then counted in qdq_kernel_gran-element groups.
    qdq_kernel_gran = 32
    q_loop_range = 10
    dq_loop_range = 12
    minimum_subv_qdq = qdq_kernel_gran * max(q_loop_range, dq_loop_range)
    # Floor division avoids a float round-trip; int() keeps the result an
    # int for struct.pack even if iceil were to return a float.
    qdq_inner_g = int(iceil(C, minimum_subv_qdq) // qdq_kernel_gran)

    log("C:", int(C))
    log("el_size:", int(el_size))
    log("offset_bytes:", int(offset_bytes))
    log("outer_loop:", outer_loop)
    for name, dims in (("dims_2d_t_x", dims_2d_t_x), ("dims_2d_t_y", dims_2d_t_y)):
        for key in ("num0", "inc0", "inc1"):
            log(f"{name}['{key}']:", int(dims[key]))

    log("qdq_inner_g", int(qdq_inner_g))
    log("core_qbuf_offset", core_qbuf_offset)
    log("core_dqbuf_offset", core_dqbuf_offset)

    if is_int16:
        return struct.pack(
            _INT16_PARAMS_FMT,
            int(offset_bytes),           # I
            core_qbuf_offset,            # I
            core_dqbuf_offset,           # I
            qdq_inner_g,                 # I
            int(el_size == 2),           # I  flag; meaning TODO confirm
            int(el_size == 2),           # I  flag; meaning TODO confirm
            sign_A,                      # I
            sign_O,                      # I
            outer_loop,                  # I
            int(dims_2d_t_x["num0"]),    # I
            int(dims_2d_t_x["inc0"]),    # i
            int(dims_2d_t_x["inc1"]),    # i
            int(dims_2d_t_y["num0"]),    # I
            int(dims_2d_t_y["inc0"]),    # i
            int(dims_2d_t_y["inc1"]),    # i
        )

    # Otherwise, assume the add2d_int8x8 kernel.
    shift_in = 0
    shift_in1 = 0
    shift_res = 0
    # All four ctrl bits (sign_A, sign_W, sign_O, sign_srs) are set, so
    # ctrl == 0b1111 == 15 (not 7).
    # NOTE(review): confirm the kernel expects sign_srs = 1 here.
    ctrl = _make_ctrl(1, 1, 1, 1)
    max_value = 127
    # Print all values before packing
    log("shift_in:", int(shift_in))
    log("shift_in1:", int(shift_in1))
    log("shift_res:", int(shift_res))
    log("ctrl (packed bits):", bin(ctrl), f"= {ctrl}")
    log("max_value:", int(max_value))
    dims_2d_t = _dims_2d(C_g, C, R, gran, el_size)

    return struct.pack(
        _INT8_PARAMS_FMT,
        int(offset_bytes),           # H  slot labelled num_elements; NOTE(review): confirm
        outer_loop,                  # H
        int(dims_2d_t['num0']),      # I
        int(dims_2d_t['inc0']),      # i
        int(dims_2d_t['inc1']),      # i
        int(shift_in),               # b
        int(shift_in1),              # b
        int(shift_res),              # b
        int(ctrl) & 0xFF,            # B
        int(max_value),              # b
    )


