from curses import raw
import math
import struct
import ctypes
import os
from enum import IntEnum

# Absolute path of the directory that contains this module.
CURRDIR = os.path.split(os.path.abspath(__file__))[0]

from kerneltest.helpers import \
    ceildiv, \
    iceil, \
    round_up_to_multiple, \
    is_power_of_two

from kernel.common.kernel_params_helper import (
    DimsHelper,
)


class ActFmt(IntEnum):
    YCXC = 0
    CYXC = 1

class ActMode(IntEnum):
    """Activation mode codes.

    The member names suggest SRS only, ReLU, ReLU6, leaky ReLU and hard-swish.
    Confirm the exact semantics against the kernel. The integer values are
    part of the interface.
    """

    AC_SRS = 0
    AC_RELU = 1
    AC_RELU6 = 2
    AC_LRELU = 3
    AC_HSWISH = 4

def calc_in_pixel_dim(
    X: int,
    K: int,
    P: int,
    S: int,
) -> int:
    """Return the input extent needed to produce ``X`` output pixels.

    Evaluates ``(X - 1) * S + K - 2 * P``. This is the span covered by ``X``
    windows of size ``K`` placed ``S`` apart, reduced by ``2 * P``
    (presumably symmetric padding of ``P`` on each side).
    """
    window_span = S * (X - 1) + K
    return window_span - 2 * P

def align_Xis(
    Xos: int,
    Cin: int,
    Kx: int,
    Sx: int,
) -> int:
    """Return the aligned input-subvolume width for ``Xos`` output columns.

    The unaligned width is ``(Xos - 1) * Sx + Kx``. It is passed to ``iceil``
    with a granule of ``max(1, 64 // Cin)``. ``iceil`` presumably rounds up
    to a multiple of the granule (confirm in kerneltest.helpers). The
    alignment is meant for small Cin, where Kx is folded together with Cin
    in the weights.
    """
    unaligned_width = Sx * (Xos - 1) + Kx
    granule = 64 // Cin
    if granule < 1:
        # Cin > 64 (or negative): fall back to a granule of 1.
        granule = 1
    return iceil(unaligned_width, granule)
    
class Ctrl(ctypes.Structure):
    """32-bit control word described as ctypes ``c_uint32`` bit-fields.

    Field widths sum to 32. The offsets noted below assume LSB-first bit-field
    allocation, which is typical for little-endian GCC-compatible ABIs. They
    agree with the ``reserved3`` / ``reserved10`` names.
    """

    _fields_ = [
        # name             type             width   # offset
        ("zero_init",      ctypes.c_uint32, 1),     # bit 0
        ("sign_N",         ctypes.c_uint32, 1),     # bit 1
        ("sign_O",         ctypes.c_uint32, 1),     # bit 2
        ("reserved3",      ctypes.c_uint32, 3),     # bits 3-5
        ("skip_casc_in",   ctypes.c_uint32, 1),     # bit 6
        ("skip_casc_out",  ctypes.c_uint32, 1),     # bit 7
        ("sign_W",         ctypes.c_uint32, 1),     # bit 8
        ("sign_A",         ctypes.c_uint32, 1),     # bit 9
        ("reserved10",     ctypes.c_uint32, 14),    # bits 10-23
        ("norm_ch_g",      ctypes.c_uint32, 8),     # bits 24-31
    ]

