# pylint: skip-file
#######################################################################################
from dmacompiler import CascDir

from dataflow.dataflow_common import overlay_4x4_core_stream_bdcast, overlay_8x4_core_stream_bdcast
from math import ceil
from enum import Enum
import struct

class StreamType(Enum):
    """Stream kind selector with two members, ``BC`` and ``UC``.

    NOTE(review): BC/UC presumably stand for broadcast/unicast - confirm.
    """
    # Values start at 1 here, whereas CoreStreamChan below numbers from 0.
    BC = 1
    UC = 2

class CoreStreamChan(Enum):
    """Channel identifiers: ``UC_channel`` is 0 and ``BC_channel`` is 1.

    NOTE(review): presumably the core stream channel index used for
    unicast vs. broadcast traffic - confirm against the callers.
    """
    UC_channel = 0
    BC_channel = 1

class SplitType(Enum):
    """Split selector with two members, ``RowSplit`` (0) and ``ColSplit`` (1).

    Note that ``sfmx_params`` in this file has an integer parameter that is
    also named ``SplitType``; inside that function the parameter shadows
    this class.
    """
    RowSplit = 0
    ColSplit = 1

class NormTransferType(Enum):
    """Identifiers for four transfer kinds, numbered 0 through 3.

    NOTE(review): the names suggest input/output feature map (Ifm/Ofm)
    transfers in the MM2S and S2MM DMA directions - confirm.
    """
    IfmMM2S  = 0
    IfmS2MM  = 1
    OfmS2MM  = 2
    OfmMM2S  = 3

class AccessFormat(Enum):
    """Access format selector with two members, ``W8`` (0) and ``Linear`` (1).

    NOTE(review): ``W8`` presumably corresponds to the 8-column blocked
    access pattern computed by the ``access_w8_*`` helpers in this file -
    confirm.
    """
    W8 = 0
    Linear = 1

