import ast
import json
from typing import Dict, List, Optional
from dataclasses import asdict, dataclass, field
from dataclass_wizard import JSONWizard
import os

from OGOAT.src.Tiler.layer import Layer
from OGOAT.src.Tiler.op_dims import OpDims
from OGOAT.src.Tiler.overlay import Overlay


@dataclass
class SubVols:
    """Sub-volume sizes of the tensors processed at one tile level.

    Field names presumably stand for input feature map (``ifm``), weights
    (``wgt``) and output feature map (``ofm``) — TODO confirm. Each defaults
    to a fresh empty list so instances never share state.
    """

    ifm: List[int] = field(default_factory=list)  # IFM sub-volume
    wgt: List[int] = field(default_factory=list)  # weight sub-volume
    ofm: List[int] = field(default_factory=list)  # OFM sub-volume


@dataclass
class Iters:
    """Iteration counts per tensor (ifm / wgt / ofm) at one tile level.

    NOTE(review): presumably one count per dimension, mirroring SubVols;
    confirm against the tiler that fills it. Every list defaults to a new
    empty list per instance.
    """

    ifm: List[int] = field(default_factory=list)  # IFM iterations
    wgt: List[int] = field(default_factory=list)  # weight iterations
    ofm: List[int] = field(default_factory=list)  # OFM iterations


@dataclass
class Sizes:
    """Scalar size per tensor (ifm / wgt / ofm) at one tile level.

    All three default to 0. Units are not visible here (bytes vs. elements)
    — TODO confirm with the producer of these values.
    """

    ifm: int = 0  # IFM size
    wgt: int = 0  # weight size
    ofm: int = 0  # OFM size


@dataclass
class TileParams:
    """Sub-volumes, iteration counts and sizes describing one tile level.

    Each component is built by its own default factory, so two TileParams
    instances never share nested objects.
    """

    subvols: SubVols = field(default_factory=SubVols)  # per-tensor sub-volumes
    iters: Iters = field(default_factory=Iters)  # per-tensor iteration counts
    sizes: Sizes = field(default_factory=Sizes)  # per-tensor scalar sizes


@dataclass
class KernelInfo:
    """Kernel selection data attached to a tiling result.

    Holds placement constraints (string to string), kernel names and kernel
    include entries; every container defaults to a fresh empty one.
    """

    placement_constraints: Dict[str, str] = field(default_factory=dict)
    kernel_names: List[str] = field(default_factory=list)
    kernel_includes: List[str] = field(default_factory=list)


@dataclass
class Shape:
    """A (row, col) pair; both coordinates default to 0."""

    row: int = 0  # row count / row coordinate
    col: int = 0  # column count / column coordinate


@dataclass
class OverlayInfo:
    """Which overlay a result targets, its grid shape, and the run mode.

    Filled by ``TilingResult.add_overlay_info`` from an ``Overlay``'s name,
    rows and cols plus a caller-supplied mode string.
    """

    overlay: str = ""  # overlay name
    shape: Shape = field(default_factory=Shape)  # rows x cols of the overlay
    mode: str = ""  # caller-provided mode string


@dataclass
class IOType:
    """Descriptor of one layer input or output.

    Instances are built via ``IOType(**entry)`` from the parsed ``inputs`` /
    ``outputs`` entries of a Layer, so field names double as the expected
    keys of those entries.
    """

    # NOTE: ``type`` shadows the builtin but is kept because it is a
    # serialized key.
    type: str = ""
    shape: Shape = field(default_factory=Shape)
    dtype: str = ""
    dtype_bytes: int = 0  # presumably bytes per element — TODO confirm


@dataclass
class LayerAttr:
    """Operator attributes of a layer; currently only the axis list."""

    axis: List[int] = field(default_factory=list)  # axis indices (may be empty)


@dataclass
class LayerInfo:
    """Per-layer metadata copied from a Layer and serialized with a result.

    Most fields are populated by ``update_from`` (name-for-name copy from a
    Layer object); ``inputs`` and ``outputs`` are parsed separately into
    lists of IOType descriptors.
    """

    op_type: str = ""
    inputs: List[IOType] = field(default_factory=list)
    # A list, like ``inputs``: TilingResult.add_layer_info stores one IOType
    # per output here. Annotating it as a single IOType made JSON produced by
    # ``dump`` (a list) impossible to load back via ``from_json``.
    outputs: List[IOType] = field(default_factory=list)
    in_act_shape: List[int] = field(default_factory=list)
    in_wgt_shape: List[int] = field(default_factory=list)
    in_wgt1_shape: List[int] = field(default_factory=list)
    out_act_shape: List[int] = field(default_factory=list)
    in_datatype: str = ""
    wgt_datatype: str = ""
    wgt1_datatype: str = ""
    out_datatype: str = ""
    in_bytes: int = 0
    wgt_bytes: int = 0
    wgt1_bytes: int = 0
    out_bytes: int = 0
    attributes: LayerAttr = field(default_factory=LayerAttr)
    qdq_symmetry: str = ""
    coeff_shape: str = ""
    in_act_residency: str = ""
    out_act_residency: str = ""
    # Capitalized name kept as-is: it is a serialized key.
    Frequency: int = 0
    nodenames: List[str] = field(default_factory=list)
    orig_op_type: str = ""

    def update_from(self, other, condition=lambda k, v: True):
        """Copy same-named instance attributes from ``other`` onto ``self``.

        Args:
            other: Any object with a ``__dict__``; only attributes whose name
                also exists on ``self`` are copied.
            condition: Predicate ``(name, value) -> bool``; an attribute is
                copied only when it returns true. Defaults to copying all.

        Values are assigned as-is (no copying or type conversion).
        """
        for name, value in other.__dict__.items():
            if not hasattr(self, name):
                continue  # ignore attributes LayerInfo does not declare
            if condition(name, value):
                setattr(self, name, value)