class ConvDims:
    """Convolution layer description plus derived subvolume sizes and loop counts.

    Naming follows the size formulas in ``__init__``:
    - ``Yi``/``Xi``/``Ci`` and ``Yo``/``Xo``/``Co`` are the full input and
      output dims.
    - ``Yis``/``Xis``/``Cis`` and ``Yos``/``Xos``/``Cos`` are the per-subvolume
      input and output tile dims.
    - ``Ky``/``Kx`` is the kernel size, ``Py``/``Px`` the padding and
      ``Sy``/``Sx`` the stride.

    Supported output subvolume shapes (``Cos`` must be 64):
        Yos=1, Xos=64 or
        Yos=2, Xos=32 or
        Yos=4, Xos=16 or
        Yos=8, Xos=8

    Several names in ``__slots__`` are declared but never assigned here
    (``Yis_pad``, ``Cis_pad``, ``Yos_pad``, ``Xos_pad``, ``Cos_pad``,
    ``stride_efficiency``). Reading them raises ``AttributeError`` unless
    they are set elsewhere.
    """
    __slots__ = (
                'Param_size',
                'mem_align',
                'N',
                'Yi', 'Xi', 'Ci',
                'Yo', 'Xo', 'Co',
                'Yis', 'Xis', 'Cis',
                'Yos', 'Xos', 'Cos',
                'Ky', 'Kx',
                'Sy', 'Sx',
                'Py', 'Px',
                'aie_rows', 'aie_cols',
                'act_bits', 'wgt_bits', 'bias_bits', 'out_bits', 'param_bits',
                'Ci_gran',
                'Co_loop', 'Y_loop', 'X_loop', 'Ci_loop',
                'act_subv_bytes', 'wgt_subv_bytes', 'bias_subv_bytes', 'out_subv_bytes',
                'size_bytes', 'stride_efficiency',
                'Yis_pad', 'Xis_pad', 'Cis_pad',
                'Yos_pad', 'Xos_pad', 'Cos_pad',
                'act_fmt', "small_Cin_mode",
                 )

    def __init__(
        self,
        N: int,
        Yi: int,
        Xi: int,
        Ci: int,
        Yo: int,
        Xo: int,
        Co: int,
        Yis: int,
        Xis: int,
        Cis: int,
        Yos: int,
        Xos: int,
        Cos: int,
        Ky: int,
        Kx: int,
        Py: int,
        Px: int,
        Sy: int,
        Sx: int,
        aie_rows: int,
        aie_cols: int,
        act_bits: int,
        wgt_bits: int,
        bias_bits: int,
        out_bits: int,
        param_bits: int,
        Ci_gran: int,
        act_fmt: ActFmt,
    ):
        """Store the layer description and precompute subvolume byte sizes.

        Raises:
            AssertionError: if the channel tiling or the output subvolume
                shape is unsupported (see the class docstring).
        """
        self.Param_size = 1024
        self.mem_align = 128
        self.size_bytes = 1
        bits_per_byte = 8
        # NOTE: qdq size is hardcoded to 128 bytes
        qdq_size_bytes = 128

        # Tiling constraints. These stay assertions (same exception type as
        # before) but now say which constraint failed.
        assert Ci % Ci_gran == 0, f"Ci ({Ci}) must be a multiple of Ci_gran ({Ci_gran})"
        assert Ci % Cis == 0, f"Ci ({Ci}) must be a multiple of Cis ({Cis})"
        assert Co % Cos == 0, f"Co ({Co}) must be a multiple of Cos ({Cos})"
        assert Xos % 8 == 0, f"Xos ({Xos}) must be a multiple of 8"
        assert Xos * Yos == 64, f"Xos * Yos must be 64, got {Xos} * {Yos}"
        assert Cos == 64, f"Cos must be 64, got {Cos}"

        self.N = N
        self.Yi = Yi
        self.Xi = Xi
        self.Ci = Ci
        self.Yo = Yo
        self.Xo = Xo
        self.Co = Co
        self.Yis = Yis
        self.Xis = Xis
        self.Cis = Cis
        self.Yos = Yos
        self.Xos = Xos
        self.Cos = Cos
        self.Ky = Ky
        self.Kx = Kx
        self.Sy = Sy
        self.Sx = Sx
        self.Py = Py
        self.Px = Px
        self.aie_rows = aie_rows
        self.aie_cols = aie_cols
        self.act_bits = act_bits
        self.wgt_bits = wgt_bits
        self.bias_bits = bias_bits
        self.out_bits = out_bits
        self.param_bits = param_bits
        self.Ci_gran = Ci_gran
        self.small_Cin_mode = Cis < 64

        self.bias_subv_bytes = (Cos * bias_bits) // bits_per_byte
        # Small-Cin mode sizes the per-(Cos, Ky) weight slice as 64 elements.
        # Kx is folded with Cin (see align_Xis); otherwise the slice is Kx * Cis.
        # NOTE(review): if Kx * Cis > 64 while Cis < 64, this may undersize the
        # weights relative to ceil(Kx * Cis / 64) chunks. Confirm with the kernel.
        reduction_elems = 64 if self.small_Cin_mode else Kx * Cis
        self.wgt_subv_bytes = (
            (Cos * Ky * reduction_elems * wgt_bits) // bits_per_byte
            + self.bias_subv_bytes
            + qdq_size_bytes
        )
        self.out_subv_bytes = (Yos * Xos * Cos * out_bits) // bits_per_byte

        # Integer ceiling division: -(-a // b) == ceil(a / b), without a
        # round-trip through float.
        self.Co_loop = -(-Co // Cos)
        self.Y_loop = -(-Yi // Yis)
        self.X_loop = -(-Xi // Xis)
        self.Ci_loop = -(-Ci // Cis)
        self.act_fmt = act_fmt
        self.Xis_pad = align_Xis(Xos, Cis, Kx, Sx)
        # The activation subvolume uses the aligned width, not Xis.
        self.act_subv_bytes = (Yis * self.Xis_pad * Cis * act_bits) // bits_per_byte

    def __str__(self) -> str:
        """Return ``ConvDims(name=value, ...)`` for all assigned fields."""
        shown = (
            "N", "Yi", "Xi", "Ci",
            "Yo", "Xo", "Co",
            "Yis", "Xis", "Cis",
            "Yos", "Xos", "Cos",
            "Ky", "Kx",
            "Py", "Px",
            "Sy", "Sx",
            "aie_rows", "aie_cols",
            "act_bits", "wgt_bits",
            "bias_bits", "out_bits",
            "param_bits",
            "Ci_gran",
            "act_fmt",
            "act_subv_bytes",
            "bias_subv_bytes",
            "wgt_subv_bytes",
            "out_subv_bytes",
            "Co_loop", "Y_loop",
            "X_loop", "Ci_loop",
            "size_bytes", "Xis_pad",
            "small_Cin_mode",
        )
        body = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"ConvDims({body})"


def derive_hardened_loop(Cis: int, folded_Kx: int, Ky: int) -> int:
    """Pick the hardened_loop code for a reduction of ``Cis * folded_Kx * Ky``.

    Intended constraint:
        Cis * folded_Kx * Ky >= 64 * min(4, hardened_loop & 7)

    Args:
        Cis: Channel input size
        folded_Kx: Folded kernel width
        Ky: Kernel height

    Returns:
        The number of whole 64-element groups when that number is 1, 2 or 3.
        Otherwise (fewer than 1 group, or 4 or more) the result is 0.
    """
    groups = (Cis * folded_Kx * Ky) // 64
    if 1 <= groups <= 3:
        return int(groups)
    return 0

def setup_conv_params_new_kernel(
    dims: ConvDims,
    mode: int,
    verbose: bool = True,
) -> bytes:
    '''
    Derive the conv kernel's layer parameters from ``dims`` and pack them.

    NOTE: All the layer params calculations are with respect to the output Xos and Yos dimensions
    TODO: The kernel supports a Cos < 64 but it has to be folded into Xos

    Args:
        dims: Layer and subvolume description.
        mode: Packed verbatim as the second uint32 field. Presumably an
            ``ActMode`` value; confirm against the kernel.
        verbose: When True, print every intermediate value.

    Returns:
        56 bytes, little-endian, with layout ``'<4I3H2B1I2i2I3i'`` (17 fields
        in the order of the ``struct.pack`` call below).

    Raises:
        AssertionError: ``Cos != 64`` or an unsupported Xos/Yos subvolume.
        struct.error: a derived value does not fit its packed field. Examples
            are ``outer_time_iters`` > 65535, or a negative ``step_align``
            (``incr_Xi`` with fewer than 3 trailing zero bits).
    '''
    def log(msg: str) -> None:
        # Debug trace. Replaces the `print(...) if verbose else None`
        # expression statements.
        if verbose:
            print(msg)

    assert dims.Cos == 64, "Cos must be 64"
    log(f"Cos: {dims.Cos}")
    folded_Xos = 64 // dims.Cos
    log(f"folded_Xos: {folded_Xos}")
    # Kernel width spanned by folded_Xos adjacent outputs at stride Sx.
    folded_Kx = dims.Kx + (folded_Xos - 1) * dims.Sx
    log(f"folded_Kx: {folded_Kx}")
    # Xis_aligned = iceil((dims.Xis - 1) * dims.Sx + dims.Kx, max(1, 64 // dims.Cis))
    Xis_aligned = dims.Xis_pad
    log(f"Xis_aligned: {Xis_aligned}")
    Cis_ilb = min(64, dims.Cis)
    log(f"Cis_ilb: {Cis_ilb}")
    Xos_ilb = min(64, dims.Xos // folded_Xos)
    log(f"Xos_ilb: {Xos_ilb}")
    kernel_dims = DimsHelper(-64)
    log(f"kernel_dims: {kernel_dims}")
    assert (dims.Xos / folded_Xos) in [8, 16, 32, 64], "Xos must be 8, 16, 32 or 64, got {}".format(dims.Xos / folded_Xos)
    # Integer form of Yos == 4096 / (Xos * Cos). With Cos == 64 that means
    # Yos is 1, 2, 4 or 8 (the previous message wrongly listed 8..64).
    assert dims.Yos * dims.Xos * dims.Cos == 4096, (
        "Yos must equal 4096 / (Xos * Cos), got {}".format(dims.Yos)
    )
    # step_Ci = Xis_aligned * ( 1 if dims.act_fmt == ActFmt.YCXC else dims.Yis ) * Cis_ilb
    # step_Ky = Xis_aligned * ( dims.Cis if dims.act_fmt == ActFmt.YCXC else Cis_ilb)
    step_Ci = Xis_aligned * dims.Yis * Cis_ilb
    log(f"step_Ci: {step_Ci}")
    step_Ky = Xis_aligned * Cis_ilb
    log(f"step_Ky: {step_Ky}")
    step_Xi = Cis_ilb * dims.Sx * folded_Xos
    log(f"step_Xi: {step_Xi}")
    step_Yi = dims.Sy * step_Ky
    log(f"step_Yi: {step_Yi}")
    incr_Xi = kernel_dims.from_steps(1, Cis_ilb * dims.Sx * folded_Xos)
    log(f"incr_Xi: {incr_Xi}")
    # step_align = (trailing zero bits of incr_Xi) - 3, or 0 when incr_Xi == 0.
    # `x & -x` isolates the lowest set bit. This is the exact integer form of
    # the former int(math.log2(x ^ (x - 1))) expression.
    if incr_Xi == 0:
        step_align = 0
    else:
        incr_Xi_int = int(incr_Xi)
        step_align = (incr_Xi_int & -incr_Xi_int).bit_length() - 1 - 3
    log(f"step_align: {step_align}")
    norm_ch_g = 1
    dims_YXi = kernel_dims.from_steps((Xos_ilb, 64 // Xos_ilb), (step_Xi, step_Yi))
    log(f"dims_YXi: {dims_YXi}")
    dims_KCi = kernel_dims.from_steps((math.ceil(folded_Kx * Cis_ilb / 64), dims.Ky), (64, step_Ky, step_Ci))
    log(f"dims_KCi: {dims_KCi}")
    outer_time_iters = dims.Co_loop * dims.Y_loop * dims.X_loop
    log(f"outer_time_iters: {outer_time_iters}")
    inner_time_iters = dims.Ci_loop
    log(f"inner_time_iters: {inner_time_iters}")
    inner_loop = dims.Ky * int(math.ceil(folded_Kx * dims.Cis / 64))
    log(f"inner_loop: {inner_loop}")
    # Weight bytes without bias/qdq. This mirrors ConvDims.wgt_subv_bytes.
    if dims.small_Cin_mode:
        raw_wgt_subv_size = dims.Cos * dims.Ky * 64 * dims.wgt_bits // 8
    else:
        raw_wgt_subv_size = dims.Cos * dims.Ky * dims.Kx * dims.Cis * dims.wgt_bits // 8
    log(f"raw_wgt_subv_size: {raw_wgt_subv_size}")
    raw_bias_size = dims.Cos * dims.bias_bits // 8
    log(f"raw_bias_size: {raw_bias_size}")
    hardened_loop = derive_hardened_loop(dims.Cis, folded_Kx, dims.Ky)
    log(f"hardened_loop: {hardened_loop}")
    packed_params = struct.pack(
        '<4I3H2B1I2i2I3i',
        hardened_loop,          # I
        mode,                   # I
        raw_wgt_subv_size,      # I
        raw_bias_size,          # I
        outer_time_iters,       # H
        inner_time_iters,       # H
        inner_loop,             # H
        step_align,             # B
        norm_ch_g,              # B
        dims_YXi['num0'],       # I
        dims_YXi['inc0'],       # i
        dims_YXi['inc1'],       # i
        dims_KCi['num0'],       # I
        dims_KCi['num1'],       # I
        dims_KCi['inc0'],       # i
        dims_KCi['inc1'],       # i
        dims_KCi['inc2'],       # i
    )
    return packed_params
