"""Remove L1 fused nodes (Relu/LeakyRelu/Q/DQ) and fold/remove a specific YOLO size pattern:

Conv -> X
  Shape(X) splits to:
    A) Gather -> Cast -> Slice -> Mul -> Cast -> \
                                                   Concat -> (as sizes) -> Resize(X, ..)
    B) Slice ------------------------------------/

We:
 - Wire Conv's output X directly to Resize input[0];
 - Replace Resize.sizes (input[3]) by a constant INT64 initializer from inferred output shape;
 - Delete ONLY the matched nodes (Shape, Gather, Casts, Slices, Mul, Concat);
 - Prune dangling graph members.

"""
import os
import json
import argparse
from collections import defaultdict
from typing import Dict, List, Optional

import onnx
from onnx import shape_inference, numpy_helper
import numpy as np
from tabulate import tabulate


# -----------------------------
# Helpers
# -----------------------------
# Op types that make up the Shape-driven Resize-sizes pattern described in the module docstring.
PATTERN_TYPES = {"Concat", "Mul", "Slice", "Cast", "Gather", "Shape"}


def _consumers_map(graph: onnx.GraphProto) -> Dict[str, List[onnx.NodeProto]]:
    """Index the graph's nodes by every non-empty tensor name they read.

    A node that lists the same tensor several times appears that many times
    under that key. The result is a ``defaultdict(list)``.
    """
    readers = defaultdict(list)
    for node in graph.node:
        # filter(None, ...) drops omitted optional inputs ("").
        for tensor_name in filter(None, node.input):
            readers[tensor_name].append(node)
    return readers


def _producers_map(graph: onnx.GraphProto) -> Dict[str, onnx.NodeProto]:
    """Map each non-empty output tensor name to the node that writes it.

    If several nodes claim the same output name, the last one in ``graph.node`` wins.
    """
    return {
        out_name: node
        for node in graph.node
        for out_name in node.output
        if out_name
    }


def _prune_graph_members(graph: onnx.GraphProto):
    """Remove unused initializers / inputs / value_info.

    A name counts as used when it is read or written by any node, including
    nodes inside nested subgraphs held in GRAPH / GRAPHS attributes (e.g. the
    bodies of If, Loop and Scan). Those subgraphs may read outer-scope tensors
    implicitly, so they must be scanned too. A graph output name also counts as
    used. Unused graph inputs are dropped as well, which changes the model's
    input signature.
    """
    referenced = set()

    def _collect(g):
        # Record every non-empty tensor name touched by g's nodes, recursing into subgraphs.
        for n in g.node:
            referenced.update(x for x in n.input if x)
            referenced.update(x for x in n.output if x)
            for attr in n.attribute:
                if attr.type == onnx.AttributeProto.GRAPH:
                    _collect(attr.g)
                elif attr.type == onnx.AttributeProto.GRAPHS:
                    for sub in attr.graphs:
                        _collect(sub)

    _collect(graph)
    referenced.update(o.name for o in graph.output if o.name)

    keep = [init for init in graph.initializer if init.name in referenced]
    del graph.initializer[:]
    graph.initializer.extend(keep)

    keep = [gi for gi in graph.input if gi.name in referenced]
    del graph.input[:]
    graph.input.extend(keep)

    keep = [vi for vi in graph.value_info if vi.name in referenced]
    del graph.value_info[:]
    graph.value_info.extend(keep)


def _value_info_shape_map(model: onnx.ModelProto) -> Dict[str, List[Optional[int]]]:
    """Collect known tensor shapes from value_info, graph inputs and graph outputs.

    Each entry maps a tensor name to its dims: the integer ``dim_value`` when
    it is set, otherwise ``None`` (symbolic or unknown dim). Entries whose
    tensor type carries no shape are omitted. Sources are read in the order
    value_info, inputs, outputs; a later source overwrites an earlier one for
    the same name.
    """
    graph = model.graph
    shapes = {}
    for source in (graph.value_info, graph.input, graph.output):
        for tvi in source:
            tensor_type = tvi.type.tensor_type
            if not tensor_type.HasField("shape"):
                continue
            shapes[tvi.name] = [
                int(dim.dim_value) if dim.HasField("dim_value") else None
                for dim in tensor_type.shape.dim
            ]
    return shapes


def _ensure_len(lst: List[str], n: int):
    while len(lst) < n:
        lst.append("")


