import numpy as np
from named_list import *
import struct

# from pooling_common import PoolingDims

class PoolingSubvDims:
    """Shape, kernel and datatype description of one pooling sub-volume.

    Attributes mirror the constructor arguments, with two differences:
    ``Co_gran`` is stored as ``C_gran``, and ``dtype`` is derived from
    ``ifm_bits`` (8 -> "int8", 16 -> "int16", 32 -> "int32", 64 -> "int64").

    Raises:
        AssertionError: if ``ifm_bits`` is not one of 8, 16, 32 or 64.
    """

    # Supported input bit widths and the dtype name recorded for each.
    # The 16-bit entry is also used for bfloat16 data.
    _DTYPE_BY_BITS = {8: "int8", 16: "int16", 32: "int32", 64: "int64"}

    def __init__(
        self,
        Y : int,
        X : int,
        C : int,
        Ky : int,
        Kx : int,
        Sy : int,
        Sx : int,
        X_gran : int,
        Co_gran: int,
        ifm_bits: int,
        subv_elem: int,
        max_or_avg: int = 0,
        is_signed: int = 0,
        scratch: int = 0,
    ):
        self.Y = Y      # Yis with Y_gran = 1
        self.X = X      # Xis with X_gran = 1
        self.C = C      # Cos = Cis with C_gran = 8
        self.Ky = Ky    # kernel extent along Y
        self.Kx = Kx    # kernel extent along X
        self.Sy = Sy    # stride along Y
        self.Sx = Sx    # stride along X
        self.subv_elem = subv_elem
        self.C_gran = Co_gran
        self.X_gran = X_gran
        # NOTE(review): which value selects max vs. average pooling is not
        # visible in this file - confirm against the kernel.
        self.max_or_avg = max_or_avg

        self.ifm_bits = ifm_bits
        self.is_signed = is_signed
        self.scratch = scratch

        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would leave `self.dtype` unset.
        # AssertionError is kept so existing callers see the same exception.
        try:
            self.dtype = self._DTYPE_BY_BITS[ifm_bits]
        except (KeyError, TypeError):
            raise AssertionError("Input datatype is not supported by kernel") from None



# log2 = np.log2
# array = lambda *x: np.array( x )
# all = lambda *x: np.all( x )
# any = lambda *x: np.any( x )

def ceil( n, d=1 ):
    """Round ``n`` up to the next multiple of ``d`` (default: next integer).

    The rounding is done by ``np.ceil``, so the result is a NumPy float
    rather than a Python ``int``.
    """
    if d != 1:
        # Count how many whole multiples of d are needed, then scale back up.
        multiples = np.ceil( n / d )
        return d * multiples
    return np.ceil( n )

def sign( val ):
    """Return the sign flag for a type name or a numeric value.

    For a string, the result is element ``[1]`` of ``type_decoder( val )``.
    For anything else, the result is ``val < 0``.

    ``isinstance`` is used rather than an exact type comparison so that
    ``str`` subclasses (for example ``numpy.str_``) are treated as type names
    instead of being compared against 0.
    """
    if isinstance( val, str ):
        return type_decoder( val )[1]
    return val < 0

def sizeof( val ):
    """Return element ``[0]`` of ``type_decoder( val )`` divided by 8.

    True division is used, so the result is a float.
    """
    # presumably a width in bits, making the result a size in bytes - TODO confirm
    width = type_decoder( val )[0]
    return width / 8

