"""Gen Input/ Output Data"""
import sys
import os
import shutil
import numpy as np
import ml_dtypes
from scipy.special import erf
from utils.utils_common import log


def copy_and_rename(src_path, old_name, dest_path, new_name):
    """Copy ``src_path/old_name`` to ``dest_path/new_name``.

    The file is copied directly to its final name. A file that already exists
    as ``dest_path/old_name`` is therefore left untouched, and no intermediate
    copy is left behind. An existing ``dest_path/new_name`` is overwritten.

    Parameters:
        src_path (str): Directory containing the source file.
        old_name (str): Name of the source file.
        dest_path (str): Existing destination directory.
        new_name (str): Name to give the copy.

    Returns:
        str: Path of the newly written file.
    """
    new_path = os.path.join(dest_path, new_name)
    shutil.copy(os.path.join(src_path, old_name), new_path)
    return new_path


def load_binary_array(file_path, dtype=np.uint16, shape=None):
    """
    Read a headerless binary file into a NumPy array.

    Parameters:
        file_path (str): Location of the raw binary file.
        dtype (np.dtype): Element type used to interpret the bytes.
        shape (tuple or None): Target shape; when None the flat 1-D array is returned.

    Returns:
        np.ndarray: The decoded array.

    Raises:
        FileNotFoundError: If ``file_path`` is not an existing regular file.
        RuntimeError: If reading fails or the element count does not match
            ``shape``; the underlying error is chained as ``__cause__``.
    """
    if os.path.isfile(file_path):
        try:
            data = np.fromfile(file_path, dtype=dtype)
            if shape is None:
                return data
            # The element count must agree exactly with the requested shape.
            if data.size == np.prod(shape):
                return data.reshape(shape)
            raise ValueError(f"File size does not match expected shape {shape}")
        except Exception as err:
            raise RuntimeError(f"Error loading binary file: {err}") from err
    raise FileNotFoundError(f"File not found: {file_path}")


