from itertools import product
import os
import sys
import json
import shutil
import logging
import argparse
from typing import List, Optional
from functools import reduce
from operator import mul

# Absolute directory that contains this script. It anchors the import search
# paths below and is also the directory generated artifacts are collected from
# when bins are saved.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Extend the module search path with the parent directory, the grandparent
# directory, the grandparent's 'dmacompiler' folder and this directory itself.
# NOTE(review): presumably this is what makes the project-local imports that
# follow resolvable when the file is run as a script -- confirm.
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(CURRDIR)

from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)
from dataflow_common import clean_overlay, build_sim_overlay, tiling_json_gen
from transpose_common import TransposeDims, TransposeKernelDims, transpose_preproc_directives, padding
from transpose_tiler import run_tiler
import transpose_dataflow, transpose_kernel_dataflow

# Import-time global configuration of the compiler: select the Aie2p device
# generation and switch on the ENABLE_BUSY_POLL config flag. Both settings are
# applied as a side effect of importing this module.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

def build_transpose(
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
    aie_cols: int,
    aie_rows: int,
    enable_batch: bool,
    batch_size: int,
    input: List[int],
    perm: List[int],
    is_int16: bool,
    is_signed: bool,
    qdq_mode: bool,
    kernel_dataflow: bool,
    frontend_only: bool = False,
    out_folder: Optional[str] = None,
    save_bins: Optional[bool] = False,
):
    """Tile, schedule and optionally archive one transpose operation.

    The function runs the tiler, writes ``tiling.json`` into the current
    working directory, invokes ``run_scheduler`` and, when requested, moves
    the generated artifacts into ``out_folder``.

    Args:
        back_end: Backend to compile for.
        kernel_names: Kernel names forwarded to the dataflow compiler.
        kernel_includes: Kernel include files forwarded to the dataflow compiler.
        aie_cols: Overlay column count; only 8 is accepted.
        aie_rows: Overlay row count; only 4 is accepted.
        enable_batch: Forwarded to the tiler.
        batch_size: Forwarded to the tiler.
        input: Input shape; its first four entries are forwarded to the tiler.
        perm: Axis permutation; also used to derive ``orig_output``.
        is_int16: Forwarded to the tiler.
        is_signed: Forwarded to the tiler.
        qdq_mode: QDQ mode; converted with ``int()`` before reaching the tiler.
        kernel_dataflow: Must be true; the non-kernel dataflow is a placeholder.
        frontend_only: Forwarded to ``run_scheduler``.
        out_folder: Destination for the artifacts when bins are saved.
        save_bins: Move artifacts to ``out_folder``; only honoured for the
            ``TxnHostPatch`` backend and a non-``None`` ``out_folder``.

    Raises:
        AssertionError: If the overlay is not 8x4 or ``kernel_dataflow`` is false.
    """
    # Explicit raises (instead of ``assert``) so the checks survive ``python -O``;
    # otherwise ``dims`` would be unbound on the non-kernel path.
    if (aie_cols, aie_rows) != (8, 4):
        raise AssertionError(
            f"unsupported overlay {aie_cols}x{aie_rows}; only 8x4 is supported"
        )
    if not kernel_dataflow:
        raise AssertionError(
            "kernel_dataflow should be enabled, place holder for non-kernel dataflow"
        )

    dims = run_tiler(
        aie_cols, aie_rows,
        is_int16,
        is_signed,
        perm,
        enable_batch,
        batch_size,
        input[0], input[1], input[2], input[3],
        None, None, None, None,
        qdq_mode=int(qdq_mode),
    )

    tiling = {
        "op_type": "transpose",
        "orig_input": input,
        "orig_output": [input[i] for i in perm],
    }
    tiling_json_gen(tiling, os.path.join(os.getcwd(), 'tiling.json'))

    run_scheduler(dims, back_end, kernel_names, kernel_includes, kernel_dataflow, frontend_only)

    if save_bins and back_end == BackEnd.TxnHostPatch and out_folder is not None:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(out_folder, exist_ok=True)
        # NOTE(review): artifacts are collected from CURRDIR while tiling.json is
        # written to the current working directory; this presumably relies on
        # the two being the same directory -- confirm.
        in_folder = CURRDIR
        files = ('ifm.bin', 'wgt.bin', 'ofm.bin', 'dma.hpp',
                 'tiling.json', 'txn.bin', 'param.bin', 'ctrl.bin', 'patch.json')
        for file in files:
            shutil.move(os.path.join(in_folder, file), os.path.join(out_folder, file))


