from calendar import c
import math
import struct
import ctypes
import os
import sys
from enum import IntEnum

# Absolute path of the directory that contains this source file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))

from kernel.common.kernel_params_helper import (
    DimsHelper,
)

class GemmDims:
    '''Problem, tiling and datatype description of one GEMM layer.

    Every constructor argument is stored under an attribute of the same
    name.  On top of that the constructor derives:

    * ``M_loop`` / ``N_loop`` / ``K_loop``: floor quotient of each problem
      dimension by its subvolume size (``M // Msubv`` and so on).
    * ``wgt_size``: ``Ksubv * Nsubv * wgt_bits`` converted to whole bytes.
    * ``c0_size`` / ``c1_size`` / ``c2_size``: ``Nsubv * cX_bits`` converted
      to whole bytes.
    * ``wgt_subv_bytes``: ``wgt_size + c0_size + c1_size + c2_size +
      qdq_size``.

    ``param_size`` (1024), ``param_bits`` (8) and ``qdq_size`` (128) are
    fixed constants set by the constructor.

    Raises:
        AssertionError: if ``Ksubv`` is not a multiple of ``K_gran``.
        ZeroDivisionError: if ``K_gran`` or any subvolume size is zero.
    '''

    def __init__(
        self,
        M, K, N,
        Msubv, Ksubv, Nsubv,
        aie_rows, aie_cols,
        ifm_bits, wgt_bits, c0_bits, 
        c1_bits, c2_bits, out_bits,
        K_gran,
        sign_A, sign_W, sign_O, shift_out, vector_coeff
    ):
        # Raised explicitly instead of using an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type is
        # kept so existing handlers continue to work.
        if Ksubv % K_gran != 0:
            raise AssertionError(
                f"Ksubv ({Ksubv}) must be a multiple of K_gran ({K_gran})"
            )

        # Fixed constants.
        self.param_size = 1024
        self.param_bits = 8
        self.qdq_size = 128

        # Problem dimensions and per-iteration subvolume sizes.
        self.M = M
        self.K = K
        self.N = N
        self.Msubv = Msubv
        self.Ksubv = Ksubv
        self.Nsubv = Nsubv
        self.K_gran = K_gran

        # Array geometry.
        self.aie_rows = aie_rows
        self.aie_cols = aie_cols

        # Element widths in bits.
        self.ifm_bits = ifm_bits
        self.wgt_bits = wgt_bits
        self.c0_bits = c0_bits
        self.c1_bits = c1_bits
        self.c2_bits = c2_bits
        self.out_bits = out_bits

        # Signedness flags, output shift and coefficient mode, stored as given.
        self.sign_A = sign_A
        self.sign_W = sign_W
        self.sign_O = sign_O
        self.shift_out = shift_out
        self.vector_coeff = vector_coeff

        # Number of subvolume iterations along each dimension (floor division:
        # a remainder, if any, is dropped).
        self.M_loop = M // Msubv
        self.N_loop = N // Nsubv
        self.K_loop = K // Ksubv

        # Byte sizes of one weight subvolume and its per-column coefficients.
        self.wgt_size = (Ksubv * Nsubv * wgt_bits) // 8
        self.c0_size = (Nsubv * c0_bits) // 8
        self.c1_size = (Nsubv * c1_bits) // 8
        self.c2_size = (Nsubv * c2_bits) // 8
        self.wgt_subv_bytes = (
            self.wgt_size
            + self.c0_size
            + self.c1_size
            + self.c2_size
            + self.qdq_size
        )


