import json
import os
import sys
from resize_common import ResizeDims
from dataflow_common import ceildiv
# Absolute directory containing this file; also used as the output location
# for files written by this module.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Extend the module search path with this directory, its parent, and two
# sibling project trees so later imports can resolve modules located there.
# NOTE(review): the `resize_common` / `dataflow_common` imports above execute
# before these entries are appended, so they only succeed if those modules are
# already importable by other means — confirm the intended ordering.
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'OGOAT', 'src', 'L1_fusion'))



def memtile_subv_shape(
        Ni: int,
        Yi: int,
        Xi: int,
        Ci: int,
):
    """Return the (N, Y, X, C) sub-volume shape for the given input dims.

    The N, X and C extents are passed through unchanged, while the Y extent
    is always 1. ``Yi`` is accepted for a uniform signature but does not
    influence the result.

    Returns:
        A 4-tuple ``(Ni, 1, Xi, Ci)``.
    """
    # Exactly one element along Y per sub-volume; every other dim is kept whole.
    return (Ni, 1, Xi, Ci)


def temporal_loop_per_dim(
        Ni: int,
        Yi: int,
        Xi: int,
        Ci: int,
        Nis: int,
        Yis: int,
        Xis: int,
        Cis: int,
        spatial_split: dict[str, int]
):
    """Compute one loop count per dimension, in N, Y, X, C order.

    For each dimension the count is
    ``ceildiv(total, sub_extent * spatial_split[dim])``, where ``ceildiv`` is
    the helper imported from ``dataflow_common`` (presumably ceiling
    division — confirm against that module).

    Args:
        Ni, Yi, Xi, Ci: Full extent of each dimension.
        Nis, Yis, Xis, Cis: Sub-volume extent of each dimension.
        spatial_split: Mapping with keys ``'N'``, ``'Y'``, ``'X'`` and ``'C'``
            giving the split factor applied to each dimension.

    Returns:
        A 4-tuple ``(N_loop, Y_loop, X_loop, C_loop)``.
    """
    # (split key, full extent, sub-volume extent), listed in output order.
    per_axis = (
        ('N', Ni, Nis),
        ('Y', Yi, Yis),
        ('X', Xi, Xis),
        ('C', Ci, Cis),
    )
    return tuple(
        ceildiv(full, sub * spatial_split[axis])
        for axis, full, sub in per_axis
    )





def run_tiler(
        aie_rows: int,
        aie_cols: int,
        Ni: int,
        Yi: int,
        Xi: int,
        Ci: int,
        num_interpolations: int,
        ifm_bits: int,
        int_16: int,
        bfloat_16: int,
):
    """Derive the tiling for the given input shape and record it on disk.

    Steps, in order:
      1. Use a fixed split of N=1, Y=8, X=1, C=1.
      2. Obtain the sub-volume shape from ``memtile_subv_shape`` and the
         per-dimension loop counts from ``temporal_loop_per_dim``.
      3. Build a ``ResizeDims`` from the inputs and the derived values.
      4. Write a summary to ``tiling.json`` in the directory of this file
         (``CURRDIR``), overwriting any existing file.

    Args:
        aie_rows, aie_cols: Forwarded unchanged to ``ResizeDims``.
        Ni, Yi, Xi, Ci: Input extents along N, Y, X and C.
        num_interpolations: Multiplier applied to ``Yi`` and ``Xi`` to obtain
            the ``h_out`` / ``w_out`` values written to the JSON file.
        ifm_bits, int_16, bfloat_16: Forwarded to ``ResizeDims`` and recorded
            in the JSON file.

    Returns:
        The constructed ``ResizeDims`` instance.
    """
    # Split factors are fixed here; they are not derived from aie_rows/aie_cols.
    spatial_split = {'N': 1, 'Y': 8, 'X': 1, 'C': 1}

    ifm_shape = (Ni, Yi, Xi, Ci)
    sub_n, sub_y, sub_x, sub_c = memtile_subv_shape(*ifm_shape)
    loop_n, loop_y, loop_x, loop_c = temporal_loop_per_dim(
        *ifm_shape, sub_n, sub_y, sub_x, sub_c, spatial_split
    )

    # ResizeDims is constructed positionally: array size, full shape,
    # sub-volume shape, loop counts, then the mode/precision parameters.
    dims = ResizeDims(
        aie_rows, aie_cols,
        *ifm_shape,
        sub_n, sub_y, sub_x, sub_c,
        loop_n, loop_y, loop_x, loop_c,
        num_interpolations, ifm_bits, int_16, bfloat_16,
    )

    # Output height/width scale with num_interpolations; channels are unchanged.
    tiling_json = {
        'h_in': Yi,
        'h_out': Yi * num_interpolations,
        'w_in': Xi,
        'w_out': Xi * num_interpolations,
        'c_in': Ci,
        'c_out': Ci,
        'num_interpolations': num_interpolations,
        'ifm_bits': ifm_bits,
        'int_16_mode': int_16,
        'bfloat_16_mode': bfloat_16,
    }
    # Keys are sorted on output, so the literal's ordering above is cosmetic.
    with open(os.path.join(CURRDIR, 'tiling.json'), 'w') as out_file:
        out_file.write(json.dumps(tiling_json, sort_keys=True, indent=4))

    return dims