import os
import sys
import json
import shutil
import logging
from typing import List, Optional
import numpy as np
import copy

from bilinear_resize_common import (
    bilinear_resize_input_subvol_dims,
    bilinear_resize_wgt_subvol_dims,
    CoordinateTransfromationMode,
    BilinearResizeShape,
)

from dataflow_common import ceildiv, iceil, overlay_stack_addr


# Candidate spatial splits, keyed by a mnemonic name.
# Each value is ordered as (Y split, X split, C split); the key spells out
# the same three factors. Insertion order is the order in which the
# candidates get iterated.
bilinear_resize_split_modes = {
    'Y1X1C32': (1, 1, 32),  # channel-only split
    'Y4X8C1': (4, 8, 1),    # spatial split, wider along X
    'Y8X4C1': (8, 4, 1),    # spatial split, wider along Y
}


def map_bilinear_resize(
    ShapeDims: BilinearResizeShape,
    verbose: bool = False,
) -> None:
    """
    Map a bilinear resize onto the spatial split stored in ``ShapeDims``.

    The per-core output subvolume starts as the output extent divided by
    the split (channels rounded up via ``iceil`` with ``C_gran``). While the
    resulting L1 footprint exceeds ``overlay_stack_addr()``, the subvolume is
    reduced, trying in this order on every pass:

    1. halve the channel subvolume (re-rounded with ``C_gran``),
    2. halve the Y subvolume,
    3. halve the X subvolume,
    4. set the IFM L1 buffer count to 1, then the OFM L1 buffer count to 1.

    The result is written back onto ``ShapeDims`` (``Yos/Xos/Cos/Cis``,
    ``Yis/Xis`` with their ``*_step`` and ``*_offset``, ``wgt_subvol_dims``
    and ``Y_loop/X_loop/C_loop``); nothing is returned. ``ifm_l1_no_buff`` and
    ``ofm_l1_no_buff`` may be lowered to 1 as a side effect.

    ``ShapeDims.feasible`` is set to ``False`` when the output is smaller
    than ``gran * split`` along any axis, or when no further reduction is
    possible and the footprint still does not fit.
    NOTE(review): this function never sets ``feasible`` to ``True``; it
    presumably relies on the attribute's initial value - confirm.

    Args:
        ShapeDims: shape/split description; read and updated in place.
        verbose: print progress and the final mapping when true.
    """
    mem_align = 64  # alignment handed to iceil() for every L1 buffer size

    # Reject splits that would give a core less than one granule of output.
    if (
        ShapeDims.Yo < ShapeDims.Y_gran * ShapeDims.Y_split
        or ShapeDims.Xo < ShapeDims.X_gran * ShapeDims.X_split
        or ShapeDims.Co < ShapeDims.C_gran * ShapeDims.C_split
    ):
        if verbose:
            print(
                f"Skipping mapping for spatial split: "
                f"{ShapeDims.Y_split, ShapeDims.X_split, ShapeDims.C_split} as it exceeds limits."
            )
        ShapeDims.feasible = False
        return

    Yos = ceildiv(ShapeDims.Yo, ShapeDims.Y_split)
    Xos = ceildiv(ShapeDims.Xo, ShapeDims.X_split)
    Cos = iceil(ceildiv(ShapeDims.Co, ShapeDims.C_split), ShapeDims.C_gran)
    Y_temporal_iters = 1
    X_temporal_iters = 1
    C_temporal_iters = 1
    qdq_l1_size = ShapeDims.qdq_bytes
    available_l1_size = overlay_stack_addr()
    # Only the first subvolume evaluation forwards `verbose`; later ones use
    # the callee's default, exactly as before.
    subvol_kwargs = {"verbose": verbose}

    while True:
        # Evaluate the current candidate subvolume.
        Cis = Cos
        wgt_subvol_dims = bilinear_resize_wgt_subvol_dims(output_subvol_dims=(Yos, Xos))
        (
            (Yis, Xis),
            (Yis_step, Xis_step),
            (Yis_offset, Xis_offset),
        ) = bilinear_resize_input_subvol_dims(
            ShapeDims.mode,
            (ShapeDims.Yo, ShapeDims.Xo),
            (ShapeDims.Yi, ShapeDims.Xi),
            (Yos, Xos),
            split=(ceildiv(ShapeDims.Yo, Yos), ceildiv(ShapeDims.Xo, Xos)),
            **subvol_kwargs,
        )
        subvol_kwargs = {}

        # Publish the candidate, so ShapeDims always reflects the last
        # evaluated mapping (also when the function bails out below).
        ShapeDims.Y_loop = Y_temporal_iters
        ShapeDims.X_loop = X_temporal_iters
        ShapeDims.C_loop = C_temporal_iters
        ShapeDims.Yos = Yos
        ShapeDims.Xos = Xos
        ShapeDims.Cos = Cos
        ShapeDims.Cis = Cis
        ShapeDims.Yis = Yis
        ShapeDims.Xis = Xis
        ShapeDims.Yis_step = Yis_step
        ShapeDims.Xis_step = Xis_step
        ShapeDims.Yis_offset = Yis_offset
        ShapeDims.Xis_offset = Xis_offset
        ShapeDims.wgt_subvol_dims = wgt_subvol_dims

        # Buffer sizes: element count * bits per element, converted to bytes.
        ifm_l1_size = iceil((Yis * Xis * Cis * ShapeDims.act_bits) // 8, mem_align)
        ofm_l1_size = iceil((Yos * Xos * Cos * ShapeDims.act_bits) // 8, mem_align)
        wgt_l1_size = iceil((wgt_subvol_dims * ShapeDims.act_bits) // 8, mem_align)
        # NOTE(review): the weight buffer is multiplied by the IFM buffer
        # count (there is no separate weight count) - confirm this is intended.
        occupied_l1_size = (
            ifm_l1_size * ShapeDims.ifm_l1_no_buff
            + ofm_l1_size * ShapeDims.ofm_l1_no_buff
            + wgt_l1_size * ShapeDims.ifm_l1_no_buff
            + qdq_l1_size
        )
        if occupied_l1_size <= available_l1_size:
            break

        # Does not fit: apply exactly one reduction, then re-evaluate.
        if (Cos > ShapeDims.C_gran) and (Cos % ShapeDims.C_gran == 0):
            # First choice: split the channel dimension.
            Cos = iceil(Cos // 2, ShapeDims.C_gran)
            C_temporal_iters = ceildiv(ShapeDims.Co, (Cos * ShapeDims.C_split))
        elif (Yos > ShapeDims.Y_gran) and (Yos % ShapeDims.Y_gran == 0):
            # Second choice: split the Y dimension.
            Yos = Yos // 2
            Y_temporal_iters = ceildiv(ShapeDims.Yo, (Yos * ShapeDims.Y_split))
        elif (Xos > ShapeDims.X_gran) and (Xos % ShapeDims.X_gran == 0):
            # Third choice: split the X dimension.
            Xos = Xos // 2
            X_temporal_iters = ceildiv(ShapeDims.Xo, (Xos * ShapeDims.X_split))
        elif ShapeDims.ifm_l1_no_buff > 1:
            # Fourth choice: give up multi-buffering, IFM first ...
            ShapeDims.ifm_l1_no_buff = 1
        elif ShapeDims.ofm_l1_no_buff > 1:
            # ... then OFM.
            ShapeDims.ofm_l1_no_buff = 1
        else:
            # Nothing left to reduce. Reaching this branch for any buffer
            # count <= 1 (not only == 1) guarantees the loop terminates.
            if verbose:
                print(
                    f"Unable to split further. "
                    f"Occupied L1 size: {occupied_l1_size}, "
                    f"Available L1 size: {available_l1_size}"
                )
            ShapeDims.feasible = False
            return

    if verbose:
        print(f"Yos: {Yos}, Xos: {Xos}, Cos: {Cos}")
        print(f"Yis: {Yis}, Xis: {Xis}, Cis: {Cis}")
        print(f"Yis_step: {Yis_step}, Xis_step: {Xis_step}")
        print(f"Yis_offset: {Yis_offset}, Xis_offset: {Xis_offset}")
        print(f"wgt_subvol_dims: {wgt_subvol_dims}")
        print(f"Y_temporal_iters: {Y_temporal_iters}")
        print(f"X_temporal_iters: {X_temporal_iters}")
        print(f"C_temporal_iters: {C_temporal_iters}")


def tiling_ranker(mapped_solns: dict, verbose: bool = False) -> dict:
    '''
    Drop the non-feasible solutions and rank the remaining ones.

    For every solution whose ``feasible`` attribute is truthy, the computed
    extent per axis is ``subvolume * split * loop`` and is stored on the
    solution as ``Y_overcompute``, ``X_overcompute`` and ``C_overcompute``.
    The product of the three is the solution's overcompute figure.

    Ranking order:
        1. fewest temporal iterations (``Y_loop * X_loop * C_loop``),
        2. ties broken by the smallest overcompute figure,
        3. remaining ties keep the order of ``mapped_solns``.

    This equals sorting by overcompute first and then stable-sorting by
    iteration count.

    Args:
        mapped_solns: mapping of split name -> mapped solution object.
        verbose: print the overcompute figure of each feasible solution.

    Returns:
        A new dict holding only the feasible solutions, best first.
    '''
    sort_keys = {}
    for key, soln in mapped_solns.items():
        if not soln.feasible:
            continue
        soln.Y_overcompute = soln.Yos * soln.Y_split * soln.Y_loop
        soln.X_overcompute = soln.Xos * soln.X_split * soln.X_loop
        soln.C_overcompute = soln.Cos * soln.C_split * soln.C_loop
        overcompute_macs = soln.Y_overcompute * soln.X_overcompute * soln.C_overcompute
        if verbose:
            print(f"key {key} overcompute: {overcompute_macs}")
        temporal_iters = soln.Y_loop * soln.X_loop * soln.C_loop
        sort_keys[key] = (temporal_iters, overcompute_macs)
    # sorted() is stable, so fully tied solutions keep their input order.
    ranked_keys = sorted(sort_keys, key=sort_keys.__getitem__)
    return {key: mapped_solns[key] for key in ranked_keys}


def bilinear_resize_tiler(
    ShapeDims: BilinearResizeShape,
    verbose: bool = False,
) -> list:
    """
    Tile a bilinear resize by trying every known spatial split.

    Each entry of ``bilinear_resize_split_modes`` is applied to its own deep
    copy of ``ShapeDims`` (the argument itself is left untouched), mapped
    with ``map_bilinear_resize`` and then handed to ``tiling_ranker``.

    Args:
        ShapeDims: shape description used as the template for every candidate.
        verbose: print progress and the ranked result.

    Returns:
        The solutions returned by ``tiling_ranker``, in its ranking order.
    """
    candidates = {}
    for mode_name, split in bilinear_resize_split_modes.items():
        if verbose:
            print(f"Mapping for {mode_name}")
        candidate = copy.deepcopy(ShapeDims)
        candidate.Y_split, candidate.X_split, candidate.C_split = split
        candidates[mode_name] = candidate
        map_bilinear_resize(candidate, verbose=verbose)

    ranked = tiling_ranker(candidates, verbose=verbose)
    if verbose:
        print("Ranked solutions: ")
        for mode_name, candidate in ranked.items():
            print(f"{mode_name} {candidate}")
    return list(ranked.values())