from curses import raw
from dataclasses import dataclass
import math
import struct
import ctypes
import os
from enum import IntEnum

import numpy as np

CURRDIR = os.path.dirname(os.path.abspath(__file__))

from kernel.common.kernel_params_helper import (
    DimsHelper,
)

from scheduler.common import (
    BaseDims,
    ceildiv,
    iceil,
)

# NOTE: this enum is intentionally NOT decorated with @dataclass. With no
# annotated fields, the dataclass-generated __eq__/__hash__ compare an empty
# field tuple, which makes every member equal to (and hash like) every other
# member. Enum members are already immutable singletons.
class CoordinateTransfromationMode(IntEnum):
    '''
    Coordinate transformation modes accepted by the Resize NNI helpers.

    The mode selects how resize_nni_input_subvol_dims maps an output pixel
    index back to an input coordinate.
    '''
    # A 0.5 grid offset is applied when mapping output indices to input.
    HALF_PIXEL = 0
    # 1 is subtracted from both output and input extents when computing scale.
    ALIGN_CORNERS = 1
    # Neither a grid offset nor a scale adjustment is applied.
    ASYMMETRIC = 2


def resize_nni_wgt_subvol_dims(
    output_subvol_dims: tuple[int, int],
    wgt_bits: int,
    verbose: bool = False,
) -> int:
    '''
    Compute the weight subvolume size for a given output subvolume.

    The result is ``(Yos + Xos) * 2 * (wgt_bits // 8)``: two entries per
    output row and per output column, each ``wgt_bits // 8`` bytes wide.
    Refer to the kernel spec for the meaning of those entries.

    Args:
        output_subvol_dims: ``(Yos, Xos)`` output subvolume height and width.
        wgt_bits: Bit width of one weight entry (floor-divided by 8).
        verbose: When True, print the inputs and the computed size.

    Returns:
        The weight subvolume size as an int.
    '''
    rows, cols = output_subvol_dims
    bytes_per_entry = wgt_bits // 8
    wgt_subvol_dims = int((rows + cols) * 2 * bytes_per_entry)
    if not verbose:
        return wgt_subvol_dims
    print(f"output_subvol_dims: {output_subvol_dims}")
    print(f"wgt_bits: {wgt_bits}")
    print(f"wgt_subvol_dims: {wgt_subvol_dims}")
    return wgt_subvol_dims


