
import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import List

from dmacompiler import \
    BackEnd, \
    set_dev_gen, DevGen, config
    
from dataflow_common import ceildiv, calculate_row_split, overlay_stack_addr
# Module-level side effects, run at import time: select the Aie2p device
# generation through dmacompiler's global setter.
set_dev_gen(DevGen.Aie2p)
# Enable dmacompiler's busy-poll option. NOTE(review): presumably makes the
# generated code poll for completion rather than block — confirm in dmacompiler.
config.ENABLE_BUSY_POLL = True

class SliceDims:
    """Tiling geometry for a width-slice layer mapped onto an AIE array.

    The input is ``h_in`` rows by ``w_in`` columns with a single channel. The
    output keeps every row and the half-open column window
    ``[w_out_start, w_out_stop)`` of each row, so its width is
    ``w_out_stop - w_out_start``.

    Args:
        aie_rows: Number of AIE rows the work is divided over.
        aie_cols: Number of AIE columns the work is divided over.
        h_in: Input height in rows (also the output height).
        w_in: Input width in columns.
        w_out_start: First input column kept (inclusive).
        w_out_stop: End of the kept column window (exclusive).
        ifm_bits: Input feature-map element size in bits; forwarded to
            ``calculate_row_split``.

    Raises:
        ValueError: If the column window is empty, starts before column 0,
            or extends past ``w_in``.
    """

    def __init__(
        self,
        aie_rows: int, aie_cols: int,
        h_in: int, w_in: int,
        w_out_start: int, w_out_stop: int,
        ifm_bits: int,
    ):
        # Reject bad windows up front. A non-positive width or an out-of-range
        # window would otherwise flow silently into the tiling math below and
        # into the generated preprocessor macros.
        if not 0 <= w_out_start < w_out_stop <= w_in:
            raise ValueError(
                f"invalid slice window [{w_out_start}, {w_out_stop}) "
                f"for input width {w_in}"
            )

        self.aie_rows = aie_rows
        self.aie_cols = aie_cols
        self.h_in = h_in
        self.w_in = w_in
        self.c_in = 1  # single input channel
        self.w_out_start = w_out_start
        self.w_out_stop = w_out_stop
        self.w_out = w_out_stop - w_out_start
        self.ifm_bits = ifm_bits
        # Number of row splits chosen by the shared dataflow helper.
        # NOTE(review): the positional literals (False, 0, 8, 4, True, True)
        # are opaque from here; see calculate_row_split for their meaning.
        self.num_splits = calculate_row_split(
            False, self.h_in, self.w_out, self.c_in, self.ifm_bits, 0,
            overlay_stack_addr(), 8, 4, True, True,
        )
        # NOTE(review): magic parameter size; presumably bytes — confirm.
        self.CoreqdqPrmSize = 64
        # Output area (h_in * w_out) divided by columns, row splits and rows,
        # flooring at each step. Presumably elements per core subvolume.
        self.ifm_subv_elem = (
            h_in * self.w_out // aie_cols // self.num_splits // aie_rows
        )

        # Slicing along the width keeps every row.
        self.h_out = h_in

        self.param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE
        self.wgt_subv_size = 0  # this layer has no weight subvolume

        # Rows per AIE column per split (floor division).
        self.input_rows_split = self.h_in // self.aie_cols // self.num_splits
        self.input_cols_split = self.w_in

        self.Yi = self.h_in
        self.Yis = self.input_rows_split

        self.Yo = self.h_out
        self.Yos = self.input_rows_split
        # Passes needed to cover Yo when each pass handles aie_cols * Yos rows.
        self.Y_loop = ceildiv(self.Yo, (self.aie_cols * self.Yos))

        # NOTE(review): Xi uses w_out while Xis/Xos use w_in. Kept as-is;
        # confirm this asymmetry is what the dataflow expects.
        self.Xi = self.w_out
        self.Xis = self.input_cols_split

        self.Xo = self.w_out
        self.Xos = self.w_in

        self.X_split = 1
        self.output_cols_Split = 1
        

def slice_neg_preproc_directives(
    dims: SliceDims,
    back_end: BackEnd,
) -> List[str]:
    """Return the preprocessor defines that describe a slice layer.

    Args:
        dims: Slice geometry whose fields are exported as macros.
        back_end: Target back end. For ``BackEnd.Adf`` each define is wrapped
            as ``--Xpreproc="-D<NAME>=<VALUE>"``; for any other back end it
            is a bare ``-D<NAME>=<VALUE>``.

    Returns:
        One define string per macro, in a fixed order ending with
        ``TXN_MODE``. ``TXN_MODE`` is 0 for ``BackEnd.Adf`` and 1 otherwise.
    """
    macros = [
        ('AIE_COLS', dims.aie_cols),
        ('AIE_ROWS', dims.aie_rows),
        ('H_IN', dims.h_in),
        ('W_IN', dims.w_in),
        ('W_OUT_START', dims.w_out_start),
        ('W_OUT_STOP', dims.w_out_stop),
        ('H_OUT', dims.h_out),
        ('W_OUT', dims.w_out),
        ('WGT_SIZE', dims.wgt_subv_size),
        ('TXN_MODE', int(back_end != BackEnd.Adf)),
    ]
    # The ADF toolchain needs each define forwarded through --Xpreproc.
    if back_end == BackEnd.Adf:
        return [f'--Xpreproc="-D{name}={value}"' for name, value in macros]
    return [f"-D{name}={value}" for name, value in macros]
