import os
import struct
import numpy as np
import copy
import math
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels', 'conv'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels', 'bilinear_pixel_resize_bf16'))
from typing import List, Union
from enum import IntEnum

from dataflow_common import ceildiv, iceil, overlay_stack_addr
from kernels.bilinear_pixel_resize_bf16.bilinear_pixel_resize_bf16_kernel_params import (
    CoordinateTransfromationMode,
    genereate_bilinear_resize_kernel_params,
    mode_map,
    BilinearResizeShape,
)
from dmacompiler import \
    BackEnd, \
    set_dev_gen, DevGen, config
# Module-level side effects: both statements below run as soon as this file
# is imported.
# Select the Aie2p device generation in dmacompiler.
set_dev_gen(DevGen.Aie2p)
# Switch on dmacompiler's ENABLE_BUSY_POLL configuration flag.
config.ENABLE_BUSY_POLL = True


def bilinear_resize_wgt_subvol_dims(
    output_subvol_dims: tuple[int, int],
    verbose: bool = False,
) -> int:
    '''
    Return the weight subvolume size for a given output subvolume.

    The size is ``(Yos + Xos) * 4`` where ``(Yos, Xos)`` is
    ``output_subvol_dims``; the result is converted to ``int``.
    Refer to kernels/bilinear_pixel_resize_bf16/bilinear_pixel_resize_bf16.json
    for more details.

    Args:
        output_subvol_dims: ``(Yos, Xos)`` extent of one output subvolume.
        verbose: when true, print the input and the computed size.

    Returns:
        The weight subvolume size as an ``int``.
        NOTE(review): the unit (elements vs. bytes) is not visible here -
        confirm against the kernel JSON.
    '''
    if verbose:
        print(f"output_subvol_dims: {output_subvol_dims}")

    rows, cols = output_subvol_dims
    size = int((rows + cols) * 4)

    if verbose:
        print(f"wgt_subvol_dims: {size}")
    return size


