
import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels', 'conv'))
from typing import List

from dmacompiler import \
    BackEnd, \
    set_dev_gen, DevGen, config
    
from dataflow_common import ceildiv, calculate_row_split, overlay_stack_addr
# Module-level side effects, executed once at import time:
# select the AIE2P device generation via dmacompiler's set_dev_gen ...
set_dev_gen(DevGen.Aie2p)
# ... and turn on dmacompiler's ENABLE_BUSY_POLL config option.
config.ENABLE_BUSY_POLL = True

def split_core_ifm_subv(
    aie_rows: int,
    ifm_H: int,
    ifm_W: int,
    ifm_C: int,
    core_size_cutoff: int,
    ifm_bits: int
) -> int:
    """Find a factor ``k`` to fold IFM width into height so a row fits a core.

    Candidate factors are tried in increasing order.  A factor ``k`` is
    accepted when all of the following hold:

    * ``k`` divides ``ifm_W`` exactly, so that reshaping (H, W) into
      (H * k, W // k) -- which is what ``dummy_node_ifm_split_cost`` does
      with the result -- keeps every element;
    * ``k`` is a multiple of ``aie_rows``;
    * four buffers of one (1, W // k, C) row (IFM and OFM, each ping-pong)
      fit in ``core_size_cutoff`` bytes.

    Args:
        aie_rows: Number of AIE rows; ``k`` must be a multiple of it.
        ifm_H: IFM height (not used; kept for interface compatibility).
        ifm_W: IFM width to be divided by ``k``.
        ifm_C: IFM channel count.
        core_size_cutoff: Available core memory in bytes.
        ifm_bits: Bits per IFM element.

    Returns:
        The smallest valid factor ``k``, or -1 if there is none.
    """
    for num_X_split in range(1, ifm_W + 1):
        # A non-divisor would silently drop ifm_W % num_X_split columns when
        # the caller rebuilds the shape as (H * k, W // k).
        if ifm_W % num_X_split != 0 or num_X_split % aie_rows != 0:
            continue
        new_ifm_W = ifm_W // num_X_split
        subv_size = (new_ifm_W * ifm_C * ifm_bits) // 8  # one row, N=H=1
        mem_size = subv_size * 4  # ifm + ofm, both ping-pong
        if mem_size <= core_size_cutoff:
            return num_X_split
    return -1

