'''
Map conv shapes to the AIE-4 dataflow architecture.
External-facing functions are documented below.

    generate_mappings - enumerate all possible ways to map a conv shape
    onto the compute array and sort them by projected latency
    (fastest mappings first)
'''
from dataclasses import dataclass
from utils.utils_common import ceildiv, log, ifloor

from dmacompiler import (
    DevGen, set_dev_gen, config
)
set_dev_gen(DevGen.Aie4)


@dataclass
class BinaryShape:
    """Extents and element widths describing one binary-op workload.

    The first seven fields are required and positional; the remaining
    three default to 1.
    """
    # Input extents: their product is the input element count, and that
    # product times ifm_bytes is the input size in bytes.
    Ci: int
    Yi: int
    Xi: int
    # Output extents, combined the same way to size the output.
    Co: int
    Yo: int
    Xo: int
    # Bytes per input element (multiplied by 8 wherever a bit width is needed).
    ifm_bytes: int
    # Bytes per weight element.
    wgt_bytes: int = 1
    # Not read by the code in this file section; presumably input/output
    # batch counts - TODO confirm against callers.
    Ni: int = 1
    No: int = 1


class BinaryL2Tiler:
    """
    Choose one uniform per-core subvolume for a binary op's input volume.

    Every active core processes the same subvolume `S` on every iteration:
      - all iterations except possibly the last keep every core busy
      - the last iteration may use only some of the cores
      - so: total = full_iters * num_cores * S + active_last * S
    No fallback scheme is implemented: when no `S` satisfies the
    constraints, `compute_tiling` raises ValueError.
    """
    def __init__(self, shape: BinaryShape,
                 num_cores: int = 12,
                 kernel_gran: int = 64,
                 kernel_loop_range: int = 8,
                 q_enable: int = 0,
                 dq_enable: int = 0,
                 core_bank_mem_size_software: int = 16384
                 ):
        """
        Derive the element count to tile and the per-core subvolume cap.

        shape: supplies Ci, Yi, Xi (multiplied into the total element count)
            and ifm_bytes (bytes per element).
        num_cores: number of cores the work is spread across per iteration.
        kernel_gran, kernel_loop_range: their product is stored as
            `kernel_subv_requirement`; `compute_tiling` does not read it.
        q_enable, dq_enable: when either is non-zero, the bank size is first
            passed through `ifloor` with a 32 * 12 chunk before being
            converted to elements.
        core_bank_mem_size_software: bank size in bytes that bounds the
            per-core subvolume.
        """
        self.shape = shape
        self.total_ifm_elements = shape.Ci * shape.Yi * shape.Xi
        self.num_cores = num_cores
        self.kernel_loop_range = kernel_loop_range

        self.kernel_subv_requirement = kernel_gran * kernel_loop_range
        elem_bytes = shape.ifm_bytes
        self.memtile_subv_requirement = 4 // elem_bytes
        self.ifm_bytes = elem_bytes

        if q_enable != 0 or dq_enable != 0:
            qdq_kernel_gran = 32
            q_loop_range = 10
            dq_loop_range = 12
            qdq_chunk = qdq_kernel_gran * max(q_loop_range, dq_loop_range)
            # ifloor is an imported helper; presumably it rounds the bank
            # size down to a multiple of qdq_chunk - TODO confirm.
            usable_bytes = ifloor(core_bank_mem_size_software, qdq_chunk)
            self.max_subvolume = usable_bytes // elem_bytes
        else:
            # Plain path: the whole bank is available, expressed in elements.
            self.max_subvolume = core_bank_mem_size_software // elem_bytes

    def compute_tiling(self) -> tuple[int, int, bool, int]:
        """
        Search for the best uniform per-core subvolume S and its schedule.

        Candidates are the multiples of 4 that are no larger than
        min(self.max_subvolume, self.total_ifm_elements) and divide
        self.total_ifm_elements exactly. For scoring, each core-run is
        padded up to the next multiple of 512 elements.

        Candidates are ranked by, in order:
        1) smallest total padding over all core-runs
        2) most active cores in the last iteration
           (a full last iteration counts as num_cores)
        3) fewest iterations
        4) largest S

        Returns a tuple
            (subvolume, total_iterations, partial_last_iter,
             active_cores_last_iter)
        where active_cores_last_iter equals num_cores when the last
        iteration is full. An empty volume yields (0, 0, False, 0).

        Raises ValueError when the element count is negative or when no
        candidate S exists.
        """
        total = self.total_ifm_elements
        if total < 0:
            raise ValueError("self.total_ifm_elements must be non-negative")
        if total == 0:
            return (0, 0, False, 0)

        largest_candidate = min(self.max_subvolume, total)

        best_key = None
        chosen = None
        # NOTE(review): the step of 4 is hard-coded and does not use
        # memtile_subv_requirement - confirm that this is intended.
        for subv in range(4, largest_candidate + 1, 4):
            core_runs, remainder = divmod(total, subv)
            if remainder:
                # Only subvolumes that cover the volume exactly are allowed.
                continue

            # Elements wasted per core-run by rounding up to 512.
            # NOTE(review): 512 matches the default
            # kernel_gran * kernel_loop_range (64 * 8); whether it should
            # track kernel_subv_requirement is not established here.
            padding = ceildiv(subv, 512) * 512 - subv

            cores = self.num_cores
            whole_rounds, tail = divmod(core_runs, cores)
            if tail:
                # A trailing iteration uses only `tail` cores.
                last_active = tail
                rounds = whole_rounds + 1
                is_partial = tail < cores
            else:
                # Work divides evenly: the last iteration is a full one.
                last_active = cores
                rounds = whole_rounds
                is_partial = False

            # Lexicographic key, lower wins; negated entries flip the
            # preference towards larger values.
            key = (core_runs * padding, -last_active, rounds, -subv)
            if best_key is None or key < best_key:
                best_key = key
                chosen = (subv, rounds, is_partial, last_active)

        if chosen is None:
            raise ValueError(
                "No feasible subvolume S found: require S | self.total_ifm_elements, S % 4 == 0, and S <= self.max_subvolume."
            )

        subvolume, total_iterations, partial_last_iter, active_cores_last_iter = chosen
        log("best_plan['subvolume']", subvolume)
        log("best_plan['total_iterations']", total_iterations)
        log("best_plan['partial_last_iter']", partial_last_iter)
        log("best_plan['active_cores_last_iter']", active_cores_last_iter)
        return subvolume, total_iterations, partial_last_iter, active_cores_last_iter