def force_dtype_int8(model: onnx.ModelProto):
    """
    Change all graph inputs, outputs and intermediate value_infos to INT8.
    NOTE: This only changes type metadata, not the actual underlying data.

    Only tensor-typed entries with a concrete element type are retagged.
    Entries whose element type is UNDEFINED (0), and non-tensor entries
    (sequence / map / optional types), are left untouched.
    """
    INT8 = onnx.TensorProto.INT8
    UNDEFINED = onnx.TensorProto.UNDEFINED

    def _fix_tvi(tvi):
        # Touching `type.tensor_type` on a non-tensor TypeProto would switch its
        # oneof to tensor_type, so skip those entirely.
        if not tvi.type.HasField("tensor_type"):
            return
        tt = tvi.type.tensor_type
        # Skip UNDEFINED; override any concrete element type.
        if tt.elem_type != UNDEFINED:
            tt.elem_type = INT8

    # Graph inputs/outputs
    for inp in model.graph.input:
        _fix_tvi(inp)
    for out in model.graph.output:
        _fix_tvi(out)

    # Intermediate tensors recorded in value_info (what Netron shows on edges)
    for vi in model.graph.value_info:
        _fix_tvi(vi)

# -----------------------------
# L1 removals (Relu/LeakyRelu/Q/DQ) with DFS map
# -----------------------------


def get_chained_removal_map(graph, ops_to_remove):
    '''Dependency map for Chained Removal.

    For every source (a non-removable node, a graph input, or an initializer
    that is not also a graph input), record the removable nodes reachable from
    it through an unbroken chain of removable nodes.

    Args:
        graph: ONNX GraphProto, or any object exposing ``node``, ``input`` and
            ``initializer`` sequences with the same field names.
        ops_to_remove: op_type strings considered removable.

    Returns:
        dict mapping a source display name to the display names of the
        removable nodes reached from it, in DFS discovery order. Each node is
        visited at most once, so a removable node reachable from several
        sources is attributed only to the first source that reaches it.
    '''
    input_consumers = defaultdict(list)
    for node in graph.node:
        for input_name in node.input:
            # "" marks an omitted optional input. Indexing it would make any node
            # with an omitted ("") output look like a producer for this node.
            if input_name:
                input_consumers[input_name].append(node)

    visited = set()
    removal_chains = defaultdict(list)

    def _safe_name(n):
        return n if n else "unnamed"

    def _node_display_name(n):
        return n.name if n.name else (n.output[0] if n.output else "unnamed")

    def _visit_key(n):
        return (n.name, n.output[0] if n.output else "")

    def dfs(start_node, from_node):
        # from_node is the chain's source: a node, or a tensor name (str) for
        # graph inputs / initializers.
        key = from_node if isinstance(from_node, str) else _node_display_name(from_node)
        vkey = _visit_key(start_node)
        if vkey in visited:
            return
        visited.add(vkey)

        if start_node.op_type in ops_to_remove:
            removal_chains[_safe_name(key)].append(_node_display_name(start_node))
            for out in start_node.output:
                for consumer in input_consumers.get(out, []):
                    dfs(consumer, from_node)

    for node in graph.node:
        if node.op_type not in ops_to_remove:
            for out in node.output:
                for consumer in input_consumers.get(out, []):
                    dfs(consumer, node)

    graph_input_names = set()
    for input_info in graph.input:
        graph_input_names.add(input_info.name)
        for consumer in input_consumers.get(input_info.name, []):
            dfs(consumer, input_info.name)

    # Since IR v4, initializers need not be listed in graph.input. Treat them as
    # sources too. Otherwise removable nodes fed only by constants (e.g. a
    # DequantizeLinear on quantized weights) are removed but never reported.
    for init in graph.initializer:
        if init.name in graph_input_names:
            continue
        for consumer in input_consumers.get(init.name, []):
            dfs(consumer, init.name)

    return dict(removal_chains)