class DimsHelper:
    """Convert per-dimension address steps into relative address increments.

    ``reset`` is a running offset that is added to the next increment
    produced; every dimension that is added updates it.
    """

    def __init__( self, reset=0 ):
        # Offset folded into the next increment that gets generated.
        self.reset = reset

    def __getitem__( self, key ):
        # Subscript access is forwarded to attribute lookup, e.g. helper["reset"].
        return getattr( self, key )

    def add_dimension( self, num, step ):
        """Return the increment for one dimension and update ``reset``.

        The increment is the current ``reset`` plus ``step``; ``reset`` is
        then lowered by ``num * step``.
        """
        inc = self.reset + step
        self.reset -= num * step
        return inc

    def from_steps( self, wraps, steps ):
        """Build wrap counts and address increments for up to five dimensions.

        Args:
            wraps: wrap count per dimension (scalar or sequence). At least
                ``len( steps ) - 1`` entries are required; entries beyond
                ``len( steps )`` are not read.
            steps: address step per dimension (scalar or sequence, 1 to 5
                entries).

        Returns:
            The single increment when only one step is given. Otherwise a
            ``TypedNamedList`` whose fields are ``num0..`` followed by
            ``inc0..``, each declared as ``uint32``.

        Raises:
            AssertionError: if ``steps`` has an unsupported length or
                ``wraps`` is too short.

        Side effects:
            Mutates ``self.reset`` while walking the dimensions.
        """
        # Accept scalars as well as sequences for both arguments.
        wraps = make_tuple( wraps )
        steps = make_tuple( steps )
        assert len( steps ) in [1,2,3,4,5], "Only 1d to 5d address increments supported"
        assert len( wraps ) >= len( steps ) - 1, "Wrap dimesions passed are not sufficient"

        nums = []
        incs = []
        for i,s in enumerate( steps ):
            if i == len( wraps ):
                # No wrap count for this step (only possible for the last
                # one): emit the increment and clear the running reset.
                incs.append( self.reset + s )
                self.reset = 0
            else:
                if i < len( steps ) - 1:
                    # Non-final dimension: a count is recorded next to the increment.
                    if i % 3 == 2:
                        # Every third dimension records the product of all
                        # wraps so far, minus one, and uses the full wrap
                        # count for the reset update.
                        # NOTE(review): presumably mirrors how the consuming
                        # kernel nests its counters - confirm.
                        num = wraps[i]
                        nums.append( np.prod( wraps[:i+1] ) - 1 )
                    else:
                        num = wraps[i] - 1
                        nums.append( num )
                else:
                    # Final dimension with a wrap count: no count is recorded.
                    num = wraps[i]
                incs.append( self.add_dimension( num, s ))

                if i < len( steps ) - 1 and i % 3 == 2:
                    # After every third non-final dimension the running reset
                    # is replaced (not accumulated) with this dimension's span.
                    self.reset = -wraps[i] * s

        if len( incs ) == 1:
            return incs[0]
        else:
            return TypedNamedList([ f"uint32 num{n}" for n in range( len( nums ))] + [ f"uint32 inc{n}" for n in range( len( incs ))], nums + incs )



# known_math_ops = ( "min", "max", "ceil", "sign", "sizeof", "DimsHelper", "from_steps", "range", "for", "in", "if", "else", "int", "log2", "prod", "array", "all", "any" )

def make_tuple( val ):
    """Return ``val`` as a tuple.

    Anything exposing ``__len__`` is converted element-wise with ``tuple()``
    (so a string becomes a tuple of its characters); any other value is
    wrapped in a one-element tuple.
    """
    return tuple( val ) if hasattr( val, "__len__" ) else ( val, )


