
import json
import os
import sys
from typing import List
from dataflow_utils import CommonDims
from q_dq_common import QDQDims
from dataflow_common import ceildiv
from dmacompiler import config
# Directory containing this module; run_tiler also writes tiling.json here.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Make this directory, its parent, and the dmacompiler / kernels directories
# two levels up importable.
# NOTE(review): these entries are appended AFTER the project imports above
# (dataflow_utils, q_dq_common, dataflow_common, dmacompiler), so those
# imports cannot rely on them. Presumably the modules are already on sys.path
# via the caller -- confirm, or move this setup before those imports.
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels'))

def Y_X_subv(
    aie_rows: int,
    input: List[int],
    ifm_bits: int,
    ofm_bits: int
):
    """Choose a (Y, X) split and per-tile element count for a flat tensor.

    The total element count ``Ci * Yi * Xi`` is divided into ``y * x`` equal
    sub-volumes, where ``x`` is either 1 or ``aie_rows``. ``y`` ranges over
    ``1 .. ceildiv(num_elements, 64)`` (``ceildiv`` presumably being ceiling
    division). A candidate is valid when all of the following hold:

    * ``y * x`` divides the element count exactly;
    * the sub-volume is a multiple of 32 bits at the input precision;
    * the estimated per-core footprint fits in 58000 bytes. The footprint is
      double-buffered input, double-buffered output, and 2 scratch bytes per
      element.

    Among valid candidates, the lexicographically smallest ``[y, x, subv]`` is
    returned: the smallest ``y`` (largest sub-volume), then the smaller ``x``.

    Args:
        aie_rows: Number of AIE rows, used as the alternative X split.
            Must be >= 1.
        input: ``[Ci, Yi, Xi]``. Only the product of the three values is used,
            so their order does not affect the result. (The name shadows the
            builtin but is kept for keyword-argument compatibility.)
        ifm_bits: Input element width in bits.
        ofm_bits: Output element width in bits.

    Returns:
        ``[y, x, subv]`` where ``subv == num_elements // (y * x)``.

    Raises:
        ValueError: If ``aie_rows < 1``, or if no split satisfies the
            constraints (including an empty tensor).
    """
    if aie_rows < 1:
        # x == 0 would divide by zero; negative x would yield negative sizes.
        raise ValueError(f"aie_rows must be >= 1, got {aie_rows}")

    (Ci, Yi, Xi) = input
    num_elements = Ci * Yi * Xi

    ifm_bytes = ifm_bits // 8
    ofm_bytes = ofm_bits // 8

    core_mem_cutoff = 58000   # per-core byte budget for the buffers below
    address_factor = 2        # ping-pong (double) buffering of input and output
    scratch_buffer_bytes = 2  # scratch bytes per element
    min_subv = 64             # bounds the Y search range (see max_Y)

    # Bytes needed on a core per element of one sub-volume.
    bytes_per_elem = (ifm_bytes + ofm_bytes) * address_factor + scratch_buffer_bytes

    max_Y = ceildiv(num_elements, min_subv)
    # Candidate X splits, smallest first; the set collapses aie_rows == 1.
    X_list = sorted({1, aie_rows})

    # Scanning y upward and x smallest-first makes the first valid hit the
    # lexicographically smallest [y, x, subv]. That is exactly what sorting
    # every valid split and taking element 0 would pick, without building
    # the full list.
    for y in range(1, max_Y + 1):
        for x in X_list:
            tiles = y * x
            if num_elements % tiles != 0:
                continue
            subv = num_elements // tiles
            fits = subv * bytes_per_elem <= core_mem_cutoff
            aligned = (subv * ifm_bits) % 32 == 0
            if fits and aligned:
                return [y, x, subv]

    raise ValueError(
        f"Y_X_subv: no valid (Y, X) split for {num_elements} elements "
        f"(aie_rows={aie_rows}, ifm_bits={ifm_bits}, ofm_bits={ofm_bits}, "
        f"core_mem_cutoff={core_mem_cutoff})"
    )