def remove_nodes_by_op_type_chained_base(graph: onnx.GraphProto, ops_to_remove: List[str]) -> Dict[str, List[str]]:
    """Remove specified ops and rewire (forward only inputs[0]).

    Every node whose ``op_type`` is in ``ops_to_remove`` is deleted from
    ``graph`` in place. For each output of a removed node:
      - consumer nodes are rewired to read the removed node's first input
        (``input[0]``) instead;
      - a graph output with that name is renamed to ``input[0]``.
    Inputs other than ``input[0]`` are not forwarded anywhere. A removable node
    with no inputs is deleted without any rewiring.

    Args:
        graph: graph to edit in place.
        ops_to_remove: op_type strings to strip.

    Returns:
        The report from ``get_chained_removal_map``, computed before any edit.
    """
    # Build the report first, while the graph is still intact.
    chained_removal_map = get_chained_removal_map(graph, ops_to_remove)

    # Built once up front and never refreshed. Chains still work because the
    # lists hold the node objects themselves, which are rewired in place.
    # For A -> B -> C (A, B removable): handling A makes B read A's input, and
    # handling B then forwards that input on to C.
    # NOTE(review): this relies on graph.node being in topological order
    # (as ONNX requires). If B were handled before A, C would be left reading
    # A's deleted output.
    input_consumers = _consumers_map(graph)
    graph_output_by_name = {o.name: o for o in graph.output}

    removable = [n for n in list(graph.node) if n.op_type in ops_to_remove]
    for node in removable:
        ins = list(node.input)
        outs = list(node.output)
        if not ins:
            # NOTE(review): consumers of this node's outputs are not rewired here
            # and will reference a tensor that no longer has a producer.
            graph.node.remove(node)
            continue

        primary = ins[0]

        for out in outs:
            if out in graph_output_by_name:
                # Rename the graph output to `primary` and re-key the lookup under the new name.
                # NOTE(review): if two removed outputs forward to the same `primary`,
                # both graph outputs end up with that name.
                graph_output_by_name[out].name = primary
                graph_output_by_name[primary] = graph_output_by_name.pop(out)
            for consumer in input_consumers.get(out, []):
                consumer.input[:] = [primary if i == out else i for i in consumer.input]

        graph.node.remove(node)

    return chained_removal_map


# -----------------------------
# Pattern matcher (order-agnostic for Concat inputs)
# -----------------------------
def _backtrack_branchA(producers: Dict[str, onnx.NodeProto], tip_tensor: str):
    """Expect Cast <- Mul <- Slice <- Cast <- Gather <- Shape upstream of ``tip_tensor``.

    Each hop follows the node's first input, except at the Mul. Mul is
    commutative, so the Slice may feed either of its operands.

    Returns:
        (ok, nodes, shape_node). On success ``nodes`` is
        [Gather, Cast, Slice, Mul, Cast(tip), Shape] (Shape last) and
        ``shape_node`` is that Shape node. On failure the result is (False, [], None).
    """
    failed = (False, [], None)

    def _produced_by(tensor, op_type):
        # The node writing `tensor` if its op_type matches, else None.
        node = producers.get(tensor)
        return node if node and node.op_type == op_type else None

    def _first_input(node):
        return node.input[0] if node.input else ""

    cast_tip = _produced_by(tip_tensor, "Cast")
    if cast_tip is None:
        return failed

    mul = _produced_by(_first_input(cast_tip), "Mul")
    if mul is None:
        return failed

    # Mul(slice, scale) and Mul(scale, slice) are equivalent; accept either.
    slice_node = next(
        (s for s in (_produced_by(t, "Slice") for t in mul.input) if s is not None),
        None,
    )
    if slice_node is None:
        return failed

    cast_first = _produced_by(_first_input(slice_node), "Cast")
    if cast_first is None:
        return failed

    gather = _produced_by(_first_input(cast_first), "Gather")
    if gather is None:
        return failed

    shape_node = _produced_by(_first_input(gather), "Shape")
    if shape_node is None:
        return failed

    return True, [gather, cast_first, slice_node, mul, cast_tip, shape_node], shape_node


def _backtrack_branchB(producers: Dict[str, onnx.NodeProto], tip_tensor: str):
    """Expect Slice <- Shape upstream of ``tip_tensor``.

    Returns (True, [slice_node], shape_node) on a match, else (False, [], None).
    """
    slice_node = producers.get(tip_tensor)
    if slice_node and slice_node.op_type == "Slice":
        upstream = slice_node.input[0] if slice_node.input else ""
        source = producers.get(upstream)
        if source and source.op_type == "Shape":
            return True, [slice_node], source
    return False, [], None