def setup_kernel_params(
    dims: PoolingSubvDims,
    is_int16, is_signed,
):
    """Pack the kernel parameter blob for one pooling sub-volume.

    Args:
        dims: sub-volume description (sizes, kernel, strides, granularities).
        is_int16: truthy when elements are 2 bytes wide, otherwise 1 byte.
        is_signed: selects the minimum value and is stored as the sign flag.

    Returns:
        ``bytes``: six 2-byte little-endian unsigned header fields
        (max_or_avg, ifm_bits, subv_elem, is_signed, scratch, 0) followed by
        the fields packed with ``struct.pack`` format ``'HHiiiiiHBBBBBhB'``.
        That format has no byte-order prefix, so it uses the native layout.

    Raises:
        AssertionError: if ``dims.X`` or ``dims.C`` is not a suitable multiple.
    """
    out_rows = dims.Y
    out_cols = dims.X
    channels = dims.C
    subv_elem = dims.subv_elem

    # X must be a multiple of 8, except for 2-byte data with max_or_avg == 0,
    # where a multiple of 4 is enough.
    required_x_multiple = 4 if (is_int16 and not dims.max_or_avg) else 8
    assert out_cols % required_x_multiple == 0
    assert channels % 8 == 0

    bytes_per_elem = 2 if is_int16 else 1

    x_groups = out_cols // dims.X_gran
    c_groups = channels // dims.C_gran
    # Input width needed to produce out_cols outputs, rounded up to X_gran.
    in_cols = (out_cols - 1) * dims.Sx + dims.Kx
    padded_in_cols = ceil(in_cols, dims.X_gran)

    # Byte steps between consecutive kernel rows / X groups / C groups / output rows.
    step_ky = padded_in_cols * channels * bytes_per_elem
    step_xi = dims.X_gran * dims.C_gran * dims.Sx * bytes_per_elem
    step_ci = padded_in_cols * dims.C_gran * bytes_per_elem
    step_yi = dims.Sy * padded_in_cols * channels * bytes_per_elem

    outer_loop = out_rows * x_groups * c_groups
    inner_loop = dims.Ky

    # The helper starts with the offset that undoes the Ky kernel-row steps.
    addr_helper = DimsHelper(reset=-(dims.Ky * step_ky))
    addr = addr_helper.from_steps((x_groups, c_groups), (step_xi, step_ci, step_yi))

    # Shuffle mode codes: stride 2 uses the T128or64_4or8x2 lo/hi pair
    # (8/9 for 2-byte data, 6/7 otherwise); any other stride uses T512_1x2_lo (22).
    if dims.Sx == 2:
        shfl_0, shfl_1 = (8, 9) if is_int16 else (6, 7)
    else:
        shfl_0 = shfl_1 = 22

    wide_kernel = dims.Kx > 2
    shft_0 = 8 if (dims.Sy == 1 and dims.Sx == 1) else 0
    shft_1 = 8 if (wide_kernel and dims.Sx == 2) else 0
    shft_2 = 16 if (wide_kernel and dims.Sx == 1) else 0

    # Equality with False (not truthiness) is intentional: it keeps the
    # original treatment of values such as None.
    if is_signed == False:
        min_val = 0
    elif is_int16 == False:
        min_val = -128
    else:
        min_val = -32768

    raw_fields = (
        outer_loop,     # uint16_t
        inner_loop,     # uint16_t
        addr['num0'],   # int32_t
        addr['num1'],   # int32_t
        addr['inc0'],   # int32_t
        addr['inc1'],   # int32_t
        addr['inc2'],   # int32_t
        step_ky,        # uint16_t
        shfl_0,         # uint8_t
        shfl_1,         # uint8_t
        shft_0,         # uint8_t
        shft_1,         # uint8_t
        shft_2,         # uint8_t
        min_val,        # int16_t
        is_signed,      # uint8_t (ctrl_sign)
    )
    # Several values are NumPy floats at this point; struct needs plain ints.
    struct_fields = tuple(int(field) for field in raw_fields)

    header_words = (
        dims.max_or_avg,
        dims.ifm_bits,
        subv_elem,
        dims.is_signed,
        dims.scratch,
        0,              # dummy padding word
    )
    layer_params = b"".join(
        word.to_bytes(length=2, byteorder='little', signed=False)
        for word in header_words
    )

    return layer_params + struct.pack('HHiiiiiHBBBBBhB', *struct_fields)

if __name__ == '__main__':
    # Smoke run with one example configuration.
    # PoolingSubvDims(Yis, Xis, Cis, Ky, Kx, Sy, Sx, X_gran, Co_gran, ifm_bits, subv_elem, max_or_avg)
    example_dims = PoolingSubvDims(1, 16, 64, 3, 3, 2, 2, 4, 8, 16, 1*16*64, 0)
    # setup_kernel_params requires is_int16 and is_signed as well as dims;
    # derive both from the dims object so the three arguments stay consistent.
    setup_kernel_params(
        example_dims,
        example_dims.ifm_bits == 16,
        example_dims.is_signed,
    )