"""
This module provides utility functions for performing L2 fusion.
"""

import os
import logging
from functools import lru_cache
from typing import Any, List
import hashlib
import json
from dataclasses import is_dataclass, asdict
import onnx
from dotenv import load_dotenv


@lru_cache(maxsize=None)
def get_logger() -> logging.Logger:
    """Return the shared ``aie4_models`` logger, configuring it on first use.

    Propagation to the root logger is always disabled. A stream handler with a
    timestamped formatter is attached, and the level set to INFO, only when the
    logger has no handlers yet. The result is cached, so every call returns the
    same logger object.
    """
    log = logging.getLogger("aie4_models")

    # pytest installs its own handlers on the root logger; stopping propagation
    # keeps each record from being emitted twice.
    log.propagate = False

    if log.handlers:
        return log

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    log.addHandler(stream_handler)
    log.setLevel(logging.INFO)
    return log


# Module-wide logger shared by the helpers in this module.
logger = get_logger()


def config_logger_from_env():
    """Set the module logger's level from the ``AIE4_LOG_LEVEL`` variable.

    ``load_dotenv()`` is called first so values from a ``.env`` file are
    visible to ``os.getenv``. The level name is matched case-insensitively
    (``debug`` and ``DEBUG`` are equivalent) and surrounding whitespace is
    ignored. When the variable is unset, ``INFO`` is used.

    Raises:
        ValueError: if the value is not a level name known to ``logging``.
    """
    load_dotenv()
    # logging registers its level names in upper case and setLevel() looks
    # them up verbatim, so normalise the value before passing it on.
    log_level = os.getenv("AIE4_LOG_LEVEL", "INFO").strip().upper()
    logger.setLevel(log_level)
    logger.info("Log level set to %s", log_level)


# Bytes per element for each ONNX ``TensorProto`` data type, keyed by the
# string form of the enum member (e.g. "TensorProto.FLOAT"). Undefined and
# variable-size types map to -1; 4-bit types occupy half a byte.
_TENSORPROTO_DTYPE_SIZE = {
    "TensorProto.UNDEFINED": -1,
    "TensorProto.FLOAT": 4,
    "TensorProto.UINT8": 1,
    "TensorProto.INT8": 1,
    "TensorProto.UINT16": 2,
    "TensorProto.INT16": 2,
    "TensorProto.INT32": 4,
    "TensorProto.INT64": 8,
    "TensorProto.STRING": -1,  # Variable size
    "TensorProto.BOOL": 1,
    "TensorProto.FLOAT16": 2,
    "TensorProto.DOUBLE": 8,
    "TensorProto.UINT32": 4,
    "TensorProto.UINT64": 8,
    "TensorProto.COMPLEX64": 8,
    "TensorProto.COMPLEX128": 16,
    "TensorProto.BFLOAT16": 2,
    "TensorProto.FLOAT8E4M3FN": 1,
    "TensorProto.FLOAT8E4M3FNUZ": 1,
    "TensorProto.FLOAT8E5M2": 1,
    "TensorProto.FLOAT8E5M2FNUZ": 1,
    "TensorProto.UINT4": 0.5,
    "TensorProto.INT4": 0.5,
    "TensorProto.FLOAT4E2M1": 0.5,
}


def tensor_dtype_to_size(dtype: str) -> float:
    """Return the number of bytes per element for an ONNX TensorProto dtype.

    Args:
        dtype: dtype name in the form ``"TensorProto.<NAME>"``, for example
            ``"TensorProto.FLOAT"``.

    Returns:
        Element size in bytes as a float. The value is ``0.5`` for 4-bit types
        and ``-1.0`` for ``UNDEFINED`` and the variable-size ``STRING``.

    Raises:
        KeyError: if ``dtype`` is not a known TensorProto dtype name.
    """
    try:
        return float(_TENSORPROTO_DTYPE_SIZE[dtype])
    except KeyError:
        # Log at ERROR because this is a hard failure. KeyError is kept as the
        # exception type so callers that already catch it keep working.
        logger.error("Unknown tensor dtype: %s", dtype)
        raise KeyError(f"Unknown tensor dtype: {dtype}") from None


def subclasses(cls: Any, just_leaf: bool = False) -> List[Any]:
    """Return all transitive subclasses of ``cls``, not including ``cls`` itself.

    Direct subclasses come first, followed by the descendants of each direct
    subclass in turn. A class reachable through several parents (diamond
    inheritance) is listed only once, at its first position.

    Args:
        cls: class whose descendants are collected.
        just_leaf: when True, keep only classes that have no subclasses of
            their own.

    Returns:
        List of subclasses without duplicates.
    """
    direct = cls.__subclasses__()
    found = [s for s in direct if not just_leaf or not s.__subclasses__()]
    for child in direct:
        found.extend(subclasses(child, just_leaf))
    # dict.fromkeys drops repeats while preserving first-seen order.
    return list(dict.fromkeys(found))


