import numpy as np
import ast
import copy
from dataflow.conv.conv_common import iceil

class UnsupportedDataTypeError(Exception):
    """Raised when an operator is configured with a data type it cannot handle."""

class Layer:
    """Per-node layer description built from one node's JSON dictionary.

    ``__init__`` copies every key of the node description onto the instance as an
    attribute, then post-processes the shape attributes according to the operator
    type (``orig_op_type``): W8 channel padding, fused-transpose handling, and
    folding shapes into the 1D / 3D layouts used downstream.
    """

    @staticmethod
    def _is_transposed(attrs, attr_key):
        """Return True when ``attrs`` flags a fused transpose (``attrs[attr_key][0] == 1``)."""
        return attrs is not None and attr_key in attrs and attrs[attr_key][0] == 1

    @staticmethod
    def get_permutation(shape, attrs, attr_key, perm_prefix):
        """Return the fused-transpose permutation for a tensor of rank ``len(shape)``.

        If ``attrs[attr_key][0] == 1`` the permutation stored under
        ``attrs[perm_prefix]`` is returned (the very object held by ``attrs``, not a
        copy); otherwise the identity permutation ``[0, 1, ..., len(shape) - 1]``.
        """
        if Layer._is_transposed(attrs, attr_key):
            return attrs[perm_prefix]
        return list(range(len(shape)))

    @staticmethod
    def apply_permutation(shape, attrs, attr_key, perm_prefix):
        """Apply the fused transpose described by ``attrs`` to ``shape``.

        Returns:
            ``(permuted_shape, perm)``. When no transpose is flagged, ``shape`` is
            returned unchanged (same object) together with the identity permutation.
        """
        perm = Layer.get_permutation(shape, attrs, attr_key, perm_prefix)
        if not Layer._is_transposed(attrs, attr_key):
            return shape, perm
        return np.array(shape)[perm].tolist(), perm

    @staticmethod
    def reverse_permutation(shape, attrs, attr_key, perm_prefix):
        """Undo the fused transpose described by ``attrs`` on ``shape``.

        Returns:
            ``(unpermuted_shape, inverse_perm)`` where ``inverse_perm`` is
            ``argsort(perm)``. When no transpose is flagged, ``shape`` is returned
            unchanged (same object) together with the identity permutation.
        """
        perm = Layer.get_permutation(shape, attrs, attr_key, perm_prefix)
        if not Layer._is_transposed(attrs, attr_key):
            return shape, perm
        reverse_perm = np.argsort(perm).tolist()
        return np.array(shape)[reverse_perm].tolist(), reverse_perm

    # Process permutation vectors and make them 3D
    @staticmethod
    def process_perm_vector(perm_vec):
        """Convert a permutation vector into an equivalent permutation of 3 axes.

        * rank 2: a leading, untouched axis is prepended (``[1, 0] -> [0, 2, 1]``);
        * rank >= 3: the leading ``rank - 3`` axes must be the identity; only the
          last three entries are kept and re-based to ``0..2``.

        Raises:
            ValueError: if ``perm_vec`` is shorter than 2 or permutes a leading axis.
        """
        rank = len(perm_vec)
        if rank == 2:
            return [0] + (np.asarray(perm_vec) + 1).tolist()
        if rank >= 3 and np.array_equal(np.asarray(perm_vec[:rank - 3]), np.arange(rank - 3)):
            return (np.asarray(perm_vec)[-3:] - (rank - 3)).tolist()
        raise ValueError("Unexpected perm order or perm too short")

    @staticmethod
    def reshape_shape(shape, last_dims=1):
        """Merge all but the last ``last_dims`` dims of ``shape`` into one leading dim.

        Returns ``[prod(shape[:-last_dims])] + list(shape[-last_dims:])``; if
        ``shape`` has at most ``last_dims`` dims the leading dim is 1 (so a 2D shape
        with ``last_dims=2`` gets a leading 1). Accepts any sequence (list, tuple,
        1D array) and always returns a new list.
        """
        shape = list(shape)
        prod = (
            int(np.prod(shape[:-last_dims]))
            if len(shape) > last_dims
            else 1
        )
        return [prod] + shape[-last_dims:]

    @staticmethod
    def squeeze_leading_dims(shape: list[int]) -> list[int]:
        """
        Remove all leading dimensions equal to 1 and return the resulting shape.

        At least one dimension is always kept (``[1, 1] -> [1]``). If there is no
        leading 1, ``shape`` itself is returned.
        """
        i = 0
        while i < len(shape) - 1:
            if shape[i] != 1:
                break
            i += 1

        if i == 0:
            return shape

        return shape[i:]

    @staticmethod
    def reshape_into_3D_tensor(shape: list[int], batch_size: int) -> list[int]:
        """Fold ``shape`` into ``[batch_size, prod(middle dims), last_dim]``.

        Shapes of rank < 4 are first left-padded with 1s. Leading dims are then
        multiplied together until their product equals ``batch_size``; all
        remaining dims except the last are merged into the middle dim. ``shape`` is
        not modified and the result holds plain Python ints (JSON-serialisable).

        Raises:
            RuntimeError: if no prefix of the leading dims multiplies to exactly
                ``batch_size`` while leaving at least two trailing dims.
        """
        assert batch_size >= 1
        # Plain ints: np.prod would otherwise leak numpy scalars into the result,
        # which json.dumps rejects. list(...) also lets tuples be passed in.
        new_shape = [int(dim) for dim in shape]

        # Make sure that we have at least a shape of rank 4 so we can fold the batch dimension or the
        # second to second last dimensions if needed.
        while len(new_shape) < 4:
            new_shape.insert(0, 1)

        # fold the outer most dimensions into a single batch dimension matching the batch size requested
        outer_dim = new_shape.pop(0)
        while outer_dim != batch_size and new_shape:
            outer_dim *= new_shape.pop(0)

        new_shape.insert(0, outer_dim)
        if len(new_shape) < 3 or new_shape[0] != batch_size:
            raise RuntimeError(f"Could not reshape {shape} into a 3d shape with requested batch size '{batch_size}'")

        # Now that the outer most dimension is equal to the batch size,
        # merge the middle dimensions and keep the last one intact.
        if len(new_shape) > 3:
            new_shape = new_shape[:1] + [int(np.prod(new_shape[1:-1]))] + new_shape[-1:]

        return new_shape

    def MHA_init(self, json_dict):
        """Multi-head-attention specific initialisation (called from ``__init__``).

        Replaces the generic ``in_*_shape`` / ``out_act_shape`` attributes with
        ``self.activations_shapes``: a dict mapping ``"ifm_<name>"`` (one entry per
        activation input, ``<name>`` = lower-cased ``param_name``) and ``"ofm"`` to
        3D shapes. Q and the output are folded by ``num_heads``, K/V by
        ``num_groups``, B by ``num_bias`` and M by ``num_mask`` (each of the latter
        defaults to ``num_heads``). If K has a fused transpose, its permutation is
        applied and the 3D form is stored in ``self.permK_3d``.

        Raises:
            ValueError: if the Q/K or K/V inner dimensions do not match.
        """
        attributes = json_dict["attributes"]
        assert "InTransposeK" in attributes, "Information about whether K input is transposed or not is required"
        assert "num_heads" in attributes, "number of heads should always be presented"

        # Make sure that we do not use the old attributes. ``in_wgt1_shape`` only
        # exists when the node JSON provides it, so tolerate its absence.
        for legacy_attr in ("in_wgt_shape", "in_wgt1_shape", "in_act_shape", "out_act_shape"):
            if hasattr(self, legacy_attr):
                delattr(self, legacy_attr)

        # Extract the number of heads, groups
        num_heads = attributes["num_heads"][0]
        num_groups = attributes.get("num_groups", [num_heads])[0]
        num_bias = attributes.get("num_bias", [num_heads])[0]
        num_mask = attributes.get("num_mask", [num_heads])[0]
        # Leading (batch) dim of the 3D shape per activation input; others (Q) use num_heads.
        fold_by_input = {"k": num_groups, "v": num_groups, "b": num_bias, "m": num_mask}

        # inputs/outputs are stored as Python-literal strings in the json;
        # ast.literal_eval only accepts literals (no arbitrary code).
        inputs = ast.literal_eval(json_dict["inputs"])
        outputs = ast.literal_eval(json_dict["outputs"])
        assert len(outputs) == 1, "Expecting only one output for the MHA node"

        # Extract the activation inputs and outputs shape, apply transpose if needed and reshape them
        # into 3D tensors
        self.activations_shapes: dict[str, list[int]] = dict()
        self.permK_3d = None
        for input_info in inputs:
            if input_info["type"] != "act":
                continue

            input_name = input_info["param_name"].lower()
            shape = input_info["shape"]

            # add input datatype and data bytes with the old key.
            # FIXME: Should not be needed if inputs dict is used
            setattr(self, "in_" + input_name + "_datatype", input_info["dtype"])
            setattr(self, "in_" + input_name + "_bytes", input_info["dtype_bytes"])

            # Apply the permutation if a transpose op was fused with the K input
            if attributes["InTransposeK"][0] and input_name == "k":
                shape, permK = Layer.apply_permutation(shape, attributes, "InTransposeK", "permK")
                self.permK_3d = Layer.process_perm_vector(permK)

            self.activations_shapes["ifm_" + input_name] = Layer.reshape_into_3D_tensor(
                shape, fold_by_input.get(input_name, num_heads)
            )

        self.activations_shapes["ofm"] = Layer.reshape_into_3D_tensor(outputs[0]["shape"], num_heads)

        # Add the output tensor information.
        # FIXME: inconsistency between the reader and the consumer of the layer infos.
        # Hence the need to change the value of the attribute (change the key in the final json file).
        self.out_ofm_datatype = self.out_datatype
        delattr(self, "out_datatype")
        self.out_ofm_bytes = self.out_bytes
        delattr(self, "out_bytes")

        # Check that the inner dimensions are the same for the inputs tensors
        if self.activations_shapes["ifm_q"][2] != self.activations_shapes["ifm_k"][1]:
            raise ValueError(
                "K dimension of input and weight must be same for Q/K MatMul"
            )
        if "ifm_v" in self.activations_shapes and self.activations_shapes["ifm_v"][1] != self.activations_shapes["ifm_k"][2]:
            raise ValueError(
                "N dimension of input and weight must be same for Sfm/V MatMul"
            )

    def __init__(self, json_dict):
        """Build a layer from one node description ``json_dict``.

        Every key of ``json_dict`` becomes an instance attribute; ``op_type`` must be
        present. ``in_act_shape``, ``out_act_shape`` and ``in_wgt_shape`` default
        to ``[]``. Side effect: ``json_dict['conv_padding_enable']`` is set to True
        when absent.

        NOTE(review): attribute values are assigned by reference, not copied, so
        the in-place shape edits done by the op handlers (e.g. W8 padding of
        ``in_act_shape[-1]``) and the Conv attribute updates are also visible
        through ``json_dict``. Confirm whether callers rely on that before changing it.
        """
        self.in_act_shape = []
        self.out_act_shape = []
        self.in_wgt_shape = []

        json_dict.setdefault('conv_padding_enable', True)

        for key, value in json_dict.items():
            setattr(self, key, value)

        # Custom operator names look like "<Op>_<suffix>"; dispatch on "<Op>".
        self.orig_op_type = self.op_type.split('_')[0]

        if self.orig_op_type == 'PWLA':
            self._init_pwla()
        elif self.orig_op_type in ['Add', 'Mul']:
            self._init_eltwise(json_dict)
        elif self.orig_op_type == "LayerNormalization":
            self.in_wgt_shape = Layer.squeeze_leading_dims(self.in_wgt_shape)
            self.in_wgt1_shape = Layer.squeeze_leading_dims(self.in_wgt1_shape)
        elif self.orig_op_type == 'MHA':
            self.MHA_init(json_dict)
        elif self.orig_op_type == 'MatMul':
            self._init_matmul(json_dict)
        elif self.orig_op_type == 'RoPE':
            self._init_rope()
        elif self.orig_op_type == 'Conv':
            self._init_conv()

    def _init_pwla(self):
        """PWLA (LUT) ops: W8-pad the innermost dims, then flatten the act shapes to 1D arrays."""
        self.in_act_shape[-1] = iceil(self.in_act_shape[-1], 8)
        self.out_act_shape[-1] = iceil(self.out_act_shape[-1], 8)
        if self.in_wgt_shape and len(self.in_wgt_shape) > 1:
            self.in_wgt_shape[-1] = iceil(self.in_wgt_shape[-1], 8)

        # NOTE: unlike the other ops these stay numpy arrays (kept for compatibility).
        self.in_act_shape = np.array([np.prod(self.in_act_shape)])
        self.out_act_shape = np.array([np.prod(self.out_act_shape)])

    def _init_eltwise(self, json_dict):
        """Add/Mul: keep the original shapes in ``orig_*_shape`` and W8-pad the innermost dims."""
        self.attributes = json_dict.get('attributes', None)
        self.orig_act_shape = copy.deepcopy(self.in_act_shape)
        self.orig_wgt_shape = copy.deepcopy(self.in_wgt_shape)
        self.orig_ofm_shape = copy.deepcopy(self.out_act_shape)

        # A broadcast weight whose innermost dim is 1 (and differs from the act's)
        # keeps that 1; every other innermost dim is padded to a multiple of 8.
        is_inner_broadcast = (
            'BroadCast' in self.op_type
            and self.in_wgt_shape[-1] == 1
            and self.in_act_shape[-1] != self.in_wgt_shape[-1]
        )
        if not is_inner_broadcast:
            self.in_wgt_shape[-1] = np.ceil(self.in_wgt_shape[-1] / 8) * 8
        self.in_act_shape[-1] = np.ceil(self.in_act_shape[-1] / 8) * 8
        self.out_act_shape[-1] = np.ceil(self.out_act_shape[-1] / 8) * 8

        # Padded shapes as lists of Python ints (np.ceil produced floats).
        self.in_act_shape = np.array(self.in_act_shape).astype(int).tolist()
        self.in_wgt_shape = np.array(self.in_wgt_shape).astype(int).tolist()
        self.out_act_shape = np.array(self.out_act_shape).astype(int).tolist()

    def _init_matmul(self, json_dict):
        """MatMul: apply fused transposes and fold all operands into 3D ``[B, rows, cols]`` shapes.

        Raises:
            ValueError: if the B/M/K/N dimensions of the operands are inconsistent.
        """
        self.bias_bytes = (
            8 * (self.coeff_shape[0] // self.in_wgt_shape[-1])
            if 'qdq' in self.op_type
            else self.wgt1_bytes
        )
        attrs = json_dict.get('attributes', None)
        if attrs is not None and 'Unsqueeze' in attrs and attrs['Unsqueeze'][0] == 1:
            # Move the unsqueezed axes (values of every "*axes*" attribute) to the front.
            axes_list = [attrs[x][0] for x in attrs.keys() if 'axes' in x]
            original_dims = [x for x in range(len(self.out_act_shape)) if x not in axes_list]
            out_shape = np.array(self.out_act_shape)
            self.out_act_shape = out_shape[axes_list].tolist() + out_shape[original_dims].tolist()

        # NOTE: Preserving the original shape from unique_nodes.json
        self.in_act_shape_orig = self.in_act_shape
        self.in_wgt_shape_orig = self.in_wgt_shape
        self.out_act_shape_orig = self.out_act_shape

        self.in_act_shape, self.permA = Layer.apply_permutation(
            self.in_act_shape, attrs, "InTransposeA", "permA"
        )
        self.in_wgt_shape, self.permB = Layer.apply_permutation(
            self.in_wgt_shape, attrs, "InTransposeB", "permB"
        )
        self.permY = Layer.get_permutation(
            self.out_act_shape, attrs, "OutTranspose", "permY"
        )
        self.out_act_shape, self.rev_permY = Layer.reverse_permutation(
            self.out_act_shape, attrs, "OutTranspose", "permY"
        )

        # A missing 'attributes' dict or a missing/None/empty 'num_batches' all mean
        # a single batch (previously None passed the check below and then crashed on [0]).
        num_batches = (attrs or {}).get("num_batches") or [1]
        # Reshape 2d matmuls to 2d shape
        if num_batches[0] == 1:
            self.in_act_shape = Layer.reshape_shape(self.in_act_shape)
            self.in_wgt_shape = Layer.reshape_shape(self.in_wgt_shape)
            self.out_act_shape = Layer.reshape_shape(self.out_act_shape)

        # Reshape all tensors to 3D; for 2d matmuls a leading 1 is prepended.
        self.in_act_shape = Layer.reshape_shape(self.in_act_shape, last_dims=2)
        self.in_wgt_shape = Layer.reshape_shape(self.in_wgt_shape, last_dims=2)
        self.out_act_shape = Layer.reshape_shape(self.out_act_shape, last_dims=2)

        # Check B, M, K, N dimensions
        if num_batches[0] != self.out_act_shape[0]:
            raise ValueError('Inconsistent B dimensions')
        if self.in_act_shape[-2] != self.out_act_shape[-2]:
            raise ValueError('Inconsistent M dimensions')
        if self.in_act_shape[-1] != self.in_wgt_shape[-2]:
            raise ValueError('Inconsistent K dimensions')
        if self.in_wgt_shape[-1] != self.out_act_shape[-1]:
            raise ValueError('Inconsistent N dimensions')

        # Make batch dims consistent across ifm/wgt (PSMU_ST0 MatMul_4)
        self.in_act_shape[0] = self.in_wgt_shape[0] = num_batches[0]

        self.permA = Layer.process_perm_vector(self.permA)
        self.permB = Layer.process_perm_vector(self.permB)
        self.permY = Layer.process_perm_vector(self.permY)

        # Compute reverse permutations
        self.rev_permA = np.argsort(self.permA).tolist()
        self.rev_permB = np.argsort(self.permB).tolist()

    def _init_rope(self):
        """RoPE: record a 2D view of the activation, then flatten every operand shape to 1D."""
        self.orig_act_shape = [int(np.prod(self.in_act_shape[:-1])), self.in_act_shape[-1]]
        self.in_act_shape = np.array([np.prod(self.in_act_shape)]).astype(int).tolist()
        self.in_wgt_shape = np.array([np.prod(self.in_wgt_shape)]).astype(int).tolist()  # sin
        self.in_wgt1_shape = np.array([np.prod(self.in_wgt1_shape)]).astype(int).tolist()  # cos
        self.out_act_shape = np.array([np.prod(self.out_act_shape)]).astype(int).tolist()

    def _init_conv(self):
        """Conv: read kernel/stride/pad attributes, apply the 7x7/s4 fold mapping,
        W8-pad the channel dims and derive the ``aligned_*`` shapes.

        Raises:
            UnsupportedDataTypeError: for float32 inputs.
            ValueError: for an odd Xi in the 7x7/s4 mapping, or a Cin mismatch
                between IFM and WGT.
        """
        if self.in_datatype in ["float32"]:
            raise UnsupportedDataTypeError(f"'{self.orig_op_type}' op does not support in_datatype as '{self.in_datatype}'")
        self.kernel_shape = self.attributes['kernel_shape']
        self.strides = self.attributes['strides']
        # TODO: Currently assuming pads are symmetric
        pads = self.attributes['pads']
        self.padding = [pads[0], pads[1], pads[2], pads[3]]

        # AIESW-18486: Fold-based mapping for Conv 7x7 stride 4x4.
        # Host runtime folds X into C with:
        #   - fold_factor_Ci = 3  => Kx: 7 -> ceil(7/3)=3, Ci: Ci -> Ci*3
        #   - fold_factor_Xi = 2  => Xi: Xi -> Xi/2, Sx: 4 -> 2
        # and adjusts X padding to (Px_b=1, Px_a=0) while keeping Y padding unchanged.
        if (self.kernel_shape[0] == 7 and self.kernel_shape[1] == 7 and
                self.strides[0] == 4 and self.strides[1] == 4):
            fold_factor_Ci = 3  # documents the mapping; not used in this method
            fold_factor_Xi = 2

            if self.in_act_shape[2] % fold_factor_Xi != 0:
                raise ValueError(
                    f"Conv7x7 special format requires Xi divisible by {fold_factor_Xi}, got Xi={self.in_act_shape[2]}"
                )

            # Update kernel/stride params used by tiler + downstream
            self.kernel_shape = [7, 3]
            self.strides = [4, 2]
            self.attributes['kernel_shape'] = self.kernel_shape
            self.attributes['strides'] = self.strides

            # Adjust pads: [Py_b, Px_b, Py_a, Px_a]
            folded_pads = self.attributes.get('pads', None)
            if folded_pads is not None and len(folded_pads) >= 4:
                folded_pads = list(folded_pads)
                folded_pads[1] = 1
                folded_pads[3] = 0
                self.attributes['pads'] = folded_pads
                self.padding = [folded_pads[0], folded_pads[1], folded_pads[2], folded_pads[3]]

            self.attributes['conv7x7_special_format'] = [True]

        # WAIC flow: when conv_padding_enable is set, Ci (in_act_shape[3] /
        # in_wgt_shape[2]) and Co (out_act_shape[3] / in_wgt1_shape[0], plus
        # in_wgt_shape[3] for 4D weights) are rounded up to a multiple of 8.
        # NOTE(review): an older note described this as max(64, W8); only W8
        # rounding via iceil(x, 8) is applied here -- confirm which is intended.
        if self.conv_padding_enable:
            self.in_act_shape[3], self.out_act_shape[3], self.in_wgt_shape[2], self.in_wgt1_shape[0] = \
                [iceil(dim, 8) for dim in (self.in_act_shape[3], self.out_act_shape[3], self.in_wgt_shape[2], self.in_wgt1_shape[0])]
            if len(self.in_wgt_shape) == 4:
                self.in_wgt_shape[3] = iceil(self.in_wgt_shape[3], 8)

        # Check for Cin mismatch between input and weight
        if self.in_act_shape[3] != self.in_wgt_shape[2]:
            raise ValueError(
                f'Cin mismatch between IFM and WGT: in_act_shape={self.in_act_shape}, '
                f'in_wgt_shape={self.in_wgt_shape}'
            )

        # Create aligned shapes for hardware requirements (Cin%8=0)
        self.aligned_in_act_shape = self.in_act_shape.copy()
        self.aligned_in_wgt_shape = self.in_wgt_shape.copy()
        self.aligned_out_act_shape = self.out_act_shape.copy()
        self.aligned_in_wgt1_shape = self.in_wgt1_shape.copy()

        # Align input channels to multiple of 8
        if self.aligned_in_act_shape[3] % 8 != 0:
            self.aligned_in_act_shape[3] = (self.aligned_in_act_shape[3] // 8 + 1) * 8
            self.aligned_in_wgt_shape[2] = (self.aligned_in_wgt_shape[2] // 8 + 1) * 8

        # Align output channels to multiple of 8
        if self.aligned_out_act_shape[3] % 8 != 0:
            self.aligned_out_act_shape[3] = (self.aligned_out_act_shape[3] // 8 + 1) * 8
            self.aligned_in_wgt_shape[3] = (self.aligned_in_wgt_shape[3] // 8 + 1) * 8
            self.aligned_in_wgt1_shape[0] = (self.aligned_in_wgt1_shape[0] // 8 + 1) * 8

        if not hasattr(self, 'is_standalone_dwc'):
            # Temporary solution: treat as depthwise when group == Ci.
            # NOTE(review): Ci here may already be W8-padded, so a depthwise conv
            # with Ci % 8 != 0 would not be detected -- confirm intent.
            self.is_standalone_dwc = (self.attributes['group'][0] == self.in_act_shape[3])

# if __name__=='__main__':
#     import json
#     with open('c:/Users/sourabhd/Downloads/PSI_conv.json') as f:
#         mdict = json.loads(f.read())

#     # md = mdict['/up_blocks.1/attentions.2/transformer_blocks.1/attn2/MatMul_1'].copy()
#     cd = mdict['/conv_down/Conv1'].copy()

#     # md['weight_datatype']='bfp16'
#     # cd['weight_datatype']='bfp16'

#     # ml = Layer(md)
#     cl = Layer(cd)
