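'''
This module contains the setup functions that initialize all constants
in the config module based on the device generation.
External-facing functions are documented below.

set_dev_gen - initializes all constants in the config module based
on the device generation. Any unused config parameters will be
set to None.

set_fast_pm - switches the AIE-2p control packet configs to their
fast PM values. Raises RuntimeError for any other device generation.

disable_fast_pm - restores the AIE-2p control packet configs to their
non-fast-PM values. Does nothing for other device generations.

set_multi_uc - sets config.IS_MULTI_UC based on the device generation,
config.ENABLE_MULTI_UC and the selected back end.
'''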

from .types import DevGen, DmaDir, TileType, BackEnd
from . import config


def clear_all_configs():
    '''
    Reset the device-dependent constants in the config module.

    Every scalar parameter is set to None. config.NUM_CHANNEL_LUT is
    rebuilt with the same Memtile/Shim and S2MM/MM2S keys, but with
    None in place of each channel count.
    '''
    config.DEV_GEN = None
    config.NUM_AIE_ROWS = None
    config.NUM_AIE_COLS = None
    config.ENABLE_BUSY_POLL = None
    config.MIN_STEP_SIZE = None
    config.MAX_LOCK_VALUE = None
    config.MAX_TASK_QUEUE_SIZE = None
    config.MAX_NEIGHBOR_ACCESS = None
    config.MAX_REPEAT_COUNT = None
    config.MAX_ITER_WRAP = None
    config.ENABLE_MULTI_UC = None
    config.MAX_CORE_ADDR = None
    config.MAX_CORE_BUFFER_LENGTH = None
    config.MIN_CORE_BUFFER_ALIGNMENT = None
    config.MAX_CORE_STEP = None
    config.MAX_CORE_WRAP = None
    config.MAX_CORE_DIMS = None
    config.MAX_CORE_S2MM_DMA_CHANNEL = None
    config.MAX_CORE_MM2S_DMA_CHANNEL = None
    config.MAX_CORE_LOCK_ID = None
    config.MAX_CORE_LAYER_PARAM_SIZE = None
    config.MAX_CORE_NUM_KERNELS = None
    config.MEMTILE_BASE_ADDR = None
    config.MAX_MEMTILE_ADDR = None
    config.MAX_MEMTILE_BUFFER_LENGTH = None
    config.MAX_MEMTILE_STEP = None
    config.MAX_MEMTILE_WRAP = None
    config.MAX_MEMTILE_DIMS = None
    config.MAX_MEMTILE_S2MM_DMA_CHANNEL = None
    config.MAX_MEMTILE_MM2S_DMA_CHANNEL = None
    config.MAX_MEMTILE_S2MM_NEIGHBOR_CHANNEL = None
    config.MAX_MEMTILE_MM2S_NEIGHBOR_CHANNEL = None
    config.MAX_MEMTILE_LOCK_ID = None
    config.MAX_MEMTILE_PAD_DIMS = None
    config.MAX_MEMTILE_D0_PAD = None
    config.MAX_MEMTILE_D1_PAD = None
    config.MAX_MEMTILE_D2_PAD = None
    config.MEMTILE_ADDR_GRAN = None
    config.MAX_SHIM_ADDR = None
    config.MAX_SHIM_BUFFER_LENGTH = None
    config.MAX_SHIM_STEP = None
    config.MAX_SHIM_WRAP = None
    config.MAX_SHIM_DIMS = None
    config.MAX_SHIM_S2MM_DMA_CHANNEL = None
    config.MAX_SHIM_MM2S_DMA_CHANNEL = None
    config.MAX_SHIM_LOCK_ID = None
    config.SHIM_PARAM_BUFFER_IDX = None
    config.SHIM_CTRL_BUFFER_IDX = None
    config.BD_CONFIG_CTRL_PKT_WORDS = None
    config.LOCK_CONFIG_CTRL_PKT_WORDS = None
    config.TASK_ENQUEUE_CTRL_PKT_WORDS = None
    config.MAX_DDR_BURST_LENGTH = None
    config.IS_MULTI_UC = None
    config.MAX_UC = None
    config.NUM_UC_USED = None
    config.MAX_REMOTE_BARRIER = None

    #
    # AIE-2p Specific Configs
    #

    config.MAX_MEMTILE_BD_LO_ID = None
    config.MAX_MEMTILE_BD_HI_ID = None
    config.MAX_SHIM_BD_ID = None
    config.SHIM_CTRL_MM2S_CHANNEL_ID = None
    config.DATA_TRANSFER_PKT_SPLIT_IDX = None
    config.SHIM_CTRL_PKT_SPLIT_IDX = None
    config.MEMTILE_CTRL_PKT_SPLIT_IDX = None
    config.CORE_CTRL_PKT_SPLIT_IDX = None
    config.NUM_CTRL_PKT_SPLIT = None
    config.SHIM_CTRL_PKT_BD_ID = None

    #
    # AIE-4 Specific Configs
    #

    config.MAX_MEMTILE_PRIVATE_BD_ID = None
    config.MAX_SHIM_LO_S2MM_DMA_CHANNEL = None
    config.MAX_SHIM_LO_MM2S_DMA_CHANNEL = None
    config.MAX_SHIM_LO_BD_ID = None
    config.MAX_SHIM_HI_BD_ID = None

    #
    # Fast PM Configs
    #

    # ENABLE_FAST_PM is written by both set_dev_gen_aie2p() and
    # set_dev_gen_aie4(); reset it here as well so that no stale value
    # survives when set_dev_gen() rejects the requested generation.
    config.ENABLE_FAST_PM = None
    # NOTE(review): DATA_TRANSFER_PKT_ID, SHIM_CTRL_PKT_IDX and
    # MEMTILE_CTRL_PKT_IDX (written by set_fast_pm()/disable_fast_pm())
    # are not reset here, so they carry over between set_dev_gen() calls.
    # Confirm whether that is intended.

    # This map tracks how many channels are available in a specific direction for each tile.
    config.NUM_CHANNEL_LUT = {
        TileType.Memtile: {
            DmaDir.S2MM: None,
            DmaDir.MM2S: None
            },
        TileType.Shim: {
            DmaDir.S2MM: None,
            DmaDir.MM2S: None
            }
        }


def set_dev_gen_aie2p():
    '''
    Populate the config module with the constants for DevGen.Aie2p.

    Only the common and AIE-2p specific parameters are written here;
    the AIE-4 specific parameters are not touched by this function.
    '''
    # Device identity and array size
    config.DEV_GEN = DevGen.Aie2p
    config.NUM_AIE_ROWS = 4
    config.NUM_AIE_COLS = 8

    # Parameters shared by all tile types
    config.ENABLE_BUSY_POLL = True
    config.MIN_STEP_SIZE = 1
    config.MAX_LOCK_VALUE = 63
    config.MAX_TASK_QUEUE_SIZE = 4
    config.MAX_NEIGHBOR_ACCESS = 1
    config.MAX_REPEAT_COUNT = 2**8
    config.MAX_ITER_WRAP = 2**6
    config.ENABLE_MULTI_UC = False

    # Core parameters
    config.MAX_CORE_ADDR = 2**16 - 1
    config.MAX_CORE_BUFFER_LENGTH = 2**14 - 1
    config.MIN_CORE_BUFFER_ALIGNMENT = 64
    config.MAX_CORE_STEP = 2**13
    config.MAX_CORE_WRAP = 2**8 - 1
    config.MAX_CORE_DIMS = 3
    config.MAX_CORE_S2MM_DMA_CHANNEL = 1
    config.MAX_CORE_MM2S_DMA_CHANNEL = 1
    config.MAX_CORE_LOCK_ID = 15
    config.MAX_CORE_LAYER_PARAM_SIZE = 1024
    config.MAX_CORE_NUM_KERNELS = 64

    # Memtile parameters
    config.MEMTILE_BASE_ADDR = 0
    config.MAX_MEMTILE_ADDR = 2**19 - 1
    config.MAX_MEMTILE_BUFFER_LENGTH = 2**17 - 1
    config.MAX_MEMTILE_STEP = 2**17
    config.MAX_MEMTILE_WRAP = 2**10 - 1
    config.MAX_MEMTILE_DIMS = 4
    config.MAX_MEMTILE_S2MM_DMA_CHANNEL = 5
    config.MAX_MEMTILE_MM2S_DMA_CHANNEL = 5
    config.MAX_MEMTILE_S2MM_NEIGHBOR_CHANNEL = 3
    config.MAX_MEMTILE_MM2S_NEIGHBOR_CHANNEL = 3
    config.MAX_MEMTILE_LOCK_ID = 63
    config.MAX_MEMTILE_PAD_DIMS = 3
    config.MAX_MEMTILE_D0_PAD = 2**6 - 1
    config.MAX_MEMTILE_D1_PAD = 2**5 - 1
    config.MAX_MEMTILE_D2_PAD = 2**4 - 1
    config.MEMTILE_ADDR_GRAN = 2**5

    # Shim parameters
    config.MAX_SHIM_ADDR = 4
    config.MAX_SHIM_BUFFER_LENGTH = 2**32 - 1
    config.MAX_SHIM_STEP = 2**20
    config.MAX_SHIM_WRAP = 2**10 - 1
    config.MAX_SHIM_DIMS = 3
    config.MAX_SHIM_S2MM_DMA_CHANNEL = 1
    config.MAX_SHIM_MM2S_DMA_CHANNEL = 1
    config.MAX_SHIM_LOCK_ID = 15
    config.SHIM_PARAM_BUFFER_IDX = 3
    config.SHIM_CTRL_BUFFER_IDX = 4

    # Control packet sizes and DDR burst length
    config.BD_CONFIG_CTRL_PKT_WORDS = 12
    config.LOCK_CONFIG_CTRL_PKT_WORDS = 3
    config.TASK_ENQUEUE_CTRL_PKT_WORDS = 3
    config.MAX_DDR_BURST_LENGTH = 32

    # uC parameters
    config.IS_MULTI_UC = False
    config.MAX_UC = 1
    config.NUM_UC_USED = 1
    config.MAX_REMOTE_BARRIER = 0

    #
    # AIE-2p Specific Configs
    #

    config.MAX_MEMTILE_BD_LO_ID = 23
    config.MAX_MEMTILE_BD_HI_ID = 47
    config.MAX_SHIM_BD_ID = 15
    config.SHIM_CTRL_MM2S_CHANNEL_ID = 0
    config.DATA_TRANSFER_PKT_SPLIT_IDX = 0
    config.SHIM_CTRL_PKT_SPLIT_IDX = 1
    config.MEMTILE_CTRL_PKT_SPLIT_IDX = 2
    config.CORE_CTRL_PKT_SPLIT_IDX = 3
    config.NUM_CTRL_PKT_SPLIT = 4
    config.SHIM_CTRL_PKT_BD_ID = 0

    # Fast PM is enabled by default on AIE-2p.
    config.ENABLE_FAST_PM = True

    # Number of channels available per direction for each tile type.
    # The MAX_*_DMA_CHANNEL values are the highest channel index, hence
    # the +1 to turn them into counts.
    memtile_channels = {
        DmaDir.S2MM: config.MAX_MEMTILE_S2MM_DMA_CHANNEL + 1,
        DmaDir.MM2S: config.MAX_MEMTILE_MM2S_DMA_CHANNEL + 1,
    }
    shim_channels = {
        DmaDir.S2MM: config.MAX_SHIM_S2MM_DMA_CHANNEL + 1,
        DmaDir.MM2S: config.MAX_SHIM_MM2S_DMA_CHANNEL + 1,
    }
    config.NUM_CHANNEL_LUT = {
        TileType.Memtile: memtile_channels,
        TileType.Shim: shim_channels,
    }



def set_dev_gen_aie4():
    '''
    Populate the config module with the constants for DevGen.Aie4.

    Only the common and AIE-4 specific parameters are written here;
    the AIE-2p specific parameters are not touched by this function.
    '''
    # Device identity and array size
    config.DEV_GEN = DevGen.Aie4
    config.NUM_AIE_ROWS = 4
    config.NUM_AIE_COLS = 3

    # Parameters shared by all tile types
    config.ENABLE_BUSY_POLL = False
    config.MIN_STEP_SIZE = 0
    config.MAX_LOCK_VALUE = 63
    config.MAX_TASK_QUEUE_SIZE = 16
    config.MAX_NEIGHBOR_ACCESS = 7
    config.MAX_REPEAT_COUNT = 2**12
    config.MAX_ITER_WRAP = 2**6
    config.ENABLE_MULTI_UC = True

    # Core parameters
    config.MAX_CORE_ADDR = 2**17 - 1
    config.MAX_CORE_BUFFER_LENGTH = 2**15 - 1
    config.MIN_CORE_BUFFER_ALIGNMENT = 128  # TODO: confirm with kernel
    config.MAX_CORE_STEP = 2**15 - 1
    config.MAX_CORE_WRAP = 2**9 - 1
    config.MAX_CORE_DIMS = 3
    config.MAX_CORE_S2MM_DMA_CHANNEL = 1
    config.MAX_CORE_MM2S_DMA_CHANNEL = 0
    config.MAX_CORE_LOCK_ID = 15
    config.MAX_CORE_LAYER_PARAM_SIZE = 1024
    config.MAX_CORE_NUM_KERNELS = 64

    # Memtile parameters
    config.MEMTILE_BASE_ADDR = 0x0B0_0000
    config.MAX_MEMTILE_ADDR = 3 * 2**20 - 1
    config.MAX_MEMTILE_BUFFER_LENGTH = 2**23 - 1
    config.MAX_MEMTILE_STEP = 2**23 - 1
    config.MAX_MEMTILE_WRAP = 2**12 - 1
    config.MAX_MEMTILE_DIMS = 4
    config.MAX_MEMTILE_S2MM_DMA_CHANNEL = 7
    config.MAX_MEMTILE_MM2S_DMA_CHANNEL = 9
    config.MAX_MEMTILE_S2MM_NEIGHBOR_CHANNEL = 7
    config.MAX_MEMTILE_MM2S_NEIGHBOR_CHANNEL = 9
    config.MAX_MEMTILE_LOCK_ID = 63
    config.MAX_MEMTILE_PAD_DIMS = 3
    config.MAX_MEMTILE_D0_PAD = 2**8 - 1
    config.MAX_MEMTILE_D1_PAD = 2**8 - 1
    config.MAX_MEMTILE_D2_PAD = 2**8 - 1
    config.MEMTILE_ADDR_GRAN = 2**5

    # Shim parameters
    config.MAX_SHIM_ADDR = 4
    config.MAX_SHIM_BUFFER_LENGTH = 2**32 - 1
    config.MAX_SHIM_STEP = 2**22 - 1
    config.MAX_SHIM_WRAP = 2**12 - 1
    config.MAX_SHIM_DIMS = 4
    config.MAX_SHIM_S2MM_DMA_CHANNEL = 1
    config.MAX_SHIM_MM2S_DMA_CHANNEL = 3
    config.MAX_SHIM_LOCK_ID = 31
    config.SHIM_PARAM_BUFFER_IDX = 3
    config.SHIM_CTRL_BUFFER_IDX = 4

    # Control packet sizes and DDR burst length
    config.BD_CONFIG_CTRL_PKT_WORDS = 12    # TODO: confirm new format
    config.LOCK_CONFIG_CTRL_PKT_WORDS = 3   # TODO: confirm new format
    config.TASK_ENQUEUE_CTRL_PKT_WORDS = 3  # TODO: confirm new format
    config.MAX_DDR_BURST_LENGTH = 8

    # uC parameters. IS_MULTI_UC starts out False here even though
    # ENABLE_MULTI_UC is True; set_multi_uc() is what turns it on.
    config.IS_MULTI_UC = False
    config.MAX_UC = 6
    config.NUM_UC_USED = 3
    config.MAX_REMOTE_BARRIER = 7

    #
    # AIE-4 Specific Configs
    #

    config.MAX_MEMTILE_PRIVATE_BD_ID = 15
    config.MAX_SHIM_LO_S2MM_DMA_CHANNEL = 0
    config.MAX_SHIM_LO_MM2S_DMA_CHANNEL = 1
    config.MAX_SHIM_LO_BD_ID = 15
    config.MAX_SHIM_HI_BD_ID = 31

    # Fast PM should always be False for AIE-4.
    config.ENABLE_FAST_PM = False

    # Number of channels available per direction for each tile type.
    # The MAX_*_DMA_CHANNEL values are the highest channel index, hence
    # the +1 to turn them into counts.
    memtile_channels = {
        DmaDir.S2MM: config.MAX_MEMTILE_S2MM_DMA_CHANNEL + 1,
        DmaDir.MM2S: config.MAX_MEMTILE_MM2S_DMA_CHANNEL + 1,
    }
    shim_channels = {
        DmaDir.S2MM: config.MAX_SHIM_S2MM_DMA_CHANNEL + 1,
        DmaDir.MM2S: config.MAX_SHIM_MM2S_DMA_CHANNEL + 1,
    }
    config.NUM_CHANNEL_LUT = {
        TileType.Memtile: memtile_channels,
        TileType.Shim: shim_channels,
    }

def set_dev_gen(dev_gen: DevGen):
    '''
    Initialize the config module for the given device generation.

    All configs are first reset with clear_all_configs(), then the
    setup function matching dev_gen is run.

    Raises AssertionError if dev_gen is neither DevGen.Aie2p nor
    DevGen.Aie4. The configs have already been cleared at that point.
    '''
    clear_all_configs()
    if dev_gen == DevGen.Aie2p:
        set_dev_gen_aie2p()
    elif dev_gen == DevGen.Aie4:
        set_dev_gen_aie4()
    else:
        # Raise explicitly rather than `assert False`: assert statements
        # are removed under `python -O`, which would let an unsupported
        # generation pass silently with every config left cleared.
        raise AssertionError(f"Unsupported device generation: {dev_gen}")

def set_fast_pm():
    '''
    Switch the AIE-2p control packet configs to their fast PM values.

    Calls config.check_init() first. This function does not modify
    config.ENABLE_FAST_PM.

    Raises RuntimeError if config.DEV_GEN is not DevGen.Aie2p.
    '''
    config.check_init()
    is_aie2p = config.DEV_GEN == DevGen.Aie2p
    if not is_aie2p:
        raise RuntimeError(f"Fast PM not supported for device generation - {config.DEV_GEN}")

    # CORE_CTRL_PKT_SPLIT_IDX becomes a list here, whereas
    # set_dev_gen_aie2p() and disable_fast_pm() store a single int.
    config.CORE_CTRL_PKT_SPLIT_IDX = [3, 4, 5, 6]
    config.NUM_CTRL_PKT_SPLIT = 7
    config.DATA_TRANSFER_PKT_ID = 5
    config.SHIM_CTRL_PKT_IDX = 4
    config.MEMTILE_CTRL_PKT_IDX = 6

def set_multi_uc(back_end: BackEnd):
    '''
    Set config.IS_MULTI_UC for the selected back end.

    Calls config.check_init() first. IS_MULTI_UC becomes True only when
    the device generation is DevGen.Aie4, config.ENABLE_MULTI_UC is
    truthy and back_end is BackEnd.CertAsm; otherwise it becomes False.
    '''
    config.check_init()
    use_multi_uc = (
        config.DEV_GEN == DevGen.Aie4
        and config.ENABLE_MULTI_UC
        and back_end == BackEnd.CertAsm
    )
    # bool() keeps the stored value strictly True/False even if one of
    # the operands is merely truthy/falsy.
    config.IS_MULTI_UC = bool(use_multi_uc)

def disable_fast_pm():
    '''
    Restore the AIE-2p control packet configs to their non-fast-PM values.

    Calls config.check_init() first. For any device generation other
    than DevGen.Aie2p nothing is changed. This function does not modify
    config.ENABLE_FAST_PM.
    '''
    config.check_init()
    is_aie2p = config.DEV_GEN == DevGen.Aie2p
    if not is_aie2p:
        return

    config.CORE_CTRL_PKT_SPLIT_IDX = 3
    config.NUM_CTRL_PKT_SPLIT = 4
    config.DATA_TRANSFER_PKT_ID = 0
    config.SHIM_CTRL_PKT_IDX = 0
    config.MEMTILE_CTRL_PKT_IDX = 0


