import os
import sys
from typing import List, Union, Dict
from dataclasses import dataclass

CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))

from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)

from dataflow_common import iceil, ceildiv
from dataflow_utils import CommonDims

# Import-time configuration of dmacompiler: select the Aie2p device
# generation and set dmacompiler's `config.ENABLE_BUSY_POLL` flag.
# NOTE(review): both statements mutate global dmacompiler state for every
# module that imports this file — confirm that is intended, and keep the
# order unless it is verified that set_dev_gen() does not touch `config`.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True


@dataclass
class PadDims(CommonDims):
    """Dimensions and tiling configuration for a pad kernel.

    The four-entry list arguments are unpacked by position into scalar
    attributes (index 0..3 map to the N, Y, X, C prefixes/suffixes):

    * ``input``     -> ``Ni, Yi, Xi, Ci``
    * ``output``    -> ``No, Yo, Xo, Co``
    * ``pad_dims``  -> ``pad_N, pad_Y, pad_X, pad_C``
    * ``in_gran``   -> ``Nis_gran, Yis_gran, Xis_gran, Cis_gran``
    * ``out_gran``  -> ``Nos_gran, Yos_gran, Xos_gran, Cos_gran``
    * ``core_subv`` -> ``Nis, Yis, Xis, Cis``
    * ``mt_subv``   -> ``Nim, Yim, Xim, Cim``
    * ``loop``      -> ``N_loop, Y_loop, X_loop, C_loop``

    Only the first four entries of each list are read; a shorter list raises
    ``IndexError``. The original lists are kept as attributes as well.
    ``aie_cols``, ``aie_rows``, ``ifm_bits`` and ``ofm_bits`` are forwarded to
    ``CommonDims.__init__``. Every other argument is stored unchanged, plus
    the derived flag ``is_qdq`` (True unless ``qdq_mode == 3``).

    Note on ``@dataclass``: this class declares no annotated fields and
    provides its own ``__init__``, so the decorator only knows about fields
    inherited from ``CommonDims`` (if any). Its generated ``__eq__`` and
    ``__repr__`` would therefore ignore every PadDims-specific value, so
    configurations differing only in e.g. ``pad_dims`` would compare equal.
    ``__eq__`` and ``__repr__`` are defined explicitly below over the full
    instance state; the decorator does not overwrite methods already present
    in the class body. ``dataclasses.fields()``/``asdict()`` still only see
    inherited fields.
    """

    def __init__(
        self,
        aie_cols: int,
        aie_rows: int,
        input: List[int],
        Cip: int,
        output: List[int],
        Cop: int,
        pad_dims: List[int],
        in_gran: List[int],
        out_gran: List[int],
        pad_limit: List[int],
        ifm_bits: int,
        ofm_bits: int,
        wgt_subv_size: int,
        qdq_mode: int,
        fix_point_bits: int,
        is_signed: bool,
        param_subv_size: int,
        ifm_memtile_size: int,
        ofm_memtile_size: int,
        ifm_core_size: int,
        ofm_core_size: int,
        scratch_buf_size: int,


        #from cost function
        core_subv: List[int],
        mt_subv: List[int],
        loop: List[int],
        ping_pong: bool,
        spatial_split_mode: List[int],
        row_split_mode: List[int],

    ):
        # Initialize CommonDims fields via super
        super().__init__(
            aie_cols=aie_cols,
            aie_rows=aie_rows,
            ifm_bits=ifm_bits,
            ofm_bits=ofm_bits,
        )

        # PadDims-specific fields.
        # Input / output shapes, unpacked in (N, Y, X, C) order.
        self.input = input
        self.output = output
        self.Ni = input[0]
        self.Yi = input[1]
        self.Xi = input[2]
        self.Ci = input[3]
        self.No = output[0]
        self.Yo = output[1]
        self.Xo = output[2]
        self.Co = output[3]
        # NOTE(review): Cip/Cop are passed separately from Ci/Co; presumably
        # padded/aligned channel counts — confirm with the caller.
        self.Cip = Cip
        self.Cop = Cop
        # Per-axis padding amounts.
        self.pad_dims = pad_dims
        self.pad_N = pad_dims[0]
        self.pad_Y = pad_dims[1]
        self.pad_X = pad_dims[2]
        self.pad_C = pad_dims[3]
        # Per-axis input / output granularities.
        self.in_gran = in_gran
        self.Nis_gran = in_gran[0]
        self.Yis_gran = in_gran[1]
        self.Xis_gran = in_gran[2]
        self.Cis_gran = in_gran[3]
        self.out_gran = out_gran
        self.Nos_gran = out_gran[0]
        self.Yos_gran = out_gran[1]
        self.Xos_gran = out_gran[2]
        self.Cos_gran = out_gran[3]
        self.pad_limit = pad_limit
        # Tiling chosen by the cost function: core subvolume, memtile
        # subvolume and per-axis loop counts.
        self.core_subv = core_subv
        self.Nis = core_subv[0]
        self.Yis = core_subv[1]
        self.Xis = core_subv[2]
        self.Cis = core_subv[3]
        self.mt_subv = mt_subv
        self.Nim = mt_subv[0]
        self.Yim = mt_subv[1]
        self.Xim = mt_subv[2]
        self.Cim = mt_subv[3]
        self.loop = loop
        self.N_loop = loop[0]
        self.Y_loop = loop[1]
        self.X_loop = loop[2]
        self.C_loop = loop[3]
        self.ping_pong = ping_pong
        self.spatial_split_mode = spatial_split_mode
        self.row_split_mode = row_split_mode

        # qdq_mode == 3 is the only value treated as "not QDQ" here.
        self.is_qdq = qdq_mode != 3
        self.qdq_mode = qdq_mode
        self.fix_point_bits = fix_point_bits
        self.is_signed = is_signed
        self.param_subv_size = param_subv_size
        self.wgt_subv_size = wgt_subv_size
        # Buffer sizes, stored as given.
        self.ifm_memtile_size = ifm_memtile_size
        self.ofm_memtile_size = ofm_memtile_size
        self.ifm_core_size = ifm_core_size
        self.ofm_core_size = ofm_core_size
        self.scratch_buf_size = scratch_buf_size

    # Defining __eq__ in the class body leaves instances unhashable
    # (__hash__ is None), matching the previous dataclass-generated behavior.
    def __eq__(self, other: object) -> bool:
        """Compare all instance attributes, not only inherited dataclass fields."""
        if other.__class__ is not self.__class__:
            return NotImplemented
        return vars(self) == vars(other)

    def __repr__(self) -> str:
        """Render every instance attribute so PadDims-specific values are visible."""
        attrs = ', '.join(f'{name}={value!r}' for name, value in vars(self).items())
        return f'{type(self).__name__}({attrs})'