def subclass_where(cls: Any, **kwargs: Any) -> Any:
    """Return the first subclass of ``cls`` whose attributes match ``kwargs``.

    Candidates are examined in the order returned by ``subclasses(cls)``. A
    subclass matches when, for every ``name=value`` pair in ``kwargs``, it has
    an attribute ``name`` that compares equal to ``value``.

    Args:
        cls: base class whose descendants are searched.
        **kwargs: attribute criteria; at least one is required.

    Returns:
        The first matching subclass.

    Raises:
        TypeError: if no criteria are given.
        KeyError: if no subclass matches every criterion.
    """
    if not kwargs:
        raise TypeError("subclass_where() requires at least one attribute criterion")

    subcls = subclasses(cls)

    for s in subcls:
        if all(hasattr(s, k) and getattr(s, k) == v for k, v in kwargs.items()):
            return s

    # With a single criterion this reproduces the historical message exactly.
    criteria = " and ".join(f"cls.{k} == '{v}'" for k, v in kwargs.items())
    raise KeyError(
        f"No subclasses of {cls.__name__} with {criteria}. "
        f"Available subclasses {[s.__name__ for s in subcls]}"
    )


def make_json_serializable(obj: Any):
    """
    Convert ``obj`` recursively into plain JSON-compatible Python values.

    Dataclass values become dicts of their fields. Mapping keys are coerced to
    ``str``. Lists and tuples become lists. Any other value is returned
    untouched and is expected to be JSON-serializable already.
    """
    if is_dataclass(obj):
        field_values = asdict(obj)
        return {
            field_name: make_json_serializable(field_value)
            for field_name, field_value in field_values.items()
        }

    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[str(key)] = make_json_serializable(value)
        return converted

    if isinstance(obj, (list, tuple)):
        return list(map(make_json_serializable, obj))

    # Scalars (int, str, float, bool, None) pass through unchanged.
    return obj


def obj_to_json(obj: Any) -> str:
    """Serialize ``obj`` to a compact JSON string with sorted keys.

    Sorted keys and separators without spaces make the output canonical for
    equal inputs, which suits hashing.
    """
    return json.dumps(
        make_json_serializable(obj),
        separators=(',', ':'),
        sort_keys=True,
    )


def hash_obj(obj: Any) -> str:
    """Return the hex MD5 digest of ``obj``'s canonical JSON form."""
    payload = obj_to_json(obj).encode('utf-8')
    digest = hashlib.md5(payload)
    return digest.hexdigest()


def validate_and_load_model(model_path: str) -> onnx.ModelProto:
    """Validate an ONNX model file and return the loaded model.

    The file is first checked with ``onnx.checker.check_model`` using
    ``full_check=True``. It is then loaded without external tensor data and
    its graph is re-checked with ``onnx.checker.check_graph``. The graph must
    declare at least one input and one output.

    Args:
        model_path: path to the ONNX model file.

    Returns:
        The loaded ``onnx.ModelProto``, with external data not loaded.

    Raises:
        RuntimeError: if the graph has no inputs or no outputs.
        Exception: any exception raised during validation, including the
            RuntimeErrors above, is logged with its traceback and re-raised.
    """
    try:
        # Check model structure
        onnx.checker.check_model(model_path, full_check=True)

        # Load the graph only; external tensor data is left on disk.
        model = onnx.load(model_path, load_external_data=False)

        # Check graph
        onnx.checker.check_graph(model.graph)

        # Verify inputs exist
        if not model.graph.input:
            raise RuntimeError("Model has no inputs")

        # Verify outputs exist
        if not model.graph.output:
            raise RuntimeError("Model has no outputs")

        return model

    except Exception as e:
        logger.exception("ONNX model error: %s", e)
        # Bare raise re-raises the active exception with its original traceback.
        raise


def load_model_without_data(model_path: str) -> onnx.ModelProto:
    """Load an ONNX model without its external tensor data.

    Unlike ``validate_and_load_model``, no ONNX checker is run. The only checks
    are that the graph declares at least one input and one output.

    Args:
        model_path: path to the ONNX model file.

    Returns:
        The loaded ``onnx.ModelProto``, with external data not loaded.

    Raises:
        RuntimeError: if the graph has no inputs or no outputs.
        Exception: any exception raised while loading or checking, including
            the RuntimeErrors above, is logged with its traceback and re-raised.
    """
    try:
        # Load the graph only; external tensor data is left on disk.
        model = onnx.load(model_path, load_external_data=False)

        # Verify inputs exist
        if not model.graph.input:
            raise RuntimeError("Model has no inputs")

        # Verify outputs exist
        if not model.graph.output:
            raise RuntimeError("Model has no outputs")

        return model

    except Exception as e:
        logger.exception("ONNX model error: %s", e)
        # Bare raise re-raises the active exception with its original traceback.
        raise


def find_reverse_map(input_lst: List[str], output_lst: List[str]) -> List[int]:
    """Map each element of ``output_lst`` to its index in ``input_lst``.

    The result ``idx`` satisfies ``output_lst[j] == input_lst[idx[j]]`` for
    every position ``j``. In other words, it gives the gather indices that
    produce ``output_lst`` from ``input_lst``.

    Examples:
        >>> input_lst = ['a', 'b']
        >>> output_lst = ['b', 'a']
        >>> find_reverse_map(input_lst, output_lst)
        [1, 0]

    Raises:
        RuntimeError: if the two lists do not contain the same set of names,
            or if ``input_lst`` contains a duplicate name, which would make
            its index ambiguous.
    """
    input_ids = {name: i for i, name in enumerate(input_lst)}
    # A duplicate input name collapses to one dict entry (keeping the last
    # index), so fewer keys than elements means the mapping is not unique.
    if len(input_ids) != len(input_lst) or set(input_ids) != set(output_lst):
        raise RuntimeError("a unique mapping can't be determined")
    return [input_ids[name] for name in output_lst]
