"""Common utils for Uniop"""
from typing import Tuple
from dataclasses import dataclass


@dataclass(frozen=True)
class UnaryShape:
    """Immutable record describing the shape of one unary operation.

    Instances are frozen: fields cannot be reassigned after construction,
    and instances are hashable as long as every field value is hashable.

    NOTE(review): the per-field notes below are inferred from the field
    names alone -- nothing in this module reads them. Confirm against the
    code that builds and consumes this record.
    """

    # Operation identifier -- presumably a logical op name; TODO confirm.
    function: str
    # Three integer extents; axis order is not visible here -- TODO confirm.
    TensorDim: Tuple[int, int, int]

    # Presumably bytes per element of the input / output feature map.
    ifmbytes: int
    ofmbytes: int

    # Presumably signedness flags of the input / output feature map.
    ifmSign: int
    ofmSign: int

    # Name of the spatial split strategy; valid values are not visible here.
    SpatialSplitMode: str


# Private shorthand for the 3-tuple-of-ints annotation repeated below.
_IntTriple = Tuple[int, int, int]


@dataclass(frozen=True)
class UnaryMapping:
    """Immutable record holding sub-volume sizes and the spatial split.

    Instances are frozen: fields cannot be reassigned after construction.
    ``Npass`` and ``Ngroups`` are optional and default to 1.

    NOTE(review): the per-field notes below are inferred from the field
    names alone -- nothing in this module reads them. Confirm against the
    code that builds and consumes this record.
    """

    # Tensor extents, and presumably the same extents after padding.
    TensorDim: _IntTriple
    PaddedTDim: _IntTriple

    # Sub-volume sizes -- presumably input, output and weight respectively.
    ifm_subv: _IntTriple
    ofm_subv: _IntTriple
    wgt_subv: _IntTriple

    # Name of the chosen spatial split; valid values are not visible here.
    spatial_split: str

    # Presumably pass and group counts; both default to 1.
    Npass: int = 1
    Ngroups: int = 1


# Logical (normalized) op name -> every WAIC op name that maps to it.
# Entry order is significant: map_op_name() returns the FIRST WAIC name
# listed for a logical op.
op_mapping: dict[str, list[str]] = {
    "l2norm": ["LpNormalization_qdq_uint16xuint16"],
    "quant": ["Quant_bfloat16xuint16", "Quant_float32xuint16"],
    "dequant": ["Dequant_uint16xbfloat16", "Dequant_uint16xfloat32"],
    "softmax": ["Softmax_qdq_uint16xuint16"],
    "silu": ["Silu_qdq_uint16xuint16"],
    "copy": ["Copy_uint16xuint16"],
    "layernorm": [
        "LayerNorm_qdq_uint16xuint16",
        "LayerNormalization_qdq_uint16xuint8xuint16",
    ],
    "groupnorm": [
        "GroupNorm_qdq_uint16xuint16",
        "GroupNorm_qdq_uint16xuint16xuint16",
        "GroupNorm_qdq_uint16xint16xuint16",
        "GroupNormalization_qdq_uint16xuint16xuint16",
        "GroupNormalization_qdq_uint16xint16xuint16",
    ],
    "gelu": ["Gelu_qdq_uint16xuint16"],
    "swish": ["Swish_qdq_uint16xuint16"],
    "tanh": ["Tanh_qdq_uint16xuint16"],
    "sigmoid": ["Sigmoid_qdq_uint16xuint16"],
    "elu": ["Elu_qdq_uint16xuint16"],
}


# Reverse lookup (WAIC op name -> logical op name), derived from op_mapping
# so the two tables cannot drift apart. Built with a comprehension rather
# than a top-level loop so that no loop variables are left behind in the
# module namespace. If a WAIC name were ever listed under two logical ops,
# the later entry would win.
waic_mapping: dict[str, str] = {
    waic_name: logical_op
    for logical_op, waic_names in op_mapping.items()
    for waic_name in waic_names
}


def map_op_name(name: str) -> str:
    """Translate an op name between its logical form and its WAIC form.

    Lookup order:

    1. ``name`` is a key of ``op_mapping`` (a logical op such as
       ``"dequant"``): return the first WAIC name listed for it.
    2. ``name`` is a key of ``waic_mapping`` (a WAIC op name): return the
       logical op it belongs to.
    3. Otherwise: return ``name`` unchanged.

    The translation is not a strict inverse: a WAIC name that is not the
    first entry of its list maps to the logical op, which in turn maps back
    to the first entry, not to the original WAIC name.
    """
    waic_names = op_mapping.get(name)
    if waic_names is not None:
        # Logical op: the first listed WAIC name is the one returned.
        return waic_names[0]

    # WAIC op -> logical op; unknown names pass through untouched.
    return waic_mapping.get(name, name)