def resize_nni_input_subvol_dims(
    mode: CoordinateTransfromationMode,
    output_tensor_dims: tuple[int, int],
    input_tensor_dims: tuple[int, int],
    output_subvol_dims: tuple[int, int],
    split: tuple[int, int],
    verbose: bool = False,
) -> list:
    '''
    Derive the input subvolume geometry needed to produce each output
    subvolume, independently for the Y and X axes.
    Refer to kernel spec for more details.

    Args:
        mode: Coordinate transformation mode; only its ``.value`` is read.
        output_tensor_dims: ``(Yo, Xo)`` full output extents.
        input_tensor_dims: ``(Yi, Xi)`` full input extents.
        output_subvol_dims: ``(Yos, Xos)`` output subvolume extents.
        split: ``(Y_split, X_split)`` number of subvolumes evaluated per axis.
        verbose: When True, print the inputs and every intermediate result.

    Returns:
        A three-element list of ``(Y, X)`` tuples:
        ``[input subvol dims, input step per subvol, input offset]``.
    '''
    if verbose:
        print(f"output_tensor_dims: {output_tensor_dims}")
        print(f"input_tensor_dims: {input_tensor_dims}")
        print(f"output_subvol_dims: {output_subvol_dims}")
        print(f"split: {split}")
        print(f"mode: {mode}")
    out_y, out_x = output_tensor_dims
    in_y, in_x = input_tensor_dims
    sub_y, sub_x = output_subvol_dims
    split_y, split_x = split
    # Per-axis (output extent, input extent, output subvol extent, split).
    # Insertion order (Y, then X) fixes the order of every dict built below.
    axes = {
        "Y": (out_y, in_y, sub_y, split_y),
        "X": (out_x, in_x, sub_x, split_x),
    }

    # HALF_PIXEL shifts the sampling grid by 0.5; ALIGN_CORNERS removes one
    # element from both extents before computing the scale. Any other mode
    # uses neither adjustment.
    mode_value = mode.value
    grid_offset = 0.5 if mode_value == CoordinateTransfromationMode.HALF_PIXEL.value else 0.0
    scale_adjust = 1 if mode_value == CoordinateTransfromationMode.ALIGN_CORNERS.value else 0
    if verbose:
        print(f"grid_offset: {grid_offset}")
        print(f"scale_adjust: {scale_adjust}")

    # Output-to-input ratio per axis.
    scale = {
        axis: (out_len - scale_adjust) / (in_len - scale_adjust)
        for axis, (out_len, in_len, _, _) in axes.items()
    }
    if verbose:
        print(f"scale: {scale}")

    # Input coordinate of every output pixel index, clamped to [0, in_len - 1].
    # The index range covers the output extent rounded by iceil to the
    # subvolume size (plus the grid offset, rounded up).
    raw_indices = {
        axis: [
            min(max(0, pixel_index / scale[axis] - grid_offset), in_len - 1)
            for pixel_index in range(math.ceil(iceil(out_len, sub_len) + grid_offset))
        ]
        for axis, (out_len, in_len, sub_len, _) in axes.items()
    }
    if verbose:
        print(f"raw_indices: {raw_indices}")

    # Nominal input advance between consecutive output subvolumes.
    sv_I_step = {
        axis: int(round(sub_len / scale[axis]))
        for axis, (_, _, sub_len, _) in axes.items()
    }
    if verbose:
        print(f"sv_I_step: {sv_I_step}")

    # Smallest (floored) deviation of a subvolume's first input coordinate
    # from its nominal position i * step, over all subvolumes on the axis.
    sv_I_offset = {
        axis: int(min(
            np.floor(raw_indices[axis][i * sub_len] - sv_I_step[axis] * i)
            for i in range(n_split)
        ))
        for axis, (_, _, sub_len, n_split) in axes.items()
    }
    if verbose:
        print(f"sv_I_offset: {sv_I_offset}")

    # Largest (floored) deviation of a subvolume's last input coordinate from
    # its nominal position, relative to the offset above, plus a fixed margin
    # of 2 (see kernel spec).
    input_subvol_dims = {}
    for axis, (_, _, sub_len, n_split) in axes.items():
        last_coords = raw_indices[axis][sub_len - 1::sub_len]
        farthest = max(
            np.floor(last_coords[i] - sv_I_step[axis] * i)
            for i in range(n_split)
        )
        input_subvol_dims[axis] = int(farthest - sv_I_offset[axis] + 2)
    if verbose:
        print(f"input_subvol_dims: {input_subvol_dims}")

    return [
        (input_subvol_dims["Y"], input_subvol_dims["X"]),
        (sv_I_step["Y"], sv_I_step["X"]),
        (sv_I_offset["Y"], sv_I_offset["X"]),
    ]


