import numpy as np
from named_list import *
from named_list import type_decoder, TypedNamedList, NamedList, values
import struct
from dataclasses import dataclass
import numpy as np

def take( arr, idx ):
    """Pick the elements of ``arr`` found at position(s) ``idx``.

    Thin wrapper over ``numpy.take`` with no ``axis`` given, so ``arr`` is
    indexed as if it were flattened.
    """
    picked = np.take( arr, indices=idx )
    return picked

def print_variable_states(
    dims_in, permute, dims_out, dims_mem_in, dims_mem_out,
    step_in, step_out, idx_first_padded_dim_out,
    pad_left_first_padded_dim, block_length
):
    """Prepare the intermediate kernel variables for a debug table.

    Every value is stringified and shortened to at most 40 characters
    (37 characters plus "..."). The statement that would emit the row is
    commented out, so this function currently writes nothing and returns
    ``None``.
    """
    row_names = (
        "dims_in",
        "param.permute",
        "dims_out",
        "dims_mem_in",
        "dims_mem_out",
        "step_in",
        "step_out",
        "idx_first_padded_dim_out",
        "pad_left_first_padded_dim",
        "block_length",
    )
    row_values = (
        dims_in,
        permute,
        dims_out,
        dims_mem_in,
        dims_mem_out,
        step_in,
        step_out,
        idx_first_padded_dim_out,
        pad_left_first_padded_dim,
        block_length,
    )

    for row_name, row_value in zip(row_names, row_values):
        text = str(row_value)
        # Keep the value column within 40 characters.
        shown = text if len(text) <= 40 else text[:37] + "..."
        # Output disabled:
        # print("{:<30} | {:<40} | {}".format(row_name, shown, type(row_value)))

