# fmt: on
from dataclasses import dataclass, field
from dataclasses import asdict
from typing import ClassVar, Optional, Any


# Container for the fusion-related command-line arguments.
@dataclass
class FusionArguments:
    """Fusion-related command-line arguments.

    Fields without a default must be supplied by the caller. The remaining
    fields carry the defaults shown below.
    """

    debug: bool
    # Path of the input model; also the base for ``out_model_path``.
    model_path: str
    model_name: str
    external_data: bool
    inits_int4_to_int8: bool
    fusion_seq_path: Optional[str]
    target: Optional[str]
    opt_level: int
    out_dir_path: str
    fast_pm_enable: bool
    qdq_optimization: bool
    qdq_int16_cleanup: bool = True
    # Text inserted before the ".onnx" extension to form ``out_model_path``.
    out_model_suffix: str = ""
    old_fusion_flow: bool = True
    shape_inference_outputs: int = 3000
    prebuilt_mladf_mha: bool = False
    no_dtype_freeze: bool = False
    assign_pmid_before_partition: bool = False

    @property
    def out_model_path(self) -> str:
        """Return ``model_path`` with ``out_model_suffix`` inserted before ".onnx".

        A single trailing ".onnx" (case-sensitive) is stripped, the suffix is
        appended and ".onnx" is re-added. For example, "dir/m.onnx" with the
        suffix "_fused" becomes "dir/m_fused.onnx". If ``model_path`` does not
        end in ".onnx", the suffix and ".onnx" are simply appended.
        """
        return self.model_path.removesuffix(".onnx") + self.out_model_suffix + ".onnx"

    def asdict(self) -> dict[str, Any]:
        """Return a new dict mapping every field name to its value."""
        # The bare name resolves to the module-level ``dataclasses.asdict``,
        # not to this method, because class scope is not visible from inside
        # method bodies.
        return asdict(self)

    def copy(self) -> "FusionArguments":
        """Return a new instance of the same runtime class with equal field values.

        ``type(self)`` is used instead of the hard-coded base class so that
        copying a subclass instance does not silently downgrade it to a plain
        ``FusionArguments``.
        """
        return type(self)(**self.asdict())


@dataclass
class FusionConfigs:
    """Fusion configuration flags with a process-wide shared instance.

    Use ``save_fusion_configs`` to install an instance and
    ``get_fusion_configs`` to retrieve it. The getter lazily creates a
    default-constructed instance if none has been saved yet.
    """

    # Shared instance. As a ClassVar it is not a dataclass field, so it is
    # excluded from __init__, __eq__ and asdict().
    _instance: ClassVar[Optional["FusionConfigs"]] = None

    extend_qdq: bool = False
    keep_border_qdq: bool = False
    batch_by_out_tensor: bool = False
    # default_factory gives each instance its own fresh container.
    enable_batch_operator: list[str] = field(default_factory=list)
    MMT_configs: dict = field(default_factory=dict)

    @classmethod
    def save_fusion_configs(cls, config: "FusionConfigs") -> None:
        """Store ``config`` as the shared instance, replacing any previous one."""
        cls._instance = config

    @classmethod
    def get_fusion_configs(cls) -> "FusionConfigs":
        """Return the shared instance, creating a default one on first access."""
        current = cls._instance
        if current is not None:
            return current
        current = FusionConfigs()
        cls._instance = current
        return current