def pad_preproc_directives(
    dims: PadDims,
    back_end: BackEnd,
) -> List[str]:
    """Build the preprocessor define flags that describe a pad kernel.

    Each define is rendered as ``-DNAME=VALUE``. For the ``BackEnd.Adf``
    back end every define is additionally wrapped as
    ``--Xpreproc="-DNAME=VALUE"``.

    Args:
        dims: Pad configuration whose shapes/parameters are exported.
        back_end: Target back end; selects the flag format and TXN_MODE.

    Returns:
        The flag strings, in a fixed order from AIE_COLS to TXN_MODE.
    """
    wrap_for_adf = back_end == BackEnd.Adf

    # Channel defines use the Cip/Cop values rather than Ci/Co.
    defines = (
        ('AIE_COLS', dims.aie_cols),
        ('AIE_ROWS', dims.aie_rows),
        ('N_IN', dims.Ni),
        ('Y_IN', dims.Yi),
        ('X_IN', dims.Xi),
        ('C_IN', dims.Cip),
        ('N_OUT', dims.No),
        ('Y_OUT', dims.Yo),
        ('X_OUT', dims.Xo),
        ('C_OUT', dims.Cop),
        ('WGT_SIZE', dims.wgt_subv_size),
        ('QDQ_MODE', dims.qdq_mode),
        ('FIX_POINT_BITS', dims.fix_point_bits),
        ('IS_SIGNED', int(dims.is_signed)),
        # TXN_MODE is 0 for the ADF back end and 1 for every other one.
        ('TXN_MODE', int(back_end != BackEnd.Adf)),
    )

    flags: List[str] = []
    for name, value in defines:
        flag = f"-D{name}={value}"
        if wrap_for_adf:
            flag = f'--Xpreproc="{flag}"'
        flags.append(flag)
    return flags