import os
import struct
import numpy as np
import copy
import math
import sys
from typing import List, Union
from enum import IntEnum
from dataclasses import dataclass
from named_list import *
from named_list import TypedNamedList

from dataflow.dataflow_utils import (
    CommonDims,
)

def make_tuple( val ):
    """Return *val* as a tuple.

    Anything exposing ``__len__`` (lists, tuples, strings, arrays, ...) is
    converted with ``tuple()``; every other value is wrapped in a
    one-element tuple.
    """
    return tuple( val ) if hasattr( val, "__len__" ) else ( val, )


class DimsHelper:
    """Derive address increments while tracking a running ``reset`` offset.

    Every registered dimension returns ``reset + step`` as its increment and
    then subtracts ``num * step`` from ``reset``. ``bits`` is only used to
    build the type names of the fields returned by :meth:`from_steps`.
    """

    def __init__(self, reset=0, bits=32):
        self.bits = bits
        self.reset = reset

    def __getitem__(self, key):
        # Allow helper["reset"] / helper["bits"] as an alias for attribute access.
        return getattr(self, key)

    def add_dimension(self, num, step):
        """Return the increment for one dimension and update ``reset``."""
        increment = step + self.reset
        self.reset = self.reset - num * step
        return increment

    def from_steps(self, wraps, steps, next_loop_level=False):
        """Build the num/inc fields for a 1d to 5d address pattern.

        Args:
            wraps: Wrap size(s); a scalar or a sequence. At least
                ``len(steps) - 1`` entries are required.
            steps: Step(s), one per dimension; a scalar or a sequence.
            next_loop_level: When true, after the last step ``reset`` is set
                to ``-wraps[i] * step`` (only applies if that step has a wrap).

        Returns:
            The single increment when only one step is given; otherwise a
            ``TypedNamedList`` holding ``uint{bits} numN`` entries followed by
            ``int{bits} incN`` entries.
        """
        wraps = make_tuple(wraps)
        steps = make_tuple(steps)
        assert len( steps ) in [1,2,3,4,5], "Only 1d to 5d address increments supported"
        assert len( wraps ) >= len( steps ) - 1, "Wrap dimesions passed are not sufficient"

        last = len(steps) - 1
        counts = []
        increments = []
        for level, step in enumerate(steps):
            if level == len(wraps):
                # No wrap supplied for this step: consume the pending reset.
                increments.append(self.reset + step)
                self.reset = 0
                continue

            is_inner = level < last
            # Every third non-final level stores the product of the wraps seen
            # so far (minus one) and restarts the reset offset afterwards.
            restarts = is_inner and level % 3 == 2
            if restarts:
                span = wraps[level]
                counts.append(np.prod(wraps[:level + 1]) - 1)
            else:
                span = wraps[level] - 1
                if is_inner:
                    counts.append(span)
            increments.append(self.add_dimension(span, step))

            if restarts or (next_loop_level and level == last):
                self.reset = -wraps[level] * step

        if len(increments) == 1:
            return increments[0]

        field_types = [f"uint{self.bits} num{n}" for n in range(len(counts))]
        field_types += [f"int{self.bits} inc{n}" for n in range(len(increments))]
        return TypedNamedList(field_types, counts + increments)


@dataclass
class dims_3d_param:
    """Field layout of a three-step address descriptor (two nums, three incs).

    The field names and order match the ``dimsO`` entries that
    ``genereate_bilinear_resize_kernel_params`` packs with the ``5i`` part of
    its struct format.
    """
    num0: int   # 32-bit int, count field for the first dimension
    num1: int   # 32-bit int, count field for the second dimension
    inc0: int   # 32-bit int, address increment for the first dimension
    inc1: int   # 32-bit int, address increment for the second dimension
    inc2: int   # 32-bit int, address increment for the third dimension

@dataclass
class dims_2d_param:
    """Field layout of a two-step address descriptor (one num, two incs).

    The field names and order match the ``dimsW`` entries that
    ``genereate_bilinear_resize_kernel_params`` packs with the ``3i`` part of
    its struct format.
    """
    num0: int   # 32-bit int, count field for the first dimension
    inc0: int   # 32-bit int, address increment for the first dimension
    inc1: int   # 32-bit int, address increment for the second dimension

