import argparse
from itertools import product
import os
import sys
import json
import shutil
import logging
from typing import List, Optional

# Absolute directory containing this script; used below to extend sys.path.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Make project modules importable: this directory, its parent, and the
# dmacompiler / kernels directories two levels up.
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels'))

# Project-local imports; these resolve through the sys.path entries above.
from q_dq_tiler import run_tiler
from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)
from dataflow_common import (
    clean_overlay, build_sim_overlay, 
    sizeof, overlay_stack_addr, iceil
)
from q_dq_common import QDQDims, q_dq_preproc_directives
import q_dq_dataflow
# Module-level (import-time) dmacompiler configuration.
# NOTE(review): set_dev_gen presumably selects the AIE2P device generation for
# all later compilation, and ENABLE_BUSY_POLL's meaning lives in
# dmacompiler.config; confirm both there.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True
# NOTE(review): this import comes after the device setup above; confirm
# whether that ordering is required before moving it.
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

def build_q_dq(
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
    aie_cols: int, aie_rows: int,
    h_in: int, w_in: int, c_in: int,
    ifm_bits: int,
    ofm_bits: int,
    fixed_point_bits: int,
    op_type: str,
    frontend_only: bool = False,
    enable_Cin_pad: bool = True,
    save_bins: bool = False
):
    '''Build function for a Q/DQ op.

    Tiles the (h_in, w_in, c_in) activation with run_tiler, compiles the
    dataflow via run_scheduler and, on request, archives the generated
    artifacts.

    Args:
        back_end: code-generation backend forwarded to the scheduler.
        kernel_names: kernels to link. Callers in this module pass a dict
            mapping kernel name -> index in kernel_func_list, despite the
            List[str] annotation.
        kernel_includes: kernel headers/sources to include.
        aie_cols, aie_rows: AIE array shape used for tiling.
        h_in, w_in, c_in: input activation dimensions.
        ifm_bits, ofm_bits: input / output element widths in bits.
        fixed_point_bits: fixed-point width forwarded to the tiler.
        op_type: dataflow op type; callers here use "Quant", "Dequant" or
            "Both".
        frontend_only: when True, the simulation overlay is not built.
        enable_Cin_pad: when True, c_in is padded with iceil(c_in, 8).
        save_bins: when True and back_end is BackEnd.TxnHostPatch, move the
            generated artifacts out of CURRDIR into the folder
            qdq_<h>_<w>_<c>_<ifm_bits>_<ofm_bits>_<op_type>. That folder is
            relative to the current working directory.
    '''
    # iceil presumably rounds c_in up to a multiple of 8; confirm in
    # dataflow_common.
    c_in_pad = iceil(c_in, 8) if enable_Cin_pad else c_in

    # NOTE(review): the literal 1 and 64 arguments are opaque from here; see
    # run_tiler's signature for their meaning.
    dims = run_tiler(aie_rows, aie_cols, 1, h_in, w_in, c_in, c_in_pad,
                     ifm_bits, ofm_bits, fixed_point_bits, 64, op_type)
    run_scheduler(dims, back_end, kernel_names, kernel_includes, frontend_only)

    if save_bins and back_end == BackEnd.TxnHostPatch:
        out_folder = f'qdq_{h_in}_{w_in}_{c_in}_{ifm_bits}_{ofm_bits}_{op_type}'
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(out_folder, exist_ok=True)
        # NOTE(review): artifacts are taken from the script directory, while
        # run_scheduler resolves q_dq_main.cpp against os.getcwd(). Confirm
        # the build really writes its outputs to CURRDIR.
        artifacts = ('ifm.bin', 'wgt.bin', 'ofm.bin', 'dma.hpp',
                     'tiling.json', 'txn.bin', 'param.bin', 'ctrl.bin',
                     'patch.json')
        for name in artifacts:
            shutil.move(os.path.join(CURRDIR, name),
                        os.path.join(out_folder, name))


def extract_fields(file_name):
    '''Read the UTF-8 encoded JSON file at *file_name* and return its parsed contents.'''
    with open(file_name, encoding="utf-8") as handle:
        text = handle.read()
    return json.loads(text)