def dummy_node_ifm_split_cost(
    aie_cols: int, aie_rows: int,
    no_kernel_datapath: bool,
    ifm_H: int,
    ifm_W: int,
    ifm_C: int,
    ifm_bits: int,
    wgt_size_core: int, # not in pingpong
    wgt_size_ifm: int,  # not in pingpong
    ofm_ifm_subv_ratio: int = 1, # OFM/IFM size ratio; 1 means equal sizes
):
    """Enumerate the H-splits whose IFM + OFM tiles fit in memtile and cores.

    The (N=1, H, W, C) IFM is split along H only: across ``aie_cols``
    columns and, within a column, into ``split`` successive pieces.  On the
    kernel datapath, when a single (1, W, C) row does not fit the cores of a
    column, the IFM is first reshaped from (H, W) to (H * k, W // k) with
    ``k`` taken from ``split_core_ifm_subv``.

    Args:
        aie_cols: Number of AIE columns the H dimension is spread over.
        aie_rows: Number of AIE rows (cores) per column.
        no_kernel_datapath: True when only the memtile is involved; False
            when data also goes through the cores.
        ifm_H: IFM height.
        ifm_W: IFM width.
        ifm_C: IFM channel count.
        ifm_bits: Bits per IFM element.
        wgt_size_core: Bytes subtracted from each core's budget (weights).
        wgt_size_ifm: Bytes subtracted from the memtile budget (weights).
        ofm_ifm_subv_ratio: OFM/IFM size ratio used by the memtile-only
            feasibility check.

    Returns:
        ``(num_splits, new_ifm_H, new_ifm_W)``: the candidate split counts in
        increasing order (possibly empty) and the IFM height and width after
        any reshape (unchanged when no reshape was needed).

    Raises:
        AssertionError: if the shape cannot be mapped.  Raised explicitly
            (not via ``assert``) so the checks survive ``python -O``.
    """
    memtile_bytes = 524288  # memtile capacity used for the budget
    usable_mt_mem = memtile_bytes - (config.MAX_CORE_LAYER_PARAM_SIZE * aie_rows) - wgt_size_ifm
    usable_core_mem = overlay_stack_addr() - wgt_size_core
    # Shim -> memtile splitting is only implemented along H, so the smallest
    # possible tile is a single (1, W, C) row. Such a split always exists,
    # although it is not necessarily efficient.
    min_ifm_in_bytes = 1 * ifm_W * ifm_C * ifm_bits // 8  # input_rows = 1

    # Default: no reshape. Bound up front so every path below has a value.
    new_ifm_H = ifm_H
    new_ifm_W = ifm_W
    if no_kernel_datapath:
        if not min_ifm_in_bytes * (1 + ofm_ifm_subv_ratio) < usable_mt_mem:
            raise AssertionError("N*H*W*C(N=1, H=1) has to be fitting in the memtile, with ifm and ofm not pingpong mode")
    else:
        h_w_valid = (ifm_H * ifm_W > aie_rows) and (ifm_H * ifm_W % aie_rows == 0)
        if not h_w_valid:
            raise AssertionError("Current shape can't be supported, future improvement needed for odd shapes")
        # One row spread over the column's cores needs 3 buffers per core:
        # IFM ping + pong and a single OFM buffer.
        if (min_ifm_in_bytes // aie_rows * 3) > usable_core_mem:
            num_X_split = split_core_ifm_subv(aie_rows, ifm_H, ifm_W, ifm_C, usable_core_mem, ifm_bits)
            if num_X_split == -1:
                raise AssertionError("the Give shape is not able to be reshaped from depth to space")
            new_ifm_H = ifm_H * num_X_split
            new_ifm_W = ifm_W // num_X_split

    # Total core memory of one column.
    usable_col_core_mem = usable_core_mem * aie_rows
    max_splits = ceildiv(new_ifm_H, aie_cols)
    num_splits = []
    for split in range(1, max_splits + 1):
        row_in_splits = ceildiv(new_ifm_H, aie_cols * split)
        ifm_size_in_bytes = row_in_splits * new_ifm_W * ifm_C * ifm_bits // 8
        # The OFM tile is sized exactly like the IFM tile.
        ofm_size_in_bytes = ifm_size_in_bytes
        tile_bytes = ifm_size_in_bytes + ofm_size_in_bytes
        if tile_bytes <= usable_mt_mem and tile_bytes <= usable_col_core_mem:
            num_splits.append(split)

    return num_splits, new_ifm_H, new_ifm_W

class DepthToSpace_dims:
    """Tiling parameters for a DepthToSpace layer mapped onto an AIE array.

    The constructor validates the shape, may reshape the IFM (via
    ``dummy_node_ifm_split_cost``), picks the number of H-splits per column
    and derives the per-subvolume sizes consumed by the dataflow generator.
    """

    # Memtile capacity in bytes; must agree with dummy_node_ifm_split_cost.
    _MEMTILE_BYTES = 524288

    def __init__(
        self,
        aie_rows: int, aie_cols: int,
        in_shape: List, blockSize: int, perm_mode: str,
        ifm_bits: int
    ):
        """Compute the tiling.

        Args:
            aie_rows: Number of AIE rows (cores per column).
            aie_cols: Number of AIE columns.
            in_shape: IFM shape, read as (N, H, W, C).
            blockSize: DepthToSpace block size.
            perm_mode: "DCR" selects permute mode 1; anything else is 0.
            ifm_bits: Bits per IFM element.

        Raises:
            AssertionError: if the depth is not divisible by ``blockSize**2``
                or no feasible split exists.
        """
        self.aie_rows = aie_rows
        self.aie_cols = aie_cols
        self.batch = in_shape[0]
        self.depth = in_shape[3]
        self.height = in_shape[1]
        self.width = in_shape[2]
        self.blockSize = blockSize
        self.perm_mode = 1 if perm_mode == "DCR" else 0
        self.ifm_bits = ifm_bits
        self.param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE

        # Validate dimensions (explicit raise so it survives python -O).
        if self.depth % (blockSize * blockSize) != 0:
            raise AssertionError("Depth must be divisible by blockSize^2.")

        # Only the core datapath is used for now.
        self.no_kernel_datapath = False
        self.wgt_subv_size = 0
        self.wgt_size_ifm = 0

        num_splits, new_ifm_H, new_ifm_W = dummy_node_ifm_split_cost(
            aie_cols, aie_rows, self.no_kernel_datapath,
            self.height, self.width, self.depth, self.ifm_bits,
            self.wgt_subv_size, self.wgt_size_ifm)

        # The IFM may have been reshaped (H, W) -> (H * k, W // k).
        self.height = new_ifm_H
        self.width = new_ifm_W

        # Output dimensions.
        self.outDepth = self.depth // (blockSize * blockSize)
        self.outHeight = self.height * blockSize
        self.outWidth = self.width * blockSize

        if self.no_kernel_datapath:
            split_curr = self._select_memtile_split(num_splits)
        else:
            split_curr = self._select_core_split(num_splits)
        if split_curr == 0:
            raise AssertionError("the Split caculation wrong!")
        self.num_splits = split_curr

        self.input_rows_split = self.height // self.aie_cols // self.num_splits
        self.input_cols_split = self.width

        # IFM height (Yi) and rows per subvolume (Yis).
        self.Yi = self.height
        self.Yis = self.input_rows_split
        self.ifm_subv_elem = self.Yis * self.width * self.depth // aie_rows

        # OFM height (Yo), rows per subvolume (Yos) and H-loop trip count.
        self.Yo = self.outHeight
        self.Yos = self.input_rows_split * blockSize
        self.Y_loop = ceildiv(self.Yi, (self.aie_cols * self.Yis))

        # Width and channels are never split.
        self.Xi = self.width
        self.Xis = self.input_cols_split

        self.Ci = self.depth
        self.Cis = self.depth

        self.Xo = self.outWidth
        self.Xos = self.outWidth

        self.Co = self.outDepth
        self.Cos = self.outDepth

        self.X_split = 1
        self.output_cols_Split = 1
        self.Co_split = 1

    def _select_core_split(self, num_splits):
        """Pick the H-split for the core datapath and set the ping-pong flags.

        Candidates are visited in order:

        * if the per-core IFM + OFM tile does not fit at all, raise;
        * the first split where IFM and OFM both fit double-buffered *and*
          the height divides evenly across ``aie_cols * split`` is taken
          immediately (all ping-pong flags True);
        * otherwise the most recently visited split that fits only IFM
          ping-pong, or no ping-pong, is kept.

        The core decides; the memtile flags just mirror it.
        NOTE: the memtile could still ping-pong, since it has more memory
        than (core mem * 4); to be addressed later.

        The four ping-pong attributes are set from the returned split, not
        from whichever split was visited last.

        Returns:
            The chosen split, or 0 if none was selected.
        """
        avail_core_mem_size = overlay_stack_addr() - self.wgt_subv_size
        chosen_split = 0
        # (core_ifm, core_ofm, mt_ifm, mt_ofm) ping-pong flags of chosen_split
        chosen_flags = None
        for split in num_splits:
            rows = self.height // self.aie_cols // split
            ifm_mem = rows * self.width * self.depth * self.ifm_bits // 8 // self.aie_rows
            ofm_mem = self.blockSize * rows * self.outWidth * self.outDepth * self.ifm_bits // 8 // self.aie_rows
            if ifm_mem + ofm_mem > avail_core_mem_size:
                raise AssertionError("the problem size(W and C ) might too big to fit in the core memory with current split(H only cross all core columns)")
            if 2 * ifm_mem + 2 * ofm_mem <= avail_core_mem_size:
                if self.height % (self.aie_cols * split) == 0:
                    chosen_split, chosen_flags = split, (True, True, True, True)
                    break
                # Fits fully double-buffered but the height split is uneven:
                # not a candidate.
            elif 2 * ifm_mem + ofm_mem <= avail_core_mem_size:
                chosen_split, chosen_flags = split, (True, False, True, True)
            else:
                chosen_split, chosen_flags = split, (False, False, False, False)
        if chosen_flags is not None:
            (self.core_ifm_pingpong, self.core_ofm_pingpong,
             self.mt_ifm_pingpong, self.mt_ofm_pingpong) = chosen_flags
        return chosen_split

    def _select_memtile_split(self, num_splits):
        """Pick the H-split for the memtile-only datapath and set its flags.

        The first split where IFM and OFM both fit double-buffered in the
        memtile is taken; otherwise the last visited split is kept, with
        IFM-only or no ping-pong.

        Returns:
            The chosen split, or 0 if ``num_splits`` is empty.
        """
        # Same memtile budget as dummy_node_ifm_split_cost (memtile weights).
        avail_mt_mem_size = self._MEMTILE_BYTES - self.wgt_size_ifm - (config.MAX_CORE_LAYER_PARAM_SIZE * self.aie_rows)
        chosen_split = 0
        for split in num_splits:
            rows = self.height // self.aie_cols // split
            ifm_mem = rows * self.width * self.depth * self.ifm_bits // 8
            ofm_mem = self.blockSize * rows * self.outWidth * self.outDepth * self.ifm_bits // 8
            if ifm_mem + ofm_mem > avail_mt_mem_size:
                raise AssertionError("the problem size(W and C ) might too big to fit in the memtile memory with current split(H only cross all  columns)")
            if 2 * ifm_mem + 2 * ofm_mem <= avail_mt_mem_size:
                self.mt_ifm_pingpong = True
                self.mt_ofm_pingpong = True
                return split
            if 2 * ifm_mem + ofm_mem <= avail_mt_mem_size:
                self.mt_ifm_pingpong, self.mt_ofm_pingpong = True, False
            else:
                self.mt_ifm_pingpong, self.mt_ofm_pingpong = False, False
            chosen_split = split
        return chosen_split
                

def depthtospace_preproc_directives(
    dims: DepthToSpace_dims,
    back_end: BackEnd,
) -> List[str]:
    """Build the preprocessor defines describing a DepthToSpace tiling.

    Each define is emitted as ``-DNAME=VALUE``; for the ADF back end it is
    wrapped as ``--Xpreproc="-DNAME=VALUE"``.  TXN_MODE is 0 for ADF and 1
    for every other back end.
    """
    txn_mode = int(back_end != BackEnd.Adf)
    defines = [
        ('AIE_COLS', dims.aie_cols),
        ('AIE_ROWS', dims.aie_rows),
        ('N_BATCH', dims.batch),
        ('C_DEPTH', dims.depth),
        ('Y_HEIGHT', dims.height),
        ('X_WIDTH', dims.width),
        ('PERMUTE_MODE', dims.perm_mode),
        ('BLOCK_SIZE', dims.blockSize),
        ('TXN_MODE', txn_mode),
    ]
    if back_end == BackEnd.Adf:
        return [f'--Xpreproc="-D{ident}={val}"' for ident, val in defines]
    return [f"-D{ident}={val}" for ident, val in defines]
    
