#! /usr/bin/env python3
# Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
# Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
#
# This file contains confidential and proprietary information
# of Xilinx, Inc. and is protected under U.S. and
# international copyright and other intellectual property
# laws.
#
# DISCLAIMER
# This disclaimer is not a license and does not grant any
# rights to the materials distributed herewith. Except as
# otherwise provided in a valid license issued to you by
# Xilinx, and to the maximum extent permitted by applicable
# law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
# WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
# AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
# INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
# (2) Xilinx shall not be liable (whether in contract or tort,
# including negligence, or under any other theory of
# liability) for any loss or damage of any kind or nature
# related to, arising under or in connection with these
# materials, including for any direct, or any indirect,
# special, incidental, or consequential loss or damage
# (including loss of data, profits, goodwill, or any type of
# loss or damage suffered as a result of any action brought
# by a third party) even if such damage or loss was
# reasonably foreseeable or Xilinx had been advised of the
# possibility of the same.
#
# CRITICAL APPLICATIONS
# Xilinx products are not designed or intended to be fail-
# safe, or for use in any application requiring fail-safe
# performance, such as life-support or safety devices or
# systems, Class III medical devices, nuclear facilities,
# applications related to the deployment of airbags, or any
# other applications that could lead to death, personal
# injury, or severe property or environmental damage
# (individually and collectively, "Critical
# Applications"). Customer assumes the sole risk and
# liability of any use of Xilinx products in Critical
# Applications, subject only to applicable laws and
# regulations governing limitations on product liability.
#
# THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
# PART OF THIS FILE AT ALL TIMES.
import me_regression as me
import os
from kernels.direct_conv.common.direct_conv_test_method import DirectConvTestMethod
from kernels.direct_conv.common.generate_direct_conv_data_step import GenerateDirectConvDataStep
from kernels.ml.common.generate_ml_params import GenerateMLParams, TypedNamedList
import numpy as np

# Default options passed to DirectConvTestMethod by DirectConvTest (opts=opts).
#   MODE     : 'x86' is selected; the earlier notes name 'mesimulator' and "sa_iss" as alternatives.
#   IN_MODE  : listed values are zero, casc, tdm16, tdm32, tdm16_casc, tdm32_casc.
#   OUT_MODE : listed values are result8, casc, tdm16, tdm32.
#              NOTE(review): 'result16' is selected although it is not in that list - confirm.
opts = dict(
    MODE='x86',
    IN_MODE='zero',
    OUT_MODE='result16',
)


class GenerateParams( GenerateMLParams ):
    """Parameter generator that installs this kernel's typed parameter list."""

    def __init__( self, *args, **kwargs ):
        """Initialise the base generator and attach the kernel parameter list.

        All positional and keyword arguments are forwarded unchanged to
        GenerateMLParams.__init__; the keyword arguments are printed first.
        """
        print( kwargs )
        GenerateMLParams.__init__( self, *args, **kwargs )

        # Field declarations are "<type> <name>" strings; their order is the
        # order in which they are handed to TypedNamedList.
        size_fields = (
            'uint8_t Kx_g',
            'uint8_t Ky_g',
            'uint8_t Ci_g',
            'int8_t  S_g',
            'uint8_t N_g',
            'uint8_t X_g',
            'uint8_t Y_g',
            'uint8_t Co_g',
            'uint16_t inner_g',
            'uint16_t outer_g',
        )
        scalar_fields = (
            'int8_t shift_tdm',
            'int8_t shift_res',
            'int8_t zp_wght',
            'int8_t op_mode',
        )
        step_fields = tuple(
            'uint16_t step_' + axis
            for axis in ( 'Kx', 'Ky', 'Ci', 'Xi', 'Yi', 'Xo', 'Yo', 'Co' )
        )
        trailing_fields = (
            'int param_value',
            'DirectConvKernelConfig config',
        )
        field_specs = size_fields + scalar_fields + step_fields + trailing_fields

        self.kernel_param = TypedNamedList( field_specs )

        # The config entry is read before the list is attached to layer_param
        # and written back afterwards.
        # NOTE(review): presumably attaching the list can disturb 'config';
        # the sequence is kept exactly as it was - confirm whether it is needed.
        saved_config = self.kernel_param.config
        self.layer_param.kernel = self.kernel_param
        self.kernel_param.config = saved_config