@dataclass
class BilinearPixelResizeBf16Params:
    """Kernel parameter record for the bf16 bilinear pixel resize kernel.

    Fields are listed in the same order in which
    ``genereate_bilinear_resize_kernel_params`` packs them
    (struct format ``"<4H2i5i3i"``).
    """
    time_iters: int  # 16-bit int
    H: int           # 16-bit int
    W: int           # 16-bit int
    inner_loop: int  # 16-bit int
    step_Ci: int     # 32-bit int
    step_Hi: int     # 32-bit int
    dimsO: dims_3d_param  # five 32-bit fields: num0, num1, inc0, inc1, inc2
    dimsW: dims_2d_param  # three 32-bit fields: num0, inc0, inc1

'''
 The following enumeration is used to define the transformation mode
 of the coordinates in the bilinear resize operation.
'''
class CoordinateTransfromationMode(IntEnum):
    """Coordinate transformation modes of the bilinear resize operation.

    NOTE(review): the class name is misspelled ("Transfromation"); it is kept
    as-is because it is the public identifier other code refers to.
    """
    NONE = 0
    HALF_PIXEL = 1
    ALIGN_CORNERS = 2
    ASYMMETRIC = 3

# Lookup from the lower-case mode name string to the matching
# CoordinateTransfromationMode member.
mode_map = {
    "half_pixel": CoordinateTransfromationMode.HALF_PIXEL,
    "align_corners": CoordinateTransfromationMode.ALIGN_CORNERS,
    "asymmetric": CoordinateTransfromationMode.ASYMMETRIC,
    "none": CoordinateTransfromationMode.NONE,
}


'''
 The following class is used to define the dimensions of the
 bilinear resize operation.
'''
# NOTE(review): @dataclass is applied although this class declares no annotated
# fields and provides its own __init__. The decorator leaves that __init__ in
# place but still generates __repr__/__eq__ from the dataclass fields (none are
# declared here; CommonDims may contribute some) — confirm this is intended.
@dataclass
class BilinearResizeShape(CommonDims):
    """Shape and tiling state for one bilinear resize operation.

    The constructor stores the full tensor shape, forwards it to
    ``CommonDims.__init__`` and then initialises the remaining bookkeeping
    attributes (splits, loop counts, overcompute, offsets) to zero / defaults.
    The sub-volume sizes (``Yos``, ``Xos``, ``Cos``, ``Yis``, ``Xis``, ``Cis``)
    start at 0 here; ``genereate_bilinear_resize_kernel_params`` reads them, so
    they are presumably filled in by other code before that call.
    """

    def __init__(
        self,
        # Coordinate transformation mode
        mode: CoordinateTransfromationMode,
        # Shape info: N is forwarded to CommonDims as both Ni and No
        N: int,
        Co: int,
        Yo: int,
        Xo: int,
        Ci: int,
        Yi: int,
        Xi: int,
        # Datatype info: used for both ifm_bits and ofm_bits
        act_bits: int,
        # QDQ info
        dq_enable: bool = False,
        q_enable: bool = False,
        # Forwarded unchanged to CommonDims
        aie_cols: int = 0,
        aie_rows: int = 0,
    ):
        """Store the shape, initialise CommonDims and reset the tiling state."""
        self.N = N
        self.Co = Co
        self.Yo = Yo
        self.Xo = Xo
        self.Ci = Ci
        self.Yi = Yi
        self.Xi = Xi
        # Granularities; the same values (1, 1, 8) are passed as literals to
        # CommonDims below.
        self.Y_gran = 1
        self.X_gran = 1
        self.C_gran = 8
        # Sub-volume sizes start at zero.
        self.Yos = 0
        self.Xos = 0
        self.Cos = 0
        self.Yis = 0
        self.Xis = 0
        self.Cis = 0
        self.aie_cols = aie_cols
        self.aie_rows = aie_rows
        self.act_bits = act_bits
        # The padded ("*p") dimensions are passed the same values as the
        # unpadded ones; kernel/stride are 1 and all padding arguments are 0.
        super().__init__(
            aie_cols=self.aie_cols,
            aie_rows=self.aie_rows,
            ifm_bits=self.act_bits,
            ofm_bits=self.act_bits,
            Ni=self.N,
            Yi=self.Yi,
            Xi=self.Xi,
            Ci=self.Ci,
            No=self.N,
            Yo=self.Yo,
            Xo=self.Xo,
            Co=self.Co,
            Nip=self.N,
            Yip=self.Yi,
            Xip=self.Xi,
            Cip=self.Ci,
            Nop=self.N,
            Yop=self.Yo,
            Xop=self.Xo,
            Cop=self.Co,
            Nis=1,
            Yis=self.Yis,
            Xis=self.Xis,
            Cis=self.Cis,
            Nos=1,
            Yos=self.Yos,
            Xos=self.Xos,
            Cos=self.Cos,
            Ni_gran=1,
            Yi_gran=1,
            Xi_gran=1,
            Ci_gran=8,
            No_gran=1,
            Yo_gran=1,
            Xo_gran=1,
            Co_gran=8,
            Ky=1,
            Kx=1,
            Sy=1,
            Sx=1,
            Py_b=0,
            Px_b=0,
            Py_a=0,
            Px_a=0,
            is_Y8_split=False,
            is_X8_split=False,
        )
        # Attributes set after the CommonDims initialiser has run.
        self.mode = mode
        self.dq_enable = dq_enable
        self.q_enable = q_enable
        self.wgt_subvol_dims = 0
        self.Y_split = 0
        self.X_split = 0
        self.C_split = 0
        self.feasible = True
        self.ifm_l1_no_buff = 2 # ping-pong: 2, no ping-pong: 1
        self.ofm_l1_no_buff = 2 # ping-pong: 2, no ping-pong: 1
        self.Y_loop = 0
        self.X_loop = 0
        self.C_loop = 0
        self.qdq_bytes = 64  # presumably a size in bytes — TODO confirm
        self.param_subv_size = 1024  # presumably a size in bytes — TODO confirm
        self.Y_overcompute = 0
        self.X_overcompute = 0
        self.C_overcompute = 0
        self.Yis_step = 0
        self.Xis_step = 0
        self.Yis_offset = 0
        self.Xis_offset = 0

    def copy(self):
        """Return a deep copy of this shape object."""
        return copy.deepcopy(self)