def bilinear_resize_input_subvol_dims(
    mode: CoordinateTransfromationMode,
    output_tensor_dims: tuple[int, int],
    input_tensor_dims: tuple[int, int],
    output_subvol_dims: tuple[int, int],
    split: tuple[int, int],
    verbose: bool = False,
) -> list:
    '''
    Derive the input subvolume geometry that matches an output subvolume.

    Every quantity is computed independently for the Y axis and then the
    X axis:

      1. ``scale = (out - scale_adjust) / (in - scale_adjust)``.
      2. Output pixel indices ``0 .. ceil(iceil(out, out_subvol) + grid_offset) - 1``
         are mapped to source coordinates ``index / scale - grid_offset``,
         clamped to ``[0, in - 1]``.
      3. ``step = round(out_subvol / scale)``.
      4. ``offset`` is the minimum, over subvolumes ``i`` of the split, of
         ``floor(coord[i * out_subvol] - step * i)``.
      5. ``size`` is the maximum, over subvolumes ``i`` of the split, of
         ``floor(coord[(i + 1) * out_subvol - 1] - step * i)``, minus
         ``offset``, plus 2.

    Refer to kernels/bilinear_pixel_resize_bf16/bilinear_pixel_resize_bf16.json
    for more details.

    Args:
        mode: coordinate transformation mode; HALF_PIXEL selects a grid
            offset of 0.5, every other mode uses 0.0.
        output_tensor_dims: ``(Yo, Xo)`` of the full output tensor.
        input_tensor_dims: ``(Yi, Xi)`` of the full input tensor.
        output_subvol_dims: ``(Yos, Xos)`` of one output subvolume.
        split: ``(Y_split, X_split)`` number of subvolumes per axis.
        verbose: when true, print the inputs and every intermediate value.

    Returns:
        A list of three ``(Y, X)`` tuples:
        ``[input subvol dims, input subvol step, input subvol offset]``.

    Raises:
        ZeroDivisionError: if an input extent equals ``scale_adjust`` (1),
            or an output extent equals it (scale becomes 0).
        IndexError: if the split addresses output pixels beyond the range
            generated in step 2.
    '''
    if verbose:
        print(f"output_tensor_dims: {output_tensor_dims}")
        print(f"input_tensor_dims: {input_tensor_dims}")
        print(f"output_subvol_dims: {output_subvol_dims}")
        print(f"split: {split}")
        print(f"mode: {mode}")

    out_y, out_x = output_tensor_dims
    in_y, in_x = input_tensor_dims
    sub_y, sub_x = output_subvol_dims
    n_split_y, n_split_x = split

    # One row per axis: (key, output extent, input extent,
    #                    output subvol extent, number of subvolumes).
    # Y is listed first so dictionaries and verbose output keep Y-then-X order.
    axes = (
        ("Y", out_y, in_y, sub_y, n_split_y),
        ("X", out_x, in_x, sub_x, n_split_x),
    )

    grid_offset = 0.5 if mode == CoordinateTransfromationMode.HALF_PIXEL else 0.0
    # NOTE(review): the adjustment is 1 for every mode, ALIGN_CORNERS
    # included - confirm that HALF_PIXEL / ASYMMETRIC are not meant to use 0.
    scale_adjust = 1
    if verbose:
        print(f"grid_offset: {grid_offset}")
        print(f"scale_adjust: {scale_adjust}")

    # Output-to-input scale factor per axis.
    scale = {
        axis: (out_len - scale_adjust) / (in_len - scale_adjust)
        for axis, out_len, in_len, _sub_len, _n_split in axes
    }
    if verbose:
        print(f"scale: {scale}")

    # Source coordinate of each output pixel, clamped into the input tensor.
    raw_indices = {
        axis: [
            min(max(0, pixel / scale[axis] - grid_offset), in_len - 1)
            for pixel in range(math.ceil(iceil(out_len, sub_len) + grid_offset))
        ]
        for axis, out_len, in_len, sub_len, _n_split in axes
    }
    if verbose:
        print(f"raw_indices: {raw_indices}")

    # Nominal input stride between consecutive subvolumes.
    sv_I_step = {
        axis: int(round(sub_len / scale[axis]))
        for axis, _out_len, _in_len, sub_len, _n_split in axes
    }
    if verbose:
        print(f"sv_I_step: {sv_I_step}")

    # Smallest start coordinate of any subvolume, relative to its nominal
    # position (step * i).
    sv_I_offset = {
        axis: int(min(
            np.floor(raw_indices[axis][i * sub_len] - sv_I_step[axis] * i)
            for i in range(n_split)
        ))
        for axis, _out_len, _in_len, sub_len, n_split in axes
    }
    if verbose:
        print(f"sv_I_offset: {sv_I_offset}")

    # Largest end coordinate of any subvolume (same relative frame), turned
    # into an extent. The +2 presumably covers the last pixel plus its
    # bilinear neighbour - TODO confirm against the kernel.
    input_subvol_dims = {
        axis: int(
            max(
                np.floor(
                    raw_indices[axis][sub_len - 1::sub_len][i]
                    - sv_I_step[axis] * i
                )
                for i in range(n_split)
            )
            - sv_I_offset[axis] + 2
        )
        for axis, _out_len, _in_len, sub_len, n_split in axes
    }
    if verbose:
        print(f"input_subvol_dims: {input_subvol_dims}")

    return [
        (table["Y"], table["X"])
        for table in (input_subvol_dims, sv_I_step, sv_I_offset)
    ]