class GenerateDataStep(GenerateDirectConvDataStep):
    """Data-generation step for the generic direct-convolution kernel test.

    Extends GenerateDirectConvDataStep with op_mode dependent adjustments of
    the weight/golden descriptors, a NumPy reference computation, and
    alternative data orders that are substituted when the x stride is 2.
    """

    # Operation names; the integer op_mode parameter is an index into this tuple.
    op_modes = ( "none", "conv", "sum", "sum_T", "dwc", "dwc_sum", "dummy1", "dummy2", "conv_sym", "conv_asym", "dwc_sym", "dwc_asym" )

    def __init__(self, tc, path, param_gen, *args, **kwargs):
        """Record the stride-2 data orders, then run the base initialiser.

        kwargs['orders'] must contain 'Tdm'. The stride-2 TDM order is taken
        from 'Tdm<bitsTdm>_S2', then 'Tdm_S2', then 'Tdm'. A stride-2 cascade
        order is recorded only when 'Casc_S2' is supplied; without it the
        cascade methods keep the order they are called with.

        Raises:
            KeyError: if kwargs has no 'orders' entry or it lacks 'Tdm'.
        """
        orders = kwargs['orders']
        self.orderTdm_S2 = orders.get('Tdm{}_S2'.format(param_gen.bitsTdm), orders.get('Tdm_S2', orders['Tdm']))
        # Previously this attribute was never assigned, so create_casc_in() and
        # run_casc() raised AttributeError whenever the x stride was 2.
        if 'Casc_S2' in orders:
            self.orderCasc_S2 = orders['Casc_S2']
        GenerateDirectConvDataStep.__init__(self, tc, path, param_gen, *args, **kwargs)

    def get_sizes(self):
        """Return the base sizes, enlarging W for depthwise operation.

        For depthwise modes W is rounded up to a multiple of 8 and multiplied
        by 9. Reads self.depthwise, which is set by get_parameters().
        """
        sz = GenerateDirectConvDataStep.get_sizes(self)
        if self.depthwise:
            sz.W = (( sz.W + 7 ) & ~7 ) * 9
        return sz

    def get_parameters(self, ps, **kwargs):
        """Derive kernel parameters for the test parameter set `ps`.

        Sets self.depthwise and self.use_S2_order, adjusts the weight/golden
        descriptors for depthwise and sum modes, appends 'zp_wght' to `ps`
        and forwards everything to param_gen.set_kernel_params().

        Returns:
            A 3-tuple: (param_gen.get_parameters() + self.append_param,
            (shift_res,), shift_tdm).
        """
        op_mode = self.op_modes[ ps.op_mode ]
        self.depthwise = op_mode.startswith( "dwc" )
        print( ps )
        self.use_S2_order = ps.Sx == 2

        # Depthwise modes accumulate over one input channel only.
        accum_ifms = 1 if self.depthwise else ps.Ci
        inner = accum_ifms * ps.Ky * ps.Kx

        if self.depthwise:
            # Collapse the 'I' dimension of the weights to 1.
            self.wght.shape = list( self.wght.shape )
            self.wght.shape[ self.wght.defOrder.index( 'I' )] = 1
            self.wght.order = self.wght.order.replace( 'I8', 'I1' )
        elif op_mode.startswith( "sum" ):
            # The 'O' dimension becomes 2**max(0, ceil(6 - log2(Ky*Kx))).
            self.wght.shape = list( self.wght.shape )
            self.wght.shape[ self.wght.defOrder.index( 'O' )] = 2**max( 0, int( np.ceil( 6 - np.log2( ps.Ky * ps.Kx ))))
            self.gold.order = self.gold.order.replace( 'C8', 'C1' )

        # Golden data with the cascade bit width is marked as signed.
        if self.gold.bits == self.bitsCasc:
            self.gold.sgn = True

        shift_res = self.srs_shift(inner, self.gold.bits, self.gold.sgn)
        shift_tdm = self.srs_shift(inner, self.bitsTdm)
        ps._append( 'zp_wght', -53 )

        offsets = self.calculate_offsets()

        self.param_gen.set_kernel_params(
                ps,
                shift_res   = shift_res,
                shift_tdm   = shift_tdm,
                sign_config = self.pack_sign(),
                offset_actv = offsets['A'],
                offset_wght = offsets['W'],
                offset_out  = offsets['O'],
                offset_interm = offsets['T'],
                **kwargs
            )
        if self.depthwise:
            # Overwrite the generated values with the depthwise ones.
            self.param_gen.kernel_param.inner_g = inner
            self.param_gen.kernel_param.Ci_g = 1
        print( self.param_gen.kernel_param )
        return (self.param_gen.get_parameters() + self.append_param, (shift_res,), shift_tdm)

    def compute_reference( self, activations, weights, shapes ):
        """Compute the expected kernel output with NumPy.

        Args:
            activations: array indexed as [batch, y, x, channel].
            weights: array indexed with four axes, the last one selecting the
                output channel; ignored and replaced for "dwc_sum".
            shapes: sequence whose first nine entries are
                (batch, oyps, oxps, ofms, ifms, ky, kx, sy, sx).

        Returns:
            For "sum_T": the sum of the activations over axes 0-2 with the
            reduced axes kept. Otherwise an int64 array of shape
            (batch, oyps, oxps, ofms).
        """
        ( batch, oyps, oxps, ofms, ifms, ky, kx, sy, sx ) = shapes[:9]
        # np.long was removed in NumPy 1.24; int64 is an explicit, portable accumulator type.
        r = np.zeros( shape=( batch, oyps, oxps, ofms ), dtype=np.int64 )
        actv = activations
        op_mode = self.op_modes[ self.param_gen.kernel_param.op_mode ]

        if op_mode == "dwc_sum":
            # Constant weights equal to the negated weight zero point.
            weights = - self.param_gen.kernel_param.zp_wght * np.ones(( 1, ky, kx, ofms ), dtype=np.int64 )

        if op_mode.startswith( "sum" ):
            if "T" in op_mode:
                r = np.sum( actv, axis=( 0, 1, 2 ), keepdims=True )
                assert np.prod(( ky, kx, sy, sx )) == 1, "Gemm required for transposed sum"
            else:
                # Sum over channels first, then over each ky x kx window; the
                # scalar is broadcast to every output channel of r[b, y, x].
                ch_sum = np.sum( actv, axis=3 )
                for b in range( batch ):
                    for y in range( oyps ):
                        for x in range( oxps ):
                            r[b, y, x] = np.sum( ch_sum[b, y*sy:y*sy+ky, x*sx:x*sx+kx] )
                print( ch_sum )
                print( r )
        else:
            depthwise = op_mode.startswith( "dwc" )
            print( "is dwc: ", depthwise, op_mode)
            for b in range( batch ):
                for y in range( oyps ):
                    for x in range( oxps ):
                        for o in range( ofms ):
                            if depthwise:
                                # Depthwise: only the matching input channel contributes.
                                a = actv[b, y*sy:y*sy+ky, x*sx:x*sx+kx, o].reshape( -1 )
                            else:
                                a = actv[b, y*sy:y*sy+ky, x*sx:x*sx+kx, :].reshape( -1 )
                            # NOTE(review): the second weight axis is sliced with kx;
                            # confirm against the weight layout whether ky was intended.
                            w = weights[:, 0:kx, :, o].reshape( -1 )
                            r[b, y, x, o] = np.sum( np.multiply( a, w ))
        return r

    def create_casc_in(self, order):
        """Create the cascade input, using the stride-2 order when one is configured."""
        if self.use_S2_order:
            order = getattr(self, 'orderCasc_S2', order)
        return GenerateDirectConvDataStep.create_casc_in(self, order)

    def create_tdm_in(self, order, shift_tdm):
        """Create the TDM input, substituting the stride-2 order when the x stride is 2."""
        if self.use_S2_order:
            order = self.orderTdm_S2
        return GenerateDirectConvDataStep.create_tdm_in(self, order, shift_tdm)

    def run_casc(self, order):
        """Run the cascade step, using the stride-2 order when one is configured."""
        if self.use_S2_order:
            order = getattr(self, 'orderCasc_S2', order)
        return GenerateDirectConvDataStep.run_casc(self, order)

    def run_tdm(self, order, shift_tdm):
        """Run the TDM step with the order matching stride and op_mode.

        The stride-2 order replaces `order` when the x stride is 2; for the
        sum modes the fixed order 'YXN(XN)8C' takes precedence over both.
        """
        if self.use_S2_order:
            order = self.orderTdm_S2
        op_mode = self.op_modes[ self.param_gen.kernel_param.op_mode ]
        if op_mode.startswith( "sum" ):
            order = 'YXN(XN)8C'
        return GenerateDirectConvDataStep.run_tdm(self, order, shift_tdm)




