
from dataclasses import dataclass
import os

CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))
from typing import List

from dmacompiler import \
    BackEnd, \
    set_dev_gen, DevGen, config

from dataflow_utils import CommonDims
# Module-level compiler configuration. These statements execute at import
# time: select the DevGen.Aie2p device generation and set the dmacompiler
# ``config.ENABLE_BUSY_POLL`` flag. ``config`` is a shared module object, so
# the flag change is global; ``set_dev_gen`` presumably also mutates global
# compiler state — TODO confirm.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

@dataclass
class QDQDims(CommonDims):
    """Dimension and configuration bundle for a quantize/dequantize (QDQ) op.

    The tensor-shape arguments, element bit widths, ``Yis``/``Yos`` and the
    AIE array size (``aie_cols``/``aie_rows``) are forwarded to
    ``CommonDims.__init__``. The remaining QDQ-specific values are stored as
    plain instance attributes under the same names as the parameters.

    NOTE(review): this class declares no annotated fields and defines its
    own ``__init__``, so ``@dataclass`` keeps this initializer rather than
    generating one. Any generated ``__repr__``/``__eq__`` only see fields
    declared on dataclass bases (if ``CommonDims`` is one). Confirm the
    decorator is still wanted here.
    """

    def __init__(
        self,
        Ni: int,
        Yi: int,
        Xi: int,
        Ci: int,
        Cip: int,
        No: int,
        Yo: int,
        Xo: int,
        Co: int,
        Cop: int,
        ifm_bits: int,
        ofm_bits: int,
        fixed_point_bits: int,
        total_Y: int,
        total_X: int,
        CoreqdqPrmSize: int,
        param_subv_size: int,
        wgt_subv_size: int,
        subv_elem: int,
        subv_size_input: int,
        subv_size_output: int,
        Yis: int,
        Yos: int,
        Y_loop: int,
        op_type: str,
        qdq_mode: int,
        aie_cols: int,
        aie_rows: int
    ):
        # QDQ-only attributes are assigned before the base initializer runs,
        # preserving the original initialization order.

        # Parameter / weight / subvolume sizes.
        self.CoreqdqPrmSize = CoreqdqPrmSize
        self.param_subv_size = param_subv_size
        self.wgt_subv_size = wgt_subv_size
        self.subv_elem = subv_elem
        self.subv_size_input = subv_size_input
        self.subv_size_output = subv_size_output

        # total_* values and the Y loop count.
        self.total_Y = total_Y
        self.total_X = total_X
        self.Y_loop = Y_loop

        # Operation selection and fixed-point width.
        self.op_type = op_type
        self.qdq_mode = qdq_mode
        self.fixed_point_bits = fixed_point_bits

        # Everything shared with CommonDims is passed through by keyword.
        shared_dims = dict(
            aie_cols=aie_cols, aie_rows=aie_rows,
            Ni=Ni, Yi=Yi, Xi=Xi, Ci=Ci, Cip=Cip,
            No=No, Yo=Yo, Xo=Xo, Co=Co, Cop=Cop,
            Yis=Yis, Yos=Yos,
            ifm_bits=ifm_bits, ofm_bits=ofm_bits,
        )
        super().__init__(**shared_dims)


def q_dq_preproc_directives(
    dims: QDQDims,
    back_end: BackEnd,
) -> List[str]:
    """Return the preprocessor macro definitions for the QDQ kernel build.

    Each macro is rendered as ``-D<NAME>=<VALUE>``. For ``BackEnd.Adf`` it is
    additionally wrapped as ``--Xpreproc="-D<NAME>=<VALUE>"``. ``TXN_MODE``
    is 0 for ``BackEnd.Adf`` and 1 for any other back end.

    Args:
        dims: QDQ dimensions supplying the macro values.
        back_end: Target back end; selects the flag format and ``TXN_MODE``.

    Returns:
        One flag string per macro, in a fixed order.
    """
    txn_mode = int(back_end != BackEnd.Adf)
    macro_values = (
        ('AIE_COLS', dims.aie_cols),
        ('AIE_ROWS', dims.aie_rows),
        ('H_IN', dims.Yi),
        ('W_IN', dims.Xi),
        ('C_IN', dims.Cip),
        # Input element width, bits floor-divided down to whole bytes.
        ('IFM_BYTES', dims.ifm_bits // 8),
        ('WGT_SIZE', dims.wgt_subv_size),
        ('QDQ_MODE', dims.qdq_mode),
        ('TXN_MODE', txn_mode),
        ('FIXED_POINT_BIT_SIZE', dims.fixed_point_bits),
    )
    # ADF takes the define through its --Xpreproc pass-through option.
    flag_format = (
        '--Xpreproc="-D{}={}"' if back_end == BackEnd.Adf else '-D{}={}'
    )
    return [flag_format.format(name, value) for name, value in macro_values]