def genereate_bilinear_resize_kernel_params(
    dims: BilinearResizeShape,
    verbose: bool = False,
) -> bytes:
    """Serialize the bilinear resize kernel parameters for ``dims``.

    The layout is little-endian without padding: four unsigned 16-bit values
    (time_iters, H, W, inner_loop) followed by ten signed 32-bit values
    (step_Ci, step_Hi, the five dimsO fields and the three dimsW fields).
    Refer to kernels/bilinear_pixel_resize_bf16/bilinear_pixel_resize_bf16.json
    for more details.

    Args:
        dims: Shape object whose sub-volume sizes (Xis, Yis, Xos, Yos, Cos)
            and C_gran are read.
        verbose: When true, print every parameter before packing.

    Returns:
        The packed parameter bytes.
    """
    # Output-side steps only feed the dimsO address descriptor.
    step_Co = 16 * dims.Xos * dims.Yos
    step_Ho = 2 * dims.Xos * 8

    # One helper serves both descriptors; from_steps clears its reset offset
    # when it consumes the final (wrap-less) step.
    helper = DimsHelper()
    out_dims = helper.from_steps(
        (dims.Xos, dims.Cos // 8),
        (16, step_Co, step_Ho),
    )
    wgt_dims = helper.from_steps(dims.Xos, (8, 0))

    # (label, value) pairs in exactly the order the struct format expects.
    fields = [
        ("time_iters", 1),  # NOTE: Unused in the kernel
        ("H", dims.Yos),
        ("W", dims.Xos),
        ("inner_loop", dims.Xos * dims.Cos // dims.C_gran),
        ("step_Ci", 16 * dims.Xis * dims.Yis),
        ("step_Hi", 2 * dims.Xis * 8),
        ("dimsO.num0", out_dims.num0),
        ("dimsO.num1", out_dims.num1),
        ("dimsO.inc0", out_dims.inc0),
        ("dimsO.inc1", out_dims.inc1),
        ("dimsO.inc2", out_dims.inc2),
        ("dimsW.num0", wgt_dims.num0),
        ("dimsW.inc0", wgt_dims.inc0),
        ("dimsW.inc1", wgt_dims.inc1),
    ]

    if verbose:
        for label, value in fields:
            print(f"kernel_params.{label}: {value}")

    return struct.pack("<4H2i5i3i", *(value for _, value in fields))



