# pylint: skip-file
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from python.named_list import TypedNamedList, values, copy
from python.known_functions import *
from kernels.conv.direct_conv_int8x8_generic.low.direct_conv_int8x8_generic_params import params as low_params
from kernels.conv.direct_conv_int8x8_generic.low.direct_conv_int8x8_generic_params import sum_params
import json


def pack_op_mode( op_mode, fm=None, struct=False ):
    """Encode an op-mode name as a set of 1-bit flags.

    Args:
        op_mode: op-mode name. The recognised names are the ones listed in the
            table below; any other name leaves every flag cleared.
        fm: accepted for interface compatibility; not used.
        struct: when true, return the flag ``TypedNamedList`` itself;
            otherwise return ``_get_stream( 32 )`` of it.

    Returns:
        The flag ``TypedNamedList`` (``struct=True``) or its 32-bit stream.
    """
    # (field declaration, op modes that set the bit). Row order defines the
    # field order of the resulting TypedNamedList, so keep it stable.
    flag_table = (
        ( "uint32_t is_conv:1",  ( "conv", "sym", "asym", "AxA" )),
        ( "uint32_t is_sum:1",   ( "sum", "asym", "AxA", "dwc_asym" )),
        ( "uint32_t is_dwc:1",   ( "dwc", "dwc_sym", "dwc_asym" )),
        ( "uint32_t is_qdq:1",   ( "sym", "asym", "AxA", "dwc_sym", "dwc_asym", "qdq" )),
        ( "uint32_t is_sum_2:1", ( "AxA", )),
    )
    declarations = [ decl for decl, _ in flag_table ]
    bits = [ op_mode in enabling_modes for _, enabling_modes in flag_table ]

    flags = TypedNamedList( declarations, bits )
    if struct:
        return flags
    return flags._get_stream( 32 )