def extract_fields(file_name):
    """Load the JSON document stored at ``file_name`` and return the parsed object."""
    with open(file_name, 'r') as handle:
        return json.load(handle)

def update_permute(permute, original_len):
    """Extend a permutation to rank 4 by prepending identity axes.

    ``4 - original_len`` leading axes are added in ascending order and every
    entry of ``permute`` is shifted up by that same amount.
    """
    offset = 4 - original_len
    leading_axes = [*range(offset)]
    shifted_axes = [axis + offset for axis in permute]
    return leading_axes + shifted_axes

def update_input(data):
    """Left-pad ``data`` with 1s, in place, until it has at least four entries.

    The same list object that was passed in is returned.
    """
    # range() of a non-positive count is empty, so longer inputs are untouched.
    for _ in range(4 - len(data)):
        data.insert(0, 1)
    return data

def get_nested_value(d, keys, default=None):
    """Walk ``keys`` through nested dicts starting at ``d``.

    Returns ``default`` as soon as a non-dict is met before all keys are
    consumed, or when the value finally reached is ``None``.
    """
    current = d
    for key in keys:
        if not isinstance(current, dict):
            return default
        current = current.get(key)
    if current is None:
        return default
    return current

def run_scheduler(dims: TransposeDims,
                back_end: BackEnd,
                kernel_names: List[str],
                kernel_includes: List[str],
                kernel_dataflow: bool = True,
                frontend_only: bool = False
                ):
    """Compile the transpose dataflow and, unless front-end only, build the overlay.

    The overlay is cleaned first. ``kernel_dataflow`` selects between the
    ``transpose_kernel_dataflow`` and ``transpose_dataflow`` compilers. When
    ``frontend_only`` is false the simulation overlay is built from
    ``transpose_main.cpp`` in the current working directory.
    """
    clean_overlay()

    dataflow_module = transpose_kernel_dataflow if kernel_dataflow else transpose_dataflow
    dataflow_module.compile_dataflow(dims, back_end, kernel_names, kernel_includes)

    if frontend_only:
        return

    host_cpp = os.path.join(os.getcwd(), 'transpose_main.cpp')
    directives = transpose_preproc_directives(dims, back_end)
    build_sim_overlay(back_end, host_cpp, directives)