def generate_bilinear_resize_layer_params(
    dims: BilinearResizeShape,
    wgt_subv_offset: int,
    col: int,
    row: int,
    verbose: bool = False,
) -> bytes:
    '''
    Serialise the layer parameters for the bilinear resize kernel.

    The result is three little-endian unsigned 32-bit words -
    ``wgt_subv_offset``, the padded input element count and the padded
    output element count - followed by the bytes returned by
    ``genereate_bilinear_resize_kernel_params(dims)``.

    Args:
        dims: shape description; ``Cis/Yis/Xis`` and ``Cos/Yos/Xos`` are
            read here, and the whole object is passed to the kernel-params
            generator.
        wgt_subv_offset: value written as the first word.
        col: unused in this function.
        row: unused in this function.
        verbose: unused in this function.

    Returns:
        The header words concatenated with the kernel parameters.

    Raises:
        OverflowError: if any header value is negative or does not fit in
            4 bytes (``int.to_bytes`` behaviour).
    '''
    elems_align = 64  # Align to 64 elems for QDQ kernel
    # Subvolume element counts passed through iceil with the alignment above.
    ifm_elems = iceil(dims.Cis * dims.Yis * dims.Xis, elems_align)  # NOTE: Used for dq
    ofm_elems = iceil(dims.Cos * dims.Yos * dims.Xos, elems_align)  # NOTE: Used for q

    header_words = (
        wgt_subv_offset,  # NOTE: used to derive the wgt pointer
        ifm_elems,
        ofm_elems,
    )
    header = b"".join(
        word.to_bytes(4, byteorder='little', signed=False)
        for word in header_words
    )
    return header + genereate_bilinear_resize_kernel_params(dims)


def bilinear_resize_preproc_directives(
    dims: BilinearResizeShape,
    back_end: BackEnd,
    verbose: bool = False,
) -> List[str]:
    '''
    Build the preprocessor ``-D`` flags describing a bilinear resize layer.

    Each flag is ``-D<NAME>=<int>``; when ``back_end`` is ``BackEnd.Adf``
    it is wrapped as ``--Xpreproc="-D<NAME>=<int>"``.

    The list contains, in order: one 0/1 flag per coordinate transformation
    mode, the tensor / subvolume / split / quantisation fields of ``dims``,
    and ``TXN_MODE`` (0 for ``BackEnd.Adf``, 1 otherwise).

    Args:
        dims: shape description whose attributes are turned into flags.
        back_end: selects the flag syntax and the ``TXN_MODE`` value.
        verbose: unused in this function.

    Returns:
        The flags as a list of strings.
    '''
    def directive(ident: str, val: int) -> str:
        define = f"-D{ident}={val}"
        return f'--Xpreproc="{define}"' if back_end == BackEnd.Adf else define

    # Macro name -> mode member; the macro is 1 only for the active mode.
    mode_macros = (
        ("HALF_PIXEL", CoordinateTransfromationMode.HALF_PIXEL),
        ("ALIGN_CORNERS", CoordinateTransfromationMode.ALIGN_CORNERS),
        ("ASYMMETRIC", CoordinateTransfromationMode.ASYMMETRIC),
    )
    # Macro name -> attribute of ``dims`` supplying its integer value.
    field_macros = (
        ("COUT", "Co"),
        ("YOUT", "Yo"),
        ("XOUT", "Xo"),
        ("CIN", "Ci"),
        ("YIN", "Yi"),
        ("XIN", "Xi"),
        ("COS", "Cos"),
        ("YOS", "Yos"),
        ("XOS", "Xos"),
        ("CIS", "Cis"),
        ("YIS", "Yis"),
        ("XIS", "Xis"),
        ("YIS_STEP", "Yis_step"),
        ("XIS_STEP", "Xis_step"),
        ("YIS_OFFSET", "Yis_offset"),
        ("XIS_OFFSET", "Xis_offset"),
        ("Y_SPLIT", "Y_split"),
        ("X_SPLIT", "X_split"),
        ("DQ_ENABLE", "dq_enable"),
        ("Q_ENABLE", "q_enable"),
    )

    txn_mode = int(back_end != BackEnd.Adf)

    flags = [
        directive(macro, int(dims.mode == member))
        for macro, member in mode_macros
    ]
    flags += [
        directive(macro, int(getattr(dims, attr)))
        for macro, attr in field_macros
    ]
    flags.append(directive('TXN_MODE', txn_mode))
    return flags