import math
import struct
import ctypes
import os
import sys
from enum import IntEnum

# Absolute path of the directory that contains this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))

from kerneltest.helpers import \
    ceildiv, \
    iceil, \
    round_up_to_multiple, \
    is_power_of_two

class Ctrl(ctypes.Structure):
    """Control word expressed as ``c_uint32`` bit fields.

    The field widths below add up to 32 bits, so the structure occupies a
    single 32-bit word. Fields are declared in the order listed; how the
    bits are placed inside the word follows the platform's ctypes bit-field
    layout.
    """

    # (field name, width in bits), in declaration order.
    _fields_ = [
        (field_name, ctypes.c_uint32, bit_width)
        for field_name, bit_width in (
            ("zero_init", 1),
            ("sign_N", 1),
            ("sign_O", 1),
            ("reserved3", 3),
            ("skip_casc_in", 1),
            ("skip_casc_out", 1),
            ("sign_W", 1),
            ("sign_A", 1),
            ("reserved10", 14),
            ("norm_ch_g", 8),
        )
    ]

class MaxpoolDims:
    """Plain record of maxpool dimensions plus a few derived quantities.

    The constructor stores every argument under an attribute of the same
    name, sets four fixed values (``wgt_subv_bytes``, ``Param_size``,
    ``mem_align``, ``size_bytes``) and derives:

    * ``act_subv_bytes`` / ``out_subv_bytes`` -- byte size of one input /
      output sub-volume, computed from ``Cs_pad`` and the bit widths.
    * ``C_loop`` / ``Y_loop`` / ``X_loop`` -- ceiling of ``C / Cs``,
      ``Yi / Yis`` and ``Xi / Xis``.

    ``Yis_pad``, ``Xis_pad``, ``Yos_pad``, ``Xos_pad`` and ``act_fmt`` are
    reserved as slots but are not assigned here; reading them before some
    other code sets them raises ``AttributeError``.

    Raises:
        AssertionError: if ``Cs`` is not a multiple of ``C_gran`` or ``C``
            is not a multiple of ``Cs``.
    """

    # Each attribute name appears exactly once; a repeated name would only
    # allocate a redundant slot.
    __slots__ = (
        'Param_size',
        'mem_align',
        'N',
        'Yi', 'Xi', 'C',
        'Yo', 'Xo',
        'Yis', 'Xis', 'Cs', 'Cs_pad',
        'Yos', 'Xos',
        'Ky', 'Kx',
        'Sy', 'Sx',
        'Py', 'Px',
        'aie_rows', 'aie_cols',
        'act_bits', 'out_bits', 'param_bits',
        'Y_gran', 'X_gran', 'C_gran',
        'C_loop', 'Y_loop', 'X_loop',
        'act_subv_bytes', 'wgt_subv_bytes', 'out_subv_bytes',
        'size_bytes',
        'Yis_pad', 'Xis_pad',
        'Yos_pad', 'Xos_pad',
        'act_fmt',
    )

    def __init__(
        self,
        N: int,
        Yi: int,
        Xi: int,
        C: int,
        Yo: int,
        Xo: int,
        Yis: int,
        Xis: int,
        Cs: int,
        Cs_pad: int,
        Yos: int,
        Xos: int,
        Ky: int,
        Kx: int,
        Py: int,
        Px: int,
        Sy: int,
        Sx: int,
        aie_rows: int,
        aie_cols: int,
        act_bits: int,
        out_bits: int,
        param_bits: int,
        Y_gran: int,
        X_gran: int,
        C_gran: int,
    ):
        # Fixed values that do not depend on the arguments.
        self.wgt_subv_bytes = 64
        self.Param_size = 1024
        self.mem_align = 128
        self.size_bytes = 1
        bits_per_byte = 8

        # The channel split must tile evenly at both levels.
        assert Cs % C_gran == 0, "Cs must be a multiple of C_gran"
        assert C % Cs == 0, "C must be a multiple of Cs"

        self.N = N
        self.Yi = Yi
        self.Xi = Xi
        self.C = C
        self.Yo = Yo
        self.Xo = Xo
        self.Yis = Yis
        self.Xis = Xis
        self.Cs = Cs
        self.Cs_pad = Cs_pad
        self.Yos = Yos
        self.Xos = Xos
        self.Ky = Ky
        self.Kx = Kx
        self.Sy = Sy
        self.Sx = Sx
        self.Py = Py
        self.Px = Px
        self.aie_rows = aie_rows
        self.aie_cols = aie_cols
        self.act_bits = act_bits
        self.out_bits = out_bits
        self.param_bits = param_bits
        self.Y_gran = Y_gran
        self.X_gran = X_gran
        self.C_gran = C_gran

        # Sub-volume sizes use the padded channel count.
        self.act_subv_bytes = (Yis * Xis * Cs_pad * act_bits) // bits_per_byte
        self.out_subv_bytes = (Yos * Xos * Cs_pad * out_bits) // bits_per_byte

        # Exact integer ceiling division: -(-a // b) == ceil(a / b) without
        # the float round-trip, which loses precision for large operands.
        self.C_loop = -(-C // Cs)
        self.Y_loop = -(-Yi // Yis)
        self.X_loop = -(-Xi // Xis)


def setup_maxpool_params(
    N: int,
    Sy: int,
    Sx: int,
    Ky: int,
    Kx: int,
    Yis: int,
    Xis: int,
    Cs: int,
    Y_gran: int,
    X_gran: int,
    C_gran: int,
    Yos: int,
    Xos: int,
    size_bytes: int,
) -> bytes:
    '''
    Build the packed maxpool kernel parameter record.

    Every scalar field is printed to stdout (one per line, in record order)
    and then packed together with a zeroed ``Ctrl`` word.

    NOTE: The smallest output subvol granularity for the maxpool kernel is 1x1x64

    NOTE(review): ``Sy`` and ``Yis`` are accepted but never read here; the
    single stride field and the Y input step both use ``Sx`` -- confirm this
    is intended for non-square strides.

    Returns:
        The ``struct.pack`` result for the format at the bottom of this
        function (native byte order and alignment, since the format has no
        prefix character).
    '''
    # Iteration counts over the output sub-volume, in granules.
    x_iters = ceildiv(Xos, X_gran)
    y_iters = ceildiv(Yos, Y_gran)
    c_iters = ceildiv(Cs, C_gran)

    # Byte length of one input row across all channels of the sub-volume.
    in_row_bytes = Xis * size_bytes * Cs

    # Channel steps collapse to 1 when the sub-volume is a single granule.
    if Cs > C_gran:
        ci_step = Xis * C_gran * size_bytes
        co_step = Xos * C_gran * size_bytes
    else:
        ci_step = 1
        co_step = 1

    # All bits start at zero; the sign bits are cleared explicitly.
    ctrl = Ctrl()
    ctrl.sign_A = 0
    ctrl.sign_W = 0
    ctrl.sign_O = 0
    ctrl_word = ctypes.string_at(ctypes.addressof(ctrl), ctypes.sizeof(ctrl))

    # (printed label, value) in the exact order the record is packed.
    record = (
        ("Kx_g,       ", Kx),
        ("Ky_g,       ", Ky),
        ("Ci_g,       ", 1),        # NOTE: Channel iterations is controlled by Co_g
        ("S_g,        ", Sx),
        ("N_g,        ", N),
        ("X_g,        ", x_iters),
        ("Y_g,        ", y_iters),
        ("Co_g,       ", c_iters),
        ("inner_g,    ", Kx * Ky),
        ("outer_g,    ", x_iters * y_iters * c_iters),
        ("shift_tdm,  ", 0),
        ("shift_res,  ", 0),
        ("shift_norm, ", 0),
        ("shift_bias, ", 0),
        ("step_Kx,    ", C_gran),   # NOTE: Unused param for maxpool kernel
        ("step_Ky,    ", in_row_bytes),
        ("step_Ci,    ", ci_step),
        ("step_Xi,    ", C_gran),   # NOTE: Unused param for maxpool kernel
        ("step_Yi,    ", in_row_bytes * Sx),
        ("step_Xo,    ", C_gran * size_bytes),
        ("step_Yo,    ", Xos * Cs * size_bytes),
        ("step_Co,    ", co_step),
        ("param_value,", 0),
    )

    for label, value in record:
        print(label, value)

    # 8 one-byte counts (the 4th signed), 2 u16 loop counts, 4 signed-byte
    # shifts, 8 u16 steps, one int, then the 4-byte control word.
    layout = 'BBBbBBBBHHbbbbHHHHHHHHi4s'
    return struct.pack(layout, *(value for _, value in record), ctrl_word)

def gen_aie4_maxpool_params(
    dims: MaxpoolDims,
    mode: int,
) -> bytes:
    """Return the kernel parameter record followed by the 2-byte mode.

    The kernel record comes from ``setup_maxpool_params``; note that its
    ``Cs`` argument is fed from ``dims.Cs_pad`` (the padded channel count),
    not ``dims.Cs``. ``mode`` is appended as an unsigned 16-bit
    little-endian value, so ``int.to_bytes`` raises ``OverflowError`` when
    it does not fit.
    """
    kernel_args = {
        "N": dims.N,
        "Sy": dims.Sy,
        "Sx": dims.Sx,
        "Ky": dims.Ky,
        "Kx": dims.Kx,
        "Yis": dims.Yis,
        "Xis": dims.Xis,
        "Cs": dims.Cs_pad,
        "Y_gran": dims.Y_gran,
        "X_gran": dims.X_gran,
        "C_gran": dims.C_gran,
        "Yos": dims.Yos,
        "Xos": dims.Xos,
        "size_bytes": dims.size_bytes,
    }
    kernel_blob = setup_maxpool_params(**kernel_args)
    mode_blob = mode.to_bytes(2, 'little', signed=False)
    return kernel_blob + mode_blob