def run_transpose_op(json_file, path, txn_mode, kernel_d, frontend_only):
    """Build a transpose layer described by a JSON layer file.

    Args:
        json_file: Path of the JSON file holding ``layer_info`` and
            ``overlay_info``.
        path: Directory to ``chdir`` into before anything else happens.
        txn_mode: ``0`` selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.
        kernel_d: Optional dict with ``'kernel_list'`` and ``'kernel_include'``
            entries; when falsy the default transpose kernels are used.
        frontend_only: Forwarded to ``build_transpose``.

    Raises:
        AssertionError: If the input shape has fewer than two dimensions.
    """
    os.chdir(path)
    _data = extract_fields(json_file)
    data = {}
    data['op_type'] = _data['layer_info']['op_type']
    data['back_end'] = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    input_shape = _data['layer_info']['in_act_shape']
    if len(input_shape) < 2:
        # Explicit raise so the validation is not stripped under ``python -O``.
        raise AssertionError(
            "This layer is not a Transpose because input shape is less than 2"
        )

    data['input'] = _data['layer_info']['in_act_shape']
    data['perm'] = _data['layer_info']['attributes']['perm']
    # Pad shape and permutation to rank 4 (shape is padded in place).
    data['input'] = update_input(data['input'])
    data['perm'] = update_permute(data['perm'], len(data['perm']))

    kernel_dataflow = True

    data['aie_rows'] = _data['overlay_info']['shape']['row']
    data['aie_cols'] = _data['overlay_info']['shape']['col']
    # Temporary fix until the input datatype issue is resolved upstream; the
    # value is only logged, the kernel choice is driven by ``is_int16`` below.
    data['act_bits'] = 16
    output_dir = os.path.dirname(os.path.realpath(json_file)) if data['back_end'] != BackEnd.Adf else None

    # disable_qdq is a list whose first entry says whether QDQ is disabled;
    # a missing or empty list is treated as "disabled".
    disable_qdq_list = get_nested_value(_data, ['layer_info', 'attributes', 'disable_qdq'], [])
    is_not_qdq = disable_qdq_list[0] if isinstance(disable_qdq_list, list) and disable_qdq_list else 1
    is_qdq = 1 - is_not_qdq
    qdq_mode = 2 if is_qdq else 3
    is_int16 = data['op_type'] == "Transpose_qdq_uint16xuint16"

    if not kernel_d:
        if kernel_dataflow:
            data['kernel_names'] = {}
            if is_int16:
                data['kernel_names']['run_transpose'] = kernel_func_list.index('run_transpose')
            else:
                data['kernel_names']['run_transpose_a8'] = kernel_func_list.index('run_transpose_a8')
            data['kernel_includes'] = ['super.hh', 'transpose/wrapper_transpose.cc']
        else:
            data['kernel_names'] = []
            data['kernel_includes'] = ['super.hh']
    else:
        data['kernel_names'] = kernel_d['kernel_list']
        data['kernel_includes'] = kernel_d['kernel_include']

    logging.info(f" Transpose input args: {data}")
    # Keyword arguments keep this call aligned with build_transpose's signature,
    # which takes enable_batch / batch_size / is_signed that the old positional
    # call omitted (shifting every later argument and raising TypeError).
    # NOTE(review): batching is disabled and the data is treated as unsigned
    # here, mirroring the defaults used by main() -- confirm.
    # NOTE(review): save_bins is left at its default, so output_dir is not
    # used by build_transpose -- confirm whether bins should be saved here.
    build_transpose(
        data['back_end'],
        data['kernel_names'], data['kernel_includes'],
        data['aie_cols'], data['aie_rows'],
        enable_batch=False,
        batch_size=1,
        input=data['input'],
        perm=data['perm'],
        is_int16=is_int16,
        is_signed=False,
        qdq_mode=qdq_mode,
        kernel_dataflow=kernel_dataflow,
        frontend_only=frontend_only,
        out_folder=output_dir,
    )