def largest_factor_pair(n):
    """Split ``n`` into the factor pair whose members are closest together.

    Searches downward from ``floor(sqrt(n))`` for the first divisor ``d``
    greater than 1 and returns ``(d, n // d)``, so the smaller factor comes
    first.

    Parameters:
        n: Non-negative number to factor (``sqrt`` rejects negative input).

    Returns:
        tuple: ``(d, n // d)`` for the largest divisor ``d`` with
        ``1 < d <= sqrt(n)``, or ``(1, n)`` when no such divisor exists
        (for example when ``n`` is prime or smaller than 4).
    """
    # Imported here because the module only imports ``ceil`` from ``math``;
    # without this the call below fails with NameError.
    from math import floor, sqrt

    for d in range(floor(sqrt(n)), 1, -1):
        if n % d == 0:
            return (d, n // d)
    return (1, n)

def listit(t):
    """Recursively turn nested lists/tuples into nested lists.

    Only objects whose exact type is ``list`` or ``tuple`` are converted;
    anything else (including subclasses of those types) is returned as is.
    """
    if type(t) not in (list, tuple):
        return t
    return [listit(item) for item in t]

def cumulative_sum_upto(array, idx):
    """
    Sum every element that precedes position (i, j) in row-major order.

    Parameters:
        array: Indexable 2D container (e.g. list of lists) whose rows can be
            passed to ``sum`` and sliced.
        idx (tuple): The position ``(i, j)``; the element at ``(i, j)``
            itself is not included.

    Returns:
        The sum of all complete rows before row ``i`` plus the first ``j``
        elements of row ``i``.
    """
    last_row, stop_col = idx

    # Complete rows strictly above the target row.
    running = sum(sum(array[r]) for r in range(last_row))

    # Partial contribution of the target row; skipped for a negative row
    # index, for which nothing is accumulated at all.
    if last_row >= 0:
        running += sum(array[last_row][:stop_col])

    return running

def access_w8_subvolume(
    subv_rows: int,
    subv_cols: int,
    elem_bytes: int
) -> tuple:
    """Compute step/wrap values for a subvolume handled in 8-column blocks.

    All sizes are divided by 4 (presumably converting bytes to 32-bit
    words - confirm). When the number of 8-column blocks exceeds
    ``config.MAX_MEMTILE_WRAP`` it is factored with ``largest_factor_pair``
    and spread over two wrap levels.

    NOTE(review): ``config`` is not imported in the visible part of this
    file; it must be provided elsewhere for this function to run.

    Parameters:
        subv_rows: Number of rows in the subvolume.
        subv_cols: Number of columns; must be a multiple of 8.
        elem_bytes: Size of one element in bytes.

    Returns:
        tuple: ``(step0, step1, step2, step3, wrap0, wrap1, wrap2)``;
        ``step3`` and ``wrap2`` are ``None`` unless the block count had to
        be split.
    """
    assert subv_cols % 8 == 0

    block_words = (8 * elem_bytes) // 4              # one 8-element run
    block_stride = (subv_rows * 8 * elem_bytes) // 4  # distance between blocks
    n_blocks = subv_cols // 8

    if n_blocks > config.MAX_MEMTILE_WRAP:
        # Too many blocks for a single wrap: use two nested wraps.
        inner, outer = largest_factor_pair(n_blocks)
        group_stride = inner*subv_rows*8*elem_bytes//4
        return (1, block_stride, group_stride, block_words, block_words, inner, outer)

    return (1, block_stride, block_words, None, block_words, n_blocks, None)

def access_w8_rd(
    subv_rows: int,
    subv_cols: int,
    elem_bytes: int
) -> tuple:
    """Compute step/wrap values for reading a subvolume in 8-column blocks.

    All sizes are divided by 4 (presumably converting bytes to 32-bit
    words - confirm).

    Parameters:
        subv_rows: Number of rows in the subvolume.
        subv_cols: Number of columns; must be a multiple of 8.
        elem_bytes: Size of one element in bytes.

    Returns:
        tuple: ``(step0, step1, step2, wrap0, wrap1)``.
    """
    assert subv_cols % 8 == 0

    block_words = (8 * elem_bytes) // 4        # used for both wrap0 and step2
    row_words = (subv_cols * elem_bytes) // 4  # one full row of the subvolume

    return (1, row_words, block_words, block_words, subv_rows)

def lrn_kernel_params(  Nlrn: int,
                        Nsubv: int,
                        split_type: int,
                        col_id: int,
                        bias_addr: int,
                        gamma_offset: int,
                        beta_offset: int,
                        qdq_addr: int,
                        ) -> bytes:
    """Serialize the eight kernel parameters into a 16-byte blob.

    Each argument is written, in signature order, as an unsigned 16-bit
    little-endian integer. ``int.to_bytes`` raises ``OverflowError`` for a
    value that does not fit in 16 unsigned bits.

    Returns:
        bytes: The concatenated 2-byte encodings.
    """
    ordered_fields = (
        Nlrn,
        Nsubv,
        split_type,
        col_id,
        bias_addr,
        gamma_offset,
        beta_offset,
        qdq_addr,
    )
    return b"".join(
        field.to_bytes(length=2, byteorder='little', signed=False)
        for field in ordered_fields
    )

def lpnorm_kernel_params(Nlrn: int,
                         Nsubv: int,
                         col_id: int,
                         split_type: int
                        ) -> bytes:
    """Serialize the four kernel parameters into an 8-byte blob.

    Each argument is written, in signature order, as an unsigned 16-bit
    little-endian integer.

    Returns:
        bytes: The concatenated 2-byte encodings.
    """
    ordered_fields = (Nlrn, Nsubv, col_id, split_type)
    return b"".join(
        field.to_bytes(length=2, byteorder='little', signed=False)
        for field in ordered_fields
    )

def sfmx_params(
    Nlayer: int,
    Msubv: int,
    Nsubv: int,
    SplitType : int,
    col_id : int,
    row_id : int,
    q_node_addr : int,
    dq_node_addr : int,
    in_addr : int,
    out_addr : int,
    fuse_mode : int,
    scalefactor : float
):
    """Serialize the parameters into a 28-byte blob.

    Layout, all little-endian: the eleven integer arguments in signature
    order as unsigned 16-bit values, one zero 16-bit padding field so that
    the next field is 4-byte aligned, then the single-precision bit pattern
    of ``scalefactor`` as an unsigned 32-bit value.

    Note: the ``SplitType`` parameter shadows the ``SplitType`` enum inside
    this function; an integer is expected here.

    Returns:
        bytes: The concatenated encoding.
    """
    # Reinterpret the float32 representation as an unsigned 32-bit integer.
    (scale_bits,) = struct.unpack('I', struct.pack('f', scalefactor))

    alignment_pad = 0  # keeps the 32-bit scale factor 4-byte aligned
    half_words = (
        Nlayer,
        Msubv,
        Nsubv,
        SplitType,
        col_id,
        row_id,
        q_node_addr,
        dq_node_addr,
        in_addr,
        out_addr,
        fuse_mode,
        alignment_pad,
    )
    header = b"".join(
        field.to_bytes(length=2, byteorder='little', signed=False)
        for field in half_words
    )
    return header + scale_bits.to_bytes(length=4, byteorder='little', signed=False)

def get_core_instrs(lrn, col, row):
    """Return whatever ``lrn.get_core_instrs(col, row)`` produces.

    Free-function wrapper that forwards ``col`` and ``row`` unchanged to the
    method of the same name on ``lrn``.
    """
    core_instrs = lrn.get_core_instrs(col, row)
    return core_instrs

def get_shim_transfers(lrn):
    """Return whatever ``lrn.get_shim_transfers()`` produces.

    Free-function wrapper around the method of the same name on ``lrn``.
    """
    shim_transfers = lrn.get_shim_transfers()
    return shim_transfers

def get_memtile_transfers(lrn):
    """Return whatever ``lrn.get_memtile_transfers()`` produces.

    Free-function wrapper around the method of the same name on ``lrn``.
    """
    memtile_transfers = lrn.get_memtile_transfers()
    return memtile_transfers


class layer_norm:
    """Empty class: it defines no attributes or methods of its own.

    NOTE(review): looks like a placeholder - confirm whether anything
    still depends on this name before removing or extending it.
    """
    pass
    
def shim_step_wrap( transferLength, ElemBytes, N):
    """Build the step and wrap lists for a transfer of ``transferLength`` elements.

    Sizes are divided by 4 (presumably converting bytes to 32-bit words -
    confirm). If the transfer length in those units exceeds
    ``config.MAX_SHIM_WRAP`` it is factored with ``largest_factor_pair``
    into two nested wraps.

    NOTE(review): ``config`` is not imported in the visible part of this
    file; it must be provided elsewhere for this function to run.

    Parameters:
        transferLength: Number of elements in one contiguous transfer.
        ElemBytes: Size of one element in bytes.
        N: Element count used to derive the outermost step.

    Returns:
        tuple: ``(step_list, wrap_list)``. Without splitting these are
        ``[1, outer_step]`` and ``[length]``; with splitting they are
        ``[1, large, outer_step]`` and ``[large, small]``, where
        ``small * large == length``.
    """
    outer_step = (N * ElemBytes) // 4
    length_words = (transferLength * ElemBytes) // 4

    if length_words > config.MAX_SHIM_WRAP:
        # Too long for a single wrap: walk ``large`` words, ``small`` times.
        small, large = largest_factor_pair(length_words)
        return [1, large, outer_step], [large, small]

    return [1, outer_step], [length_words]

def set_core_addr(actSize, biasSize, qdqSize, outSize, stackSize, totalSize, debugPrint):
    """Pick an L1 buffering scheme that fits and lay out the buffer addresses.

    Schemes are tried from most to least memory hungry: input and output
    double buffered, input only double buffered, output only double
    buffered, both single buffered, and finally input and output sharing
    one buffer (in place). A scheme fits when its total size is strictly
    less than ``totalSize``.

    NOTE(review): the in-place fit check counts only ``actSize`` for the
    shared buffer, yet the bias buffer is placed ``outSize`` after its
    start - confirm this is intended when ``outSize > actSize``.

    Parameters:
        actSize, biasSize, qdqSize, outSize, stackSize: Sizes of the
            individual regions.
        totalSize: Capacity the chosen layout must stay below.
        debugPrint: When truthy, print the name of the chosen scheme.

    Returns:
        tuple: ``(CoreInPingAddr, CoreInPongAddr, CoreOutPingAddr,
        CoreOutPongAddr, CoreBiasPingAddr, CoreQdqPrmAddr, is_inPlace)``.
        Pong addresses are ``None`` for buffers that are not double
        buffered.

    Raises:
        AssertionError: If even the in-place scheme does not fit.
    """
    # (fits, debug label, double-buffer the input?, distinct output buffers)
    # Zero output buffers means the output reuses the input buffer.
    schemes = (
        ((2*actSize+biasSize+qdqSize+2*outSize+stackSize) < totalSize,
         "L1 Act, Out double buffering", True, 2),
        ((2*actSize+biasSize+qdqSize+outSize+stackSize) < totalSize,
         "L1 Act alone double buffering", True, 1),
        ((actSize+biasSize+qdqSize+2*outSize+stackSize) < totalSize,
         "L1 Out alone double buffering", False, 2),
        ((actSize+biasSize+qdqSize+outSize+stackSize) < totalSize,
         "L1 Act, Out single buffering", False, 1),
        ((actSize+biasSize+qdqSize+stackSize) < totalSize,
         "L1 Act, Out same buffers - inplace", False, 0),
    )

    # Assert if it does not fit within core DM
    assert schemes[-1][0], "L1 memory allocation failed - check set_core_addr()"

    in_ping = 0
    in_pong = out_ping = out_pong = bias_ping = qdq_prm = None
    shares_buffer = False

    chosen = next((scheme for scheme in schemes if scheme[0]), None)
    if chosen is not None:
        _, label, double_in, out_buffers = chosen
        if debugPrint:
            print(label)

        # Walk a cursor through memory, placing regions back to back.
        cursor = in_ping + actSize
        if double_in:
            in_pong = cursor
            cursor = in_pong + actSize

        if out_buffers == 0:
            shares_buffer = True
            out_ping = in_ping
        else:
            out_ping = cursor
        cursor = out_ping + outSize

        if out_buffers == 2:
            out_pong = cursor
            cursor = out_pong + outSize

        bias_ping = cursor
        qdq_prm = bias_ping + biasSize

    return in_ping, in_pong, out_ping, out_pong, bias_ping, qdq_prm, shares_buffer

def split_val(value, max_val):
    """Split ``value`` into chunks that are each at most ``max_val``.

    ``max_val`` is peeled off repeatedly while the remainder is still
    greater than ``max_val``; the final remainder is appended last, so the
    returned chunks always add up to ``value``.

    Parameters:
        value: Total amount to split.
        max_val: Largest allowed chunk.

    Returns:
        list: Zero or more entries equal to ``max_val`` followed by the
        remainder (``[value]`` when ``value <= max_val``).

    Raises:
        ValueError: If ``value > max_val`` while ``max_val <= 0``;
            subtracting a non-positive chunk would never terminate.
    """
    if value > max_val and max_val <= 0:
        raise ValueError(
            "split_val: max_val must be positive when value exceeds it "
            f"(value={value}, max_val={max_val})"
        )

    result = []
    while value > max_val:
        result.append(max_val)
        value -= max_val
    result.append(value)
    return result

def re_enqueue_mem(x, max_rep):
    """Split every entry of ``x`` into chunks of at most ``max_rep``.

    Each entry ``x[inst, col]`` is passed through ``split_val`` and the
    resulting chunks are written along a new trailing axis; unused
    positions stay zero.

    Parameters:
        x: Array indexed as ``x[inst, col]`` that supports division by a
            scalar (assumes a 2D numpy array - TODO confirm with callers).
        max_rep: Largest value allowed in one output slot.

    Returns:
        numpy.ndarray: Integer array of shape ``x.shape + (n_enq,)`` where
        ``n_enq`` is the maximum of ``ceil(x / max_rep)`` over ``x``.
    """
    # Imported here because numpy is not imported at module level in the
    # visible part of this file; without it ``np`` is an undefined name.
    import numpy as np

    max_enq = np.max(np.ceil(x/max_rep)).astype(int)
    y = np.zeros(np.shape(x)[:] + (int(max_enq),), int)

    n_inst, n_col = np.shape(x)[0], np.shape(x)[1]
    for inst in range(n_inst):
        for col in range(n_col):
            s = split_val(x[inst,col], max_rep)
            y[inst,col,0:len(s)] = s
    return y # (inst,col, n_enq)

def set_mem_addr( actSize, biasSize, qdqSize, outSize, coreParamSize, totalSize, debugPrint):
    """Pick an L2 buffering scheme that fits and lay out the buffer addresses.

    The parameter region starts at address 0 and the input ping buffer
    follows it. Schemes are tried from most to least memory hungry: input
    and output double buffered, input only double buffered, output only
    double buffered, and both single buffered. A scheme fits when its total
    size is strictly less than ``totalSize``.

    Parameters:
        actSize, biasSize, qdqSize, outSize, coreParamSize: Sizes of the
            individual regions.
        totalSize: Capacity the chosen layout must stay below.
        debugPrint: When truthy, print the name of the chosen scheme.

    Returns:
        tuple: ``(MemPrmAddr, InPingAddr, InPongAddr, OutPingAddr,
        OutPongAddr, BiasPingAddr, QdqPrmAddr)``. Pong addresses are
        ``None`` for buffers that are not double buffered.

    Raises:
        AssertionError: If even the single-buffered scheme does not fit.
    """
    # (fits, debug label, double-buffer the input?, double-buffer the output?)
    schemes = (
        ((2*actSize+biasSize+qdqSize+2*outSize+coreParamSize) < totalSize,
         "L2 Act, Out double buffering", True, True),
        ((2*actSize+biasSize+qdqSize+outSize+coreParamSize) < totalSize,
         "L2 Act alone double buffering", True, False),
        ((actSize+biasSize+qdqSize+2*outSize+coreParamSize) < totalSize,
         "L2 Out alone double buffering", False, True),
        ((actSize+biasSize+qdqSize+outSize+coreParamSize) < totalSize,
         "L2 Act, Out single buffering", False, False),
    )

    # Assert if it does not fit within mem tile
    assert schemes[-1][0], "L2 memory allocation failed - check set_mem_addr()"

    prm_addr = 0
    in_ping = in_pong = out_ping = out_pong = bias_ping = qdq_prm = None

    chosen = next((scheme for scheme in schemes if scheme[0]), None)
    if chosen is not None:
        _, label, double_in, double_out = chosen
        if debugPrint:
            print(label)

        # Walk a cursor through memory, placing regions back to back.
        in_ping = prm_addr + coreParamSize
        cursor = in_ping + actSize
        if double_in:
            in_pong = cursor
            cursor = in_pong + actSize

        out_ping = cursor
        cursor = out_ping + outSize
        if double_out:
            out_pong = cursor
            cursor = out_pong + outSize

        bias_ping = cursor
        qdq_prm = bias_ping + biasSize

    return prm_addr, in_ping, in_pong, out_ping, out_pong, bias_ping, qdq_prm

def bytes_to_words(elem_bytes):
    """Convert a byte count into a count of 4-byte words.

    Parameters:
        elem_bytes: Size in bytes; must be a multiple of 4.

    Returns:
        ``elem_bytes // 4``.

    Raises:
        AssertionError: If ``elem_bytes`` is not a multiple of 4.
    """
    bytes_per_word = 4
    n_words, leftover = divmod(elem_bytes, bytes_per_word)
    assert leftover == 0
    return n_words

