import os
import sys
import json
import shutil
import logging
from typing import List, Optional
import numpy as np
import json

# Directory containing this script. It is also used by main() as the output
# directory for tiling.json and as the source of the hw-package artifacts.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Extend the import path with this directory, its parent, and
# ../../dmacompiler. This must run before the project imports below
# (dataflow_common, dmacompiler, bilinear_resize_*, kernels.*), which are
# presumably located through these entries — NOTE(review): confirm which
# entry each module resolves from.
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))

from dataflow_common import clean_overlay, build_sim_overlay, iceil
from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)
# Global dmacompiler configuration: select the AIE2P device generation and
# enable config.ENABLE_BUSY_POLL. This is placed before the remaining project
# imports, presumably so that modules which read this configuration at import
# time see it — NOTE(review): confirm whether this ordering is actually required.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

from bilinear_resize_common import (
    bilinear_resize_input_subvol_dims,
    bilinear_resize_wgt_subvol_dims,
    bilinear_resize_preproc_directives,
)

from kernels.bilinear_pixel_resize_bf16.bilinear_pixel_resize_bf16_kernel_params import (
    CoordinateTransfromationMode,
    genereate_bilinear_resize_kernel_params,
    mode_map,
    BilinearResizeShape,
)

from bilinear_resize_tiler import (
    bilinear_resize_tiler,
)

from bilinear_resize_scheduler import (
    gen_bilinear_resize_schedule,
)

from dmacompiler import (
    BackEnd,
)