def translate_variables_with_model(param_dict, input_shape, perm_for_kernel, output_shape):
    """Derive the dimension, stride and padding variables used by the kernel.

    Args:
        param_dict: Mapping with a "padding" entry (holding "Di" and "Do"
            mappings whose values are ``[left, right]`` pairs) and a "dtype"
            entry that is passed to ``sizeof``.
        input_shape: Four input dimensions.
        perm_for_kernel: Four permutation indices.
        output_shape: Four output dimensions.

    Returns:
        Tuple ``(dims_in, permute, dims_out, dims_mem_in, dims_mem_out,
        step_in, step_out, idx_first_padded_dim_out,
        pad_left_first_padded_dim, block_length)``.

    Raises:
        AssertionError: If ``block_length`` is smaller than 5.
        ValueError: If any of the three sequences does not have four entries.
    """
    # Tuple unpacking doubles as a "exactly four entries" check.
    in0, in1, in2, in3 = input_shape
    Do0, Do1, Do2, Do3 = perm_for_kernel
    out0, out1, out2, out3 = output_shape

    permute = (Do0, Do1, Do2, Do3)
    dims_in = array(in0, in1, in2, in3)
    dims_out = array(out0, out1, out2, out3)

    # One [left, right] row per dimension.
    padding = param_dict["padding"]
    pad_di = array(*padding["Di"].values())
    pad_do = array(*padding["Do"].values())

    # In-memory extents: logical extent plus left and right padding.
    pad_total_in = sum(pad_di, axis=1)
    dims_mem_in = dims_in + pad_total_in
    pad_total_out = sum(pad_do, axis=1)
    dims_mem_out = dims_out + pad_total_out

    # cumprod(d) // d is the product of all preceding extents, i.e. the
    # per-dimension stride in elements; it is then scaled by sizeof(dtype).
    elem_size = sizeof(param_dict["dtype"])
    step_in = (cumprod(dims_mem_in) // dims_mem_in) * elem_size
    step_out = (cumprod(dims_mem_out) // dims_mem_out) * elem_size

    # Number of leading output dimensions that carry no padding at all.
    is_unpadded = (pad_total_out == 0)
    idx_first_padded_dim_out = sum(cumprod(is_unpadded))

    # The index reaches 4 when nothing is padded, hence the clamp to 3.
    pad_left_first_padded_dim = pad_do[min(3, idx_first_padded_dim_out), 0]
    # Product of the output extents up to and including the first padded one.
    block_length = prod(dims_out[:idx_first_padded_dim_out + 1])

    # --- Requirements ---
    # Minimum block length check
    assert block_length >= 5, "Kernel cannot support shape: N < 5"

    return (
        dims_in,
        permute,
        dims_out,
        dims_mem_in,
        dims_mem_out,
        step_in,
        step_out,
        idx_first_padded_dim_out,
        pad_left_first_padded_dim,
        block_length,
    )

def ceildiv(x: int, d: int) -> int:
    """Return ``x / d`` rounded towards positive infinity.

    Implemented with floor division only: flooring against the negated
    divisor and negating the result turns the floor into a ceiling.
    """
    negated_quotient = x // -d
    return -negated_quotient

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to a multiple of ``d``, i.e. ``ceildiv(x, d) * d``."""
    multiples_needed = ceildiv(x, d)
    return multiples_needed * d

class TransposeSubvDims:
    """Shape, permutation and padding description of one transpose subvolume.

    The constructor stores its arguments, builds the ``param`` dictionary
    consumed by ``translate_variables_with_model`` and stores every value that
    function returns as an attribute (``dims_in``, ``perm``, ``dims_out``,
    ``dims_mem_in``, ``dims_mem_out``, ``step_in``, ``step_out``,
    ``idx_first_padded_dim_out``, ``pad_left_first_padded_dim``,
    ``block_length``).

    Args:
        N, Y, X, C: Extents of the four input dimensions. The kernel-side
            input shape is stored as ``[C, X, Y, N]``.
        permute: Four-element permutation; each entry indexes
            ``perm_string`` and the list is converted with
            ``get_permute_indices`` for the kernel.
        act_bits: Activation width in bits; must be 8, 16, 32 or 64.
        scratch_buf: Stored unchanged.
        is_int16: Stored unchanged.
        N_innermost: When true, the "inner" output dimension gets right
            padding of ``iceil(N, 8) - N`` and the output element count is
            computed from the padded N. Presumably callers set this only
            when the permutation makes N the innermost output dimension;
            that is not checked here.

    Raises:
        AssertionError: If ``act_bits`` is not a supported width, or if
            ``translate_variables_with_model`` rejects the shape.
    """

    # Supported activation widths and the dtype name each one selects.
    _DTYPE_BY_BITS = {8: "int8", 16: "int16", 32: "int32", 64: "int64"}

    def __init__(
        self,
        N : int,
        Y : int,
        X : int,
        C : int,
        permute: list,
        act_bits: int,
        scratch_buf: int,
        is_int16: bool = True,
        N_innermost: bool = False
    ):
        self.N = N
        self.Y = Y
        self.X = X
        self.C = C
        self.permute = permute
        self.perm_string = ["Di_outer", "Di_mid_o", "Di_mid_i", "Di_inner"]
        self.act_bits = act_bits
        self.scratch_buf = scratch_buf
        self.is_int16 = is_int16
        self.N_innermost = N_innermost

        # Raise explicitly instead of `assert False`: a bare assert is removed
        # under `python -O`, which would leave `self.dtype` unset. The
        # exception type and message stay the same for existing callers.
        try:
            self.dtype = self._DTYPE_BY_BITS[act_bits]
        except (KeyError, TypeError):
            raise AssertionError("Input datatype is not supported by kernel") from None

        self.input_shape = [C, X, Y, N]

        self.Do0, self.Do1, self.Do2, self.Do3 = get_permute_indices(self.permute)
        self.perm_for_kernel = [self.Do0, self.Do1, self.Do2, self.Do3]

        # Output extent k is the input extent selected by perm_for_kernel[k].
        self.output_shape = [
            self.input_shape[self.perm_for_kernel[0]],
            self.input_shape[self.perm_for_kernel[1]],
            self.input_shape[self.perm_for_kernel[2]],
            self.input_shape[self.perm_for_kernel[3]],
        ]

        self.Di0, self.Di1, self.Di2, self.Di3 = self.C, self.X, self.Y, self.N

        # The two layouts differ only in the right padding of the "inner"
        # output dimension and in the N used for the output element count.
        if self.N_innermost:
            padded_n = iceil(self.N, 8)
            inner_pad_right = padded_n - self.N
        else:
            padded_n = self.N
            inner_pad_right = 0

        self.param = self._build_param(inner_pad_right)
        # Element counts are rounded up to a multiple of 64.
        self.input_subv_elem_qdq = iceil(self.N * self.Y * self.X * self.C, 64)
        self.output_subv_elem_qdq = iceil(padded_n * self.Y * self.X * self.C, 64)

        (self.dims_in,
         self.perm,
         self.dims_out,
         self.dims_mem_in,
         self.dims_mem_out,
         self.step_in,
         self.step_out,
         self.idx_first_padded_dim_out,
         self.pad_left_first_padded_dim,
         self.block_length) = translate_variables_with_model(
            self.param, self.input_shape, self.perm_for_kernel, self.output_shape
        )

    def _build_param(self, inner_pad_right):
        """Build the kernel parameter dictionary for this subvolume.

        ``inner_pad_right`` is the right padding of the "inner" output
        dimension; every other padding entry is ``[0, 0]``. Key order is
        significant because consumers iterate the padding ``.values()``.
        """
        return {
            "subvolume": {
                "Di_inner": self.C,
                "Di_mid_i": self.X,
                "Di_mid_o": self.Y,
                "Di_outer": self.N
            },
            "order": {
                "Do_inner": self.perm_string[self.permute[3]],
                "Do_mid_i": self.perm_string[self.permute[2]],
                "Do_mid_o": self.perm_string[self.permute[1]],
                "Do_outer": self.perm_string[self.permute[0]]
            },
            "padding": {
                "Di": {
                    "inner": [0, 0],
                    "mid_i": [0, 0],
                    "mid_o": [0, 0],
                    "outer": [0, 0]
                },
                "Do": {
                    "inner": [0, inner_pad_right],
                    "mid_i": [0, 0],
                    "mid_o": [0, 0],
                    "outer": [0, 0]
                }
            },
            "dtype": self.dtype
        }

def get_permute_indices(lst):
    """Convert a 4-element permutation to the kernel's index convention.

    Each entry ``x`` becomes ``3 - x`` and the resulting list is reversed.

    Raises:
        ValueError: If ``lst`` does not contain exactly four elements.
    """
    if len(lst) != 4:
        raise ValueError("Input list must have exactly 4 elements")
    mirrored = [3 - axis for axis in lst]
    mirrored.reverse()
    return mirrored

# NumPy-backed math helpers. `all` and `any` below shadow the Python builtins
# of the same name for the rest of this module.
log2 = np.log2


def all( *x ):
    """Return ``np.all`` evaluated over the tuple of positional arguments."""
    return np.all( x )


def any( *x ):
    """Return ``np.any`` evaluated over the tuple of positional arguments."""
    return np.any( x )

def ceil( n, d=1 ):
    """Round ``n`` up to a multiple of ``d`` using ``np.ceil``.

    With the default ``d == 1`` this is plain ``np.ceil( n )``; otherwise the
    result is ``d * np.ceil( n / d )``.
    """
    return np.ceil( n ) if d == 1 else d * np.ceil( n / d )

def sign( val ):
    """Report the sign of a value or of a named type.

    For a string (including ``str`` subclasses such as ``numpy.str_``) this
    returns element 1 of ``type_decoder( val )`` - presumably the signedness
    flag of the named type; confirm against ``named_list``. For anything
    else it returns ``val < 0``.
    """
    # isinstance, not an exact type comparison, so that str subclasses take
    # the type-name path instead of failing on `val < 0`.
    if isinstance( val, str ):
        return type_decoder( val )[1]
    return val < 0

def sizeof( val ):
    """Return element 0 of ``type_decoder( val )`` divided by 8.

    Presumably element 0 is the bit width, making the result a size in bytes
    (confirm against ``named_list``). True division is used, so the result
    is a float.
    """
    width = type_decoder( val )[0]
    return width / 8

class DimsHelper:
    """Stateful builder of multi-dimensional address increments.

    ``reset`` accumulates a correction carried from one dimension to the
    next; ``bits`` is only used to name the field types of the result
    returned by ``from_steps``.
    """

    def __init__( self, reset=0, bits=32 ):
        self.reset = reset
        self.bits = bits

    def __getitem__( self, key ):
        # Allow attribute access through subscripting, e.g. helper["reset"].
        return getattr( self, key )

    def add_dimension( self, num, step ):
        """Return the increment for one dimension and update ``reset``.

        The increment is the current ``reset`` plus ``step``; ``reset`` is
        then decreased by ``num * step``.
        """
        inc = self.reset + step
        self.reset -= num * step
        return inc

    def from_steps( self, wraps, steps, next_loop_level=False ):
        """Turn per-dimension wrap counts and steps into counters/increments.

        ``wraps`` and ``steps`` may be scalars or sequences. Between one and
        five steps are accepted, and at least ``len( steps ) - 1`` wraps are
        required. Returns the single increment when only one step is given;
        otherwise a ``TypedNamedList`` with ``num<n>`` fields followed by
        ``inc<n>`` fields. Mutates ``self.reset``, so the order of calls on
        the same helper matters.

        Raises:
            AssertionError: If the number of steps or wraps is out of range.
        """
        wraps = make_tuple( wraps )
        steps = make_tuple( steps )
        assert len( steps ) in [1,2,3,4,5], "Only 1d to 5d address increments supported"
        assert len( wraps ) >= len( steps ) - 1, "Wrap dimesions passed are not sufficient"

        nums = []
        incs = []
        for i,s in enumerate( steps ):
            if i == len( wraps ):
                # Step without a wrap count: emit the increment and clear the
                # carried correction.
                incs.append( self.reset + s )
                self.reset = 0
            else:
                if i < len( steps ) - 1:
                    # Every step except the last also produces a counter.
                    if i % 3 == 2:
                        # Every third dimension stores the cumulative wrap
                        # product minus one and uses the full wrap count.
                        # NOTE(review): presumably a counter-nesting boundary
                        # of the address generator - confirm.
                        num = wraps[i]
                        nums.append( np.prod( wraps[:i+1] ) - 1 )
                    else:
                        num = wraps[i] - 1
                        nums.append( num )
                else:
                    # Last step: no counter is emitted.
                    num = wraps[i] - 1
                incs.append( self.add_dimension( num, s ))

                # Overwrite the carried correction with a full wrap of this
                # dimension, either on request for the last step or after a
                # third-dimension boundary.
                if ( next_loop_level and i == len( steps ) - 1 ) or ( i < len( steps ) - 1 and i % 3 == 2 ):
                    self.reset = -wraps[i] * s

        if len( incs ) == 1:
            return incs[0]
        else:
            # Field declarations: unsigned counters first, then signed increments.
            return TypedNamedList([ f"uint{self.bits} num{n}" for n in range( len( nums ))] + [ f"int{self.bits} inc{n}" for n in range( len( incs ))], nums + incs )

def array( *x ):
    """Build a NumPy array from the positional arguments.

    ``NamedList`` arguments (at any nesting depth reached through other
    ``NamedList`` values) are replaced by plain lists of their ``_values( )``.
    Two or more arguments become the elements of the array; a single
    argument is passed to ``np.array`` as-is.
    """
    def plain_values( items ):
        converted = []
        for item in items:
            if isinstance( item, NamedList ):
                converted.append( plain_values( item._values( )))
            else:
                converted.append( item )
        return converted

    unpacked = plain_values( x )
    if len( unpacked ) > 1:
        return np.array( unpacked )
    return np.array( *unpacked )

def sum( vals, axis=None ):
    """``np.sum`` that also accepts a ``NamedList`` or ``dict``.

    Such containers are first converted to a tuple via ``values( )``.
    Shadows the builtin ``sum`` within this module.
    """
    is_named = isinstance( vals, ( NamedList, dict ))
    operand = tuple( values( vals )) if is_named else vals
    return np.sum( operand, axis=axis )

def prod( vals, axis=None ):
    """``np.prod`` that also accepts a ``NamedList`` or ``dict``.

    Such containers are first converted to a tuple via ``values( )``.
    """
    is_named = isinstance( vals, ( NamedList, dict ))
    operand = tuple( values( vals )) if is_named else vals
    return np.prod( operand, axis=axis )

def cumsum( vals, axis=None ):
    """``np.cumsum`` that also accepts a ``NamedList`` or ``dict``.

    Such containers are first converted to a tuple via ``values( )``.
    """
    is_named = isinstance( vals, ( NamedList, dict ))
    operand = tuple( values( vals )) if is_named else vals
    return np.cumsum( operand, axis=axis )

def cumprod( vals, axis=None ):
    """``np.cumprod`` that also accepts a ``NamedList`` or ``dict``.

    Such containers are first converted to a tuple via ``values( )``.
    """
    is_named = isinstance( vals, ( NamedList, dict ))
    operand = tuple( values( vals )) if is_named else vals
    return np.cumprod( operand, axis=axis )

# Names listed as known math operations / keywords; the consumer of this tuple
# is not part of this file.
known_math_ops = (
    "min", "max", "ceil", "sign", "sizeof",
    "DimsHelper", "from_steps",
    "range", "for", "in", "if", "else", "int",
    "log2", "prod", "array", "all", "any",
)

def make_tuple( val ):
    """Return ``val`` as a tuple.

    Anything exposing ``__len__`` is converted with ``tuple( val )`` (so a
    string is split into characters); any other value is wrapped in a
    one-element tuple.
    """
    is_sized = hasattr( val, "__len__" )
    return tuple( val ) if is_sized else ( val, )

@dataclass
class dims_3d_param:
    """Field layout of a 3-D address descriptor: two ``num`` and three ``inc`` fields."""

    # num* fields - 32-bit int each
    num0: int
    num1: int

    # inc* fields - 32-bit int each
    inc0: int
    inc1: int
    inc2: int

@dataclass
class dims_5d_param:
    """Field layout of a 5-D address descriptor: four ``num`` and five ``inc`` fields."""

    # num* fields - 32-bit unsigned int each
    num0: int
    num1: int
    num2: int
    num3: int

    # inc* fields - 32-bit int each
    inc0: int
    inc1: int
    inc2: int
    inc3: int
    inc4: int

def setup_kernel_params(subvdims: TransposeSubvDims):
    """Serialize the transpose kernel parameters of ``subvdims`` to bytes.

    The result is a little-endian struct packed with the format
    ``"<2H2b1H2I4I5i1i5i"`` (inner_loop, outer_loop, shift, shift_fin,
    size_out, offset_in, offset_out, the 4 ``num`` + 5 ``inc`` fields of the
    input address, the output increment, and the 2 ``num`` + 3 ``inc`` fields
    of the output address), followed by five 4-byte little-endian unsigned
    integers: ``act_bits``, ``input_subv_elem_qdq``, ``output_subv_elem_qdq``,
    ``is_int16`` and ``scratch_buf``.

    Returns:
        bytes: The packed struct followed by the layer parameters.
    """
    dims = DimsHelper()

    elem_size = sizeof( subvdims.dtype )
    block_len = subvdims.block_length
    pad_left = subvdims.pad_left_first_padded_dim
    pad_rows_in = list(subvdims.param["padding"]["Di"].values())
    pad_rows_out = list(subvdims.param["padding"]["Do"].values())

    inner_loop = block_len - 1
    outer_loop = prod( subvdims.dims_out ) // block_len
    shift = 8 * elem_size
    # Round block + left padding up to a multiple of max(1, 4 // elem_size).
    granule = max( 1, 4 // elem_size )
    shift_fin = shift * ( ceil( block_len + pad_left, granule ) - block_len - pad_left + 1 )
    size_out = ceil( prod( subvdims.dims_mem_out ) * elem_size / 64 )

    # Column 0 of each padding table holds the left padding per dimension.
    offset_in = sum( array( pad_rows_in )[ :, 0 ] * subvdims.step_in )
    offset_out = sum( array( pad_rows_out )[ :, 0 ] * subvdims.step_out )

    # The three from_steps calls share `dims` state; their order matters.
    permuted_in_steps = tuple( subvdims.step_in [ subvdims.perm [ k ] ] for k in range( 4 )) + ( 0, )
    addr_in: dims_5d_param = dims.from_steps( subvdims.dims_out, permuted_in_steps )
    inc_out = dims.from_steps( block_len, elem_size )

    # 1 while every output dimension so far is unpadded, 0 afterwards; wraps
    # of such leading unpadded dimensions are replaced by 1.
    unpadded_prefix = cumprod( sum( pad_rows_out, 1 ) == 0 )[ :-1 ]
    out_wraps = subvdims.dims_out [ 1: ] * ( 1 - unpadded_prefix ) + unpadded_prefix
    addr_out: dims_3d_param = dims.from_steps( out_wraps, subvdims.step_out [ 1: ])

    # Field order must match the struct format below.
    raw_fields = (
        inner_loop, outer_loop, shift, shift_fin, size_out,
        offset_in, offset_out,
        addr_in.num0, addr_in.num1, addr_in.num2, addr_in.num3,
        addr_in.inc0, addr_in.inc1, addr_in.inc2, addr_in.inc3, addr_in.inc4,
        inc_out,
        addr_out.num0, addr_out.num1, addr_out.inc0, addr_out.inc1, addr_out.inc2,
    )
    fields = [ int( value ) for value in raw_fields ]

    struct_fields = struct.pack("<2H2b1H2I4I5i1i5i", *fields)

    layer_values = (
        subvdims.act_bits,
        subvdims.input_subv_elem_qdq,
        subvdims.output_subv_elem_qdq,
        subvdims.is_int16,
        subvdims.scratch_buf,
    )
    layer_params = b"".join(
        value.to_bytes(length=4, byteorder='little', signed=False)
        for value in layer_values
    )

    return struct_fields + layer_params

if __name__ == '__main__':
    # Smoke run when executed as a script.
    # NOTE(review): these values bind act_bits=0 and scratch_buf=True, and
    # TransposeSubvDims rejects any act_bits other than 8/16/32/64 - the
    # argument list looks one value short; confirm the intended call.
    setup_kernel_params(
        TransposeSubvDims(
            N=1,
            Y=14,
            X=14,
            C=768,
            permute=[1, 2, 3, 0],
            act_bits=0,
            scratch_buf=True,
        )
    )