def generate_all_shapes_and_qdq_modes_bins(combination: dict,
                                      qdq_modes: Optional[List[int]] = None,
                                      aie_cols: int = 8,
                                      aie_rows: int = 4
                                    ):
    """Build every (shape entry, QDQ mode) pair and save the bins per pair.

    Args:
        combination: Mapping whose values are
            ``[shape, permutation, ifm_bits, is_kernel_dataflow]`` entries.
        qdq_modes: QDQ modes to sweep; defaults to ``[0, 1, 2, 3]``.
        aie_cols: Overlay column count passed to ``build_transpose``.
        aie_rows: Overlay row count passed to ``build_transpose``.

    Each build targets ``BackEnd.TxnHostPatch`` and stores its artifacts in a
    folder named after the original shape, permutation, bit width and QDQ mode.
    """
    if qdq_modes is None:
        # Avoid a shared mutable default argument.
        qdq_modes = [0, 1, 2, 3]

    for key, curr_qdq_mode in product(combination, qdq_modes):
        curr_shape, curr_permutation, curr_ifm_bits, curr_is_kernel_dataflow = combination[key][:4]
        is_int16 = curr_ifm_bits == 16

        if curr_is_kernel_dataflow:
            kernel_names = {'run_transpose': kernel_func_list.index('run_transpose'),
                            'run_transpose_a8': kernel_func_list.index('run_transpose_a8')}
            kernel_includes = ['super.hh', 'transpose/wrapper_transpose.cc']
        else:
            kernel_names = []
            kernel_includes = ['super.hh']

        # Keyword arguments keep the call aligned with build_transpose's
        # signature (the old positional call omitted enable_batch, batch_size
        # and is_signed and therefore raised TypeError).
        # update_input pads in place, so it gets a copy: the caller's table is
        # left untouched and the folder name uses the original shape.
        # NOTE(review): save_bins=True is assumed from the function's purpose;
        # without it build_transpose ignores out_folder -- confirm.
        build_transpose(
            BackEnd.TxnHostPatch,
            kernel_names,
            kernel_includes,
            aie_cols,
            aie_rows,
            enable_batch=False,
            batch_size=1,
            input=update_input(list(curr_shape)),
            perm=update_permute(curr_permutation, len(curr_permutation)),
            is_int16=is_int16,
            is_signed=False,
            qdq_mode=curr_qdq_mode,
            kernel_dataflow=curr_is_kernel_dataflow,
            frontend_only=False,
            out_folder=f"input_{'_'.join(map(str, curr_shape + curr_permutation + [curr_ifm_bits] + [curr_qdq_mode]))}",
            save_bins=True,
        )
def shape_fuse_expand(shape, perm, ifm_bits, enable_batch):
    """Fuse adjacent axes of a transpose and normalise it to a 4-D problem.

    Axes that stay adjacent and in ascending order under ``perm`` are merged
    into a single axis, repeatedly, as long as the merged extent does not
    exceed ``1024 * 8``. Shapes with fewer than four axes are then left-padded
    with unit axes. With ``enable_batch`` a 5-D result is accepted and its
    leading axis is split off as the batch.

    Args:
        shape: Input shape.
        perm: Axis permutation (output axis ``i`` reads input axis ``perm[i]``).
        ifm_bits: Element width in bits; used by the 4-D -> 3-D fusion guard.
        enable_batch: Allow a 5-D result whose leading axis is the batch.

    Returns:
        A 5-tuple ``(batch_size, final_input, final_perm, final_output, reason)``.
        ``reason`` is ``""`` on success. On failure it describes the problem
        and the other fields are the placeholders ``1, [], [], []``, so callers
        can always unpack five values.
    """
    def failure(reason):
        # Same arity as the success path so tuple-unpacking callers get the
        # reason instead of a ValueError.
        return 1, [], [], [], reason

    def is_continuous_ascending(perm):
        return all(perm[i] + 1 == perm[i + 1] for i in range(len(perm) - 1))

    def find_fusable_blocks(shape, perm):
        """
        Find blocks of continuous ascending perm values that can be fused.
        Returns list of tuples (start_val, end_val) in **perm values**, not indices.
        The last axis is excluded from a block unless its extent is divisible by 8.
        """
        blocks = []
        i = 0
        while i < len(perm) - 1:
            start_val = perm[i]
            while i + 1 < len(perm) and perm[i] + 1 == perm[i + 1]:
                i += 1
            end_val = perm[i]
            end_idx = i
            # Exclude the last output axis unless divisible by 8 (padding rule).
            if end_idx == len(perm) - 1 and shape[perm[end_idx]] % 8 != 0:
                end_val = perm[end_idx - 1]
                end_idx -= 1
            if end_val > start_val:
                blocks.append((start_val, end_val))
            i += 1
        return blocks

    def fuse_once(shape, perm):
        """Fuse the first fusable block; return (shape, perm, fused_flag)."""
        blocks = find_fusable_blocks(shape, perm)
        if not blocks:
            return shape, perm, False

        start, end = blocks[0]
        fused_extent = reduce(mul, shape[start:end + 1], 1)
        # A fused axis larger than this limit is not attempted.
        if fused_extent > 1024 * 8:
            return shape, perm, False

        new_shape = shape[:start] + [fused_extent] + shape[end + 1:]

        # Collapse the fused axes to one entry and renumber the axes after them.
        fused_len = end - start + 1
        new_perm = []
        for p in perm:
            if start <= p <= end:
                if p == start:
                    new_perm.append(start)
            elif p > end:
                new_perm.append(p - (fused_len - 1))
            else:
                new_perm.append(p)

        # Guard for the 4-D -> 3-D case: reject the fusion when the axis that
        # becomes innermost is not a multiple of 32/ifm_bits elements and the
        # other leading axis is smaller than 4.
        if len(shape) == 4 and len(new_shape) == 3:
            if ((new_perm[2] == 0 and new_shape[0] % (32 / ifm_bits) != 0 and new_shape[1] < 4) or
                    (new_perm[2] == 1 and new_shape[1] % (32 / ifm_bits) != 0 and new_shape[0] < 4)):
                return shape, perm, False

        return new_shape, new_perm, True

    # Work on private list copies so tuples are accepted and inputs never alias.
    shape = list(shape)
    perm = list(perm)

    # An identity permutation needs no transpose at all.
    if is_continuous_ascending(perm):
        return failure("no-op (continuous ascending)")

    fused = True
    while fused:
        shape, perm, fused = fuse_once(shape, perm)

    # Expand to 4-D by prepending unit axes.
    while len(shape) < 4:
        shape = [1] + shape
        perm = [0] + [p + 1 for p in perm]

    rank = len(shape)
    # Only 4-D, or 5-D with a batch axis, can be handled downstream.
    if rank > 5 or (rank == 5 and not enable_batch):
        return failure("not able to fused to 4d")

    split_batch = enable_batch and rank == 5
    if split_batch and perm[0] != 0:
        # The leading axis can only be peeled off as batch if it stays in place.
        return failure("batch axis is permuted")

    output = [shape[i] for i in perm]
    if split_batch:
        batch_size = shape[0]
        final_input = shape[1:]
        final_output = output[1:]
        final_perm = [i - 1 for i in perm[1:]]
    else:
        batch_size = 1
        final_input = shape
        final_output = output
        final_perm = perm

    return batch_size, final_input, final_perm, final_output, ""