def extract_fields(file_name):
    """
    Load and return the parsed contents of a JSON file.

    Args:
        file_name: Path to the JSON file. It is decoded as UTF-8, the
            encoding JSON interchange requires, rather than the platform locale.

    Returns:
        The deserialized JSON value (for an operator description file this
        is a dict).
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        return json.load(f)


def write_tiler_output_json(tiled_soln, output_path):
    """
    Serialize a tiler solution to ``output_path`` in a tiling.json-like layout.

    The file holds a single ``"parameters"`` entry with the output subvolume
    (Yos/Xos/Cos), the full tensor dims (Yo/Xo/Co/Yi/Xi), the coordinate
    transformation mode, and a fixed ``"H_outer": 0``. Attributes missing from
    ``tiled_soln`` are written as ``null``.

    The mode text comes from ``tiled_soln.mode``. If the value has a ``.name``
    (e.g. an Enum member), that name is lower-cased. Otherwise the lower-cased
    ``str()`` of the value is used, which gives "" when the attribute is absent.

    Args:
        tiled_soln: Tiler result object, read only through getattr.
        output_path: Destination JSON file; overwritten if it exists.
    """
    def field(attr):
        return getattr(tiled_soln, attr, None)

    raw_mode = getattr(tiled_soln, "mode", "")
    if hasattr(raw_mode, "name"):
        mode_text = raw_mode.name.lower()
    else:
        mode_text = str(raw_mode).lower()

    subvolume = {"Ho": field("Yos"), "Wo": field("Xos"), "C": field("Cos")}
    tensor = {
        "Ho": field("Yo"),
        "Wo": field("Xo"),
        "C": field("Co"),
        "Hi": field("Yi"),
        "Wi": field("Xi"),
    }
    entry = {
        "subvolume": subvolume,
        "tensor": tensor,
        "coordinate_transformation_mode": mode_text,
        "H_outer": 0,  # fixed value; not taken from the tiler solution
    }
    with open(output_path, "w") as handle:
        json.dump({"parameters": [entry]}, handle, indent=2)


def update_json_file(json_file, dims: "BilinearResizeShape"):
    """
    Add core-tile subvolume and DRAM shape metadata to an existing JSON file.

    Reads ``json_file`` and sets (overwriting if present) the top-level keys
    ``"core_tile_params"`` and ``"dram_params"`` from attributes of ``dims``.
    The result is written back in place with 2-space indentation. Any
    attribute missing from ``dims`` is recorded as ``null``. This includes the
    derived ``Y_iters`` (Y_split * Y_loop) and ``X_iters`` (X_split * X_loop)
    when either factor is missing; previously a missing factor raised TypeError.

    Args:
        json_file: Path to a JSON file whose top level is an object.
        dims: Scheduled bilinear-resize dimensions, read only through getattr.
    """
    def pick(*names):
        # {name: dims.<name>}, with None for any attribute dims lacks.
        return {name: getattr(dims, name, None) for name in names}

    def product_or_none(a_name, b_name):
        # Multiply two dims attributes, but stay null-tolerant like pick().
        a = getattr(dims, a_name, None)
        b = getattr(dims, b_name, None)
        return None if a is None or b is None else a * b

    with open(json_file, 'r', encoding='utf-8') as f:
        data = json.load(f)

    data["core_tile_params"] = {
        "subvols": {
            "ifm": pick("N", "Yis", "Xis", "Cis"),
            "wgt": pick("wgt_subvol_dims", "Yis_step", "Xis_step", "Yis_offset", "Xis_offset"),
            "ofm": pick("N", "Yos", "Xos", "Cos"),
        }
    }
    data["dram_params"] = {
        "shapes": {
            "ifm": pick("N", "Yi", "Xi", "Ci"),
            "wgt": {
                "Y_iters": product_or_none("Y_split", "Y_loop"),
                "X_iters": product_or_none("X_split", "X_loop"),
                "wgt_subvol_dims": getattr(dims, "wgt_subvol_dims", None),
            },
            "ofm": pick("N", "Yo", "Xo", "Co"),
        }
    }

    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)


def build_bilinear_resize(
    back_end: BackEnd,
    kernel_names: dict,
    kernel_include: List[str],
    aie_cols: int,
    aie_rows: int,
    input_shape: List[int],
    output_shape: List[int],
    act_bytes: int = 2,
    dq_enable: bool = True,
    q_enable: bool = True,
    coordinate_transformation_mode: str = 'align_corners',
    frontend_only: bool = False,
    output_dir: Optional[str] = None,
):
    """
    Tile, schedule and (optionally) compile/simulate one bilinear resize shape.

    Steps:
      1. Pad the channel counts and build a ``BilinearResizeShape``.
      2. Run the tiler and take the first solution it returns.
      3. Write ``tiling.json`` into ``output_dir``.
      4. Generate the DMA schedule for ``back_end``.
      5. Unless ``frontend_only``, build/simulate the overlay with the
         ``bilinear_resize_main.cpp`` host program.
      6. Add the scheduled core-tile and DRAM parameters to ``tiling.json``.

    Args:
        back_end: dmacompiler back end (e.g. ``BackEnd.Adf`` or
            ``BackEnd.TxnHostPatch``).
        kernel_names: kernel function name -> id mapping, e.g.
            ``{'run_bilinear_resize_bf16': 25}``.
        kernel_include: kernel source/header files to include.
        aie_cols, aie_rows: overlay shape; only (8, 4) is supported.
        input_shape: ``[N, Yi, Xi, Ci]``.
        output_shape: ``[N, Yo, Xo, Co]``. The batch size N is taken from here.
        act_bytes: bytes per activation element (passed on as
            ``act_bits = act_bytes * 8``).
        dq_enable, q_enable: forwarded to ``BilinearResizeShape``.
        coordinate_transformation_mode: key looked up in ``mode_map``. Unknown
            keys fall back to ``CoordinateTransfromationMode.NONE``.
        frontend_only: if True, skip the overlay build/simulation step.
        output_dir: directory for ``tiling.json``. When None, the current
            working directory is used instead of failing in ``os.path.join``.

    Raises:
        AssertionError: if ``(aie_cols, aie_rows) != (8, 4)``.
    """
    _, Yi, Xi, Ci = input_shape
    # NOTE(review): the batch size comes from output_shape; the input batch is
    # not cross-checked against it.
    N, Yo, Xo, Co = output_shape
    mode = mode_map.get(coordinate_transformation_mode, CoordinateTransfromationMode.NONE)

    assert (aie_cols, aie_rows) == (8, 4), "Only 8x4 tiling is supported"
    if output_dir is None:
        output_dir = os.getcwd()
    tiler_output = os.path.join(output_dir, 'tiling.json')

    # Pad channel counts up to a multiple of 8 (via iceil), with a floor of 64.
    aligned_Co = max(iceil(Co, 8), 64)
    aligned_Ci = max(iceil(Ci, 8), 64)
    shape_dims = BilinearResizeShape(
        mode=mode,
        N=N,
        Yo=Yo,
        Xo=Xo,
        Co=aligned_Co,
        Yi=Yi,
        Xi=Xi,
        Ci=aligned_Ci,
        act_bits=act_bytes * 8,
        dq_enable=dq_enable,
        q_enable=q_enable,
        aie_cols=aie_cols,
        aie_rows=aie_rows,
    )

    clean_overlay()
    # The tiler returns a sorted list of solutions; take the top one.
    tiled_soln = bilinear_resize_tiler(ShapeDims=shape_dims, verbose=False)[0]
    print(f"tiled_soln: {tiled_soln}")

    # NOTE: tiling.json, together with the kernel params metadata, is consumed
    # by the test-data generation flow borrowed from
    # https://gitenterprise.xilinx.com/smunz/kernel_lib_example.git.
    # See kernels/bilinear_pixel_resize_bf16/*.json and
    # kernels/python/gen_data.py for details.
    write_tiler_output_json(tiled_soln, tiler_output)

    # Generate the DMA transfers.
    dims = gen_bilinear_resize_schedule(
        dims=tiled_soln, back_end=back_end,
        kernel_names=kernel_names, kernel_include=kernel_include, verbose=True,
    )

    # Compile and simulate.
    if not frontend_only:
        host_cpp = 'bilinear_resize_main.cpp'
        build_sim_overlay(back_end, host_cpp, bilinear_resize_preproc_directives(dims, back_end))

    update_json_file(tiler_output, dims)


def run_bilinear_resize_op(json_file: str, path: str, txn_mode: int, kernel_d, frontend_only: bool = False):
    """
    Build a bilinear resize layer described by an operator JSON file.

    Changes the process working directory to ``path`` and does not restore it.
    Then loads ``json_file`` and calls ``build_bilinear_resize`` with the layer
    and overlay parameters it contains. ``tiling.json`` is written next to
    ``json_file``.

    Args:
        json_file: operator description JSON. A relative path is resolved
            after the ``chdir``, i.e. relative to ``path``.
        path: working directory for the build (the build system works only on
            the current directory).
        txn_mode: 0 selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.
        kernel_d: optional dict with ``'kernel_list'`` and ``'kernel_include'``
            entries that override the default kernel. A falsy value selects
            the defaults.
        frontend_only: forwarded to ``build_bilinear_resize``.

    Raises:
        AssertionError: if ``layer_info.attributes.mode[0]`` is not 'linear'.
        KeyError: if an expected field is missing from the JSON.
    """
    os.chdir(path)  # because the build system only works on the current dir (subprocess)
    op_desc = extract_fields(json_file)
    layer_info = op_desc['layer_info']
    attributes = layer_info['attributes']

    mode = attributes['mode'][0]
    assert mode == 'linear', "Only linear mode is supported for bilinear resize"

    back_end = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    if kernel_d:
        kernel_names = kernel_d['kernel_list']
        kernel_include = kernel_d['kernel_include']
    else:
        kernel_names = {'run_bilinear_resize_bf16': 25}  # Refer to OGOAT/src/L1_fusion/kernel_func_list.py
        kernel_include = ['super.hh', 'bilinear_pixel_resize_bf16/bilinear_pixel_resize_bf16_wrapper.cc']

    overlay_shape = op_desc['overlay_info']['shape']
    aie_rows = overlay_shape['row']
    aie_cols = overlay_shape['col']
    # The disable_* flags are inverted into enable flags.
    dq_enable = not bool(attributes['disable_dq0'][0])
    q_enable = not bool(attributes['disable_q'][0])
    coordinate_transformation_mode = attributes['coordinate_transformation_mode'][0]
    act_bytes = layer_info['in_bytes']
    output_dir = os.path.dirname(os.path.realpath(json_file))

    build_bilinear_resize(
        back_end,
        kernel_names,
        kernel_include,
        aie_cols,
        aie_rows,
        layer_info['in_act_shape'],
        layer_info['out_act_shape'],
        act_bytes=act_bytes,
        dq_enable=dq_enable,
        q_enable=q_enable,
        coordinate_transformation_mode=coordinate_transformation_mode,
        frontend_only=frontend_only,
        output_dir=output_dir,
    )


def main():
    """
    Build every shape in a fixed regression list.

    Usage: ``python bilinear_resize_build.py <sim | hw>``.

    ``sim`` builds with ``BackEnd.Adf``. ``hw`` builds with
    ``BackEnd.TxnHostPatch`` and, after each shape, copies the generated
    artifacts from this script's directory into
    ``psp1_hw_package/<N>_<Yi>_<Xi>_<Ci>_<Yo>_<Xo>_<Co>_<mode>/``.
    Exits with status 1 on a bad command line.
    """
    if len(sys.argv) != 2:
        print("Usage: python bilinear_resize_build.py <sim | hw>")
        sys.exit(1)
    back_ends = {
        'sim': BackEnd.Adf,  # BackEnd.Adf or BackEnd.TxnHostPatch
        'hw': BackEnd.TxnHostPatch,
    }
    if sys.argv[1] not in back_ends:
        print("Invalid argument. Use 'sim' or 'hw'.")
        sys.exit(1)
    back_end = back_ends[sys.argv[1]]

    aie_cols, aie_rows = 8, 4
    frontend_only = False
    kernel_names = {'run_bilinear_resize_bf16': 25}  # Refer to OGOAT/src/L1_fusion/kernel_func_list.py
    kernel_include = ['super.hh', 'bilinear_pixel_resize_bf16/bilinear_pixel_resize_bf16_wrapper.cc']
    # Each entry: input [N, Yi, Xi, Ci], output [N, Yo, Xo, Co],
    # [DQ_enable, Q_enable], act_bytes, coordinate_transformation_mode.
    shape_list = [
        [[1, 33, 57, 384], [1, 66, 114, 384], [1, 1], 2, 'align_corners'],
        [[1, 33, 57, 48], [1, 66, 114, 48], [1, 1], 2, 'align_corners'],
        [[1, 58, 100, 48], [1, 66, 114, 48], [1, 1], 2, 'align_corners'],
        [[1, 66, 114, 192], [1, 115, 199, 192], [1, 1], 2, 'align_corners'],
        [[1, 66, 114, 48], [1, 115, 199, 48], [1, 1], 2, 'align_corners'],
        [[1, 256, 256, 8], [1, 1024, 1024, 8], [1, 1], 1, 'half_pixel'],
    ]
    # Files produced by a build that make up one hw package.
    hw_artifacts = (
        'ifm.bin', 'ofm.bin', 'wgt.bin', 'param.bin',
        'txn.bin', 'ctrl.bin', 'patch.json', 'tiling.json',
    )

    for input_shape, output_shape, (dq_enable, q_enable), act_bytes, ct_mode in shape_list:
        build_bilinear_resize(
            back_end,
            kernel_names,
            kernel_include,
            aie_cols,
            aie_rows,
            input_shape,
            output_shape,
            act_bytes=act_bytes,
            dq_enable=dq_enable,
            q_enable=q_enable,
            coordinate_transformation_mode=ct_mode,
            frontend_only=frontend_only,
            output_dir=CURRDIR,
        )
        if back_end == BackEnd.TxnHostPatch:
            # Copy this shape's artifacts into
            # psp1_hw_package/N_Yi_Xi_Ci_Yo_Xo_Co_<coordinate_transformation_mode>.
            n, yi, xi, ci = input_shape
            _, yo, xo, co = output_shape
            package_dir = os.path.join(
                CURRDIR, 'psp1_hw_package', f"{n}_{yi}_{xi}_{ci}_{yo}_{xo}_{co}_{ct_mode}"
            )
            os.makedirs(package_dir, exist_ok=True)
            for artifact in hw_artifacts:
                shutil.copy(os.path.join(CURRDIR, artifact), os.path.join(package_dir, artifact))


if __name__ == '__main__':
    main()
