/*  (c) Copyright 2019 - 2021 Xilinx, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __MMULT_QDQ_BLOCKED_INT16X8_IMPL_HPP__
#define __MMULT_QDQ_BLOCKED_INT16X8_IMPL_HPP__

#include <stdint.h>
#include "aie_api/aie.hpp"
#include "access_helpers.hpp"
#include "kernel_helpers.h"
#include "ml_params.h"
// #include "mmult_qdq_blocked_int16x2.hpp"


// Blocked int16 (activations) x int8 (weights) GEMM built on
// aie::mmul<4,8,8,int16,int8>.
//
// For each of params.outer_g outer iterations the kernel accumulates
// params.inner_g products into two independent accumulators (m[0], m[1]),
// fed by two 32-lane int16 input vectors (at pA and pA+32) that share one
// 64-lane int8 weight vector. The first product uses a plain multiply, so the
// accumulators never need explicit zeroing; the remaining inner_g-1 products
// use multiply-accumulate. Each accumulator is then written out as 32 int32
// values (two 16-lane halves) after a right shift of shift_tdm: m[0] goes to
// the tdm1 stream, m[1] to the tdm2 stream.
//
// Parameters:
//   input      int16 activations; walked with params.dimsA (2D, inner loop)
//              and dimsAO (3D, once per outer iteration).
//   weights    int8 weights; walked with dimsW (3D) on every load.
//   tdm1/tdm2  int32 output buffers for accumulator 0 / accumulator 1.
//   shift_tdm  shift applied when converting the accumulators to int32.
//   zero_init  not referenced in this function.
//   params     supplies outer_g, inner_g, sign_A and dimsA.
//   incT_0/1   byte increments applied to both output pointers; even outer
//              iterations (j = 0, 2, ...) use incT_0, odd ones use incT_1.
//
// NOTE(review): the loop-range hints below presumably promise the compiler
// minimum trip counts (outer >= 3, inner >= 7, i.e. inner_g >= 8) - confirm
// against the Chess documentation and the callers.
__aie_inline void direct_conv_int16x8_generic_gemm
(
        int16 * input,
        int8  * weights,
        int32 * tdm1,
        int32 * tdm2,
        int shift_tdm,
        bool zero_init,
        const GemmInt16x2Blocked &params,
        dims_3d_param dimsAO,
        dims_3d_param dimsW,
        int incT_0,
        int incT_1
) {
    using mm_t = aie::mmul<4,8,8,int16,int8>;
    using acc_t = mm_t::accum_type;
    // Number of lanes held by one accumulator; stored as two Va/2 halves.
    constexpr unsigned Va = 32;

    int16 * pA = ( int16 * ) input;
    int8  * pW = weights;

    int32 * restrict pOut1 = chess_copy( tdm1 );
    int32 * restrict pOut2 = chess_copy( tdm2 );

    // fA is initialised but not used anywhere else in this function.
    fifo_state_t fA;
    fA.pos = 0;

    dims_2d_t dimsAI = params.dimsA.instantiate( );
    dims_3d_t dimsAO_i = dimsAO.instantiate( );
    dims_3d_t dimsW_i  = dimsW.instantiate( );

    // GemmInt16x2Blocked::Control ctrl = params.ctrl;

    // Output increment for the current outer iteration; toggled every
    // iteration between incT_0 and incT_1 via incT_flag.
    int incTi = incT_0;
    int incT_flag = 0;
    
    for (unsigned j=0; j<params.outer_g; j++)
        chess_prepare_for_pipelining
        chess_loop_range(3,)
    {
        // zero_acc is set but never read.
        int zero_acc = 1;
        mm_t m[2];

        // Latch this iteration's output increment, then pre-select the next one.
        int incTo = chess_copy( incTi );
        incT_flag ^= 1;
        incTi = incT_flag ? incT_1 : incT_0;

        int bound = params.inner_g;

        // First inner step (peeled): plain multiply initialises both accumulators.
        aie::vector<int16,32> x0 = aie::load_v<32>( pA );
        aie::vector<int16,32> x1 = aie::load_v<32>( pA+32 ); pA=add_2d_byte(pA, dimsAI);
        aie::vector<int8, 64> y = aie::load_v<64>( pW );
        pW = add_3d_byte( pW, dimsW_i );
        m[0] = acc_t( mul_4x8_8x8_conf( x0, params.sign_A, y, 1, 0));
        m[1] = acc_t( mul_4x8_8x8_conf( x1, params.sign_A, y, 1, 0));

        // Remaining inner_g-1 steps: multiply-accumulate onto m[0] / m[1].
        for (int i=1; i<bound; i++)
            chess_prepare_for_pipelining
            //chess_peel_pipelined_loop(2)
            //chess_pipeline_adjust_preamble(-3)
            chess_loop_range(7,)
        {
            aie::vector<int16,32> x0 = aie::load_v<32>( pA );
            aie::vector<int16,32> x1 = aie::load_v<32>( pA+32 ); pA=add_2d_byte(pA, dimsAI);

            aie::vector<int8, 64> y = aie::load_v<64>( pW );
            pW = add_3d_byte( pW, dimsW_i );

            m[0] = acc_t( mac_4x8_8x8_conf( x0, params.sign_A, y, 1, m[0].to_accum( ), 0, 0, 0, 0 ));
            m[1] = acc_t( mac_4x8_8x8_conf( x1, params.sign_A, y, 1, m[1].to_accum( ), 0, 0, 0, 0 ));
        }
        // Outer-level input step, applied on top of the inner 2D stepping.
        pA = add_3d_byte( pA, dimsAO_i );
        // Write each accumulator as two 16-lane int32 halves (upper half first).
        aie::store_v( pOut1 + Va / 2, m[0].to_accum( ).extract<Va/2>( 1 ).to_vector<int32>( shift_tdm ));
        aie::store_v( pOut1,          m[0].to_accum( ).extract<Va/2>( 0 ).to_vector<int32>( shift_tdm ));
        aie::store_v( pOut2 + Va / 2, m[1].to_accum( ).extract<Va/2>( 1 ).to_vector<int32>( shift_tdm ));
        aie::store_v( pOut2,          m[1].to_accum( ).extract<Va/2>( 0 ).to_vector<int32>( shift_tdm ));
        // Both output streams advance by the same byte increment.
        pOut1 = byte_incr( pOut1, incTo );
        pOut2 = byte_incr( pOut2, incTo );
    }
}


// Final quantize/de-quantize stage for int32 TDM data.
//
// Each of the params.loop iterations produces one 32-lane output vector:
//     acc  = c0                              (8 acc64 values read at pC0,
//                                             replicated across 32 lanes)
//     acc += outer_prod( sum, c1 )           (only when has_sum)
//     acc += tdm * c2
//     out  = to_vector_sign<To>( acc, params.sign_out, shift_res )
//
// Template parameters:
//   has_sum            include the ifm_sum * c1 outer-product term.
//   has_vector_coeffs  read c1/c2 as 8-lane vectors through pC1 instead of
//                      using the scalar arguments c1 / c2.
//   To                 element type of the output buffer.
//
// Parameters:
//   tdm1, tdm2  int32 inputs. Only tdm1 is dereferenced; tdm2 is used solely
//               for its byte distance to tdm1 when building the input stepper.
//   ifm_sum     per-row sums, read as 4-lane vectors when has_sum is set.
//   coeff       c0 values (8 acc64 per iteration); when
//               params.vector_coeffs > 0 the int32 c1/c2 vectors are read
//               starting 8 int64 elements after coeff.
//   c1, c2      scalar coefficients (also spilled to a local buffer so the
//               vector path can read them when params.vector_coeffs <= 0).
//   shift_res   shift applied when converting the accumulator to To.
//   ofm         output buffer.
//   params      loop count, stepping descriptors, sign_out, split_mode.
template<bool has_sum, bool has_vector_coeffs, typename To> __aie_inline
void qdq( int32_t * tdm1, int32_t * tdm2, int32_t * ifm_sum, int64_t * coeff, int32_t c1, int32_t c2, int8_t shift_res, To * __restrict ofm, GemmInt16x2::QDQParams &params )
{
    // dims_in1 and step1 are declared/computed but not used below.
    dims_3d_t dims_in1;
    // step0 is either a multiple of dims_in1_step or the byte distance between
    // the two TDM buffers, selected by params.split_mode.
    int32_t step0 =  params.split_mode ? params.dims_in1_step * 32 : ( long )tdm2 - ( long )tdm1;
    int32_t step1 = !params.split_mode ? params.dims_in1_step * 32 : ( long )tdm2 - ( long )tdm1;
    // Custom change relative to the reference qdq: the input stepper is built
    // from explicit steps (step0, then fixed 128/128) via dims_3d_from_steps.
    AddByte<dims_3d_t> add_3d_in1( dims_3d_from_steps( params.dims_in1_wrap0, step0, params.dims_in1_wrap1, 128, 128 ));
    AddByte<dims_2d_t> add_2d_sum( params.dims_sum.instantiate( ));
    AddByte<dims_2d_t> add_2d_cf0( params.dims_qnt.instantiate( ));
    AddByte<dims_2d_t> add_2d_cf1( params.dims_qnt.instantiate( ));
    // Custom change relative to the reference qdq: 3D stepping on the output.
    AddByte<dims_3d_t> add_3d_out( params.dims_out.instantiate( ));
    int bypass_value = params.loop;
    // Without vector coefficients pC1 points at the local spill buffer.
    // NOTE(review): setting num0 to the loop count presumably keeps the pC1
    // stepper on its first-dimension increment for the whole loop - confirm
    // against the dims_2d_t / AddByte definitions.
    if ( params.vector_coeffs <= 0 )
        add_2d_cf1.num0 = bypass_value;

    auto pI = aie::begin_vector<32, aie_dm_resource::a>( tdm1 );
    auto pS = aie::begin_vector<4, aie_dm_resource::b>( ifm_sum );
    int64_t __aie_dm_resource_b * pC0 = (int64_t __aie_dm_resource_b *) coeff;
    auto pO = aie::begin_restrict_vector<32, aie_dm_resource::c>( ofm );

    // Scalar coefficients broadcast into memory: c1 at [0..7]; c2 at [8..15]
    // when has_sum, otherwise c2 overwrites [0..7]. This matches the
    // "pC1 + 8 * has_sum" addressing used for c2v in the loop.
    alignas( 32 ) int32_t cf_spill[16];
    aie::store_v( cf_spill, aie::broadcast<int32_t,8>( c1 ));
    aie::store_v( cf_spill + 8 * has_sum, aie::broadcast<int32_t,8>( c2 ));
    int32_t * pC1 = params.vector_coeffs > 0 ? ( int32_t * )( pC0 + 8 ) : cf_spill;

    [[ using chess: prepare_for_pipelining, min_loop_count( 4 )]]
    for ( unsigned o=0; o<params.loop; o++ )
    {
        // Replicate the 8 acc64 c0 values at pC0 across all 32 lanes:
        // place them in sub-vectors 0 and 1, then copy the lower 16 lanes
        // into the upper 16.
        v32acc64 a = set_v32acc64( 0, *( v8acc64* )pC0 );
        a = insert( a, 1, *( v8acc64* )pC0 );
        aie::accum<acc64, 32> acc = insert( a, 1, extract_v16acc64( a, 0 ));
        
        if constexpr( has_vector_coeffs ) {
            // Vector path: c2 (and c1 when has_sum) are read through pC1.
            aie::vector c2v = aie::load_v<8, aie_dm_resource::b>( pC1 + 8 * has_sum );
            if constexpr( has_sum ) {
                pC1 = chess_copy( pC1 );
                acc = mac_outer_prod( acc, *pS, aie::load_v<8, aie_dm_resource::b>( pC1 ));
            }
            acc = aie::mac( acc, *pI, c2v.grow_replicate<32>( ));
        } else {
            // Scalar path: use the c1 / c2 arguments directly.
            if constexpr( has_sum ) {
                acc = mac_outer_prod( acc, *pS, aie::broadcast<int32_t, 8>( c1 ));
            }

            acc = aie::mac( acc, *pI, c2 );
        }

        *pO = acc.template to_vector_sign<To>( params.sign_out, shift_res );

        // Advance every stream with its own stepper.
        pI = add_3d_in1( pI );
        pS = add_2d_sum( pS );
        pC0 = add_2d_cf0( pC0 );
        pC1 = add_2d_cf1( pC1 );
        pO = add_3d_out( pO );
    }
}

// Final quantize/de-quantize stage for 64-bit TDM data.
//
// Each of the params.loop iterations produces one 32-lane output vector:
//     acc = c0 + tdm
//     out = to_vector_sign<To>( acc, params.sign_out, shift_res )
// where c0 is the 8 acc64 values read at the current coefficient pointer,
// replicated across 32 lanes, and tdm is a 32-lane acc64 vector read from the
// running TDM buffer. No multiplication by an outer scale takes place here.
//
// Template parameters:
//   To         element type of the output buffer.
//
// Parameters:
//   tdm1, tdm2 64-bit TDM buffers. Only tdm1 is dereferenced; tdm2 is used
//              solely for its byte distance to tdm1, which becomes the first
//              step of the 3D input stepper.
//   coeff      c0 values, 8 acc64 per iteration, stepped by params.dims_qnt.
//   shift_res  shift applied when converting the accumulator to To.
//   ofm        output buffer, stepped by params.dims_out.
//   params     loop count, wrap values, stepping descriptors and sign_out.
template<typename To> __aie_inline
void qdq_tdm64( int64_t * tdm1, int64_t * tdm2, int64_t * coeff, int8_t shift_res, To * __restrict ofm, GemmInt16x2::QDQParams &params )
{
    // Input stepper: first step is the byte distance between the two TDM
    // buffers; the remaining steps are fixed at 256 bytes (32 x int64).
    dims_3d_t tdm_steps = dims_3d_from_steps( params.dims_in1_wrap0, ( long )tdm2 - ( long )tdm1, params.dims_in1_wrap1, 256, 256 );
    AddByte<dims_2d_t> add_2d_cf0( params.dims_qnt.instantiate( ));
    AddByte<dims_3d_t> add_3d_out( params.dims_out.instantiate( ));

    int64_t __aie_dm_resource_ab * pI = (int64_t __aie_dm_resource_ab *) tdm1;
    int64_t __aie_dm_resource_b * pC0 = (int64_t __aie_dm_resource_b *) coeff;
    auto pO = aie::begin_restrict_vector<32, aie_dm_resource::a>( ofm );

    [[ using chess: prepare_for_pipelining, min_loop_count( 4 )]]
    for ( unsigned o=0; o<params.loop; o++ )
    {
        // Replicate the 8 acc64 c0 values at pC0 across all 32 lanes:
        // place them in sub-vectors 0 and 1, then copy the lower 16 lanes
        // into the upper 16.
        v32acc64 a = set_v32acc64( 0, *( v8acc64* )pC0 );
        a = insert( a, 1, *( v8acc64* )pC0 );
        aie::accum<acc64, 32> acc = insert( a, 1, extract_v16acc64( a, 0 ));

        // Add the 64-bit running TDM values, then advance the input pointer.
        acc = aie::add(acc, aie::accum<acc64,32>(*((v32acc64 __aie_dm_resource_ab*)pI))); pI = add_3d_byte( pI, tdm_steps );
        *pO = acc.template to_vector_sign<To>( params.sign_out, shift_res );

        pC0 = add_2d_cf0( pC0 );
        pO = add_3d_out( pO );
    }
}


// Scales one K-block of GEMM output by c2 and accumulates it, in place, onto
// the running scaled-GEMM buffers tdm1s / tdm2s.
//
// Per iteration (params.loop_blocked iterations):
//   - 16 int32 values are loaded from tdm1 and 16 from tdm2 and combined into
//     one 32-lane vector;
//   - the vector is multiplied by c2 (8-lane vector read at c2v_ptr, replicated
//     to 32 lanes, when has_vector_coeffs; otherwise the scalar c2);
//   - the running sum is read from tdm1s (lower 16 lanes) and tdm2s (upper 16
//     lanes), combined with the product via add_conf, and written back to the
//     same locations.
//
// Template parameters:
//   has_vector_coeffs  select the vector (c2v_ptr) or scalar (c2) scale.
//   tdm_type           int64: the running sum is kept as raw acc64 and the
//                      product is added unshifted (shift_res is not applied).
//                      otherwise (int32): the product is first converted to
//                      int32 with shift_res, then added, then stored as int32.
//
// Parameters:
//   tdm1, tdm2    int32 GEMM outputs for this K-block.
//   zero_init     forwarded as the third argument of add_conf.
//                 NOTE(review): presumably makes add_conf ignore the running
//                 sum on the first K-block - confirm against the intrinsic.
//   c2v_ptr, c2   vector / scalar scale factors.
//   shift_res     shift used only on the int32 path.
//   tdm1s, tdm2s  running scaled-GEMM buffers (read and written).
//   params        loop_blocked, sgemm_c2_wrap, sgemm_c2_step.
template<bool has_vector_coeffs, typename tdm_type> __aie_inline
void blocked_c2k_qdq(
        int32 * tdm1,
        int32 * tdm2,
        bool zero_init,
        int32_t* c2v_ptr,
        int32_t c2,
        int8_t shift_res,
        tdm_type * __restrict tdm1s,
        tdm_type * __restrict tdm2s,
        const GemmInt16x2Blocked &params
){
    // pO* (write) and pTDM* (read) start at the same addresses: the running
    // sum is updated in place.
    // NOTE(review): pO0/pO1 are restrict-qualified yet alias pTDM0/pTDM1;
    // each iteration reads before it writes and iterations touch disjoint
    // 16-element chunks, but confirm this is acceptable to the compiler.
    tdm_type __aie_dm_resource_a * restrict pO0 = (tdm_type __aie_dm_resource_a* restrict)(tdm1s);
    tdm_type __aie_dm_resource_b * restrict pO1 = (tdm_type __aie_dm_resource_b* restrict)(tdm2s);
    tdm_type __aie_dm_resource_a * pTDM0 = (tdm_type __aie_dm_resource_a*) tdm1s;
    tdm_type __aie_dm_resource_b * pTDM1 = (tdm_type __aie_dm_resource_b*) tdm2s;
    int32 __aie_dm_resource_a* pI0 = (int32 __aie_dm_resource_a*) tdm1;
    int32 __aie_dm_resource_b* pI1 = (int32 __aie_dm_resource_b*) tdm2;
    int32* c2_ptr= c2v_ptr;
    // tdm_step, step0, step1 and the inp/tdmo/tdmi address descriptors below
    // are computed but never used: all data pointers advance with plain
    // "+= 16" at the end of the loop. Only c2_addr is consumed.
    uint16_t tdm_step;
    if constexpr(std::is_same_v<tdm_type,int64>)
        tdm_step=2*128;
    else
        tdm_step=128;
    int32_t step0 =  ( long )tdm2 - ( long )tdm1;
    int32_t step1 =  ( long )tdm2s - ( long )tdm1s;
    dims_2d_t inp_addr = dims_2d_t(1, step0, -step0+128);
    dims_2d_t tdmo_addr = dims_2d_t(1, step1, -step1+tdm_step);
    dims_2d_t tdmi_addr = dims_2d_t(1, step1, -step1+tdm_step);
    // Stepping descriptor for the vector scale pointer.
    dims_2d_t c2_addr = dims_2d_t(params.sgemm_c2_wrap, 0, params.sgemm_c2_step);

    [[ using chess: prepare_for_pipelining, min_loop_count( 8 )]]
    for ( unsigned o=0; o<params.loop_blocked; o++ )
    {
        // acc0 and acc1 are declared but not used.
        aie::accum<acc64,32> acc,acc0,acc1;
        // Lanes 0..15 come from tdm1, lanes 16..31 from tdm2.
        aie::vector<int32,32> loaded_tdm;
        loaded_tdm.insert( 0, aie::load_v<16>(pI0));
        loaded_tdm.insert( 1, aie::load_v<16>(pI1));
        if constexpr( has_vector_coeffs ) {
            aie::vector c2v = aie::load_v<8, aie_dm_resource::b>( c2_ptr ); c2_ptr=add_2d_byte(c2_ptr,c2_addr);
            acc = aie::mul( loaded_tdm, c2v.grow_replicate<32>( ));
        } else {
            acc = aie::mul( loaded_tdm, c2 );
        }
        if constexpr(std::is_same_v<tdm_type,int64>){
            
            // 64-bit path: running sum is loaded and stored as raw acc64
            // (16 lanes per buffer, written back as two 8-lane pieces each).
            aie::accum<acc64, 16>  add_sum_l(*((v16acc64 __aie_dm_resource_a*)pTDM0));
            aie::accum<acc64, 16>  add_sum_h(*((v16acc64 __aie_dm_resource_b*)pTDM1));
            aie::accum<acc64,32>  add_sum = aie::concat(add_sum_l,add_sum_h);
            aie::accum<acc64, 32>  scaled_gemm_tdm = add_conf((v32acc64)add_sum,(v32acc64)acc,zero_init,0,0,0);
            *((v8acc64 __aie_dm_resource_a*) pO0)= (v8acc64) scaled_gemm_tdm.extract<8>(0);
            *((v8acc64 __aie_dm_resource_a*) pO0+1)= (v8acc64) scaled_gemm_tdm.extract<8>(1);
            *((v8acc64 __aie_dm_resource_b*) pO1)= (v8acc64) scaled_gemm_tdm.extract<8>(2);
            *((v8acc64 __aie_dm_resource_b*) pO1+1)= (v8acc64) scaled_gemm_tdm.extract<8>(3);
        }else{
            // 32-bit path: the product is reduced to int32 with shift_res
            // before being added; the sum is stored back as int32 (shift 0).
            aie::accum<acc64, 16>  add_sum_l = aie::accum<acc64, 16>(aie::load_v<16>(pTDM0));
            aie::accum<acc64, 16>  add_sum_h = aie::accum<acc64, 16>(aie::load_v<16>(pTDM1));
            aie::accum<acc64,32> add_sum = aie::concat(add_sum_l,add_sum_h);
            aie::accum<acc64, 32> scaled_gemm( acc.template to_vector_sign<int32>( 1, shift_res ) );
            aie::accum<acc64, 32> scaled_gemm_tdm = add_conf((v32acc64)add_sum,(v32acc64)scaled_gemm,zero_init,0,0,0);
            aie::accum<acc64, 16> scaled_gemm_tdm_l = scaled_gemm_tdm.extract<16>(0);
            aie::accum<acc64, 16> scaled_gemm_tdm_h = scaled_gemm_tdm.extract<16>(1);
            *((v16int32 __aie_dm_resource_a*) pO0)=scaled_gemm_tdm_l.template to_vector_sign<int32>( 1, 0 );
            *((v16int32 __aie_dm_resource_b*) pO1)=scaled_gemm_tdm_h.template to_vector_sign<int32>( 1, 0 );
        }

        // All streams advance by 16 elements of their own element type.
        pI0+=16;
        pI1+=16;
        pTDM0+=16;
        pTDM1+=16;
        pO0+=16;
        pO1+=16;
    }
}


//20 cycles schedule, 512/20 = 25.6 elem/cycle
//
// Unpacks signed 2-bit weights to int8 and subtracts a (2-bit) zero point.
//
// Each of the params.inner_loop outer iterations loads one v4int4 zero-point
// group, expands it to a 64-lane int8 vector, and reuses it for four inner
// steps. Every inner step loads 64 int4 lanes of packed weights and writes two
// v64int8 vectors, so one outer iteration emits 8 x 64 = 512 int8 values.
//
// Template parameters:
//   loop_range  forwarded to chess_loop_range for the outer loop.
//   use_shfl    when non-zero, the two unpacked halves are interleaved with
//               shuffle(T8_2x64_lo / T8_2x64_hi) before the subtraction.
//
// Parameters:
//   wts     packed weights, read through an int4 pointer.
//   zp      packed zero points, read as v4int4 and stepped by params.dimsZ.
//   ofm     int8 output, written as consecutive v64int8 vectors.
//   params  supplies inner_loop and dimsZ.
//
// NOTE(review): the sups_conf(x, 30)/ssrs_conf(c, 30) pair presumably
// sign-extends the low 2 bits of each unpacked lane, and the
// sups_conf(x, 28)/ssrs_conf(c, 30) pair the next 2 bits - confirm against
// the intrinsic documentation.
template<unsigned loop_range,unsigned use_shfl>
void zp_sub_vsub_zp_int2
(
    int8 * wts,
    int8 * zp,
    int8 * __restrict ofm,
    const GemmInt16x2::UnpackInt2x8Params &params
) {
    int4 __aie_dm_resource_a* in_ptr= (int4 __aie_dm_resource_a * ) wts;
    v4int4 __aie_dm_resource_a* zp_ptr= (v4int4 __aie_dm_resource_a * ) zp;
    v64int8 __aie_dm_resource_a* __restrict ofm_ptr= (v64int8 __aie_dm_resource_a* __restrict) ofm;
    dims_3d_t dimsZ = params.dimsZ.instantiate( );
    // One zero-point vector is reused across the four unrolled inner steps.
    for(int i = 0; i<params.inner_loop; i++)
    chess_prepare_for_pipelining
    chess_loop_range(loop_range,)
    {
        // Broadcast the loaded zero-point group so it can be subtracted from
        // every weight lane; at this point the data is still packed 2-bit.
        v4int4 z = *zp_ptr; zp_ptr = add_3d_byte( zp_ptr, dimsZ );
        // Note: this local "zp" shadows the function parameter of the same name.
        v128int4 zp = broadcast_to_v128int4(z);
        v64int4 zp_64 = extract_v64int4(zp,0);
        v64int8 zp_64_8 = unpack(zp_64);
        // Split each unpacked lane into its two 2-bit fields (see NOTE above).
        v64acc32 zp_c = sups_conf( zp_64_8, 30, 0);
        v64int8 zp_64_8_l = ssrs_conf(zp_c, 30, 0, 0);
        zp_c = sups_conf( zp_64_8, 28, 0);
        v64int8 zp_64_8_h = ssrs_conf(zp_c, 30, 0, 0);
        // Combine the two field vectors into the zero-point vector used below.
        v64int8 zp_s = shuffle(zp_64_8_l,zp_64_8_h,T8_2x64_lo);
        v64int8 out0, out1;
        // Four fully unrolled weight loads per zero-point vector.
        #pragma unroll
        for ( unsigned l = 0; l < 4; l++ ) {
            // Convert packed 2-bit weights to int8: out0 takes one 2-bit
            // field of every lane, out1 the other.
            v64int8 b = unpack(  aie::load_v<64>( in_ptr) ); in_ptr+=32;
            v64acc32 c = sups_conf( b, 30, 0);
            out0 = ssrs_conf(c, 30, 0, 0);
            c = sups_conf( b, 28, 0);
            out1 = ssrs_conf( c, 30, 0, 0);
            // Optional interleave, for easy testing without preparing a
            // special weight data order.
            if constexpr(use_shfl){            
                v64int8 chess_storage(x0) h0 = shuffle(out0,out1,T8_2x64_lo);
                v64int8 chess_storage(x1) h1 = shuffle(out0,out1,T8_2x64_hi);
                out0=h0;
                out1=h1;
            }
            // Subtract the zero point and write both result vectors.
            v64int8 subed_out = sub(out0, zp_s);
            *ofm_ptr++ = subed_out;
            subed_out = sub(out1, zp_s);
            *ofm_ptr++ = subed_out;
        }
    }
}


//20 cycles schedule, 512/20 = 25.6 elem/cycle
//
// Unsigned counterpart of zp_sub_vsub_zp_int2: unpacks unsigned 2-bit weights
// to uint8 and subtracts a (2-bit) zero point.
//
// Each of the params.inner_loop outer iterations loads one v4uint4 zero-point
// group, expands it to a 64-lane uint8 vector, and reuses it for four inner
// steps. Every inner step loads 64 uint4 lanes of packed weights and writes
// two v64uint8 vectors, so one outer iteration emits 8 x 64 = 512 values.
//
// The two 2-bit fields of every unpacked lane are separated with a mask
// (band with 3) for the low field and an accumulator shift by 2
// (sups_conf / ussrs_conf) for the high field.
//
// Template parameters:
//   loop_range  forwarded to chess_loop_range for the outer loop (default 2).
//   use_shfl    when non-zero, the two unpacked halves are interleaved with
//               shuffle(T8_2x64_lo / T8_2x64_hi) before the subtraction
//               (default 0).
//
// Parameters:
//   wts     packed weights, read through a uint4 pointer.
//   zp      packed zero points, read as v4uint4 and stepped by params.dimsZ.
//   ofm     uint8 output, written as consecutive v64uint8 vectors.
//   params  supplies inner_loop and dimsZ.
template<unsigned loop_range=2,unsigned use_shfl=0>
void zp_sub_vsub_zp_uint2
(
    uint8 * wts,
    uint8 * zp,
    uint8 * __restrict ofm,
    const GemmInt16x2::UnpackInt2x8Params &params
) {
    uint4 __aie_dm_resource_a* in_ptr= (uint4 __aie_dm_resource_a * ) wts;
    v4uint4 __aie_dm_resource_a* zp_ptr= (v4uint4 __aie_dm_resource_a * ) zp;
    v64uint8 __aie_dm_resource_a* __restrict ofm_ptr= (v64uint8 __aie_dm_resource_a* __restrict) ofm;
    dims_3d_t dimsZ = params.dimsZ.instantiate( );
    // One zero-point vector is reused across the four unrolled inner steps.
    for(int i = 0; i<params.inner_loop; i++)
    chess_prepare_for_pipelining
    chess_loop_range(loop_range,)
    {
        // Broadcast the loaded zero-point group so it can be subtracted from
        // every weight lane; at this point the data is still packed 2-bit.
        v4uint4 z = *zp_ptr; zp_ptr = add_3d_byte( zp_ptr, dimsZ );
        // Note: this local "zp" shadows the function parameter of the same name.
        v128uint4 zp = broadcast_to_v128uint4(z);
        v64uint4 zp_64 = extract_v64uint4(zp,0);
        v64uint8 zp_64_8 = unpack(zp_64);
        //v64acc32 zp_c = sups_conf( zp_64_8, 30, 0, 0);
        // Low 2-bit field via mask, high 2-bit field via shift right by 2.
        v64uint8 zp_64_8_l = band( zp_64_8, broadcast_to_v64uint8( 3 ));
        v64acc32 zp_c = sups_conf( zp_64_8, 0, 0, 0);
        // NOTE(review): this ussrs_conf call passes five arguments while the
        // one in the inner loop passes four - confirm both overloads exist
        // and behave identically.
        v64uint8 zp_64_8_h = ussrs_conf(zp_c, 2, 0, 0, 0);
        // Combine the two field vectors into the zero-point vector used below.
        v64uint8 zp_s = shuffle(zp_64_8_l,zp_64_8_h,T8_2x64_lo);
        v64uint8 out0, out1;
        // Four fully unrolled weight loads per zero-point vector.
        #pragma unroll
        for ( unsigned l = 0; l < 4; l++ ) {
            // Convert packed 2-bit weights to uint8: out0 takes the low
            // 2-bit field of every lane, out1 the next one.
            v64uint8 b = unpack(  aie::load_v<64>( in_ptr) ); in_ptr+=32;
            //v64acc32 c = sups_conf( b, 30, 0, 0);
            out0 = band( b, broadcast_to_v64uint8( 3 ));
            v64acc32 c = sups_conf( b, 0, 0, 0);
            out1 = ussrs_conf( c, 2, 0, 0);
            // Optional interleave, for easy testing without preparing a
            // special weight data order.
            if constexpr(use_shfl){            
                v64uint8 chess_storage(x0) h0 = shuffle(out0,out1,T8_2x64_lo);
                v64uint8 chess_storage(x1) h1 = shuffle(out0,out1,T8_2x64_hi);
                out0=h0;
                out1=h1;
            }
            // Subtract the zero point and write both result vectors.
            v64uint8 subed_out = sub(out0, zp_s);
            *ofm_ptr++ = subed_out;
            subed_out = sub(out1, zp_s);
            *ofm_ptr++ = subed_out;
        }
    }
}


// Top-level blocked GEMM + QDQ kernel.
//
// Steps:
//   1. Optionally unpack the 2-bit weights to 8 bit and subtract the zero
//      point (has_unpack == 1: signed, has_unpack == 2: unsigned, any other
//      value: no unpacking; weight_unpack is then consumed as-is).
//   2. For each of the params.block_g K-blocks: run the int16 x int8 GEMM
//      into tdm1/tdm2, then scale the block result with that block's c2k
//      vector and accumulate it onto the running buffers tdm1s/tdm2s.
//      zero_init is forwarded for the first block only and cleared afterwards.
//   3. When final_tdm_iter is set, convert the running buffers to the int16
//      output: result = running_sum * c2_outer + c0 on the 32-bit path,
//      result = running_sum + c0 on the 64-bit path (c2_outer is not used
//      there, as that would require an int64 * int32 multiply).
//
// Template parameters:
//   scaled_mmult_tdm64  keep the running scaled-GEMM buffers in 64 bit
//                       (tdm1s/tdm2s are then reinterpreted as int64_t*).
//   use_shfl            forwarded to the unpack routines.
//   has_unpack          unpack mode, see step 1.
//
// Parameters:
//   input          int16 activations; block k starts at
//                  input + k * params.blocked_A_offset.
//   weights        packed weights (source of the unpack step).
//   weight_unpack  8-bit weights read by the GEMM; block k starts at
//                  weight_unpack + k * params.blocked_B_offset.
//   tdm1, tdm2     per-block GEMM outputs.
//   tdm1s, tdm2s   running scaled-GEMM buffers.
//   coeffs         c0 coefficients for the final QDQ.
//   zp             zero points for the unpack step.
//   c2k            per-block scale vectors; block k starts at
//                  c2k + k * params.blocked_c2k_offset (int32 elements).
//   output         int16 result buffer.
//   zero_init      true when the running buffers hold no valid data yet.
//   final_tdm_iter run the final QDQ stage after the K-block loop.
//   params         GEMM / blocking parameters.
//   qdq_params     shifts (shift_tdm, shift_sgemm, shift_res) and scalar
//                  scales (inner_c2, outer_c2).
template<bool scaled_mmult_tdm64, bool use_shfl, unsigned has_unpack>
void mmult_qdq_blocked_int16x2
(
        int16_t * input,
        int8_t * weights,
        int8_t * weight_unpack,
        int32_t * tdm1,
        int32_t * tdm2,
        int32_t * tdm1s,
        int32_t * tdm2s,
        int64_t * coeffs,
        int8_t * zp,
        int32_t * c2k,
        int16_t * restrict output,
        bool zero_init,
        bool final_tdm_iter,
        const GemmInt16x2Blocked &params,
        const GemmInt16x2_QDQ_Params &qdq_params
) {
    GemmInt16x2::UnpackInt2x8Params unpack_params;
    unpack_params.inner_loop = params.inner_loop;
    unpack_params.dimsZ = params.dimsZ;
    // Convert all weights from int2 to int8 (zero point already subtracted).
    if constexpr(has_unpack==1)
        zp_sub_vsub_zp_int2<2,use_shfl>(weights, (int8*) zp, weight_unpack, unpack_params);
    else if constexpr(has_unpack==2)
        zp_sub_vsub_zp_uint2<2,use_shfl>((uint8*) weights, (uint8*) zp, (uint8*)weight_unpack, unpack_params);

    // TDM output increments for the GEMM, in bytes. The GEMM writes 32 int32
    // values (128 bytes) per outer iteration into each buffer, so 128 keeps
    // the output contiguous for both even and odd iterations.
    const int incT_0 = 128;
    const int incT_1 = 128;

    // Loop over the K-blocks: GEMM, then scale with c2k[k_block:N] and
    // accumulate onto the running scaled GEMM.
    for(int k=0; k<params.block_g; k++)
    chess_loop_range(1,)
    {
        direct_conv_int16x8_generic_gemm( input + (k*params.blocked_A_offset), weight_unpack + (k*params.blocked_B_offset), tdm1, tdm2, qdq_params.shift_tdm, zero_init,  params, params.dimsAO, params.dimsW, incT_0, incT_1 );

        // The running sum is kept either in int64 or in int32.
        if constexpr (scaled_mmult_tdm64)
            blocked_c2k_qdq<true,int64>(tdm1, tdm2, zero_init, (int32_t*) c2k+ (k*params.blocked_c2k_offset), qdq_params.inner_c2, qdq_params.shift_sgemm, (int64_t*) tdm1s, (int64_t*) tdm2s, params);
        else
            blocked_c2k_qdq<true,int32>(tdm1, tdm2, zero_init, (int32_t*) c2k+ (k*params.blocked_c2k_offset), qdq_params.inner_c2, qdq_params.shift_sgemm, (int32_t*) tdm1s, (int32_t*) tdm2s, params);
        // Only the first K-block may start from an empty running sum.
        zero_init=false;
    }
    // Final QDQ: result = sw_scaled_gemm*c2_outer + c0.
    // c2_outer is a scalar here because scale_a/scale_o is a scalar.
    // For tdm64, c2_outer is not used (it would be an int64*int32 multiply).
    if ( final_tdm_iter ) {
        GemmInt16x2::QDQParams prm_qdq  = params.qdq_param;
        if constexpr (scaled_mmult_tdm64)
            qdq_tdm64( (int64_t*) tdm1s,(int64_t*) tdm2s, coeffs, qdq_params.shift_res, output, prm_qdq );
        else
            // has_sum is false, so the ifm_sum argument is never dereferenced;
            // output is passed only as a placeholder pointer.
            qdq<false, false>( tdm1s, tdm2s, (int32_t*) output, coeffs, 0, qdq_params.outer_c2, qdq_params.shift_res, output, prm_qdq );
    }
}
#endif // __MMULT_QDQ_BLOCKED_INT16X8_IMPL_HPP__