
from dataclasses import dataclass
import os

from dataflow_utils import CommonDims
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import List

from dmacompiler import \
    BackEnd, \
    set_dev_gen, DevGen, config
    
# Import-time side effects: select the Aie2p device generation in dmacompiler
# and set config.ENABLE_BUSY_POLL. `config` is a shared module-level object,
# so this setting is visible to any other code using dmacompiler's config in
# the same process.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

@dataclass
class ResizeDims(CommonDims):
    """Dimensions of a resize (spatial upsampling) layer.

    Batch (N) and channel (C) extents are carried through unchanged; height
    (Y) and width (X) of the output are the input extents multiplied by
    ``num_interpolations``. Resize-specific settings are stored on the
    instance first, then the full input/output geometry is forwarded to
    ``CommonDims.__init__``.

    NOTE(review): ``@dataclass`` is applied to a class that defines its own
    ``__init__`` and declares no annotated fields, so the decorator keeps this
    ``__init__`` and only generates ``__eq__``/``__repr__`` from whatever
    dataclass fields are inherited (if any). Attributes assigned here
    (``int_16``, ``bfloat_16``, ``num_interpolations``, the ``*_loop`` counts)
    do not take part in those generated methods — confirm that is intended.
    """

    def __init__(
        self,
        aie_rows: int,
        aie_cols: int,
        Ni: int,
        Yi: int,
        Xi: int,
        Ci: int,
        Nis: int,
        Yis: int,
        Xis: int,
        Cis: int,
        N_loop: int,
        Y_loop: int,
        X_loop: int,
        C_loop: int,
        num_interpolations: int,
        ifm_bits: int,
        int_16: int,
        bfloat_16: int
    ) -> None:
        # Layer-parameter size is taken straight from dmacompiler's config.
        self.param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE

        # Data-type selector flags, stored as given.
        self.int_16, self.bfloat_16 = int_16, bfloat_16

        # Integer upscale factor applied to both spatial axes.
        self.num_interpolations = num_interpolations

        # Loop counts are supplied by the caller rather than derived here.
        self.N_loop, self.Y_loop, self.X_loop, self.C_loop = (
            N_loop, Y_loop, X_loop, C_loop
        )

        scale = num_interpolations
        super().__init__(
            aie_cols=aie_cols,
            aie_rows=aie_rows,
            # Full input tensor.
            Ni=Ni, Yi=Yi, Xi=Xi, Ci=Ci,
            # Full output tensor: N and C unchanged, Y and X upscaled.
            No=Ni, Yo=Yi * scale, Xo=Xi * scale, Co=Ci,
            # Input subvolume.
            Nis=Nis, Yis=Yis, Xis=Xis, Cis=Cis,
            # NOTE(review): the output subvolume is passed equal to the input
            # subvolume (not scaled by num_interpolations) even though Yo/Xo
            # are scaled — confirm this is what CommonDims expects.
            Nos=Nis, Yos=Yis, Xos=Xis, Cos=Cis,
            # Resize does not change element width.
            ifm_bits=ifm_bits,
            ofm_bits=ifm_bits
        )


        
        
def resize_preproc_directives(
    dims: ResizeDims,
    back_end: BackEnd,
) -> List[str]:
    """Build the preprocessor define flags for the resize kernel.

    Each define is rendered in the syntax of the target back end:
    ``--Xpreproc="-DNAME=VALUE"`` when ``back_end`` is ``BackEnd.Adf`` and a
    plain ``-DNAME=VALUE`` otherwise.

    Args:
        dims: Resize geometry and data-type flags to export.
        back_end: Compiler back end the flags are generated for.

    Returns:
        One flag string per define, in a fixed order ending with TXN_MODE.
    """
    wrap_for_adf = back_end == BackEnd.Adf
    # TXN_MODE is 1 for every back end other than Adf, and 0 for Adf.
    txn_mode = int(back_end != BackEnd.Adf)

    defines = (
        ('AIE_COLS', dims.aie_cols),
        ('AIE_ROWS', dims.aie_rows),
        ('H_IN', dims.Yi),
        ('W_IN', dims.Xi),
        ('C_IN', dims.Ci),
        ('NUM_INTERPOLATIONS', dims.num_interpolations),
        ('H_OUT', dims.Yo),
        ('W_OUT', dims.Xo),
        ('C_OUT', dims.Co),
        ('INT_16', dims.int_16),
        ('BFLOAT_16', dims.bfloat_16),
        ('TXN_MODE', txn_mode),
    )

    if wrap_for_adf:
        return [f'--Xpreproc="-D{name}={value}"' for name, value in defines]
    return [f"-D{name}={value}" for name, value in defines]