def run_scheduler(
        dims: QDQDims,
        back_end: BackEnd,
        kernel_names: List[str],
        kernel_includes: List[str],
        frontend_only: bool,
    ):
    '''Compile the Q/DQ dataflow for *dims*, then build the sim overlay.

    clean_overlay() runs first. The simulation overlay is built from
    q_dq_main.cpp in the current working directory, and that step is skipped
    when *frontend_only* is true.
    '''
    clean_overlay()
    q_dq_dataflow.compile_dataflow(dims, back_end, kernel_names, kernel_includes)
    if frontend_only:
        return
    main_cpp = os.path.join(os.getcwd(), 'q_dq_main.cpp')
    directives = q_dq_preproc_directives(dims, back_end)
    build_sim_overlay(back_end, main_cpp, directives)


def run_q_dq_op(json_file, path, txn_mode, kernel_d, frontend_only):
    '''Build API exposed to WAIC.

    Reads a layer-description JSON, derives the Q/DQ build parameters from it
    and calls build_q_dq.

    Args:
        json_file: path to the layer JSON. It is read after changing into
            `path`, so a relative path is resolved against `path`.
        path: directory to os.chdir() into before building (process-wide).
        txn_mode: 0 selects BackEnd.Adf; any other value selects
            BackEnd.TxnHostPatch.
        kernel_d: optional dict with 'kernel_list' and 'kernel_include'.
            When falsy, the combined Q/DQ kernel (8-bit variant if either
            width is 8) is selected.
        frontend_only: forwarded to build_q_dq; when True the sim overlay is
            not built.

    Raises:
        NotImplementedError: if layer_info.in_act_shape is empty.
    '''
    # Process-wide side effect: later relative paths (including json_file)
    # resolve against `path`.
    os.chdir(path)
    enable_Cin_pad = True
    _data = extract_fields(json_file)
    data = {}
    data['back_end'] = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    data['op_type'] = _data['layer_info']['op_type']
    data['aie_rows'] = _data['overlay_info']['shape']['row']
    data['aie_cols'] = _data['overlay_info']['shape']['col']

    # Layer op_type -> (ifm_bits, ofm_bits, dataflow op type).
    op_table = {
        "Quant_float32xuint16": (16, 16, "Quant"),
        "Dequant_uint16xfloat32": (16, 16, "Dequant"),
        "Quant_float32xuint8": (16, 8, "Quant"),
        "Dequant_uint8xfloat32": (8, 16, "Dequant"),
    }
    # NOTE(review): unrecognised op types silently fall back to a 16-bit
    # "Both" build; confirm that is intended.
    data['ifm_bits'], data['ofm_bits'], op_type = op_table.get(
        data['op_type'], (16, 16, "Both"))
    uses_8bit = 8 in (data['ifm_bits'], data['ofm_bits'])

    if not kernel_d:
        kernel = 'run_combined_qdq_a8' if uses_8bit else 'run_combined_qdq'
        data['kernel_names'] = {kernel: kernel_func_list.index(kernel)}
        data['kernel_includes'] = ['super.hh', 'qdq/wrapper_qdq.cc']
    else:
        data['kernel_names'] = kernel_d['kernel_list']
        data['kernel_includes'] = kernel_d['kernel_include']

    data['fixed_point_bits'] = 8 if uses_8bit else 16

    in_act_shape = _data['layer_info']['in_act_shape']
    len_in_act_shape = len(in_act_shape)
    if len_in_act_shape == 0:
        raise NotImplementedError("Empty Q_DQ input shape.")

    # Only the trailing three dims are used as (h, w, c); missing leading
    # dims default to 1.
    # NOTE(review): dims before the last three are ignored; confirm they
    # are always 1.
    data['c_in'] = in_act_shape[-1]
    data['w_in'] = in_act_shape[-2] if len_in_act_shape >= 2 else 1
    data['h_in'] = in_act_shape[-3] if len_in_act_shape >= 3 else 1

    logging.info(" Q_DQ input args: %s", data)
    build_q_dq(data['back_end'],
               data['kernel_names'], data['kernel_includes'],
               data['aie_cols'], data['aie_rows'],
               data['h_in'], data['w_in'], data['c_in'],
               data['ifm_bits'],
               data['ofm_bits'],
               data['fixed_point_bits'],
               op_type,
               frontend_only,
               enable_Cin_pad)