@dataclass
class ResizeNNIDims(BaseDims):
    '''
    Shape and tiling description for the Resize NNI kernel.

    Extends BaseDims (initialised with a 1x1 kernel, stride 1, padding 0)
    with the coordinate transformation mode, the input subvolume
    size/step/offset per axis, the tensor bit widths and the derived
    per-subvolume byte sizes.
    '''

    def __init__(
        self,
        coord_mode: int,
        N: int,
        Co: int,
        Yo: int,
        Xo: int,
        Ci: int,
        Yi: int,
        Xi: int,
        Cos: int,
        Yos: int,
        Xos: int,
        Cis: int,
        Yis: int,
        Xis: int,
        N_split: int,
        Co_split: int,
        Y_split: int,
        X_split: int,
        N_loop: int,
        Co_loop: int,
        Y_loop: int,
        X_loop: int,
        Ci_loop: int,
        aie_cols: int,
        aie_rows: int,
        Yis_size: int,
        Xis_size: int,
        Yis_step: int,
        Xis_step: int,
        Yis_offset: int,
        Xis_offset: int,
        ifm_bits: int,
        ofm_bits: int,
    ):
        '''
        Store all dimensions, resolve ``coord_mode`` (0, 1 or 2) to a
        CoordinateTransfromationMode member, derive the subvolume byte sizes
        and initialise the BaseDims part.

        Raises:
            AssertionError: if ``Co != Ci`` or ``Cos != Cis``.
            ValueError: if ``coord_mode`` is not 0, 1 or 2.
        '''
        # Channel counts must match between input and output.
        assert Co == Ci
        assert Cos == Cis
        self.Ci_gran = 64

        # Full tensor extents (output, then input).
        self.N, self.Co, self.Yo, self.Xo = N, Co, Yo, Xo
        self.Ci, self.Yi, self.Xi = Ci, Yi, Xi
        # Subvolume extents (output, then input).
        self.Cos, self.Yos, self.Xos = Cos, Yos, Xos
        self.Cis, self.Yis, self.Xis = Cis, Yis, Xis
        # Split and loop counts.
        self.N_split, self.Co_split = N_split, Co_split
        self.Y_split, self.X_split = Y_split, X_split
        self.N_loop, self.Co_loop = N_loop, Co_loop
        self.Y_loop, self.X_loop = Y_loop, X_loop
        self.Ci_loop = Ci_loop
        # Array geometry.
        self.aie_cols, self.aie_rows = aie_cols, aie_rows

        # Map the integer code to its enum member.
        known_modes = (
            (0, CoordinateTransfromationMode.HALF_PIXEL),
            (1, CoordinateTransfromationMode.ALIGN_CORNERS),
            (2, CoordinateTransfromationMode.ASYMMETRIC),
        )
        for code, member in known_modes:
            if coord_mode == code:
                self.mode = member
                break
        else:
            raise ValueError(f"Invalid mode: {coord_mode}")

        # Input subvolume size / step / offset per axis.
        self.Yis_size, self.Xis_size = Yis_size, Xis_size
        self.Yis_step, self.Xis_step = Yis_step, Xis_step
        self.Yis_offset, self.Xis_offset = Yis_offset, Xis_offset

        # Bit widths; the weight width is fixed at 8 bits.
        self.ifm_bits, self.ofm_bits = ifm_bits, ofm_bits
        self.wgt_bits = 8

        # Derived per-subvolume byte sizes.
        self.wgt_subv_bytes = resize_nni_wgt_subvol_dims(
            (self.Yos, self.Xos),
            self.wgt_bits,
        )
        ifm_elem_bytes = self.ifm_bits // 8
        ofm_elem_bytes = self.ofm_bits // 8
        self.ifm_subv_bytes = self.Yis * self.Xis * self.Cis * ifm_elem_bytes
        self.ofm_subv_bytes = self.Yos * self.Xos * self.Cos * ofm_elem_bytes

        super().__init__(
            N=self.N,
            Co=self.Co,
            Yo=self.Yo,
            Xo=self.Xo,
            Ci=self.Ci,
            Yi=self.Yi,
            Xi=self.Xi,
            Cos=self.Cos,
            Yos=self.Yos,
            Xos=self.Xos,
            Cis=self.Cis,
            Yis=self.Yis,
            Xis=self.Xis,
            # Kernel size, stride and padding are fixed for this operator.
            Kx=1,
            Ky=1,
            Sy=1,
            Sx=1,
            Py=0,
            Px=0,
            N_split=self.N_split,
            Co_split=self.Co_split,
            Y_split=self.Y_split,
            X_split=self.X_split,
            N_loop=self.N_loop,
            Co_loop=self.Co_loop,
            Y_loop=self.Y_loop,
            X_loop=self.X_loop,
            Ci_loop=self.Ci_loop,
            aie_cols=self.aie_cols,
            aie_rows=self.aie_rows,
        )

    def __str__(self) -> str:
        '''Return the BaseDims text followed by the Resize NNI specific fields.'''
        # super() is resolved here, in the method scope, before the
        # comprehension below.
        base_text = super().__str__()
        shown_attrs = (
            "mode",
            "Yis_size", "Xis_size",
            "Yis_step", "Xis_step",
            "Yis_offset", "Xis_offset",
            "ifm_bits", "ofm_bits",
            "wgt_bits",
            "wgt_subv_bytes",
            "ifm_subv_bytes",
            "ofm_subv_bytes",
        )
        rendered = ", ".join([f"{attr}={getattr(self, attr)}" for attr in shown_attrs])
        return f"{base_text}, ResizeNNIDims(" + rendered + ")"


