"""Module for Global Average Pooling (GAP) implementation for AIE4."""

import math
import struct
from typing import Optional, Tuple

from dmacompiler import DevGen, set_dev_gen
from utils.utils_common import L2Alloc

set_dev_gen(DevGen.Aie4)


def conv_to_local_ptr(addr: int) -> int:
    """
    Translate an address by the fixed core-local offset.

    Args:
        addr: Address to translate

    Returns:
        ``addr`` increased by the constant 0xE0000
    """
    return 0xE0000 + addr


def compute_optimal_param_value_shift_gap(
    acc: int, s_min: Optional[int] = None, s_max: Optional[int] = None
) -> Tuple[int, int]:
    """
    Find the integer multiplier and shift whose ratio best approximates 1 / acc.

    Every shift ``s`` in ``[s_min, s_max]`` is tried with the multiplier
    ``round(2**s / acc)``; the pair minimising ``|multiplier / 2**s - 1 / acc|``
    wins. On equal error the smallest shift is kept.

    Args:
        acc: Accumulation value whose reciprocal is approximated
        s_min: Minimum shift value; defaults to ``ceil(log2(abs(acc)))``
        s_max: Maximum shift value; defaults to ``s_min + 5``

    Returns:
        Tuple of best parameter value and best shift. Both entries are
        ``None`` when the shift range is empty (``s_min > s_max``).

    Raises:
        ValueError: ``acc`` is 0 and ``s_min`` is omitted (log2 of zero).
        ZeroDivisionError: ``acc`` is 0 and ``s_min`` is supplied.
    """
    lowest_shift = math.ceil(math.log2(abs(acc))) if s_min is None else s_min
    highest_shift = lowest_shift + 5 if s_max is None else s_max
    reciprocal = 1 / acc

    winner = (None, None)
    smallest_gap = float("inf")
    for shift in range(lowest_shift, highest_shift + 1):
        scale = 2**shift
        multiplier = round(scale / acc)
        gap = abs(multiplier / scale - reciprocal)
        # Strict comparison: a later shift only wins if it is strictly better.
        if gap < smallest_gap:
            smallest_gap = gap
            winner = (multiplier, shift)
    return winner


class GAPDims:
    """Class for Global Average Pooling dimensions and parameters.

    The constructor stores the shape and precision arguments it receives and
    derives the tiling constants, address steps, the ``(param_value,
    shift_res)`` pair and the number of AIE columns actually used.
    """

    def __init__(
        self,
        Yi: int,
        Xi: int,
        Ci: int,
        Yo: int,
        Xo: int,
        Co: int,
        act_bits: int,
        out_bits: int,
        sign_act: int,
        sign_out: int,
        prm_bits: int,
        bits_per_byte: int,
        aie_rows: int,
        aie_cols: int,
        Yis: int,
        Xis: int,
        Cis: int,
        Yos: int,
        Xos: int,
        Cos: int,
        shift_res: Optional[int] = None,
        param_value: Optional[int] = None,
        fusion_param: Optional[L2Alloc] = None,
    ):
        """
        Store the GAP shape description and compute the derived attributes.

        Args:
            Yi, Xi, Ci: Sizes stored as-is (presumably input height, width
                and channels -- confirm against callers).
            Yo, Xo, Co: Sizes stored as-is (presumably the output shape).
            act_bits, out_bits, prm_bits, bits_per_byte: Stored as-is.
            sign_act, sign_out: Sign flags, stored as-is.
            aie_rows, aie_cols: AIE array extent used for the column search.
            Yis, Xis, Cis: ``s``-suffixed sizes; ``Yi // Yis`` is stored as
                ``Y`` and these values drive every derived tiling attribute.
            Yos, Xos, Cos: Stored as-is.
            shift_res: Explicit shift; see ``param_value``.
            param_value: Explicit multiplier. Only when BOTH ``param_value``
                and ``shift_res`` are ``None`` is the pair computed from
                ``inner_g * Y``; otherwise both are stored exactly as given,
                so passing just one leaves the other as ``None``.
            fusion_param: Optional allocation object, stored as-is.

        Raises:
            ValueError: ``aie_cols * aie_rows * Cis`` exceeds ``Ci`` and no
                column count below ``aie_cols`` makes ``Ci`` a multiple of
                ``cols * aie_rows * Cis``.
        """
        self.act_bits = act_bits
        self.out_bits = out_bits
        self.sign_act = sign_act
        self.sign_out = sign_out
        self.prm_bits = prm_bits
        self.bits_per_byte = bits_per_byte
        self.fusion_param = fusion_param

        # Fixed constants of this implementation.
        self.C_g = 64
        self.il_unroll_range = 2
        self.jl_unroll_range = 4
        self.align_l1 = 128
        self.prm_size = 1024

        self.Yis = Yis
        self.Xis = Xis
        self.Cis = Cis
        self.Yos = Yos
        self.Xos = Xos
        self.Cos = Cos
        self.aie_rows = aie_rows
        self.aie_cols = aie_cols

        self.Yi = Yi
        self.Xi = Xi
        self.Ci = Ci
        self.Yo = Yo
        self.Xo = Xo
        self.Co = Co
        self.Y = self.Yi // self.Yis

        # Round Cis up to a whole multiple of C_outer_g. Only the local copy
        # is padded; self.Cis keeps the value that was passed in.
        self.C_outer_g = self.C_g * self.il_unroll_range
        remainder = self.Cis % self.C_outer_g
        self.Pad = remainder != 0
        if self.Pad:
            Cis = Cis + (self.C_outer_g - remainder)
        # Defined on every instance (equal to Cis when no padding is needed)
        # so that reading it never raises AttributeError.
        self.Pad_Cis = Cis

        self.X_g = Xis
        self.Y_g = Yis
        self.Co_g = Cis // self.C_g
        self.inner_g = self.X_g * self.Y_g
        self.outer_g = Cis // self.C_g
        self.step_Xi = self.C_g
        self.step_Ci = self.X_g * self.step_Xi
        self.step_Yi = self.Co_g * self.step_Ci

        # Special case for a single-column tile: the steps are overridden.
        if self.X_g == 1:
            self.step_Ci = self.step_Xi if self.Y_g % 2 == 1 else 1
            self.step_Yi = self.Co_g * self.step_Xi

        if param_value is None and shift_res is None:
            self.param_value, self.shift_res = compute_optimal_param_value_shift_gap(
                self.inner_g * self.Y
            )
        else:
            self.param_value = param_value
            self.shift_res = shift_res

        # When the full array would cover more channels than Ci provides,
        # search downwards for a column count that divides Ci evenly.
        # Otherwise every column is used.
        self.aie_cols_used = self.aie_cols
        if self.aie_cols * self.aie_rows * self.Cis > self.Ci:
            for cols in range(self.aie_cols, 0, -1):
                if self.Ci % (cols * self.aie_rows * self.Cis) == 0:
                    self.aie_cols_used = cols
                    break
            # Still at the starting value: no smaller column count fits.
            if self.aie_cols_used == self.aie_cols:
                raise ValueError("Invalid GAP Shape")


