# pylint: skip-file
from kernels.python.named_list import TypedNamedList, values
from kernels.python.known_functions import *


def params( op_mode, templates, parameters, Hi, Wi ):
    """Build the kernel parameter record for the given op_mode.

    Fills TypedNamedList instances whose field specs are C-style type strings
    (they appear to mirror the C structs sketched in the comment block below):
    Control, BaseParams and the mode-specific ConvAddParams / SumAddParams /
    ConvSumAddParams / DwcAddParams. The values come from the subvolume
    geometry, dtypes and transpose flags in `parameters`. The function returns
    the `|` combination that matches the mode:

      * "dwc" in op_mode            -> BaseParams | DwcAddParams
      * op_mode.startswith("sum")   -> BaseParams | ConvSumAddParams (is_conv)
                                       or BaseParams | SumAddParams
      * otherwise, is_conv          -> BaseParams | ConvAddParams
      * otherwise                   -> BaseParams

    Args:
        op_mode: kernel mode string. It is tested with `in ('sum',)`,
            `startswith("dwc")`, `"dwc" in op_mode` and `startswith("sum")`.
        templates: not used by this function (only referenced in the
            commented-out print).
        parameters: dict read for "subvolume" (keys H, W, Co, Ci, Kh, Kw, Sh,
            Sw, Dh, Dw), "dtype" (I0, I1) and "transpose" (I0, I1). Dwc modes
            also read "quantization_coeffs"["zp_wght"].
        Hi, Wi: input height/width. Wi sets the activation channel step for a
            non-transposed I0. Both are forwarded to sum_params() in sum mode.
    """
    #print( f"{templates}\n {parameters}" )
    Control_fields = [
        "uint8_t sign_A:1",
        "uint8_t sign_W:1",
        "uint8_t is_conv:1",
        "uint8_t is_sum:1",
        "uint8_t tdm_overwrite:1",
        "uint8_t is_dwc:1",
        "uint8_t is_in1_T:1",
    ]
    BaseParams_fields = [
        "uint16_t outer_g",
        "uint16_t inner_g",
        "Control ctrl",
        "uint8_t step_align",
        "uint8_t shfl_0",
        "uint8_t shfl_1",
        "uint8_t shfl_W",
        "int16_t incW",
        "dims_3d_param_s16 dimsW",
        "dims_3d_param_s16 dimsA",
    ]

    ConvAddParams_fields = [
        "dims_3d_param_s16 dimsAO",
    ]

    SumAddParams_fields = [
        "uint16_t loop",
        "dims_3d_param_s16 dimsSum",
    ]

    ConvSumAddParams_fields = [
        "SumAddParams sum",
        "uint16_t reserved",
        "dims_2d_param_s16 dimsKsumI",
        "dims_2d_param_s16 dimsKsumO",
        "uint8_t kernel_size",
        "uint8_t pixels_out_g",
        "uint8_t shfl_sum_2",
    ]

    DwcAddParams_fields = [
        "ConvAddParams conv",
        "int16_t weight_size",
        "int8_t zp_wght",
    ]


    #LowParams_fields = [ BaseParams {
    #    union {
    #        ConvAddParams conv",
    #        ConvSumAddParams sum",
    #        DwcAddParams dwc",
    #    };
    #};
    #
    #
    #struct ConvParams : BaseParams, ConvAddParams { };
    #struct SumParams : BaseParams, SumAddParams { };
    #struct ConvSumParams : BaseParams, ConvSumAddParams { };
    #struct DwcParams : BaseParams, DwcAddParams { };

    Control = TypedNamedList( Control_fields )
    BaseParams = TypedNamedList( BaseParams_fields )
    ConvAddParams = TypedNamedList( ConvAddParams_fields )
    SumAddParams = TypedNamedList( SumAddParams_fields )
    ConvSumAddParams = TypedNamedList( ConvSumAddParams_fields )
    DwcAddParams = TypedNamedList( DwcAddParams_fields )

    H,W,Co,Ci,Kh,Kw,Sh,Sw,Dh,Dw = [ parameters["subvolume"][k] for k in "H,W,Co,Ci,Kh,Kw,Sh,Sw,Dh,Dw".split( ',' )]

    Control.sign_A = sign( parameters["dtype"]["I0"] )
    Control.sign_W = sign( parameters["dtype"]["I1"] )
    # Treated as a convolution when any of H, Kw, Kh, Sw, Sh exceeds 1.
    Control.is_conv = any( array( H, Kw, Kh, Sw, Sh ) > 1 )
    # Exact match only: an op_mode like "sum_x" is NOT is_sum here (compare the
    # startswith("sum") test in the return dispatch below).
    Control.is_sum = op_mode in ('sum',)
    Control.is_dwc = op_mode.startswith( "dwc" )
    Control.is_in1_T = parameters["transpose"]["I1"]
    # Output-channel iterations: Co is consumed 8 at a time for dwc, 16 otherwise.
    Co_iter = Co // 8 if op_mode.startswith( "dwc" ) else Co // 16
    Control.tdm_overwrite = Control.is_sum or op_mode.startswith( "dwc" )
    # NOTE(review): Control is mutated again later (is_in1_T in sum mode); this
    # presumably relies on BaseParams.ctrl holding a reference -- confirm.
    BaseParams.ctrl = Control

    is_dwc = Control.is_dwc

    # NOTE(review): pack_op_mode is imported but not used in this function.
    from kernels.conv.direct_conv_int8x8_generic.direct_conv_int8x8_generic_params import pack_op_mode
    # Inner loop count: kernel taps x input-channel groups of 8 (dwc uses one group of 8).
    BaseParams.inner_g = Kw * Kh * ( Ci if not is_dwc else 8 ) // 8
    # Outer loop count: output pixels x output-channel iterations, in units of 8.
    BaseParams.outer_g = W * H * Co_iter // 8

    # All address steps below are expressed in units of `stp` (8).
    stp = 8
    step_Kw = Dw * 8 // stp
    if parameters["transpose"]["I0"]:
        step_Wi = Ci * 8 // stp
        step_Ci = 64 // stp
    else:
        step_Wi = 64 * Sw // stp
        step_Ci = ceil( Wi, 8 ) * 8 // stp
    step_Hi = step_Ci * ( Ci // 8 ) * Sh
    step_Kh = step_Ci * ( Ci // 8 ) * Dh

    # Shuffle-mode IDs. The trailing "#N" comments look like earlier values
    # (kept for reference) -- TODO confirm against the kernel's shuffle table.
    T512_1x2_lo = 22#0
    T512_1x2_hi = 23#1
    T128_4x2_lo = 8#20
    T64_8x2_lo = 6
    T32_4x8_lo = 32#52
    T32_8x4_lo = 30#38
    T32_16x2_lo = 4
    T16_8x8_lo = 52#54
    T8_8x8 = 35
    BaseParams.shfl_0 = T8_8x8 if parameters["transpose"]["I0"] else ( T512_1x2_lo if Sw == 1 else T64_8x2_lo )
    # For step_Kw > 0, x ^ (x - 1) == 2**(t+1) - 1 where t is the number of
    # trailing zero bits of x, so this evaluates to t. The "+ 0*1" is a no-op.
    BaseParams.step_align = int( log2( step_Kw ^ ( step_Kw - 1 ))) + 0*1

    # DimsHelper is stateful: `dims.reset` is reassigned below before some
    # from_steps() calls and presumably applies to the dims built after it.
    dims = DimsHelper( -128 // stp, bits=16 )
    if Control.is_sum:
        BaseParams.incW = 0
        BaseParams.shfl_0 = T512_1x2_lo
        Control.is_in1_T = not parameters["transpose"]["I0"]
        step_Ws = 64 // stp if Control.is_in1_T else step_Wi
        sp = sum_params( Hi, Wi, Ci )
        Wi_g = ceil( Wi, 8 )
        BaseParams.inner_g = sp.inner_g
        BaseParams.outer_g = sp.outer_g
        stepW2 = sp.step_sum_block - 8 if sp.small_pixel_opt else 0
        BaseParams.dimsA = dims.from_steps(( Ci // 8, sp.block_0 ), ( 0, 64 // stp, stepW2 * 64 // stp ))
        BaseParams.dimsW = dims.from_steps(( Ci * Hi // 8, Wi_g // 8 ), ( step_Ci, step_Ws, 0 ))
        SumAddParams.loop = 8 * sp.outer_g
        # NOTE(review): "incAO" is not declared in SumAddParams_fields -- confirm
        # whether TypedNamedList accepts, ignores or emits undeclared fields.
        SumAddParams.incAO = 0
        SumAddParams.dimsSum = sp.dimsSum

        if Control.is_conv:
            ConvSumAddParams.sum = SumAddParams
            dims.reset = -64
            ConvSumAddParams.dimsKsumI = dims.from_steps(( Kw, Kh ), ( 4 * Dw, sp.step_sum_H * Dh ), next_loop_level=True )
            ConvSumAddParams.dimsKsumO = dims.from_steps(( W // 8, ), ( 32 * Sw, sp.step_sum_H * Sh ))
            ConvSumAddParams.kernel_size = Kw * Kh
            ConvSumAddParams.pixels_out_g = W * H // 8
            ConvSumAddParams.shfl_sum_2 = T32_16x2_lo if Sw == 2 else T512_1x2_lo

    else:
        if Control.is_conv:
            # (not is_dwc) / is_dwc act as 0/1 multipliers that zero out the
            # channel step for the mode that does not walk channels here.
            BaseParams.dimsA = dims.from_steps(( Kw, Kh, Ci // 8 ), ( step_Kw, step_Kh, step_Ci * ( not is_dwc )), next_loop_level=True )
            ConvAddParams.dimsAO = dims.from_steps(( Co_iter, W // 8 ), ( is_dwc * step_Ci, step_Wi, step_Hi ))
        else:
            BaseParams.dimsA = dims.from_steps(( Ci // 8, Co_iter ), ( step_Ci, 0, step_Wi ))

        # Only defined on this (non-sum) path; read by the dwc return branch below.
        weight_size = BaseParams.inner_g * Co // 8
        if op_mode.startswith( "dwc" ):
            BaseParams.incW = 0
            BaseParams.dimsW = dims.from_steps(( 1, weight_size ), ( 0, 8, 0 ));
        elif parameters["transpose"]["I1"]:
            BaseParams.incW = 64
            dims.reset = -BaseParams.incW // stp
            BaseParams.dimsW = dims.from_steps(( Ci // 8, Co_iter ), ( Co * 8 // stp, 128 // stp, 0 ));
        elif not Control.is_conv:
            BaseParams.incW = Ci * 8
            dims.reset = -BaseParams.incW // stp
            BaseParams.dimsW = dims.from_steps(( Ci // 8, Co_iter ), ( 64 // stp, 2 * BaseParams.incW // stp, 0 ));
        else:
            BaseParams.incW = 64
            dims.reset = -8  
            BaseParams.dimsW = dims.from_steps(( 1, weight_size // 2 ), ( 0, 16, 0 ));

    # NOTE(review): this dispatch tests `"dwc" in op_mode` and
    # `startswith("sum")`, whereas is_dwc / is_sum above use
    # `startswith("dwc")` and an exact "sum" match. The two only agree for
    # op_modes where those tests coincide -- confirm the set of valid op_modes.
    if "dwc" in op_mode:
        DwcAddParams.weight_size = weight_size
        DwcAddParams.zp_wght = parameters["quantization_coeffs"]["zp_wght"]
        DwcAddParams.conv = ConvAddParams
        return BaseParams | DwcAddParams
    elif op_mode.startswith( "sum" ):
        if Control.is_conv:
            return BaseParams | ConvSumAddParams
        else:
            return BaseParams | SumAddParams
    elif Control.is_conv:
        return BaseParams | ConvAddParams
    else:
        return BaseParams


def sum_params( Hi, Wi, Ci ):
    """Compute the blocking and reorder parameters used by the "sum" op mode.

    The image width is first rounded with ``ceil( Wi, 8 )``. Judging by the
    ``// 8`` column counts and ``+= 8`` widening below, this presumably rounds
    up to a multiple of 8. The Hi x Wi_g image is then split into ``outer_g``
    (at least 2) outer groups of ``block_0`` column-group cells. When the image
    is small enough, a "small pixel" layout is used: each block starts on a
    ``step_sum_block`` boundary, and ``dimsSum`` describes the address walk that
    scatters the packed blocks back into image order. That walk is
    self-checked on a synthetic image before returning.

    Args:
        Hi: input image height.
        Wi: input image width.
        Ci: input channel count (consumed in groups of 8).

    Returns:
        NamedList with keys outer_g, inner_g, small_pixel_opt, block_0,
        dimsSum, step_sum_H and step_sum_block.

    Raises:
        AssertionError: if fewer than 8 inner groups result (e.g. Ci < 8).
        ValueError: if the small-pixel reorder self-check fails.
    """
    Wi_g = ceil( Wi, 8 )
    outer_g = max( 2, ceil( Wi_g * Hi / 64 ))
    if Wi_g == 8:
        block_0 = ceil( Hi / outer_g )
        if Ci // 8 * block_0 < 8:
            block_0 = ceil( 64 / Ci )
            # NOTE(review): Hi is replaced by a padded height here, and every
            # value computed below uses it. The caller keeps its own Hi.
            Hi = 2 * block_0
    else:
        block_0 = Hi * ceil( Wi_g / ( 8 * outer_g ))
        # Widen until each inner group covers at least 8 entries. With Ci < 8
        # (Ci // 8 == 0) or Hi == 0 the product never grows and the loop would
        # spin forever. Skip it so the inner_g check below reports the problem.
        while Ci >= 8 and Hi > 0 and Ci // 8 * block_0 < 8:
            Wi_g += 8
            block_0 = Hi * ceil( Wi_g / ( 8 * outer_g ))

    step_sum_block = ceil( 8, min( Hi, block_0 ))
    sp_valid_range = step_sum_block * ( outer_g - 1 ) + ( Hi * Wi_g // 8 - block_0 * ( outer_g - 1 ))
    small_pixel_opt = block_0 < 8 and sp_valid_range <= 8 * outer_g
    if not small_pixel_opt:
        block_0 = 8

    dims = DimsHelper( bits=16 )
    inner_g = Ci // 8 * block_0
    assert inner_g >= 8, "This indicates a bug in the code above"

    if small_pixel_opt:
        Hi_inner = ceil( Hi / ( 1 + ( Wi_g == 8 )))
        Wi_inner = step_sum_block // Hi_inner
        step_sum_H = 32 * ( ceil( step_sum_block / Hi ) + block_0 // Hi *  (outer_g - 1 )) if Wi_g > 8 else Wi_g * 4
        step_sum_W = 32 if Wi_g > 8 else 32 * Hi
        step_sum_outer = 32 * ceil( block_0 / Hi ) if Wi_g > 8 else ceil( Hi / 2 ) * Wi_g * 4
        dimsSum = dims.from_steps(( Hi_inner, Wi_inner ), ( step_sum_H, step_sum_W, step_sum_outer ))
        _verify_sum_reorder( Hi, Wi_g, outer_g, block_0, step_sum_block, step_sum_H, dimsSum )
    else:
        # Pad the width (in steps of 8) until the image covers 64 * outer_g entries.
        Wi_all = Wi_g
        while Wi_all * Hi < 64 * outer_g:
            Wi_all += 8
        step_sum_H = Wi_all * 4
        dimsSum = dims.from_steps(( Hi, Wi_all ), ( step_sum_H, 32, Hi * step_sum_H ))

    return NamedList({
        "outer_g": outer_g,
        "inner_g": inner_g,
        "small_pixel_opt": small_pixel_opt,
        "block_0": block_0,
        "dimsSum": dimsSum,
        "step_sum_H": step_sum_H,
        "step_sum_block": step_sum_block,
    })


def _sum_dims_walk( dimsSum, count ):
    """Yield ``(s, p, c0, c1)`` for ``count`` steps of the dimsSum address walk.

    ``p`` is the destination index (``inc`` values divided by 32) for source
    slot ``s``. ``c0`` and ``c1`` are the inner and middle counters at that
    point. The counters wrap when they equal ``num0`` / ``num1``, and each
    step adds ``inc0``, ``inc1`` or ``inc2`` depending on which level wrapped.
    """
    p = c0 = c1 = 0
    for s in range( count ):
        yield s, p, c0, c1
        if c0 == dimsSum.num0:
            c0 = 0
            if c1 == dimsSum.num1:
                c1 = 0
                p += dimsSum.inc2 // 32
            else:
                c1 += 1
                p += dimsSum.inc1 // 32
        else:
            c0 += 1
            p += dimsSum.inc0 // 32


def _verify_sum_reorder( Hi, Wi_g, outer_g, block_0, step_sum_block, step_sum_H, dimsSum ):
    """Check that dimsSum scatters the packed small-pixel blocks back into image order.

    The check builds a tagged Hi x (Wi_g // 8) image, packs it as the kernel
    would, replays the dimsSum walk and compares the result with the original.
    On mismatch it prints both images and the walk trace, then raises
    ValueError("Reorder failed").
    """
    # Unique tag per cell: 10 * (row + 1) + (column + 1).
    img = np.arange( 1, 1 + Wi_g // 8 ).reshape( 1, Wi_g // 8 ) + 10 * np.arange( 1, 1 + Hi ).reshape( Hi, 1 )
    img_t = img.T.flatten( )
    # Column-major cells, in outer_g chunks of block_0, each starting on a
    # step_sum_block boundary inside an 8 * outer_g slot buffer.
    img_s = np.zeros( outer_g * 8 )
    for o in range( outer_g ):
        chunk = img_t[ o * block_0 : ( o + 1 ) * block_0 ]
        img_s[ o * step_sum_block : o * step_sum_block + chunk.size ] = chunk
    img_r = np.zeros( Hi * Wi_g // 8 + 16 )
    for s, p, _, _ in _sum_dims_walk( dimsSum, 8 * outer_g ):
        img_r[p] = img_s[s]
    img_x = img_r[ :max( step_sum_H // 32, Wi_g // 8 ) * Hi ].reshape( Hi, -1 )[ :, :Wi_g // 8 ]
    if not np.all( img == img_x ):
        print( img )
        print( img_x )
        for s, p, c0, c1 in _sum_dims_walk( dimsSum, 8 * outer_g ):
            print( f"[{p}] <= [{s}], with c0={c0}, c1={c1}, dims={dimsSum}" )
        raise ValueError( "Reorder failed" )
