import math
import os
from enum import IntEnum

from kerneltest.helpers import iceil

# Absolute path of the directory holding this module, computed once at import.
CURRDIR = os.path.split(os.path.abspath(__file__))[0]

def calc_in_pixel_dim(
    X: int,
    K: int,
    P: int,
    S: int,
) -> int:
    """Return the input extent along one axis needed for ``X`` output pixels.

    This is the standard convolution size relation
    ``X = (X_in + 2*P - K) / S + 1`` solved for ``X_in``, i.e.
    ``X_in = (X - 1) * S + K - 2 * P``.

    Args:
        X: Output size along the axis.
        K: Kernel size along the axis.
        P: Padding on each side of the axis (subtracted twice).
        S: Stride along the axis.

    Returns:
        The corresponding input size along the axis.
    """
    strided_span = S * (X - 1)
    both_sides_padding = 2 * P
    return strided_span + K - both_sides_padding

class DwcDims:
    """Shape, tiling and buffer-size configuration for a ``Dwc`` kernel test.

    NOTE(review): "Dwc" presumably stands for depthwise convolution -- confirm.

    Dimension naming (axis meanings presumed from convention -- confirm):
      * ``Y``/``X`` are the two spatial axes, ``C`` the channel axis.
      * suffix ``i`` = input (activation) tensor, ``o`` = output tensor.
      * an additional trailing ``s`` (``Yis``, ``Cos``, ...) = sub-volume
        (tile) size; ``act_subv_bytes`` is built from the ``*is`` sizes and
        ``out_subv_bytes`` from the ``*os`` sizes.
      * ``Ky``/``Kx`` kernel size, ``Sy``/``Sx`` stride, ``Py``/``Px`` padding.

    Besides storing every constructor argument, ``__init__`` derives:
      * ``act_subv_bytes`` = ``Yis * Xis * Cis * act_bits // 8``.
      * ``out_subv_bytes`` = ``Yos * Xos * Cos * out_bits // 8``.
      * ``bias_subv_bytes`` = ``Cos * bias_bits // 8``.
      * ``wgt_subv_bytes``: weight bytes for a kernel footprint padded to at
        least 3 x 4 (Ky x Kx), passed through ``iceil`` with ``mem_align``,
        plus a coefficient region of ``2 * bias_subv_bytes`` (also
        ``iceil``'d) plus a fixed 128-byte qdq parameter block.
        (``iceil`` presumably rounds up to a multiple -- confirm.)
      * ``Co_loop``/``Y_loop``/``X_loop``/``Ci_loop``: ``ceil`` of
        ``Co/Cos``, ``Yi/Yis``, ``Xi/Xis`` and ``Ci/Cis`` respectively.
    """

    __slots__ = (
        'Param_size',
        'mem_align',
        'N',
        'Yi', 'Xi', 'Ci',
        'Yo', 'Xo', 'Co',
        'Yis', 'Xis', 'Cis',
        'Yos', 'Xos', 'Cos',
        'Ky', 'Kx',
        'Sy', 'Sx',
        'Py', 'Px',
        'aie_rows', 'aie_cols',
        'act_bits', 'wgt_bits', 'bias_bits', 'out_bits', 'param_bits',
        'Ci_gran',
        'Co_loop', 'Y_loop', 'X_loop', 'Ci_loop',
        'act_subv_bytes', 'wgt_subv_bytes', 'bias_subv_bytes', 'out_subv_bytes',
        'sign_act', 'sign_wgt', 'sign_out', 'shift_out',
        'size_bytes',
    )

    @staticmethod
    def _require(cond: bool, msg: str) -> None:
        """Raise ``AssertionError(msg)`` when ``cond`` is false.

        Used instead of ``assert`` so the checks still run under
        ``python -O`` while raising the same exception type as before.
        """
        if not cond:
            raise AssertionError(msg)

    def __init__(
        self,
        N: int,
        Yi: int,
        Xi: int,
        Ci: int,
        Yo: int,
        Xo: int,
        Co: int,
        Yis: int,
        Xis: int,
        Cis: int,
        Yos: int,
        Xos: int,
        Cos: int,
        Ky: int,
        Kx: int,
        Py: int,
        Px: int,
        Sy: int,
        Sx: int,
        aie_rows: int,
        aie_cols: int,
        act_bits: int,
        wgt_bits: int,
        bias_bits: int,
        out_bits: int,
        param_bits: int,
        Ci_gran: int,
        sign_act: int,
        sign_wgt: int,
        sign_out: int,
        shift_out: int,
    ) -> None:
        """Validate the tiling and compute derived buffer sizes / loop counts.

        Raises:
            AssertionError: if ``Ci`` is not a multiple of ``Ci_gran`` or
                ``Cis``, ``Co`` is not a multiple of ``Cos``, ``Xos`` is odd,
                ``Yos`` is not a multiple of 4, or ``Cos`` is not a
                multiple of 64.
            ZeroDivisionError: if a size used as a divisor is zero.
        """
        # Not read anywhere in this class.
        self.Param_size = 1024
        # Alignment value passed to ``iceil`` for the weight/coeff regions.
        self.mem_align = 128
        # Not read anywhere in this class.
        self.size_bytes = 1
        # NOTE: qdq size is hardcoded to 128 bytes
        qdq_params_size = 128
        bits_per_byte = 8

        self._require(Ci % Ci_gran == 0,
                      f"Ci ({Ci}) must be a multiple of Ci_gran ({Ci_gran})")
        self._require(Ci % Cis == 0,
                      f"Ci ({Ci}) must be a multiple of Cis ({Cis})")
        self._require(Co % Cos == 0,
                      f"Co ({Co}) must be a multiple of Cos ({Cos})")
        self._require(Xos % 2 == 0, f"Xos ({Xos}) must be even")
        self._require(Yos % 4 == 0, f"Yos ({Yos}) must be a multiple of 4")
        self._require(Cos % 64 == 0, f"Cos ({Cos}) must be a multiple of 64")
        # Supported sub-volume shapes (original author's note):
        #   Yos=1, Xos=64 | Yos=2, Xos=32 | Yos=4, Xos=16 | Yos=8, Xos=8
        # NOTE(review): the ``Yos % 4`` check above rejects Yos=1 and Yos=2,
        # so either that check or this list is stale -- confirm which.

        # Minimum kernel footprint the weight buffer is sized for.
        Ky_gran = 3
        Kx_gran = 4

        self.N = N
        self.Yi = Yi
        self.Xi = Xi
        self.Ci = Ci
        self.Yo = Yo
        self.Xo = Xo
        self.Co = Co
        self.Yis = Yis
        self.Xis = Xis
        self.Cis = Cis
        self.Yos = Yos
        self.Xos = Xos
        self.Cos = Cos
        self.Ky = Ky
        self.Kx = Kx
        self.Sy = Sy
        self.Sx = Sx
        self.Py = Py
        self.Px = Px
        self.aie_rows = aie_rows
        self.aie_cols = aie_cols
        self.act_bits = act_bits
        self.wgt_bits = wgt_bits
        self.bias_bits = bias_bits
        self.out_bits = out_bits
        self.param_bits = param_bits
        self.Ci_gran = Ci_gran
        self.sign_act = sign_act
        self.sign_wgt = sign_wgt
        self.sign_out = sign_out
        self.shift_out = shift_out

        self.bias_subv_bytes = (Cos * bias_bits) // bits_per_byte
        print(f'bias_subv_bytes: {self.bias_subv_bytes}')
        # Weight bytes for a kernel padded up to at least Ky_gran x Kx_gran.
        raw_wgt_subv_bytes = iceil(
            (Cos * max(Ky, Ky_gran) * max(Kx, Kx_gran) * wgt_bits) // bits_per_byte,
            self.mem_align,
        )
        print(f'raw_wgt_subv_bytes: {raw_wgt_subv_bytes}')
        # Coefficient region holds qdq_terms copies of a bias-sized vector.
        qdq_terms = 2
        coeff_size = iceil(qdq_terms * self.bias_subv_bytes, self.mem_align)
        self.wgt_subv_bytes = raw_wgt_subv_bytes + coeff_size + qdq_params_size
        self.out_subv_bytes = (Yos * Xos * Cos * out_bits) // bits_per_byte
        # Number of sub-volumes along each axis (rounded up).
        self.Co_loop = math.ceil(Co / Cos)
        self.Y_loop = math.ceil(Yi / Yis)
        self.X_loop = math.ceil(Xi / Xis)
        self.Ci_loop = math.ceil(Ci / Cis)
        self.act_subv_bytes = (Yis * Xis * Cis * act_bits) // bits_per_byte

    def __str__(self) -> str:
        """Return a one-line, human-readable dump of the configuration."""
        return (
            f"{type(self).__name__}(N={self.N}, Yi={self.Yi}, Xi={self.Xi}, Ci={self.Ci}, "
            f"Yo={self.Yo}, Xo={self.Xo}, Co={self.Co}, "
            f"Yis={self.Yis}, Xis={self.Xis}, Cis={self.Cis}, "
            f"Yos={self.Yos}, Xos={self.Xos}, Cos={self.Cos}, "
            f"Ky={self.Ky}, Kx={self.Kx}, "
            f"Py={self.Py}, Px={self.Px}, "
            f"Sy={self.Sy}, Sx={self.Sx}, "
            f"aie_rows={self.aie_rows}, aie_cols={self.aie_cols}, "
            f"act_bits={self.act_bits}, wgt_bits={self.wgt_bits}, "
            f"bias_bits={self.bias_bits}, out_bits={self.out_bits}, "
            f"param_bits={self.param_bits}, "
            f"Ci_gran={self.Ci_gran}, "
            f"sign_act={self.sign_act}, sign_wgt={self.sign_wgt}, "
            f"sign_out={self.sign_out}, shift_out={self.shift_out}, "
            f"act_subv_bytes={self.act_subv_bytes}, "
            f"bias_subv_bytes={self.bias_subv_bytes}, "
            f"wgt_subv_bytes={self.wgt_subv_bytes}, "
            f"out_subv_bytes={self.out_subv_bytes}, "
            f"Co_loop={self.Co_loop}, Y_loop={self.Y_loop}, "
            f"X_loop={self.X_loop}, Ci_loop={self.Ci_loop})"
        )