class BinaryL2Dims:
    '''
    Holder for the dimensions and derived sizes of a binary op.

    Construction copies the caller's options, fixes the overlay size,
    computes input/output byte sizes from the shape, and runs
    BinaryL2Tiler to obtain the per-core subvolume and iteration schedule.
    '''
    def __init__(
        self,
        shape: BinaryShape,
        q_enable: int = 0,
        dq_enable: int = 0,
        broadcast: bool = False,
        call_kernel: str = "call_kernel",
        b_on_wgt: int = 0,
    ):
        '''
        shape: extents and element widths of the op.
        q_enable, dq_enable: stored and forwarded to BinaryL2Tiler.
        broadcast, call_kernel, b_on_wgt: stored unchanged; they are not
            otherwise used in this part of the constructor.
        '''
        self.shape = shape
        # Constants
        # Fixed sizes; their units (bytes vs. elements) are not evident
        # from this block - TODO confirm.
        self.wgt_size = 128
        self.q_buf_size = 512
        self.dq_buf_size = 512
        self.wgt_bits = shape.wgt_bytes * 8
        self.param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE
        self.q_enable = q_enable
        self.dq_enable = dq_enable
        self.broadcast = broadcast
        self.call_kernel = call_kernel
        self.b_on_wgt = b_on_wgt

        # Overlay
        # Hard-coded 3 x 4 grid; the trailing comments name the config
        # values these literals stand in for.
        self.aie_cols: int = 3  # config.NUM_AIE_COLS
        self.aie_rows: int = 4  # config.NUM_AIE_ROWS
        self.num_cores = self.aie_cols * self.aie_rows

        # IFM (might include overcompute)
        self.ifm_bits = shape.ifm_bytes * 8
        # Input size in bytes: element count times bytes per element.
        self.ifm_size = shape.Xi * shape.Yi * shape.Ci * shape.ifm_bytes

        # OFM
        self.Co = shape.Co
        self.Xo = shape.Xo
        self.Yo = shape.Yo
        # Output elements use the same byte width as input elements.
        self.ofm_bytes = shape.ifm_bytes
        self.ofm_bits = self.ofm_bytes * 8
        self.ofm_size = shape.Xo * shape.Yo * shape.Co * self.ofm_bytes

        # Constant
        # Bank size handed to the tiler to bound the per-core subvolume.
        self.core_bank_mem_size_software = 16384

        # Tiling
        tiler = BinaryL2Tiler(
            shape,
            q_enable=self.q_enable,
            dq_enable=self.dq_enable,
            num_cores=self.num_cores,
            core_bank_mem_size_software=self.core_bank_mem_size_software
        )

        # Despite its name, max_subvolume receives the subvolume the tiler
        # selected (first element of the returned tuple), not the tiler's cap.
        (self.max_subvolume,
         self.total_iterations,
         self.partial_last_iter,
         self.active_cores_last_iter,) = tiler.compute_tiling()

        log("========Tiler Output========")
        log("self.max_subvolume", self.max_subvolume)
        log("self.total_iterations", self.total_iterations)
        log("self.partial_last_iter", self.partial_last_iter)
        log("self.active_cores_last_iter", self.active_cores_last_iter)
        log("============================")

        # Iterations in which every core is active: drop the last one when
        # it is only partially populated.
        self.full_subvol_iterations = self.total_iterations - 1 if self.partial_last_iter else self.total_iterations

        # Buffer Offset
        # Set to the input size in bytes.
        self.buffer_offset = self.ifm_size