class DirectConvTest(DirectConvTestMethod):
    """Test case for the "generic" direct-convolution kernel."""

    def __init__(self):
        """Initialise the base test method and install layouts, test sizes and cost figures."""
        DirectConvTestMethod.__init__(
            self,
            16,
            8,
            opts=opts,
            tdm_split=True,
            gen_step=GenerateDataStep,
            param_gen=GenerateParams,
            appendix="generic",
        )

        # Data-order strings per buffer name.
        self.data_order = {
            'A': 'YCXN(XN)4C8',
            'W': 'OIYXI8O8',
            'O': 'YCXNC8',
            'Tdm': '(XN)2>1YX<1N<1C(XN)4C8',
            'Tdm_S2': 'C2>1YXNC<1(XN)4C8',
            # Disabled entries:
            #'Tdm32':'YC(XN)4>2(XN)<1(XN)2<1(XN)2C8', 'Tdm32_S2':'YC(XN)4>1X<1N<1(XN)2C8',
            #'Casc': 'YCX(XN)4>1C2(XN)2<1(XN)2C8',    'Casc_S2': 'YCX(XN)4C2*2(XN)2C8'
        }

        # Names for the leading values of every row in direct_conv_tests.
        self.size_order = (
            'N', 'Y', 'X', 'Co', 'Ci', 'Ky', 'Kx', 'Sy', 'Sx', 'op_mode',
        )

        # Each row carries twelve values while size_order names ten.
        # NOTE(review): the meaning of the two trailing values is presumably
        # defined by the base test method - confirm.
        self.direct_conv_tests = (
            (1, 2, 24, 24, 64, 1, 1, 1, 1, 1, 0, 7),
            (1, 2, 24, 24, 16, 3, 3, 1, 1, 1, 1, 6),
            (1, 2, 24, 24, 16, 3, 3, 1, 1, 1, 1, 6),
            (1, 2, 24, 24, 16, 3, 3, 1, 1, 1, 1, 7),
            (1, 2, 24, 16, 8, 5, 5, 1, 1, 1, 0, 3),
            (1, 2, 24, 16, 16, 3, 3, 2, 2, 1, 0, 4),
            (1, 4, 8, 16, 8, 5, 5, 2, 2, 1, 0, 2),
            (1, 2, 16, 16, 8, 7, 4, 4, 2, 1, 0, 7),
            (1, 4, 24, 16, 16, 3, 3, 1, 1, 4, 0, 7),
            (1, 7, 8, 16, 16, 3, 3, 2, 2, 4, 0, 7),
            (1, 1, 32, 1, 64, 1, 1, 1, 1, 2, 0, 7),
            (1, 5, 40, 1, 16, 1, 1, 1, 1, 2, 0, 7),
            (1, 1, 64, 1, 64, 1, 1, 1, 1, 3, 0, 7),
            (1, 2, 32, 1, 64, 1, 1, 1, 1, 2, 0, 7),
            (1, 2, 32, 1, 16, 3, 3, 1, 1, 2, 0, 7),
            (1, 2, 24, 1, 16, 3, 3, 1, 1, 2, 0, 7),
            (1, 4, 8, 1, 8, 5, 5, 2, 2, 2, 0, 2),
            (1, 2, 32, 1, 16, 3, 3, 2, 2, 2, 0, 7),
            (1, 2, 16, 1, 8, 7, 7, 4, 2, 2, 0, 7),
            (1, 4, 32, 16, 16, 3, 3, 1, 1, 5, 0, 7),
            (1, 4, 32, 16, 16, 3, 3, 1, 1, 4, 0, 7),
            (1, 8, 8, 16, 16, 3, 3, 2, 2, 5, 0, 7),
            # Disabled rows:
            #(1, 2, 32, 24, 48, 1, 1, 1, 2, 0, 0, 7),
            #(1, 2, 32, 24, 48, 1, 1, 1, 2, 1, 0, 7),
            #(1, 4, 32, 16, 16, 3, 3, 1, 1, 0, 0, 7),
        )

        self.granularity = (1, 1, 8, 8, 8, 1, 1)
        self.stride_efficiency = 0.5

        self.overhead = {'func': 270}

    def run(self, tc, args):
        """Run the direct-convolution test method for test case `tc`.

        When in_IOmode('tdm32', 'result16') holds, the test case is first
        marked as an expected failure.
        """
        marked_as_failing = self.in_IOmode('tdm32', 'result16')
        if marked_as_failing:
            tc.set_expected_fail("Wrong results - likely testcase issue - needs to be debugged")
        self.direct_conv_test_method(tc, args)


# Module-level alias for the test class.
# NOTE(review): presumably the name the regression framework looks up - confirm.
RegressionTest = DirectConvTest

if __name__ == "__main__":
    # Raise SystemExit directly instead of calling exit(): that helper is
    # installed by the site module for interactive sessions and is not
    # available when the interpreter is started with -S. The process exit
    # status is still the value returned by run_single_testcase().
    raise SystemExit(me.run_single_testcase(DirectConvTest))