def params( templates, parameters, Hi, Wi ):
    """Assemble the parameter list for the configured op mode and serialise it.

    Sub-structs are appended to ``prm`` in this fixed order, each only when
    its flag from ``pack_op_mode`` is set:
      1. ``conv``;
      2. ``dwc`` or else ``sum``;
      3. ``sum_2`` and ``sum2c0``;
      4. ``qdq``.
    The blob layout presumably follows this append order (TODO confirm
    against ``TypedNamedList._to_byte_array``), so do not reorder the blocks.

    Args:
        templates: template switches, forwarded unchanged to ``low_params``.
        parameters: layer description dict. This function reads:
            - "op_mode";
            - "subvolume" (H, W and Co are used here);
            - "transpose" (I0, I1);
            - "quantization_coeffs"["vector_coeffs"];
            - "dtype"["O0"].
            The whole dict is also passed to ``low_params``.
        Hi, Wi: forwarded to ``low_params`` for the conv/dwc/sum sub-structs
            (presumably input height/width; verify against low_params).

    Returns:
        ``prm._to_byte_array()`` of the assembled parameter list.

    Side effects:
        Prints ``prm`` to stdout before serialising (``main`` relies on this
        for its output).
    """
    prm = TypedNamedList([])
    om = pack_op_mode( parameters["op_mode"], struct=True )
    # Only H, W and Co are used below; the rest are unpacked but unused.
    H,W,Co,Ci,Kh,Kw,Sh,Sw,Dh,Dw = [ parameters["subvolume"][k] for k in "H,W,Co,Ci,Kh,Kw,Sh,Sw,Dh,Dw".split( ',' )]

    if om.is_conv:
        prm._append( "struct conv", low_params( "conv", templates, parameters, Hi, Wi ))

    # dwc takes precedence over sum: a mode with both flags ("dwc_asym" in
    # pack_op_mode) gets only the dwc sub-struct.
    if om.is_dwc:
        prm._append( "struct dwc", low_params( "dwc", templates, parameters, Hi, Wi ))
    elif om.is_sum:
        prm._append( "struct sum", low_params( "sum", templates, parameters, Hi, Wi ))
    # Second sum pass; pack_op_mode sets is_sum_2 only for "AxA".
    if om.is_sum_2:
        # Work on a copy so the caller's dict is not modified. The nested
        # "transpose"/"subvolume" dicts are copied explicitly because they
        # are mutated below (presumably named_list.copy is shallow; verify).
        param_2 = copy( parameters )
        param_2["transpose"] = copy( param_2["transpose"] )
        param_2["subvolume"] = copy( param_2["subvolume"] )
        # NOTE(review): stores a bool here, while main() supplies 0/1 ints.
        param_2["transpose"]["I0"] = not param_2["transpose"]["I1"]
        param_2["subvolume"]["W"]  = Co
        #param_2["subvolume"]["Ci"] = Co
        param_2["op_mode"] = "sum"
        # The second sum runs over a 1 x Co extent.
        prm._append( "struct sum_2", low_params( "sum", templates, param_2, 1, Co ))

        SumToC0Params_fields = (
            "uint8_t N_g",
            "uint8_t coeff_step",
            "uint16_t offset",
        )
        sum2c0 = TypedNamedList( SumToC0Params_fields )
        # Appended before its fields are filled in. Presumably _append keeps
        # the reference, so the assignments below still reach the blob; verify.
        prm._append( "struct sum2c0", sum2c0 )
        sum2c0.N_g = Co // 8
        # 192 with per-channel (vector) coefficients, 64 otherwise. This is the
        # same choice the qdq dims_qnt step makes when sum_2 is active.
        sum2c0.coeff_step = 192 if parameters["quantization_coeffs"]["vector_coeffs"] > 0 else 64
        # Reads outer_g from the first "sum" sub-struct. This requires is_sum
        # and not is_dwc, which holds for "AxA".
        sum2c0.offset = 256 * prm.sum.outer_g
            

    if om.is_qdq:
        QDQParams_fields = (
            "int16_t loop",
            "uint8_t split_mode:1",
            "uint8_t sign_out:1",
            "int8_t vector_coeffs",
            "int16_t dims_in1_wrap0",
            "int16_t dims_in1_wrap1",
            "int16_t dims_in1_step",
            "dims_2d_param_s16 dims_sum",
            "dims_2d_param_s16 dims_qnt",
            "dims_3d_param_s16 dims_out",
          )

        qdq = TypedNamedList( QDQParams_fields )
        prm._append( "struct qdq", qdq )
        vc = parameters["quantization_coeffs"]["vector_coeffs"]

        qdq.loop           = H * W * Co // 32
        # split_mode is hard-wired off; the commented-out `Sw != 1` was the
        # former condition. With split_mode == 0 the three fields below become:
        #   dims_in1_wrap0 = 2
        #   dims_in1_wrap1 = H*W//8
        #   dims_in1_step  = 4*Co//8
        qdq.split_mode     = 0 #Sw != 1
        qdq.vector_coeffs  = vc
        qdq.sign_out       = sign( parameters["dtype"]["O0"] )
        qdq.dims_in1_wrap0 = H * W // 4 if qdq.split_mode else 2
        qdq.dims_in1_wrap1 = H * W // 8 if not qdq.split_mode else 2
        qdq.dims_in1_step  = ( 2 if qdq.split_mode else 4 ) * Co // 8

        # Step/wrap descriptors built by DimsHelper with 16-bit fields.
        dims = DimsHelper( bits=16 )
        qdq.dims_sum = dims.from_steps( H * W // 4, ( 16, 0 ))
        # Quantization-coefficient step:
        #   192 with vector coeffs and sum_2;
        #   128 with vector coeffs only;
        #   64 with scalar coeffs.
        qdq.dims_qnt = dims.from_steps( H * W // 4, ( 0, (( 192 if om.is_sum_2 else 128 ) if vc > 0 else 64 )))
        qdq.dims_out = dims.from_steps(( W // 4, H ), ( 4, W * Co // 8, W ))
    
    # Dump the assembled structure for inspection (main() relies on this).
    print(prm)
    blob = prm._to_byte_array()
    return blob

def main():
    """Run ``params`` on one hard-coded sample configuration.

    The sample has:
      - a 1 x 64 subvolume with Co = Ci = 64;
      - a 1x1 kernel with unit strides and dilations;
      - int8 inputs and output;
      - op_mode "sym".
    ``params`` prints the assembled parameter list itself; the returned blob
    is not used here.
    """
    # Template switches forwarded to low_params.
    templates = dict(
        has_dwc=0,
        has_conv=0,           # original note: "Since GEMM, set this to 1"
        has_sum=0,            # original note: "Asym = 1"
        has_vector_coeffs=0,  # 1 if channel-wise coefficient vectors
    )

    subvolume = dict(H=1, W=64, Co=64, Ci=64, Kh=1, Kw=1, Sh=1, Sw=1, Dh=1, Dw=1)
    dtypes = dict(I0="int8", I1="int8", O0="int8")
    transposes = dict(I0=0, I1=0)
    quant = dict(
        shift_res=16,
        zp_wght=0,
        vector_coeffs=0,
        qdq_c0=0,
        qdq_c1=0,
        qdq_c2=1,
        qdq_c3=0,
    )
    # Key order matches the layer-description layout expected by params().
    parameters = dict(
        subvolume=subvolume,
        op_mode="sym",
        dtype=dtypes,
        transpose=transposes,
        quantization_coeffs=quant,
    )

    input_h, input_w = 1, 64
    params(templates, parameters, input_h, input_w)


if __name__ == "__main__":
    main()
