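"""Synchronize an allocation JSON with its IR JSON.

update_alloc_json() copies per-input/output dtypes and IR attributes onto the
matching alloc blocks and applies op-specific rules that derive extra fields;
copy_L3_to_ir() copies each alloc block's 'L3' field back into the IR JSON.
"""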
import json
import argparse
from typing import Any, Callable, Dict, List, Optional, Protocol, Union

# ----------------------------- Helpers -----------------------------


def _get_ir_input(ir_block: Dict[str, Any], param_name: str) -> Optional[Dict[str, Any]]:
    """
    Look up an IR input by its parameter name.

    Scans ir_block['inputs'] (treated as empty when the key is absent) in order and
    returns the first entry whose 'param_name' equals `param_name`, or None if none does.
    """
    candidates = ir_block.get("inputs", [])
    return next(
        (entry for entry in candidates if entry.get("param_name") == param_name),
        None,
    )


def _first_ir_input_where(ir_block: Dict[str, Any],
                          pred: Callable[[Dict[str, Any]], bool]) -> Optional[Dict[str, Any]]:
    """
    Return the first entry of ir_block['inputs'] for which `pred` is truthy.

    A missing 'inputs' key is treated as an empty list; `pred` is evaluated lazily
    and stops at the first hit. Returns None when no entry qualifies.
    """
    matching = filter(pred, ir_block.get("inputs", []))
    return next(matching, None)


def _shape_kind(shape_val: Any) -> str:
    """
    Classify a 'shape' value:
      - 'scalar'  : [] or None
      - 'vector1' : [int]  -- a one-element list whose element is a real int
      - 'other'   : anything else (including [True]/[False] and non-list values)
    """
    if shape_val is None:
        return "scalar"
    if isinstance(shape_val, list):
        if not shape_val:
            return "scalar"
        if len(shape_val) == 1:
            dim = shape_val[0]
            # bool subclasses int, so JSON true/false would otherwise pass as a dimension.
            if isinstance(dim, int) and not isinstance(dim, bool):
                return "vector1"
    return "other"


def _get_op_type(ir_block: Dict[str, Any]) -> str:
    """
    Return the IR block's operation type as a string.

    Uses the first truthy value among 'op_type' and 'op' (in that order),
    converted with str(); returns "" when neither is truthy.
    """
    for field in ("op_type", "op"):
        candidate = ir_block.get(field)
        if candidate:
            return str(candidate)
    return ""

# --------------------------- Rule framework -----------------------------


class Rule(Protocol):
    """Interface for a rule that may write TOP-LEVEL fields on an alloc block.

    update_alloc_json() iterates the module's rule registry and, for each rule,
    calls ``matches(ir_block)``; only when that returns True does it call
    ``apply(alloc_block, ir_block)``. This is a structural (duck-typed) interface:
    any object with these two methods satisfies it.
    """
    def matches(self, ir_block: Dict[str, Any]) -> bool:
        """Return True if this rule should apply to the given IR block.

        Args:
            ir_block: The IR JSON node matched to the alloc block by name.
        """

    def apply(self, alloc_block: Dict[str, Any], ir_block: Dict[str, Any]) -> None:
        """Apply the rule to modify the alloc block based on the IR block.

        Args:
            alloc_block: The alloc JSON block; mutated in place.
            ir_block: The IR JSON node matched to ``alloc_block`` by name.
        """


class OpStartsWithRule:
    """
    Base class for rules gated by an op-type prefix (one prefix or several).

    ``matches()`` tests the IR block's op type (from 'op_type', falling back to
    'op') against the prefixes. ``apply()`` skips alloc blocks explicitly marked
    non-compilable and otherwise delegates to ``transform()``, which subclasses
    implement to write top-level fields on the alloc block.
    """
    def __init__(self, prefix: Union[str, List[str]]) -> None:
        """
        Args:
            prefix: A single op-type prefix, or a list of prefixes; the rule
                matches when the op type starts with any of them.
        """
        # Normalize to a tuple: immutable, and str.startswith() accepts it directly.
        self.prefixes = (prefix,) if isinstance(prefix, str) else tuple(prefix)

    def matches(self, ir_block: Dict[str, Any]) -> bool:
        """Return True if the IR block's op type starts with any of this rule's prefixes."""
        return _get_op_type(ir_block).startswith(self.prefixes)

    def apply(self, alloc_block: Dict[str, Any], ir_block: Dict[str, Any]) -> None:
        """
        Invoke ``transform()`` unless the alloc block is marked non-compilable.

        A missing 'is_compilable' key counts as compilable; only a present,
        falsy value causes the block to be skipped.
        """
        if not alloc_block.get("is_compilable", True):
            return
        self.transform(alloc_block, ir_block)

    def transform(self, alloc_block: Dict[str, Any], ir_block: Dict[str, Any]) -> None:
        """
        Write derived top-level fields onto ``alloc_block`` (mutated in place).

        Raises:
            NotImplementedError: always, in this base class; subclasses override it.
        """
        raise NotImplementedError

# -------------------------- Op-wise Rules -----------------------------


class MatMulVectorCoeffRule(OpStartsWithRule):
    """
    For MatMul* ops, infer top-level `vector_coeff` from a B_* quant param shape:

      - shape is [], [1] or None    -> vector_coeff = 0  (scalar)
      - shape is [N] for an int N   -> vector_coeff = 2  (per-channel)
      - any other shape             -> vector_coeff is left untouched

    The probe input is `B_scale` if present, else `B_zero_point`, else the first
    input whose param_name starts with 'B' and ends in '_scale' or '_zero_point'.

    NOTE: when no probe input exists, or it has no 'shape', the shape is taken as
    None and therefore as scalar, so vector_coeff is set to 0 in that case too.
    """
    def __init__(self) -> None:
        super().__init__("MatMul")

    @staticmethod
    def _is_b_quant_param(inp: Dict[str, Any]) -> bool:
        """True for inputs named like B*_scale or B*_zero_point."""
        name = inp.get("param_name")
        return (isinstance(name, str)
                and name.startswith("B")
                and name.endswith(("_scale", "_zero_point")))

    def transform(self, alloc_block: Dict[str, Any], ir_block: Dict[str, Any]) -> None:
        probe = _get_ir_input(ir_block, "B_scale") or _get_ir_input(ir_block, "B_zero_point")
        if probe is None:
            probe = _first_ir_input_where(ir_block, self._is_b_quant_param)
        shape_val = probe.get("shape") if isinstance(probe, dict) else None
        kind = _shape_kind(shape_val)

        # [1] is classified as 'vector1' by _shape_kind, but a single element is
        # still a scalar coefficient, so test it before the per-channel branch.
        if kind == "scalar" or shape_val == [1]:
            alloc_block["vector_coeff"] = 0
        elif kind == "vector1":
            alloc_block["vector_coeff"] = 2


class MatMulL3TransposeRule(OpStartsWithRule):
    """
    For MatMul_qdq_actxac* ops, infer top-level `transpose_wgts` from the alloc
    block's `input0` and `input1` shapes:

      - input0[-1] == input1[-1]  -> transpose_wgts = 1
      - otherwise                 -> transpose_wgts = 0
      - either shape missing, empty, or not a list/tuple -> left untouched

    NOTE(review): this rule is not listed in this module's _DEFAULT_RULES, so
    update_alloc_json() does not apply it; confirm that is intentional (it may
    be used from elsewhere).
    """
    def __init__(self) -> None:
        super().__init__("MatMul_qdq_actxac")

    def transform(self, alloc_block: Dict[str, Any], ir_block: Dict[str, Any]) -> None:
        # .get() rather than indexing: a block lacking these keys is skipped
        # instead of raising KeyError.
        input0 = alloc_block.get("input0")
        input1 = alloc_block.get("input1")
        if not (isinstance(input0, (list, tuple)) and input0
                and isinstance(input1, (list, tuple)) and input1):
            return
        alloc_block["transpose_wgts"] = 1 if input0[-1] == input1[-1] else 0


class LayerNormGammaBetaRule(OpStartsWithRule):
    """
    For LayerNormalization* ops, copy selected IR input dtypes to named top-level keys:
      - input with param_name "Scale" -> alloc["in_dtype_gamma"]
      - input with param_name "B"     -> alloc["in_dtype_beta"]
    Absent inputs, and inputs whose 'dtype' is missing or empty, are ignored.
    """
    # (IR param_name, alloc key) pairs, processed in this order.
    _DTYPE_TARGETS = (
        ("Scale", "in_dtype_gamma"),
        ("B", "in_dtype_beta"),
    )

    def __init__(self) -> None:
        super().__init__("LayerNormalization")

    def transform(self, alloc_block: Dict[str, Any], ir_block: Dict[str, Any]) -> None:
        for source_param, target_key in self._DTYPE_TARGETS:
            found = _get_ir_input(ir_block, source_param)
            if not isinstance(found, dict):
                continue
            dtype_value = found.get("dtype")
            if dtype_value:
                alloc_block[target_key] = dtype_value


class BinaryQdqInputRule(OpStartsWithRule):
    """
    For binary qdq ops (Add/Mul/Div/Sub with a `_qdq_BroadCast` or `_qdq_EleWise`
    suffix), record per-input types and promote a constant operand's shape.

    Writes on the alloc block:
      - input_types: {"A": ..., "B": ...}. Each entry defaults to "act" and is
        overwritten with the IR input's 'type' whenever that input exists.
      - For the first input (A, then B) whose 'type' == "const":
          * const_idx: its index (0 for A, 1 for B)
          * input{idx} / padded_input{idx}: its shape and padded shape, only when
            the shape is a list of ints
          * input{idx}_name: its tensor name, when present and the shape was copied
        and the generic activation fields ("input", "padded_input", "input_name"),
        if present, are moved to the other (activation) index.
    """
    def __init__(self) -> None:
        super().__init__([
            "Add_qdq_BroadCast",
            "Mul_qdq_BroadCast",
            "Div_qdq_BroadCast",
            "Sub_qdq_BroadCast",
            "Add_qdq_EleWise",
            "Mul_qdq_EleWise",
            "Div_qdq_EleWise",
            "Sub_qdq_EleWise",
        ])

    def transform(self, alloc_block: Dict[str, Any], ir_block: Dict[str, Any]) -> None:
        from graph.tensor_types import get_padded_shape_rev  # pylint: disable=import-outside-toplevel

        input_types = {"A": "act", "B": "act"}
        alloc_block["input_types"] = input_types
        const_idx = -1
        const_inp: Optional[Dict[str, Any]] = None

        # Record every input's type; previously the scan stopped at the first const,
        # leaving the other input's type at the "act" default regardless of the IR.
        for idx, param_name in enumerate(("A", "B")):
            inp = _get_ir_input(ir_block, param_name)
            if inp is None:
                continue
            inp_type = inp.get("type")
            input_types[param_name] = inp_type
            # Only the first constant operand is promoted.
            if inp_type == "const" and const_idx == -1:
                const_idx = idx
                const_inp = inp

        if const_inp is None:
            return

        shape = const_inp.get("shape")
        tensor_name = const_inp.get("name")
        if isinstance(shape, list) and all(isinstance(x, int) for x in shape):
            alloc_block[f"input{const_idx}"] = shape
            alloc_block[f"padded_input{const_idx}"] = get_padded_shape_rev(shape)
            if tensor_name:
                alloc_block[f"input{const_idx}_name"] = tensor_name

        # Move generic activation fields to the activation input's index
        alloc_block["const_idx"] = const_idx
        act_idx = 1 - const_idx
        for key, indexed_key in (
            ("input", f"input{act_idx}"),
            ("padded_input", f"padded_input{act_idx}"),
            ("input_name", f"input{act_idx}_name"),
        ):
            if key in alloc_block:
                alloc_block[indexed_key] = alloc_block.pop(key)


class MatMulFlattenShape(OpStartsWithRule):
    """
    For MatMul_qdq_actxac* ops, rewrite the alloc block's `input0`, `input1` and
    `output` shapes with graph.tensor_types.flatten_batch_dims(), which collapses
    leading batch dimensions while keeping the last 2 dims. Intended mapping
    (the exact behavior is defined by flatten_batch_dims):
      - (1, 1, 48, 45, 47) -> (1, 1, 48, 45, 47)    # leading 1s preserved
      - (1, 12, 48, 45, 47) -> (1, 1, 576, 45, 47)  # 12*48 = 576
      - (1, 48, 45, 47) -> (1, 48, 45, 47)          # leading 1 preserved
      - (12, 48, 45, 47) -> (1, 576, 45, 47)        # 12*48 = 576, prepend 1
      - (48, 45, 47) -> (48, 45, 47)                # 3D unchanged
      - (45, 47) -> (45, 47)                        # 2D unchanged

    Missing or non-list shapes are skipped. When a shape changes and a matching
    `padded_<key>` entry already exists, it is recomputed from the new shape.
    """
    def __init__(self) -> None:
        super().__init__("MatMul_qdq_actxac")

    def transform(self, alloc_block: Dict[str, Any], ir_block: Dict[str, Any]) -> None:
        from graph.tensor_types import flatten_batch_dims, get_padded_shape_rev  # pylint: disable=import-outside-toplevel

        for shape_key in ("input0", "input1", "output"):
            current = alloc_block.get(shape_key)
            if not isinstance(current, list):
                continue
            new_shape = flatten_batch_dims(current)
            if new_shape != current:
                alloc_block[shape_key] = new_shape
                pad_key = "padded_" + shape_key
                if pad_key in alloc_block:
                    alloc_block[pad_key] = get_padded_shape_rev(new_shape)


# Internal registry of default rules; extend this list to add more rules.
# update_alloc_json() tries every rule, in list order, on each matched alloc
# block, and applies those whose matches() accepts the IR block.
# NOTE(review): MatMulL3TransposeRule is defined in this module but is not
# registered here, so update_alloc_json() never applies it. Confirm intentional.
_DEFAULT_RULES: List[Rule] = [
    MatMulVectorCoeffRule(),
    MatMulFlattenShape(),
    LayerNormGammaBetaRule(),
    BinaryQdqInputRule(),
]


def _copy_io_dtypes(alloc_block: Dict[str, Any], entries: Any, key_prefix: str) -> None:
    """Set alloc_block[key_prefix + param_name] = dtype for each IR entry having both fields."""
    # `or ()` also tolerates an explicit JSON null for 'inputs'/'outputs'.
    for entry in entries or ():
        param_name = entry.get("param_name")
        dtype = entry.get("dtype")
        if param_name and dtype:
            alloc_block[f"{key_prefix}{param_name}"] = dtype


def update_alloc_json(alloc_json_path: str, ir_json_path: str) -> None:
    """
    Enrich alloc_json blocks with dtype fields, IR attributes and rule-derived
    fields taken from the matching node in ir_json.

    An alloc block matches an IR node when alloc_block['name'] is a key of
    ir_json. For every matching block whose IR node is a dict:
      1) Copy each IR input's dtype to the TOP LEVEL as in_dtype_<param_name>
         (e.g. in_dtype_A, in_dtype_B_scale).
      2) Copy each IR output's dtype as out_dtype_<param_name> (e.g. out_dtype_Y).
      3) If the IR node has a non-empty 'attributes' dict, store it as
         alloc_block['attributes'] (this replaces any existing value; no merge).
      4) Run every rule in _DEFAULT_RULES whose matches() accepts the IR node,
         letting it write extra TOP-LEVEL fields (e.g. MatMul* -> vector_coeff).

    Alloc entries that are not dicts, and IR nodes that are not dicts, are skipped.

    Args:
        alloc_json_path (str): Path to the allocation JSON file to update.
        ir_json_path (str): Path to the IR JSON file providing dtype info.

    Returns:
        None — the alloc_json file is rewritten in place.
    """
    with open(alloc_json_path, "r", encoding="utf-8") as f:
        alloc_data: Dict[str, Any] = json.load(f)

    with open(ir_json_path, "r", encoding="utf-8") as f:
        ir_data: Dict[str, Any] = json.load(f)

    updated_count = 0

    for block_val in alloc_data.values():
        if not isinstance(block_val, dict):
            continue  # a non-dict entry has no 'name' to match on
        node_name = block_val.get("name")
        if not node_name or node_name not in ir_data:
            continue
        ir_block = ir_data[node_name]
        if not isinstance(ir_block, dict):
            continue

        # 1) + 2) Copy input and output dtypes
        _copy_io_dtypes(block_val, ir_block.get("inputs"), "in_dtype_")
        _copy_io_dtypes(block_val, ir_block.get("outputs"), "out_dtype_")

        # 3) Store IR attributes (overwrites any existing alloc attributes)
        attrs = ir_block.get("attributes")
        if isinstance(attrs, dict) and attrs:
            block_val["attributes"] = attrs

        # 4) Apply internal rules
        for rule in _DEFAULT_RULES:
            if rule.matches(ir_block):
                rule.apply(block_val, ir_block)

        updated_count += 1

    with open(alloc_json_path, "w", encoding="utf-8") as f:
        json.dump(alloc_data, f, indent=4)

    print(f"[INFO] Updated {updated_count} blocks in {alloc_json_path}")


def copy_L3_to_ir(alloc_json_path: str, ir_json_path: str) -> None:
    """
    Copy the 'L3' field of each alloc_json block into the matching ir_json node.

      - An alloc block matches when its 'name' is a key of ir_json.
      - Alloc entries that are not dicts, have no 'name', or whose 'L3' is
        missing/None are skipped silently.
      - Matched names whose IR value is absent or not a dict are counted as skipped.
      - Existing ir_json fields are kept; an existing 'L3' is overwritten.

    The IR JSON file is rewritten in place (even when nothing was copied).

    Args:
        alloc_json_path: Path to the allocation JSON file to read.
        ir_json_path: Path to the IR JSON file to update.
    """
    with open(alloc_json_path, "r", encoding="utf-8") as f:
        alloc_data: Dict[str, Any] = json.load(f)

    with open(ir_json_path, "r", encoding="utf-8") as f:
        ir_data: Dict[str, Any] = json.load(f)

    copied = 0
    missing = 0

    for block in alloc_data.values():
        if not isinstance(block, dict):
            continue  # a non-dict entry cannot carry 'name'/'L3'
        node_name = block.get("name")
        l3 = block.get("L3")
        if not node_name or l3 is None:
            continue  # nothing to copy for this block

        ir_block = ir_data.get(node_name)
        if not isinstance(ir_block, dict):
            missing += 1
            continue

        ir_block["L3"] = l3
        copied += 1

    with open(ir_json_path, "w", encoding="utf-8") as f:
        json.dump(ir_data, f, indent=4)

    print(f"[INFO] Copied L3 to {copied} IR nodes in {ir_json_path} "
          f"(skipped {missing} unmatched names or non-dicts)")


if __name__ == "__main__":
    # Step order matters: copy_L3_to_ir reads the alloc JSON that
    # update_alloc_json has just rewritten.
    parser = argparse.ArgumentParser(
        description=(
            "Update alloc JSON with dtype info from IR JSON, then copy each "
            "alloc block's L3 field back into the IR JSON (both files are "
            "modified in place)."
        )
    )
    parser.add_argument("--alloc_json", required=True,
                        help="Path to allocation JSON file to update.")
    parser.add_argument("--ir_json", required=True,
                        help="Path to IR JSON file containing node details "
                             "(its nodes receive the alloc 'L3' values in place).")
    args = parser.parse_args()

    update_alloc_json(args.alloc_json, args.ir_json)
    copy_L3_to_ir(args.alloc_json, args.ir_json)