@dataclass
class TilingResult(JSONWizard):
    """Complete tiling decision for one layer, serializable to/from JSON.

    Bundles op dimensions, tile parameters for the core, mem and shim
    levels, scheduling entries, layer padding, layer metadata, overlay info
    and kernel info.
    """

    op_dims: Optional[OpDims] = field(default=None)
    core_tile_params: TileParams = field(default_factory=TileParams)
    mem_tile_params: TileParams = field(default_factory=TileParams)
    shim_tile_params: TileParams = field(default_factory=TileParams)
    scheduling: Dict[str, str] = field(default_factory=dict)
    layer_padding: TileParams = field(default_factory=TileParams)
    layer_info: LayerInfo = field(default_factory=LayerInfo)
    overlay_info: OverlayInfo = field(default_factory=OverlayInfo)
    kernel_info: KernelInfo = field(default_factory=KernelInfo)

    def dump(self, sub_dir: str, output_file: str) -> None:
        """Write this result as indented JSON (via ``asdict``) to a file.

        Args:
            sub_dir: Directory that must already exist. It is only checked;
                it is NOT joined with ``output_file``.
            output_file: Path of the JSON file to create or overwrite.

        Raises:
            AssertionError: If ``sub_dir`` does not exist.
        """
        # NOTE(review): assert-based checks are stripped under ``python -O``.
        assert os.path.exists(sub_dir), f"Directory provided does not exist: {sub_dir}"
        with open(output_file, "w", encoding="utf-8") as fd:
            json.dump(asdict(self), fd, indent=4)

    @staticmethod
    def load(input_dir: str, input_path: str):
        """Read a TilingResult from a JSON file written by ``dump``.

        Args:
            input_dir: Directory that must exist (checked only).
            input_path: Path of the JSON file to read.

        Returns:
            Whatever ``TilingResult.from_json`` produces for the file content.

        Raises:
            AssertionError: If ``input_dir`` or ``input_path`` does not exist.
        """
        assert os.path.exists(input_dir), f"Directory provided does not exist: {input_dir}"
        assert os.path.exists(input_path), f"Input path does not exist: {input_path}"
        with open(input_path, "r", encoding="utf-8") as fd:
            json_str = fd.read()
        return TilingResult.from_json(json_str)

    def add_mem_tile_params(
        self, subV_ifm, subV_ofm, iters_ifm, iters_ofm, sizes_ifm, sizes_ofm
    ):
        """Replace the mem-tile sub-volumes, iterations and sizes.

        Only the ifm/ofm entries are supplied; because fresh SubVols, Iters
        and Sizes objects are built, any previous ``wgt`` values are reset
        to their defaults.
        """
        params = self.mem_tile_params
        params.subvols = SubVols(ifm=subV_ifm, ofm=subV_ofm)
        params.iters = Iters(ifm=iters_ifm, ofm=iters_ofm)
        params.sizes = Sizes(ifm=sizes_ifm, ofm=sizes_ofm)

    def add_overlay_info(self, overlay: Overlay, mode: str):
        """Record the overlay's name and rows x cols shape, plus ``mode``."""
        self.overlay_info.overlay = overlay.overlay_name
        self.overlay_info.shape = Shape(row=overlay.rows, col=overlay.cols)
        self.overlay_info.mode = mode

    def add_kernel_info(
        self,
        placement_constraint: Optional[Dict[str, object]] = None,
        kernel_names: Optional[List[str]] = None,
        kernel_includes: Optional[List[str]] = None,
    ):
        """Set kernel info; a missing or empty argument becomes an empty container.

        The annotation previously used the builtin function ``any`` rather
        than a type; ``object`` expresses "any value" correctly.
        """
        self.kernel_info.placement_constraints = placement_constraint or {}
        self.kernel_info.kernel_names = kernel_names or []
        self.kernel_info.kernel_includes = kernel_includes or []

    def add_layer_info(self, layer: Layer):
        """Populate ``layer_info`` from ``layer``.

        Every attribute of ``layer`` that LayerInfo also declares is copied,
        except ``inputs``/``outputs``, which are parsed into lists of IOType.

        Raises:
            KeyError: If ``layer`` has no ``inputs`` or ``outputs`` attribute.
            ValueError / SyntaxError: If a string descriptor is not a valid
                Python literal (raised by ``ast.literal_eval``).
        """
        # Exclude the raw descriptors; the former ``k != a or k != b`` test
        # was always true and copied them anyway.
        excluded = ("inputs", "outputs")
        self.layer_info.update_from(layer, lambda k, v: k not in excluded)
        layer_dict = vars(layer)
        self.layer_info.inputs = self._parse_io_list(layer_dict["inputs"])
        self.layer_info.outputs = self._parse_io_list(layer_dict["outputs"])

    @staticmethod
    def _parse_io_list(raw) -> List[IOType]:
        """Turn a Layer's inputs/outputs descriptor into a list of IOType.

        ``raw`` is either a string holding a Python literal sequence of
        dicts (parsed safely with ``ast.literal_eval``) or an already-parsed
        sequence of such dicts. Each dict is unpacked into ``IOType``.
        """
        entries = ast.literal_eval(raw) if isinstance(raw, str) else raw
        return [IOType(**entry) for entry in entries]