def _remove_dead_nodes(graph: onnx.GraphProto, candidates: List[onnx.NodeProto]) -> None:
    """Delete those ``candidates`` whose outputs nothing in ``graph`` still reads.

    A candidate is dead when none of its outputs is an input of any node still
    in the graph, nor the name of a graph output. The check repeats until
    nothing changes, so deleting a consumer can expose its producer as dead
    in turn. Candidates that stay alive are kept, e.g. a Shape whose output
    is also used elsewhere. Duplicate candidates (by identity) are ignored.
    """
    pending = []
    for node in candidates:
        if not any(node is seen for seen in pending):
            pending.append(node)

    graph_outputs = {o.name for o in graph.output}
    while pending:
        live = {name for n in graph.node for name in n.input if name} | graph_outputs
        dead = [n for n in pending if not any(out in live for out in n.output if out)]
        if not dead:
            break
        for n in dead:
            graph.node.remove(n)
        pending = [n for n in pending if not any(n is d for d in dead)]


def _pattern_fold_once(model: onnx.ModelProto) -> int:
    """
    Find one instance of the pattern and fold it. Returns 1 if a fold happened, else 0.

    Match: a Resize whose sizes (input[3]) come from a Concat with two inputs.
    In either order, those inputs are:
      - BranchA: Cast <- Mul <- Slice <- Cast <- Gather <- Shape;
      - BranchB: Slice <- Shape.
    Both branches must start at the same Shape node. That Shape must read
    Resize.input[0], and Resize.input[0] must be produced by a Conv.

    Fold: Resize.sizes is replaced by a fresh INT64 initializer holding the
    Resize output shape from ONNX shape inference. The candidate is skipped if
    that shape is not fully static. The matched nodes are then deleted, but only
    those whose outputs are no longer read by any remaining node or graph
    output, so no dangling tensor references are left behind.

    ``model.graph`` is edited in place.
    """
    g = model.graph
    producers = _producers_map(g)
    inferred_shapes = None  # computed lazily, at most once per call

    # Iterate over Resize nodes
    for resize in g.node:
        if resize.op_type != "Resize":
            continue
        if len(resize.input) < 4 or not resize.input[3]:
            continue

        concat = producers.get(resize.input[3])
        if not concat or concat.op_type != "Concat" or len(concat.input) < 2:
            continue

        # Try both permutations of Concat inputs mapped to BranchA/BranchB
        a_tip, b_tip = concat.input[0], concat.input[1]

        okA1, chainA1, shapeA1 = _backtrack_branchA(producers, a_tip)
        okB1, chainB1, shapeB1 = _backtrack_branchB(producers, b_tip)

        okA2, chainA2, shapeA2 = _backtrack_branchA(producers, b_tip)
        okB2, chainB2, shapeB2 = _backtrack_branchB(producers, a_tip)

        if okA1 and okB1 and shapeA1 is shapeB1:
            chainA, chainB, shape_node = chainA1, chainB1, shapeA1
        elif okA2 and okB2 and shapeA2 is shapeB2:
            chainA, chainB, shape_node = chainA2, chainB2, shapeA2
        else:
            continue

        # Verify Shape input equals Resize.data and comes from Conv
        data_t = resize.input[0]
        if not shape_node.input or shape_node.input[0] != data_t:
            continue
        conv_prod = producers.get(data_t)
        if not conv_prod or conv_prod.op_type != "Conv":
            continue

        # Compute constant sizes from the Resize output shape. infer_shapes
        # returns a new ModelProto that is only read here. All edits below go
        # to `g`, the caller's graph. The graph is unchanged until a fold
        # happens, so one inference serves every candidate in this call.
        if inferred_shapes is None:
            inferred_shapes = _value_info_shape_map(shape_inference.infer_shapes(model))
        out_name = resize.output[0] if resize.output else None
        if not out_name or out_name not in inferred_shapes:
            continue
        out_shape = inferred_shapes[out_name]
        if not out_shape or any(d is None for d in out_shape):
            continue

        # Pick a name not used by any existing tensor, so an unrelated
        # initializer with different contents is never silently reused.
        base_name = (resize.name or out_name) + "_sizes_const"
        taken = {init.name for init in g.initializer}
        taken.update(gi.name for gi in g.input)
        for node in g.node:
            taken.update(node.input)
            taken.update(node.output)
        sizes_name, suffix = base_name, 0
        while sizes_name in taken:
            suffix += 1
            sizes_name = f"{base_name}_{suffix}"
        g.initializer.append(numpy_helper.from_array(np.asarray(out_shape, dtype=np.int64), sizes_name))

        # input[0] is already Conv's output (verified above); only swap in constant sizes.
        resize.input[3] = sizes_name

        # Delete matched nodes (branchA incl. Shape, branchB, Concat) that are now unused.
        _remove_dead_nodes(g, list(chainA) + list(chainB) + [concat, shape_node])

        return 1  # folded one instance

    return 0  # no match