def setup_gemm_params(
    dims: GemmDims
) -> bytes:
    '''Pack the GEMM kernel loop counts, steps and address dims into bytes.

    Loop counts and step sizes are derived from ``dims`` (``M_loop``,
    ``N_loop``, ``K_loop``, ``Msubv``, ``Ksubv``, ``Nsubv``).  Three
    ``DimsHelper.from_steps`` results (A, B and Q) are appended after them.
    Every packed value is also printed to stdout as a ``name: value`` line.

    Args:
        dims: GEMM description; only the attributes listed above are read.

    Returns:
        68 bytes in little-endian order without padding, laid out as
        ``<3H2B4H2I3i2I3i1I2i``: three uint16 loop counts, two uint8 group
        counts, four uint16 steps, then (num0, num1, inc0, inc1, inc2) for A,
        the same five fields for B, and (num0, inc0, inc1) for Q.

    Raises:
        struct.error: if a value does not fit its field (for example the
            steps are packed as uint16) or is not an integer.
    '''
    gran_m = 32
    step_back_a_int16 = -128
    # Floor division keeps this offset an int; true division (``/``) would
    # hand a float (-2048.0) to DimsHelper and on into the packed fields.
    step_back_b_int4 = -4096 // 2

    helper_a = DimsHelper(step_back_a_int16 * gran_m)
    helper_b = DimsHelper(step_back_b_int4)
    helper_q = DimsHelper(0)

    outer_iters = dims.M_loop * dims.N_loop
    inner_iters = dims.K_loop
    inner_loop = dims.Ksubv // 64
    Y_g = dims.Msubv // 32
    X_g = dims.Nsubv // 64
    step_Xi = 64 * dims.Msubv * 2
    step_Yi = 64 * 2
    step_Kx = 64
    step_Ky = dims.Nsubv * (64 // 2)

    dimsA = helper_a.from_steps((inner_loop, Y_g), (step_Xi, step_Yi * gran_m, 0))
    dimsB = helper_b.from_steps((inner_loop, Y_g), (step_Ky, 0, step_Kx * 64))
    # NOTE(review): the first argument is a bare int here, whereas the A and B
    # calls pass tuples -- confirm DimsHelper.from_steps accepts a scalar, or
    # whether a 1-tuple ``(Y_g,)`` was intended.
    dimsQ = helper_q.from_steps(Y_g, (0, 512))

    # Debug dump of every value that goes into the packed structure.
    scalar_fields = (
        ("outer_iters", outer_iters),
        ("inner_iters", inner_iters),
        ("inner_loop", inner_loop),
        ("Y_g", Y_g),
        ("X_g", X_g),
        ("step_Xi", step_Xi),
        ("step_Yi", step_Yi),
        ("step_Kx", step_Kx),
        ("step_Ky", step_Ky),
    )
    for label, value in scalar_fields:
        print(f"{label}: {value}")

    dims_2level_keys = ("num0", "num1", "inc0", "inc1", "inc2")
    dims_1level_keys = ("num0", "inc0", "inc1")
    for label, table, keys in (
        ("dimsA", dimsA, dims_2level_keys),
        ("dimsB", dimsB, dims_2level_keys),
        ("dimsQ", dimsQ, dims_1level_keys),
    ):
        for key in keys:
            print(f"{label}['{key}']: {table[key]}")

    # The shift_res / ctrl / reserved-byte fields are not part of this layout.
    packed_params = struct.pack(
        '<3H2B4H2I3i2I3i1I2i',
        outer_iters,            # H
        inner_iters,            # H
        inner_loop,             # H
        Y_g,                    # B
        X_g,                    # B
        step_Xi,                # H
        step_Yi,                # H
        step_Kx,                # H
        step_Ky,                # H
        dimsA['num0'],          # I
        dimsA['num1'],          # I
        dimsA['inc0'],          # i
        dimsA['inc1'],          # i
        dimsA['inc2'],          # i
        dimsB['num0'],          # I
        dimsB['num1'],          # I
        # int() kept defensively in case DimsHelper returns non-int increments.
        int(dimsB['inc0']),     # i
        int(dimsB['inc1']),     # i
        int(dimsB['inc2']),     # i
        dimsQ['num0'],          # I
        dimsQ['inc0'],          # i
        dimsQ['inc1'],          # i
    )
    return packed_params

def gemm_layer_params(
    core_spill_buf: int,
    core_ifm_tmp_buffer: int,
    core_coeff_tmp_buffer: int,
    wgt_size: int,
    coeff_size: int,
) -> bytes:
    '''Serialize the five GEMM layer fields as consecutive 32-bit words.

    Each argument is encoded as an unsigned 4-byte little-endian integer,
    in parameter order, which yields a 20-byte result.

    Raises:
        OverflowError: if a value is negative or does not fit in 32 bits.
    '''
    ordered_fields = (
        core_spill_buf,
        core_ifm_tmp_buffer,
        core_coeff_tmp_buffer,
        wgt_size,
        coeff_size,
    )
    return b"".join(
        word.to_bytes(length=4, byteorder='little', signed=False)
        for word in ordered_fields
    )


def gen_aie4_gemm_params(
    dims: GemmDims,
    core_spill_buf: int,
    core_ifm_tmp_buffer: int,
    core_coeff_tmp_buffer: int,
) -> bytes:
    '''Build the complete GEMM parameter blob for one layer.

    The result is the output of ``gemm_layer_params`` (fed with the three
    buffer values, ``dims.wgt_size`` and the sum of the three coefficient
    sizes) immediately followed by the output of ``setup_gemm_params``.
    The combined length is printed to stdout.
    '''
    weight_bytes = dims.wgt_size
    coeff_bytes = dims.c0_size + dims.c1_size + dims.c2_size

    layer_blob = gemm_layer_params(
        core_spill_buf,
        core_ifm_tmp_buffer,
        core_coeff_tmp_buffer,
        weight_bytes,
        coeff_bytes,
    )
    kernel_blob = setup_gemm_params(dims)

    combined = layer_blob + kernel_blob
    print("layer params + kernel params", len(combined))
    return combined
        

        
        