def _str2bool(value):
    '''Parse a command-line boolean ("true"/"false", "1"/"0", "yes"/"no", ...).

    The empty string maps to False, matching the old `type=bool` behaviour
    for that input.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    '''
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv: Optional[List[str]] = None):
    '''Parse the standalone build's command line.

    Args:
        argv: argument list to parse; None (the default) uses sys.argv[1:].

    Returns:
        argparse.Namespace with backend, qdq_mode, shape_index, save_bins and
        dtype.
    '''
    parser = argparse.ArgumentParser(description="Build QDQ for some predefined shapes")
    parser.add_argument(
        "--backend", type=int, default=0,
        help="Backend type (default: 0 for Adf)"
    )
    parser.add_argument(
        "--qdq_mode", type=int,
        help="QDQ mode"
    )
    parser.add_argument(
        "--shape_index", type=int,
        help="Index of the shape from the input set to run"
    )
    # type=bool would turn any non-empty string (including "False") into
    # True. Parse the text instead, and allow a bare `--save_bins` flag.
    parser.add_argument(
        "--save_bins", type=_str2bool, nargs="?", const=True, default=False,
        help="Save generated bin files and dma.hpp"
    )
    parser.add_argument(
        "--dtype", type=int,
        help="Size of the fixed Point Datatype involved in the operation"
    )
    return parser.parse_args(argv)


def _qdq_bits(op_type: str, dtype: int):
    '''Return (input_bits, output_bits) for a standalone build of *op_type*.

    For dtype 8: Dequant is 8 -> 16, Quant is 16 -> 8, Both is 8 -> 8.
    Any other dtype (only 16 reaches here) is 16 -> 16.
    '''
    if dtype != 8:
        return 16, 16
    if op_type == "Dequant":
        return 8, 16
    if op_type == "Quant":
        return 16, 8
    return 8, 8


def main():
    '''Standalone flow for build function.

    If --shape_index or --qdq_mode is missing, every predefined shape is
    built for every op type. Otherwise a single (shape, op type) pair is
    built. Input/output bit widths are derived per op type from --dtype.
    '''
    args = parse_args()
    back_end = BackEnd(args.backend)
    qdq_mode = args.qdq_mode
    save_bins = args.save_bins
    dtype = args.dtype

    # NOTE: there is no index 13 in this table.
    input_shapes = {
        0: [1, 1, 64],
        1: [1, 64, 1024],
        2: [1, 64, 32],
        3: [1, 64, 64],
        4: [1, 1, 1024],
        5: [1, 1, 32128],
        6: [1, 1, 6],
        7: [32, 32, 5120],
        8: [32, 32, 2560],
        9: [1, 32, 5120],
        10: [1, 256, 160],
        11: [1, 256, 640],
        12: [1, 1, 1280],
        14: [1, 64, 3072],
        15: [5, 64, 64],
        16: [5, 63, 64],
        17: [64, 140, 64],
        18: [1024, 1024, 512]
    }

    possible_op_types = ["Quant", "Dequant", "Both"]
    aie_cols, aie_rows = 8, 4
    frontend_only = False
    enable_Cin_pad = True

    if args.shape_index is None or qdq_mode is None:
        shapes = product(input_shapes.values(), possible_op_types)
    else:
        # Look the shape up first so an unknown index fails before the
        # qdq_mode check, as before.
        input_shape = input_shapes[args.shape_index]
        qdq_op_type = {0: "Dequant", 1: "Quant", 2: "Both"}.get(qdq_mode)
        if qdq_op_type is None:
            sys.exit("This is a no-op")
        shapes = [(input_shape, qdq_op_type)]

    if dtype not in (8, 16):
        sys.exit("The current dataflow does not expect such a dtype")

    if dtype == 8:
        kernel_names = {
            'run_combined_qdq_a8': kernel_func_list.index("run_combined_qdq_a8"),
        }
    else:
        kernel_names = {
            'run_combined_qdq': kernel_func_list.index("run_combined_qdq"),
        }
    kernel_includes = ['super.hh', 'qdq/wrapper_qdq.cc']

    for input_shape, qdq_op_type in shapes:
        # Bit widths depend on the op type. Picking them once from qdq_mode
        # gave every op in a sweep the same (wrong) widths.
        input_bits, output_bits = _qdq_bits(qdq_op_type, dtype)
        build_q_dq(
            back_end,
            kernel_names,
            kernel_includes,
            aie_cols, aie_rows,
            input_shape[0], input_shape[1], input_shape[2],
            input_bits,
            output_bits,
            dtype,
            qdq_op_type,
            frontend_only,
            enable_Cin_pad,
            save_bins
        )


if __name__ == "__main__":
    # Run the standalone build flow only when executed as a script.
    main()