def gen_aie4_gap_params(dims: GAPDims, iter_cnt: int, tdm_cnt: int, offset_interm: int) -> bytes:
    """
    Serialise the GAP kernel and layer parameter records into one byte string.

    Both records are packed with ``struct`` format strings that carry no
    byte-order prefix, i.e. native byte order, size and alignment.

    Args:
        dims: GAP dimensions; supplies the tile sizes, address steps,
            ``shift_res``, ``param_value`` and the two sign flags
        iter_cnt: First field of the layer record (format code ``H``)
        tdm_cnt: Second field of the layer record (format code ``B``)
        offset_interm: Address passed through ``conv_to_local_ptr`` before
            being written to the layer record (format code ``i``)

    Returns:
        Byte string containing the packed kernel record followed by the
        packed layer record
    """
    local_interm = conv_to_local_ptr(offset_interm)

    # Control word layout: field name -> (value, bit position).
    ctrl_layout = {
        "zero_init": (False, 0),
        "sign_N": (0, 1),
        "sign_O": (dims.sign_out, 2),
        "reserved3": (0, 3),
        "skip_casc_in": (False, 6),
        "skip_casc_out": (False, 7),
        "sign_W": (1, 8),
        "sign_A": (dims.sign_act, 9),
        "reserved10": (0, 10),
        "norm_ch_g": (0, 24),
    }
    ctrl = 0
    for value, bit in ctrl_layout.values():
        ctrl |= value << bit

    # Values are grouped by the run of format codes that consumes them.
    kernel_values = [
        0, 0, 0, 0, 0, dims.X_g, dims.Y_g, 0,  # "BBBbBBBB"
        dims.inner_g, dims.outer_g,  # "HH"
        0, dims.shift_res, 0, 0,  # "bbbb"
        0, 0, dims.step_Ci, dims.step_Xi, dims.step_Yi, 0, 0, 0,  # "HHHHHHHH"
        dims.param_value, ctrl,  # "ii"
    ]
    layer_values = [
        iter_cnt,  # "H"
        tdm_cnt, 0, 0, 0, 0, 0,  # "BBBBBB"
        0, 0, 0, local_interm, 0,  # "iiiii"
    ]

    kernel_blob = struct.pack("BBBbBBBBHHbbbbHHHHHHHHii", *kernel_values)
    layer_blob = struct.pack("HBBBBBBiiiii", *layer_values)
    return kernel_blob + layer_blob