def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``0`` to bool."""
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the transpose build script.

    Returns:
        The ``argparse.Namespace`` with ``backend``, ``qdq_mode``,
        ``ifm_bits_override``, ``shape_index`` and ``save_bins``.
    """
    parser = argparse.ArgumentParser(description="Build transpose for various shapes and permutations.")
    parser.add_argument(
        "--backend", type=int, default=0,
        help="Backend type (default: 0 for Adf)"
    )
    parser.add_argument(
        "--qdq_mode", type=int, default=2,
        help="QDQ mode (default: 2)"
    )
    parser.add_argument(
        "--ifm_bits_override", type=int, default=0,
        help="ifm bits override (default: 0, not override)"
    )
    parser.add_argument(
        "--shape_index", type=int,
        help="Index of the shape from the input set to run (if not provided, runs all)"
    )
    # ``type=bool`` would turn every non-empty string (including "False") into
    # True. The converter parses the text instead; a bare ``--save_bins`` flag
    # is also accepted and means True.
    parser.add_argument(
        "--save_bins", type=_str2bool, nargs="?", const=True, default=False,
        help="Save generated bin files and dma.hpp"
    )
    return parser.parse_args()


def main():
    """Command-line entry point: build the selected transpose test shapes.

    Each entry of the active shape table is
    ``[input_shape, permutation, ifm_bits, kernel_dataflow]``. Every selected
    entry is fused/expanded with ``shape_fuse_expand`` and then built with
    ``build_transpose``.

    Raises:
        KeyError: If ``--shape_index`` is neither a key of the active table
            nor a valid position in it.
        AssertionError: If a shape cannot be reduced to a supported form.
    """
    args = parse_args()

    back_end = BackEnd(args.backend)
    qdq_mode = args.qdq_mode
    save_bins = args.save_bins
    aie_cols, aie_rows = 8, 4
    ifm_bits_override = args.ifm_bits_override

    # Shape catalogues; switch ``target_shape`` below to run a different set.
    INPUT_SHAPES = {
        0 : [[1, 12, 50, 64],        [0, 2, 1, 3],      16,   True],    # PASS
        1 : [[1, 224, 224, 3],       [0, 3, 1, 2],      16,   True ],    # PASS
        2 : [[1, 224, 224, 3],       [0, 3, 2, 1],      16,   True ],    # PASS
        3 : [[1, 50, 12, 64],        [0, 2, 1, 3],      16,   True],    # PASS
        4 : [[10, 77, 8, 64],        [0, 2, 1, 3],      16,   True],    # PASS
        5 : [[12, 50, 64],           [0, 2, 1],         16,   True ],    # PASS
        6 : [[80, 77, 64],           [0, 2, 1],         16,   True ],    # PASS
        7 : [[1, 12, 197, 64],       [0, 2, 1, 3],      16,   True],    # PASS
        8 : [[1, 197, 12, 64],       [0, 2, 1, 3],      16,   True],    # PASS
        9 : [[1, 768, 14, 14],       [0, 3, 2, 1],      16,   True ],    # PASS
        10: [[1, 768, 14, 14],       [0, 2, 3, 1],      16,   True ],    # PASS
        11: [[12, 197, 64],          [0, 2, 1],         16,   True ],    # PASS
        12: [[1, 197, 12, 64],       [0, 2, 3, 1],      16,   True ],    # PASS
        13: [[1, 768, 196],          [0, 2, 1],         16,   True ],    # PASS
        14: [[1, 12, 128, 64],       [0, 2, 1, 3],      16,   True],    # PASS
        15: [[1, 128, 12, 64],       [0, 2, 1, 3],      16,   True],    # PASS
        16: [[1, 128, 12, 64],       [0, 2, 3, 1],      16,   True ],    # PASS
        17: [[10, 8, 77, 64],        [0, 2, 1, 3],      16,   True],    # PASS
        18: [[1, 14, 14, 768],       [0, 3, 1, 2],      16,   True ],    # PASS
        19: [[1, 14, 14, 768],       [0, 3, 2, 1],      16,   True ],    # PASS
        20: [[1, 1886, 6, 64],       [0, 2, 3, 1],      16,   True ],    # PASS
        21: [[1, 6, 1886, 64],       [0, 1, 3, 2],      16,   True ],    # PASS
        22: [[1, 14, 1886, 34],      [0, 1, 3, 2],      16,   True ],    # PASS
        23: [[1, 4, 4, 4],           [0, 2, 1, 3],      16,   True ],    # PASS
        24: [[8, 128, 64, 576],      [0, 3, 2, 1],      16,    True],    # PASS
    }

    PSP1_SHAPES = {
        0 : [[1, 115, 199, 144],     [0, 3, 1, 2],      16,   True ],    # PASS
        1 : [[1, 115, 199, 54],      [0, 3, 1, 2],      16,   True ],    # PASS
        2 : [[1, 3, 462, 768],       [0, 2, 3, 1],      16,   True ],    # PASS
        3 : [[1, 33, 57, 384],       [0, 3, 1, 2],      16,   True ],    # PASS
        4 : [[1, 384, 1881],         [0, 2, 1],         16,   True ],    # PASS
        5 : [[1, 384, 33, 57],       [0, 2, 3, 1],      16,   True ],    # PASS
        6 : [[1, 4, 460, 796],       [0, 2, 3, 1],      16,   True ],    # PASS
        7 : [[1, 460, 796, 4],       [0, 3, 1, 2],      16,   True ],    # PASS
        8 : [[1, 462, 768, 3],       [0, 3, 1, 2],      16,   True ],    # PASS
    }

    PERM_3_0 = {
        0 : [[8, 64, 4, 150],        [3, 1, 2, 0],      16,   True ],    # PASS
        1 : [[7, 45, 9, 57],         [3, 1, 0, 2],      16,   True ],    # PASS
        2 : [[6, 46, 99, 56],        [3, 0, 1, 2],      16,   True ],    # PASS
        3 : [[5, 16, 47, 64],        [3, 0, 2, 1],      16,   True ],    # PASS
        4 : [[13, 45, 74, 64],       [3, 2, 1, 0],      16,   True ],    # PASS
        5 : [[5, 16, 99, 64],        [3, 2, 0, 1],      16,   True ],    # PASS
    }

    debug = {
        # 'Transpose_2_sam2_en_a8': [ [1, 1024, 1024, 3],  [0, 3, 1, 2], 8, True ],
        # 'Transpose_6_sam2_en_a8': [ [1, 32, 32, 2, 2, 384],  [0, 1, 3, 2, 4, 5], 8, True ],
        'Transpose_8_sam2_de_a8': [ [1, 32, 2, 32, 2, 256],  [0, 1, 3, 5, 2, 4], 8, True ],  # output = [1, 32, 32, 256, 2, 2]
        'Transpose_13_sam2_de_a8': [ [1, 64, 2, 64, 2, 128],  [0, 1, 3, 5, 2, 4], 8, True ], # output = [1, 64, 64, 128, 2, 2]
        'Transpose_17_sam2_de_a8': [ [128, 2, 128, 2, 1],  [0, 2, 1, 3, 4], 8, True ],
        'Transpose_18_sam2_de_a8': [ [64, 2, 64, 2, 4],  [0, 2, 1, 3, 4], 8, True ],
    }

    target_shape = debug

    # --shape_index is an int, but tables such as ``debug`` are keyed by name.
    # Use it as a key when it is one, otherwise as a position in the table.
    if args.shape_index is None:
        shapes = list(target_shape.values())
    elif args.shape_index in target_shape:
        shapes = [target_shape[args.shape_index]]
    else:
        entries = list(target_shape.values())
        if not 0 <= args.shape_index < len(entries):
            raise KeyError(
                f"shape_index {args.shape_index} is out of range for {len(entries)} shapes"
            )
        shapes = [entries[args.shape_index]]

    frontend_only = False
    is_signed = False

    for idx, (in_shape, perm, act_bits, kernel_dataflow) in enumerate(shapes):
        print("input", in_shape)
        print("perm", perm)
        print("idx", idx)
        # When the innermost axis stays innermost, its extent is passed through
        # padding() before fusing.
        W8_input = in_shape[:-1] + [padding(in_shape[-1])] if perm[-1] == len(perm) - 1 else in_shape
        # Batching is only attempted for 6-D shapes whose two leading axes stay put.
        enable_batch = len(perm) == 6 and (perm[0] == 0 and perm[1] == 1)
        batch_size, new_input, new_perm, _, reason = shape_fuse_expand(W8_input, perm, act_bits, enable_batch)
        if reason:
            # Explicit raise so the failure is reported even under ``python -O``.
            raise AssertionError(f"the shape: {in_shape} with perm: {perm} failed: {reason}")

        if kernel_dataflow:
            kernel_names = {'run_transpose': kernel_func_list.index('run_transpose'),
                            'run_transpose_a8': kernel_func_list.index('run_transpose_a8')}
            kernel_includes = ['super.hh', 'transpose/wrapper_transpose.cc']
        else:
            kernel_names = []
            kernel_includes = ['super.hh']
        # The override is honoured only for the supported widths 8 and 16.
        ifm_bits = ifm_bits_override if ifm_bits_override in (8, 16) else act_bits
        is_int16 = ifm_bits == 16
        build_transpose(
            back_end,
            kernel_names,
            kernel_includes,
            aie_cols,
            aie_rows,
            enable_batch,
            batch_size,
            new_input,
            new_perm,
            is_int16,
            is_signed,
            qdq_mode,
            kernel_dataflow,
            frontend_only=frontend_only,
            out_folder=f"input_{'_'.join(map(str, in_shape + perm + [ifm_bits] + [qdq_mode]))}",
            save_bins=save_bins,
        )
        print("=" * 90)

# Run the command-line entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()