'''
This module defines the top-level interface to run all phases of compilation.
External facing functions are documented below.

run_layer_compilation - runs all phases of compilation and saves the final
artifacts to output files in the current working directory. The files created
are dma.hpp (the default; configurable via layer_file), graph.hpp, and
super.hh / super.cc, which contain the runtime control code, the ADF
connectivity graph, and the super kernel, respectively.
'''


from typing import List, Dict, Optional, Union, Type
import os

from .types import (
    OverlayShape, CoreInstr, AieTile, DataTransfer,
    DmaConnection, CoreConnection,
    BackEnd, CascDir, DmaPaddingMap,
)
from . import config

from .setup import disable_fast_pm, set_fast_pm, set_multi_uc
from .simulator import run_overlay_deadlock_check
from .optimizer import compile_data_transfers
from .control import generate_layer_control
from .codegen import (
    generate_runtime_control,
    generate_overlay_graph,
    generate_super_kernel,
)
from .print_run_layer_compilation import print_run_layer_compilation_inputs


def run_layer_compilation(
    overlay_shape: OverlayShape,
    kernel_names: Union[List[str], Dict[str, int]],
    kernel_includes: List[str],
    core_instrs: Union[List[Type[CoreInstr]], Dict[AieTile, List[Type[CoreInstr]]]],
    memtile_transfers: List[DataTransfer],
    shim_transfers: List[DataTransfer],
    dma_connections: List[DmaConnection],
    back_end: BackEnd = BackEnd.Adf,
    core_stack_addr: int = 57344,
    param_channel_id: int = 1,
    layer_name: str = 'run_dma_layer_config',
    layer_file: str = 'dma.hpp',
    casc_dir: Optional[CascDir] = None,
    core_connections: Optional[List[CoreConnection]] = None,
    enable_debug_print: bool = False,
    enable_task_queue_optimization: bool = True,
    # NOTE(review): this default instance is created once, when the function
    # is defined, and is shared by every call that omits the argument.
    # Confirm that DmaPaddingMap is never mutated downstream.
    dma_padding_map: DmaPaddingMap = DmaPaddingMap(),
):
    '''
    Run every compilation phase for one layer and write the generated sources.

    The phases run in this order: configuration setup, overlay deadlock
    check, data transfer compilation, layer control generation, and code
    generation. The generated text is then written to four files in the
    current working directory: ``layer_file``, ``graph.hpp``, ``super.hh``
    and ``super.cc``. Existing files with those names are overwritten.

    Args:
        overlay_shape: Forwarded to the deadlock check, the data transfer
            compiler, and the control / code generators.
        kernel_names: Forwarded to the runtime control and super kernel
            generators.
        kernel_includes: Forwarded to the super kernel generator.
        core_instrs: Forwarded to the deadlock check and the runtime control
            generator.
        memtile_transfers: Forwarded to the deadlock check and the data
            transfer compiler.
        shim_transfers: Forwarded to the deadlock check and the data
            transfer compiler.
        dma_connections: Forwarded to the deadlock check and the control /
            code generators.
        back_end: Passed to ``set_multi_uc`` and the runtime control
            generator.
        core_stack_addr: Forwarded to the overlay graph generator.
        param_channel_id: Forwarded to the deadlock check and the super
            kernel generator.
        layer_name: Forwarded to the runtime control generator.
        layer_file: Name of the runtime control output file, resolved
            against the current working directory.
        casc_dir: Forwarded to the overlay graph generator.
        core_connections: Forwarded to the overlay graph generator. ``None``
            (the default) is treated as an empty list.
        enable_debug_print: When true, the inputs are printed through
            ``print_run_layer_compilation_inputs`` before compilation.
        enable_task_queue_optimization: Requested optimization setting for
            the data transfer compiler and layer control generator. If data
            transfer compilation fails while this is enabled, it is retried
            once with the optimization disabled.
        dma_padding_map: Forwarded to the runtime control generator.

    Raises:
        Exception: Whatever the underlying phases raise. A data transfer
            compilation failure propagates when the optimization was not
            requested, or when the retry without optimization fails too.
    '''
    # A fresh list per call avoids sharing one mutable default across calls.
    if core_connections is None:
        core_connections = []

    config.check_init()
    if config.ENABLE_FAST_PM:
        set_fast_pm()
    else:
        disable_fast_pm()

    set_multi_uc(back_end)

    if enable_debug_print:
        print_run_layer_compilation_inputs(
            overlay_shape,
            kernel_names,
            kernel_includes,
            core_instrs,
            memtile_transfers,
            shim_transfers,
            dma_connections,
            back_end,
            core_stack_addr,
            param_channel_id,
            layer_name,
            layer_file,
            casc_dir,
            core_connections,
        )

    run_overlay_deadlock_check(
        overlay_shape,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        dma_connections,
        param_channel_id=param_channel_id,
    )

    try:
        data_buffers = compile_data_transfers(
            overlay_shape,
            memtile_transfers,
            shim_transfers,
            enable_task_queue_optimization,
        )
    except Exception as e1:
        if not enable_task_queue_optimization:
            # The optimization was already off, so there is no alternative
            # configuration to retry with; surface the original failure.
            raise
        print("\033[91m[WARNING] compile_data_transfers failed with task queue optimization enabled.\033[0m")
        print("\033[91m[WARNING] Error: {}\033[0m".format(str(e1)))
        # Best-effort fallback: retry once without the optimization. A
        # failure here propagates to the caller.
        data_buffers = compile_data_transfers(
            overlay_shape,
            memtile_transfers,
            shim_transfers,
            enable_task_queue_optimization=False,
        )

    # NOTE(review): after a fallback the buffers above were compiled without
    # task queue optimization, yet the caller's original flag is still passed
    # here. Confirm that generate_layer_control tolerates this combination.
    layer_control = generate_layer_control(
        overlay_shape,
        data_buffers,
        dma_connections,
        enable_task_queue_optimization,
    )
    dma_hpp = generate_runtime_control(
        overlay_shape,
        dma_connections,
        layer_control,
        data_buffers,
        core_instrs,
        kernel_names,
        back_end,
        layer_name,
        dma_padding_map,
    )
    graph_hpp = generate_overlay_graph(
        overlay_shape,
        dma_connections,
        core_stack_addr=core_stack_addr,
        casc_dir=casc_dir,
        core_connections=core_connections,
    )
    super_hh, super_cc = generate_super_kernel(
        kernel_includes,
        kernel_names,
        param_channel_id=param_channel_id
    )

    # All artifacts are written relative to the current working directory.
    currdir = os.getcwd()
    artifacts = (
        (layer_file, dma_hpp),
        ('graph.hpp', graph_hpp),
        ('super.hh', super_hh),
        ('super.cc', super_cc),
    )
    for filename, contents in artifacts:
        with open(os.path.abspath(os.path.join(currdir, filename)), 'w') as f:
            f.write(contents)