def run_tiler(
        aie_rows: int,
        aie_cols: int,
        Ni: int,
        Yi: int,
        Xi: int,
        Ci: int,
        Cip: int,
        ifm_bits: int,
        ofm_bits: int,
        fixed_point_bits: int,
        CoreqdqPrmSize: int,
        op_type: str = "Quant"
):
    """Build the QDQDims tiling description for a quantize/dequantize layer.

    The sub-volume split is chosen by ``Y_X_subv`` over ``Cip * Yi * Xi``
    elements, using the padded channel count. The output shape is set equal
    to the input shape (No=Ni, Yo=Yi, Xo=Xi, Co=Ci, Cop=Cip).

    Side effect: writes ``tiling.json`` into ``CURRDIR`` (the directory of this
    module), overwriting any existing file. The file records ``h_in``/``w_in``/
    ``c_in``/``ifm_bits`` and a ``host_layer_padding`` section whose ifm dims
    use ``Ci`` and whose ofm dims use the padded ``Cip``.

    Args:
        aie_rows: Number of AIE rows, forwarded to Y_X_subv and QDQDims.
        aie_cols: Number of AIE columns. total_Y is divided by it (rounded up)
            to give Y_loop.
        Ni: Forwarded as both Ni and No.
        Yi: Input height (``h_in`` in tiling.json); also used as Yo.
        Xi: Input width (``w_in`` in tiling.json); also used as Xo.
        Ci: Input channels (``c_in`` in tiling.json); also used as Co.
        Cip: Padded channel count; also used as Cop.
        ifm_bits: Input element width in bits.
        ofm_bits: Output element width in bits.
        fixed_point_bits: Forwarded to QDQDims unchanged.
        CoreqdqPrmSize: Forwarded as both CoreqdqPrmSize and wgt_subv_size.
        op_type: "Quant" selects qdq_mode 1 and "Dequant" selects 0. Any other
            value selects 2 ("both").

    Returns:
        The constructed QDQDims instance.
    """
    # Only the element count matters to Y_X_subv. The order here matches its
    # (C, Y, X) unpacking to avoid confusion.
    ifm_dims = [Cip, Yi, Xi]

    total_Y, total_X, subv_elem = Y_X_subv(aie_rows, ifm_dims, ifm_bits, ofm_bits)
    subv_size_input = subv_elem * (ifm_bits // 8)
    subv_size_output = subv_elem * (ofm_bits // 8)
    # Y iterations needed when total_Y is split across aie_cols (rounded up).
    Y_loop = ceildiv(total_Y, aie_cols)

    # qdq_mode encoding: 0 - dequant, 1 - quant, 2 - both (fallback for any other op_type).
    qdq_mode = {"Quant": 1, "Dequant": 0}.get(op_type, 2)

    dims = QDQDims(
        Ni=Ni,
        Yi=Yi,
        Xi=Xi,
        Ci=Ci,
        Cip=Cip,
        No=Ni,
        Yo=Yi,
        Xo=Xi,
        Co=Ci,
        Cop=Cip,
        ifm_bits=ifm_bits,
        ofm_bits=ofm_bits,
        fixed_point_bits=fixed_point_bits,
        total_Y=total_Y,
        total_X=total_X,
        CoreqdqPrmSize=CoreqdqPrmSize,
        param_subv_size=config.MAX_CORE_LAYER_PARAM_SIZE,
        wgt_subv_size=CoreqdqPrmSize,
        subv_elem=subv_elem,
        subv_size_input=subv_size_input,
        subv_size_output=subv_size_output,
        Yis=1,
        Yos=1,
        Y_loop=Y_loop,
        op_type=op_type,
        qdq_mode=qdq_mode,
        aie_rows=aie_rows,
        aie_cols=aie_cols,
    )

    tiling_json = {
        'h_in': Yi,
        'w_in': Xi,
        'c_in': Ci,
        'ifm_bits': ifm_bits,
        'host_layer_padding': {
            "act_bits": ifm_bits,
            "ifm": {
                "dims": [Yi, Xi, Ci],
                "value": [0, 0, 0, 0]
            },
            "ofm": {
                "dims": [Yi, Xi, Cip],
                "value": [0, 0, 0, 0]
            }
        }
    }

    tiling_json_filename = os.path.join(CURRDIR, 'tiling.json')
    with open(tiling_json_filename, 'w', encoding="utf-8") as f:
        json.dump(tiling_json, f, sort_keys=True, indent=4)
    return dims

    