def fold_all_patterns(model: onnx.ModelProto) -> int:
    """
    Call ``_pattern_fold_once`` until it reports no match, pruning unreferenced
    graph members after every successful fold.
    Returns the number of folds performed.
    """
    folds = 0
    while (folded := _pattern_fold_once(model)):
        folds += folded
        _prune_graph_members(model.graph)
    return folds


# -----------------------------
# Main pipeline
# -----------------------------
def process_model(input_model_path: str,
                  output_model_path: str,
                  ops_to_remove: List[str],
                  is_dtype_int8: bool = False) -> Dict[str, List[str]]:
    '''Callable API for the script: run the whole clean-up pipeline on one file.

    Pipeline: load -> strip ``ops_to_remove`` nodes -> prune unused graph
    members -> fold the Shape-driven Resize-sizes pattern -> shape inference
    -> optionally ``force_dtype_int8`` -> save to ``output_model_path``.

    Returns:
        The removal report from ``remove_nodes_by_op_type_chained_base``.
    '''
    model = onnx.load(input_model_path)
    graph = model.graph

    removal_report = remove_nodes_by_op_type_chained_base(graph, ops_to_remove)
    _prune_graph_members(graph)

    folded_count = fold_all_patterns(model)
    print(f"Pattern-folded {folded_count} Resize node(s).")

    # infer_shapes hands back a new ModelProto; the remaining steps use it.
    model = shape_inference.infer_shapes(model)

    if is_dtype_int8:
        print("Forcing all graph inputs/outputs to INT8")
        force_dtype_int8(model)

    onnx.save(model, output_model_path)
    print(f"Saved updated model to {output_model_path}")
    return removal_report


def print_removed_table(removal_map_print):
    '''Print the removal map as a GitHub-flavoured markdown table (one row per source).'''
    headers = ["Source Node", "Removed Nodes"]
    rows = [(source, ", ".join(removed)) for source, removed in removal_map_print.items()]
    print("\n### Removed Nodes Table:\n")
    print(tabulate(rows, headers=headers, tablefmt="github"))


def main(input_model_path: str,
         output_model_path: str,
         is_dtype_int8: bool = False,
         ops_to_remove: Optional[List[str]] = None):
    '''Main function wrapper for command line usage.

    Args:
        input_model_path: ONNX file to read.
        output_model_path: where the cleaned model is written.
        is_dtype_int8: forwarded to ``process_model``.
        ops_to_remove: op_types to strip. ``None`` (the default) means
            ["Relu", "LeakyRelu", "QuantizeLinear", "DequantizeLinear"].

    Returns:
        The removal report returned by ``process_model``.
    '''
    if ops_to_remove is None:
        # Build a fresh list per call; a list default would be shared across calls.
        ops_to_remove = ["Relu", "LeakyRelu", "QuantizeLinear", "DequantizeLinear"]
    return process_model(
        input_model_path=input_model_path,
        output_model_path=output_model_path,
        ops_to_remove=ops_to_remove,
        is_dtype_int8=is_dtype_int8,
    )


# -----------------------------
# Usage
# -----------------------------
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Remove L1 ops, fold YOLO Resize size pattern, and tidy Netron layout."
    )
    cli.add_argument('--input', required=True, help="Path to input ONNX model")
    cli.add_argument('--output', required=False, help="Optional path for output ONNX model (overrides default)")
    cli.add_argument("-int8", "--dtype_int8", action="store_true", help="Force all graph input/output tensors to INT8")
    cli_args = cli.parse_args()

    # Derived files live next to the input model and share its base name.
    src_path = cli_args.input
    src_dir = os.path.dirname(src_path)
    stem, _ = os.path.splitext(os.path.basename(src_path))

    output_model = (
        cli_args.output
        if cli_args.output
        else os.path.join(src_dir, f"{stem}_cleaned_graph.onnx")
    )
    json_filename = os.path.join(src_dir, f"{stem}_removed_nodes.json")

    removal_map = main(
        input_model_path=src_path,
        output_model_path=output_model,
        is_dtype_int8=cli_args.dtype_int8,
    )
    print_removed_table(removal_map)

    with open(json_filename, "w", encoding="utf-8") as fh:
        json.dump(removal_map, fh, indent=2)

    print(f"\nRemoved map saved to {json_filename}")
    print(f"Output model saved to {output_model}")