# Define the non-linear functions
class nonlinear_function:
    """Float reference model for the unary (non-linear) operators under test.

    The operator is selected by name in the constructor and evaluated with
    :meth:`call`. Supported names: "l2norm", "silu", "gelu", "softmax",
    "layernorm", "groupnorm", "copy", "swish", "sigmoid", "tanh", "elu".

    Attributes:
        func (str): Name of the selected operator.
        gamma_beta (np.ndarray or None): Set by "layernorm"/"groupnorm" on every
            call to freshly drawn N(0, 3.5**2) affine parameters of shape
            (2, C), where C = x.shape[-1]. Row 0 is used as beta (offset) and
            row 1 as gamma (scale). None until such a call.
        lut: Not used by this class; always None.
    """

    def __init__(self, func: str):
        self.func = func
        self.gamma_beta = None
        self.lut = None

    @staticmethod
    def _bitmask_to_keep(bitmask, N, C):
        """Expand a packed bitmask into a boolean (N, C) keep-mask.

        Element (r, c) is kept when bit ``c % 16`` of ``bitmask[r, c // 16]`` is set.
        """
        cols = np.arange(C)
        words = np.asarray(bitmask)[:N, cols // 16].astype(np.uint32)
        return ((words >> (cols % 16)) & 1).astype(bool)

    def call(self, x, bitmask=None, axis=-1, epsilon=1e-12):
        """Evaluate the selected operator on ``x``.

        Parameters:
            x (np.ndarray): Input values. "softmax" and "groupnorm" require a
                2-D (N, C) array; "layernorm" normalizes over the last axis.
            bitmask (np.ndarray or None): Only used by "softmax". This is a packed
                keep-mask: element (r, c) is kept when bit ``c % 16`` of
                ``bitmask[r, c // 16]`` is set. Cleared elements are treated as
                -inf, and a row with every element cleared yields NaN.
            axis (int): Reduction axis for "l2norm" and "softmax". Default -1.
            epsilon (float): Stabilizer added to the denominator for "l2norm"
                and "layernorm". "groupnorm" ignores it and uses float32(1e-5).

        Returns:
            np.ndarray: Operator output with the same shape as ``x``. The
            caller's array is never modified.

        Raises:
            AssertionError: If the operator name is not supported.
        """

        if self.func == "l2norm":
            x = np.asarray(x)
            norm = np.linalg.norm(x, ord=2, axis=axis, keepdims=True)
            y = x / (norm + epsilon)

        elif self.func == "silu":
            y = x / (1 + np.exp(-x))

        elif self.func == "gelu":
            y = 0.5 * x * (1 + erf(x / np.sqrt(2)))

        elif self.func == "softmax":
            (N, C) = x.shape
            if bitmask is not None:
                keep = self._bitmask_to_keep(bitmask, N, C)
                # Mask a copy so the caller's array is left unchanged.
                x = x.copy()
                x[~keep] = -np.inf

            e_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
            y = e_x / e_x.sum(axis=axis, keepdims=True)

        elif self.func == "layernorm":
            sigma = 3.5
            mu = 0
            self.gamma_beta = sigma * np.random.randn(2, x.shape[-1]) + mu
            mean = np.mean(x, axis=-1, keepdims=True)
            std = np.std(x, axis=-1, keepdims=True)

            # Note: divides by (std + epsilon), not sqrt(var + epsilon).
            y = ((x - mean) / (std + epsilon)) * self.gamma_beta[1, :] + self.gamma_beta[0, :]

        elif self.func == "groupnorm":
            num_groups = 32  # Currently fixed to 32 groups
            # Overrides the `epsilon` argument: this literal is float32(1e-5) written out in full.
            epsilon = 0.000009999999747378752
            # Sigma and mu are the params of the normal distribution; they can be any value in the TB.
            sigma = 3.5
            mu = 0
            self.gamma_beta = sigma * np.random.randn(2, x.shape[-1]) + mu
            # Round-trip through x.dtype to mimic reduced-precision parameters.
            # NOTE(review): callers in this file pass float32 x, so this round-trip is a
            # no-op there. bf16/fp16 rounding only happens if x itself has that dtype.
            self.gamma_beta = (self.gamma_beta.astype(x.dtype)).astype(np.float32)
            # Channels are along the last axis and must split evenly into the groups.
            assert x.shape[-1] % num_groups == 0
            group_size = x.shape[-1] // num_groups
            # (N, C) -> (N, num_groups, group_size) -> (num_groups, N * group_size):
            # statistics are taken per group over all rows and the group's channels.
            new_shape = [x.shape[0], num_groups, group_size]
            x_grouped = x.reshape(new_shape)
            x_reshaped = x_grouped.transpose(1, 0, 2).reshape(num_groups, x.shape[0] * group_size)
            mean = np.mean(x_reshaped, axis=-1, keepdims=True)
            var = np.var(x_reshaped, axis=-1, keepdims=True)
            # Expand the per-group statistics to one value per channel, shape (C,).
            mean_rep = np.repeat(mean, group_size)
            var_rep = np.repeat(var, group_size)
            op1 = 1/np.sqrt(var_rep + epsilon)
            op2 = mean_rep * op1

            beta = self.gamma_beta[0, :]
            gamma = self.gamma_beta[1, :]
            # Algebraically (x - mean) / sqrt(var + eps) * gamma + beta.
            y = x*gamma*op1 - op2*gamma + beta

        elif self.func == "copy":
            y = x.copy()

        elif self.func == "swish":
            y = x / (1.0 + np.exp(-1.70*x))

        elif self.func == "sigmoid":
            y = 1.0 / (1.0 + np.exp(-x))

        elif self.func == "tanh":
            y = np.tanh(x)

        elif self.func == "elu":
            y = np.where(x > 0, x, 1.7 * (np.exp(x) - 1))

        else:
            log("Function Not supported")
            # Explicit raise rather than `assert False`: it still fires under `python -O`,
            # where a stripped assert would fall through to an unbound `y`.
            raise AssertionError(f"Function Not supported: {self.func}")

        return y


def dequantize(x, scale, zero_point, outputType):
    """Dequantize integer codes: return ``(x - zero_point) * scale`` cast to ``outputType``.

    ``outputType`` must be bfloat16, float16 or float32. ``x`` is widened to
    int32 first, so the zero-point subtraction is not done in x's original
    (e.g. unsigned 16-bit) type.
    """
    supported = {ml_dtypes.bfloat16, np.float16, np.float32}
    assert outputType in supported

    centered = x.astype(np.int32) - zero_point
    return (centered * scale).astype(outputType)


def quantize(x, scale, zero_point, inputType):
    """Quantize to uint16: ``clip(round(x * (1 / scale) + zero_point), 0, 65535)``.

    ``inputType`` is only validated (bfloat16, float16 or float32). The
    arithmetic is done in float64, and ``np.round`` rounds halves to even.
    """
    assert inputType in {ml_dtypes.bfloat16, np.float16, np.float32}

    as_f64 = np.float64(x)
    scaled = as_f64 * (1 / scale) + zero_point
    rounded = np.round(scaled, decimals=0)
    return np.clip(rounded, 0, 65535).astype(np.uint16)


class ConfigParam:
    """Quantize (Q), dequantize (dQ) and non-linear-function settings of one layer.

    Attributes:
        zp_i, sc_i: dQ zero-point and scale (input side); both start at 0.
        dq_enable: True once enable_dQ() has been called.
        zp_o, sc_o: Q zero-point and scale (output side); both start at 0.
        q_enable: True once enable_Q() has been called.
        nlf_enable: Non-linear-function flag. It starts as True, so
            enable_nlf() does not change the default state.
    """
    def __init__(
        self,
    ):
        # dQ (input side): disabled, zero-point and scale zeroed.
        self.zp_i, self.sc_i, self.dq_enable = 0, 0, False
        # Q (output side): disabled, zero-point and scale zeroed.
        self.zp_o, self.sc_o, self.q_enable = 0, 0, False
        self.nlf_enable = True

    def enable_dQ(self):
        """Turn dequantization on."""
        self.dq_enable = True

    def enable_Q(self):
        """Turn quantization on."""
        self.q_enable = True

    def enable_nlf(self):
        """Turn the non-linear function on."""
        self.nlf_enable = True

    def set_dQ_zero_point(self, zp: float):
        """Store the dequantization zero-point."""
        self.zp_i = zp

    def set_Q_zero_point(self, zp: float):
        """Store the quantization zero-point."""
        self.zp_o = zp

    def set_dQ_scale(self, sc: float):
        """Store the dequantization scale."""
        self.sc_i = sc

    def set_Q_scale(self, sc: float):
        """Store the quantization scale."""
        self.sc_o = sc


def YC_to_CYC(x, num_cols_per_blk):
    """Dimension unfolding: lay a 2-D (H, W) array out again as column blocks.

    The columns are split into ``W // num_cols_per_blk`` blocks of
    ``n = num_cols_per_blk`` columns each. The result is a flat 1-D array
    holding block 0 (all H rows, row-major), then block 1, and so on::

        y[b*H*n + r*n + c] = x[r, b*n + c]

    If W is not a multiple of ``n``, the trailing ``H * (W % n)`` elements of
    the result are copied unchanged from ``x.flatten()``. They are not the
    leftover columns rearranged into the block layout.

    Parameters:
        x (np.ndarray): 2-D input array.
        num_cols_per_blk (int): Block width in columns (must be > 0).

    Returns:
        np.ndarray: 1-D array with ``x.size`` elements and ``x``'s dtype.
    """
    y = x.flatten()  # flatten() always returns a copy

    H = x.shape[0]
    W = x.shape[1]   # Total number of columns
    n = num_cols_per_blk
    total_blks = W // n
    used = total_blks * n   # columns covered by whole blocks

    # (H, used) -> (H, blk, n) -> (blk, H, n) -> flat. This is one vectorized
    # pass instead of an element-by-element Python loop.
    y[:H * used] = x[:, :used].reshape(H, total_blks, n).transpose(1, 0, 2).reshape(-1)

    return y


# Standard deviation of the random float input drawn when dQ is disabled.
# Functions not listed here use 3.5.
_INPUT_SIGMA = {"swish": 1.3, "tanh": 0.2, "sigmoid": 1.0, "elu": 1.0}

# LUT file (in $AIE4_ROOT_DIR/scheduler/uniop) copied to input_1.bin for each LUT-based function.
_LUT_FILES = {
    "swish": "swish_lut_1_7.bin",
    "sigmoid": "sigmoid_lut.bin",
    "tanh": "tanh_lut.bin",
    "elu": "elu_lut.bin",
    "silu": "silu_lut.bin",
    "gelu": "gelu_lut.bin",
}


def _save_padded(data, num_samples, padded_dim, width, path):
    """Zero-pad ``data`` to (num_samples, padded_dim) in its own dtype, write it raw to ``path``, and return it."""
    padded = np.zeros((num_samples, padded_dim)).astype(data.dtype)
    padded[:, :width] = data
    padded.tofile(path)
    return padded


def generate_and_save_io_pairs(
    num_samples,
    input_dim,      # input_dim is true dimension of the model data
    padded_dim,
    config: ConfigParam,
    test_data_dir: str = "./",
    function: str = "l2norm",
    overlay_is_1x1: bool = False,
    qdq_input_fp32: bool = False,
    qdq_output_fp32: bool = False,
    qdq_floating_is_fp16: bool = False
):
    """Generate a random input / reference-output pair for a unary op and save them as raw binaries.

    Files written to ``test_data_dir``:
        input_0.bin  -- the op input.
        output_0.bin -- the reference output.
        input_1.bin  -- only in some cases. When several apply, the later write wins:
                        first the packed bitmask given to the operator (``overlay_is_1x1``),
                        then gamma/beta ("layernorm"/"groupnorm"), then a LUT copied from
                        ``$AIE4_ROOT_DIR/scheduler/uniop`` (LUT-based activations).

    Parameters:
        num_samples (int): Number of rows.
        input_dim (int): True row length of the model data.
        padded_dim (int): Row length on disk; rows are zero-padded from ``input_dim``
            (must be >= input_dim). Not used by the "copy" path or by the
            operator path when ``overlay_is_1x1`` is set.
        config (ConfigParam): dQ/Q enables, scales and zero-points.
        test_data_dir (str): Output directory.
        function (str): "dequant", "quant", "copy", or any operator name
            supported by ``nonlinear_function``.
        overlay_is_1x1 (bool): For operators, write input/output in the
            64-column block layout of ``YC_to_CYC`` (unpadded), and generate a
            random bitmask.
        qdq_input_fp32 (bool): Q input (and "copy" input) is float32 instead of 16-bit.
        qdq_output_fp32 (bool): dQ output (and "copy" output) is float32 instead of 16-bit.
        qdq_floating_is_fp16 (bool): The 16-bit float type is float16 instead of bfloat16.

    Raises:
        AssertionError: If ``padded_dim < input_dim``.
        RuntimeError: If a LUT-based function is requested and AIE4_ROOT_DIR is unset.
        SystemExit: For an unsupported "copy" type combination.
    """
    log("qdq_input_fp32:", qdq_input_fp32)
    log("padded_dim:", padded_dim)
    log("input_dim:", input_dim)

    assert padded_dim >= input_dim
    output_dim = input_dim    # basic assumption of Unary op

    # 16-bit float type, plus the float types at the dQ output / Q input.
    float_16bit_type = np.float16 if qdq_floating_is_fp16 else ml_dtypes.bfloat16
    dQ_OutType = np.float32 if qdq_output_fp32 else float_16bit_type
    Q_InType = np.float32 if qdq_input_fp32 else float_16bit_type

    log("qdq_input_fp32:", qdq_input_fp32)
    log("qdq_output_fp32:", qdq_output_fp32)

    input_path = os.path.join(test_data_dir, "input_0.bin")
    output_path = os.path.join(test_data_dir, "output_0.bin")

    # Standalone dequantize: uint16 codes in, dQ_OutType out (pass-through when dQ is disabled).
    if function == "dequant":
        input_data = np.random.randint(5, 15, size=(num_samples, input_dim)).astype(np.uint16)
        dqOut = dequantize(input_data, config.sc_i, config.zp_i, dQ_OutType) if config.dq_enable else input_data

        log("config.dq_enable:", config.dq_enable)
        log("dqOut.dtype:", dqOut.dtype)

        _save_padded(input_data, num_samples, padded_dim, input_dim, input_path)
        _save_padded(dqOut, num_samples, padded_dim, output_dim, output_path)
        return

    # Standalone quantize: Q_InType in, uint16 codes out (pass-through when Q is disabled).
    if function == "quant":
        sigma = 3.5
        mu = 0

        input_data = (sigma * np.random.randn(num_samples, input_dim) + mu).astype(Q_InType)
        qOut = quantize(input_data, config.sc_o, config.zp_o, Q_InType) if config.q_enable else input_data

        log("config.q_enable:", config.q_enable)
        log("input_data.dtype:", input_data.dtype)
        log("qOut.dtype:", qOut.dtype)

        _save_padded(input_data, num_samples, padded_dim, input_dim, input_path)
        _save_padded(qOut, num_samples, padded_dim, output_dim, output_path)
        return

    nlf = nonlinear_function(function)
    if nlf.func == "copy":
        InType = np.float32 if qdq_input_fp32 else np.uint16
        OutType = np.float32 if qdq_output_fp32 else np.uint16
        input_data = np.random.randint(5, 15, size=(num_samples, input_dim)).astype(InType)

        if OutType == InType:
            y1 = input_data.copy().astype(InType)

        elif InType == np.float32 and OutType == np.uint16:
            # fp32 in, u16 out: regenerate the input as a flat fp32 vector and keep
            # uint16 elements 0, 2, 4, ... of its raw bytes. On a little-endian host
            # these are the low halves of each fp32 word.
            input_data = (3.5 * np.random.randn(num_samples * input_dim) + 0.0).astype(np.float32)
            log("input_data:", input_data)
            log("input_data.shape:", input_data.shape)
            log("input_data.dtype:", input_data.dtype)

            input_1d = input_data.reshape(input_data.size)
            arr_2 = input_1d.view(np.uint16)      # as uint16 element, dimension doubled

            log("arr_2.shape:", arr_2.shape)
            log("arr_2.dtype:", arr_2.dtype)

            y1 = arr_2[::2].reshape(input_data.shape)

        elif InType == np.uint16 and OutType == np.float32:
            # u16 in, fp32 out: each input value becomes uint16 element 2k of the
            # output, with the constant 0x43bb in element 2k+1.
            input_data = np.random.randint(5, 65531, size=(num_samples, input_dim)).astype(InType)
            y1 = np.zeros(input_data.size * 2).astype(np.uint16)
            y1[::2] = input_data.reshape(input_data.size).astype(np.uint16)   # lower  16bits
            y1[1::2] = 0x43bb                                                 # higher 16bits

        else:
            log("input_data/output type pairs not supported!")
            sys.exit(1)

        log("y1.shape:", y1.shape)
        log("y1.dtype:", y1.dtype)
        log("y1:", y1)

        input_data.tofile(input_path)
        y1.tofile(output_path)
        return

    # Random input: 16-bit floats when dQ is disabled, uint16 codes when it is enabled.
    sigma = _INPUT_SIGMA.get(function, 3.5)
    mu = 0

    if not config.dq_enable:
        input_data = (sigma * np.random.randn(num_samples, input_dim) + mu).astype(float_16bit_type)
        dqOut = None
    else:
        # groupnorm draws from a wider code range than the other operators.
        in_min, in_max = (2450, 63400) if function == "groupnorm" else (5, 15)
        input_data = np.random.randint(in_min, in_max, size=(num_samples, input_dim)).astype(np.uint16)
        dqOut = dequantize(input_data, config.sc_i, config.zp_i, dQ_OutType)

    # Debug hook: to replay captured data, override the arrays from .bin files, e.g.
    # input_data = load_binary_array(test_data_dir + "../../../../scheduler/uniop/activation.bin", dtype=np.uint16, shape=(num_samples, input_dim))
    # dqOut = dequantize(input_data, config.sc_i, config.zp_i, dQ_OutType)
    # (nlf.call redraws gamma_beta for layernorm/groupnorm, so overriding it here has no effect.)

    # Save input: 64-column block layout for the 1x1 overlay, zero-padded rows otherwise.
    if overlay_is_1x1:
        YC_to_CYC(input_data, 64).tofile(input_path)
    else:
        _save_padded(input_data, num_samples, padded_dim, input_dim, input_path)

    log("input_data:", input_data)
    log("dqOut:", dqOut)
    log("num_samples:", num_samples)
    log("input_dim:", input_dim)

    if overlay_is_1x1:
        # One uint16 word per 16 columns, saved in the 4-word block layout.
        bitmask = np.random.randint(0, 65535, size=(num_samples, input_dim // 16)).astype(np.uint16)
        YC_to_CYC(bitmask, 4).tofile(os.path.join(test_data_dir, "input_1.bin"))
    else:
        bitmask = None

    # The reference output is computed in float32 from the dequantized values when dQ
    # is enabled. It is then quantized to uint16 when Q is enabled, or otherwise cast
    # to the 16-bit float type.
    nlf_input = dqOut if config.dq_enable else input_data
    y0 = nlf.call(nlf_input.astype(np.float32), bitmask=bitmask)
    if config.q_enable:
        y1 = quantize(y0, config.sc_o, config.zp_o, Q_InType).astype(np.uint16)
    else:
        y1 = y0.astype(float_16bit_type)

    log("config.sc_o:", config.sc_o)
    log("config.zp_o:", config.zp_o)
    output = y1

    # Gamma/beta for the normalization operators, written in the 16-bit float type.
    # NOTE(review): the reference above was computed with gamma/beta in float32,
    # while the values written here are rounded to 16 bits. Confirm the comparison
    # tolerance accounts for this.
    if function in ["layernorm", "groupnorm"]:
        (nlf.gamma_beta.astype(float_16bit_type)).tofile(os.path.join(test_data_dir, "input_1.bin"))

    # LUT-based activations: copy the matching LUT in as input_1.bin.
    if function in _LUT_FILES:
        os.environ["ENABLED_LOG"] = "true"
        current_directory = os.getcwd()
        aie4_model_repo_root = os.environ.get("AIE4_ROOT_DIR")
        log("current_directory:", current_directory)
        if aie4_model_repo_root is None:
            raise RuntimeError(f"AIE4_ROOT_DIR is not set; cannot locate the LUT for '{function}'")
        copy_and_rename(os.path.join(aie4_model_repo_root, "scheduler", "uniop"),
                        _LUT_FILES[function], test_data_dir, "input_1.bin")

    # Save output in the same layout as the input.
    if overlay_is_1x1:
        output_CYC64 = YC_to_CYC(output, 64)
        output_CYC64.tofile(output_path)
        log("output_CYC64:", output_CYC64)
    else:
        padded_output = _save_padded(output, num_samples, padded_dim, output_dim, output_path)
        # `padded_output` only exists on this path, so it is logged here.
        log("padded_output:", padded_output)
    log("test_data_dir:", test_data_dir)
    log("Saved input and output pairs to binary files.")