def generate_resize_nni_noqdq_a8_params(
    dims: ResizeNNIDims,
) -> bytes:
    '''
    Build the packed parameter blob for the Resize NNI kernel
    (no-QDQ, a8 variant).

    The computed values are printed to stdout and then packed
    little-endian with format ``<1I6H3I3i1I2i`` (52 bytes):
    weight L1 offset, six 16-bit loop/offset fields (the last is a zero
    filler), the Y input step, the dimsA descriptor (num0, num1, inc0..inc2)
    and the dimsW descriptor (num0, inc0, inc1).

    Args:
        dims: Resize NNI dimensions; reads Yos, Xos, Cos, Xis_size, ofm_bits
            and ifm_subv_bytes.

    Returns:
        The packed parameters as bytes.
    '''
    # order_select encodings:
    #   0: CHW(WC)32, 1: HCW(WC)32, 2: CHW(WC)64, 3: HCW(WC)64
    order_select = 2
    uses_c64 = order_select >= 2

    kernel_dims = DimsHelper()
    ofm_elem_bytes = dims.ofm_bits // 8

    outer_loop = dims.Yos
    inner_loop = dims.Xos * dims.Cos // 32
    time_iters = 1
    wts_x_offset = dims.Yos
    step_ci = dims.Cos * ofm_elem_bytes * (64 if uses_c64 else 32)
    Wo_c64 = 2 if uses_c64 else 1
    num_wi_c64 = 1 if uses_c64 else 0
    step_yi = dims.Xis_size * dims.Cos * ofm_elem_bytes

    dimsA = kernel_dims.from_steps((Wo_c64, dims.Xos), (32 * ofm_elem_bytes, 0, step_ci))
    # NOTE(review): the first argument here is a bare int, not a 1-tuple,
    # unlike the dimsA call above. Confirm DimsHelper.from_steps accepts both.
    dimsW = kernel_dims.from_steps(dims.Xos, (2, 0))
    wgt_l1_offset = dims.ifm_subv_bytes

    # Dump every computed value for debugging.
    scalar_report = (
        ("order_select", order_select),
        ("outer_loop", outer_loop),
        ("inner_loop", inner_loop),
        ("time_iters", time_iters),
        ("wts_x_offset", wts_x_offset),
        ("step_ci", step_ci),
        ("Wo_c64", Wo_c64),
        ("num_wi_c64", num_wi_c64),
        ("step_yi", step_yi),
    )
    for label, value in scalar_report:
        print(f"{label}: {value}")
    dimsA_keys = ("num0", "num1", "inc0", "inc1", "inc2")
    dimsW_keys = ("num0", "inc0", "inc1")
    for key in dimsA_keys:
        print(f"dimsA['{key}']: {dimsA[key]}")
    for key in dimsW_keys:
        print(f"dimsW['{key}']: {dimsW[key]}")

    # Field order must match the format string: I, 6H, 3I, 3i, I, 2i.
    fields = [
        wgt_l1_offset,   # I
        outer_loop,      # H
        inner_loop,      # H
        time_iters,      # H
        wts_x_offset,    # H
        num_wi_c64,      # H
        0,               # H (zero filler)
        step_yi,         # I
        dimsA['num0'],   # I
        dimsA['num1'],   # I
        dimsA['inc0'],   # i
        dimsA['inc1'],   # i
        dimsA['inc2'],   # i
        dimsW['num0'],   # I
        dimsW['inc0'],   # i
        dimsW['inc1'],   # i
    ]
    return struct.pack('<1I6H3I3i1I2i', *fields)
