import os
# Absolute path of the directory that contains this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
# Extend the module search path with two directories located two levels above
# this file: ``dmacompiler`` and ``kernels/conv``.
# NOTE(review): presumably these hold the project modules imported further
# down -- confirm which import resolves from which directory.
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels', 'conv'))
from typing import List, Union
from dataclasses import dataclass

from dataflow_common import ceildiv, iceil, overlay_stack_addr

from dmacompiler import \
    BackEnd, \
    set_dev_gen, DevGen, config

from dataflow_utils import CommonDims

# Module-level dmacompiler configuration: both statements execute as a side
# effect of importing this file.
# NOTE(review): presumably selects the Aie2p device generation as the target
# for everything compiled afterwards -- confirm against dmacompiler.
set_dev_gen(DevGen.Aie2p)
# Sets dmacompiler's ENABLE_BUSY_POLL option to True.
# NOTE(review): the exact effect of this flag is defined in dmacompiler and is
# not visible here -- confirm.
config.ENABLE_BUSY_POLL = True

def padding(dIn: int):
    """Return ``dIn`` padded for the W8 layout.

    The value is passed to ``iceil`` (imported from ``dataflow_common``)
    together with a granule of 8, and the result is returned unchanged.

    NOTE(review): ``iceil`` is presumably "round up to the next multiple of
    the second argument" -- confirm against ``dataflow_common``.

    Args:
        dIn: Dimension size to pad. The annotation was previously the builtin
            function ``input``, which is not a type; ``int`` is what the
            arithmetic helper expects.

    Returns:
        Whatever ``iceil(dIn, 8)`` returns.
    """
    # NOTE: This is W8 padding
    return iceil(dIn, 8)

@dataclass
class TransposeKernelDims(CommonDims):
    """``CommonDims`` extended with a kernel-transpose marker.

    Every field declared by ``CommonDims`` is inherited as-is; this subclass
    contributes a single extra field.
    """

    # Defaults to True for this class; TransposeDims defaults the same-named
    # field to False.
    is_kernel_transpose: bool = True

@dataclass
class TransposeDims:
    """Standalone dims record that declares only the kernel-transpose marker.

    NOTE(review): ``transpose_preproc_directives`` accepts this type in its
    signature but reads attributes that are not declared here (for example
    ``input``, ``perm`` and ``aie_cols``) -- confirm how instances get those
    attributes before they are passed to it.
    """

    # Defaults to False for this class; TransposeKernelDims defaults the
    # same-named field to True.
    is_kernel_transpose: bool = False

def transpose_preproc_directives(
    dims: Union[TransposeDims, TransposeKernelDims],
    back_end: BackEnd,
) -> List[str]:
    """Build the preprocessor define flags for a transpose configuration.

    Side effect: (re)writes ``shapes.txt`` in the current working directory
    with two comma-separated lines -- the items of ``dims.input`` followed by
    the items of ``dims.perm``.

    Args:
        dims: Object supplying every attribute read below (``input``,
            ``perm``, ``aie_cols``, ``aie_rows``, ``is_int16``, ...).
        back_end: Compared against ``BackEnd.Adf`` to choose the flag syntax
            and the value of ``TXN_MODE``.

    Returns:
        One flag string per define, in a fixed order. For ``BackEnd.Adf``
        each flag has the form ``--Xpreproc="-DNAME=VALUE"``; for any other
        back end it is the plain ``-DNAME=VALUE``.
    """
    # The file is opened (and truncated) before any attribute of ``dims`` is
    # read, and each line is written as soon as its attribute is fetched.
    with open("shapes.txt", "w") as shapes_file:
        for attr_name in ("input", "perm"):
            items = getattr(dims, attr_name)
            shapes_file.write(",".join(str(item) for item in items) + "\n")

    wrap_for_adf = back_end == BackEnd.Adf

    def as_flag(name: str, value: int) -> str:
        # Plain compiler define, wrapped in --Xpreproc only for the Adf back end.
        flag = f"-D{name}={value}"
        return f'--Xpreproc="{flag}"' if wrap_for_adf else flag

    # TXN_MODE is 0 for the Adf back end and 1 for every other back end.
    txn_mode = int(back_end != BackEnd.Adf)

    # (define name, value) pairs in emission order.
    defines = (
        ('AIE_COLS', dims.aie_cols),
        ('AIE_ROWS', dims.aie_rows),
        # ACT_BITS is currently not emitted:
        # ('ACT_BITS', dims.ifm_bits),
        ('IS_INT16', int(dims.is_int16)),
        ('BATCH_SIZE', dims.batch_size),
        ('N_OP', dims.Nop),
        ('Y_OP', dims.Yop),
        ('X_OP', dims.Xop),
        ('C_OP', dims.Cop),
        ('N_IP', dims.Nip),
        ('Y_IP', dims.Yip),
        ('X_IP', dims.Xip),
        ('C_IP', dims.Cip),
        ('TXN_MODE', txn_mode),
        ('QDQ_MODE', dims.qdq_mode),
        ('IS_SIGNED', int(dims.is_signed)),
    )
    return [as_flag(name, value) for name, value in defines]
