from OGOAT.src.Scheduling_Engine.schedules.BufferAllocatorResult import (
    BufferAllocations,
    BufferAllocation,
)
from OGOAT.src.Scheduling_Engine.infra.const import CORE_BANK_SIZE, MEMTILE_SIZE

import sys

"""
Usage: python OGOAT/misc_tools/visualize_buffer_allocation.py <path/to/layer/dir>
"""

WIDTH = 50


def get_flatten_alloc(
    allocs: dict[str, BufferAllocation]
) -> dict[str, tuple[int, int]]:
    """
    Flatten buffer allocations into one entry per allocated address.

    Every allocation contributes its ping addresses and, when ``has_pong()``
    is true, its pong addresses (ping entries first). Entry names are
    decorated so that they stay distinct:
      - " [ping]" / " [pong]" is appended when the allocation has a pong half;
      - " [<index>]" is appended when a half has more than one address.

    Return type:
    key: name of the buffer allocated (decorated as described above)
    value: (address of the allocation, size of the allocation)
    """
    flatten_alloc: dict[str, tuple[int, int]] = {}

    def add_half(base_name: str, half) -> None:
        # Record one entry per address of a ping or pong half; every entry of
        # a half shares that half's size.
        # Only index the entries when the half is split over several addresses.
        is_split = len(half.addresses) > 1
        for i, address in enumerate(half.addresses):
            name = f"{base_name} [{i}]" if is_split else base_name
            flatten_alloc[name] = (address, half.size)

    for alloc_name, alloc in allocs.items():
        if alloc.has_pong():
            add_half(alloc_name + " [ping]", alloc.ping)
            add_half(alloc_name + " [pong]", alloc.pong)
        else:
            add_half(alloc_name, alloc.ping)

    return flatten_alloc


def check_memory_allocation_overlap(allocs: dict[str, tuple[int, int]]) -> None:
    """
    Check for overlap between the memory allocations and print a warning when found.

    pre-condition: allocations are sorted in increasing order by their start address

    Each allocation is compared against the furthest end address reached by
    ANY earlier allocation, not just the immediately preceding one. A buffer
    that overlaps a large earlier buffer is therefore still reported when a
    smaller buffer nested inside that large buffer sits between them.

    :param allocs: buffer name -> (start address, size)
    """
    furthest_end = 0
    furthest_end_owner = None

    for buff_name, (buff_start_addr, buff_size) in allocs.items():
        if buff_start_addr < furthest_end:
            # furthest_end only grows past 0 once some buffer has set its owner
            # (this holds as long as start addresses are non-negative).
            assert furthest_end_owner is not None
            print(
                f"WARNING: overlap found between buffer '{furthest_end_owner}' which ends at 0x{furthest_end:x} and {buff_name} which starts at 0x{buff_start_addr:x}"
            )

        buff_end_addr = buff_start_addr + buff_size
        # Never move the frontier backwards: a buffer nested inside an earlier
        # one must not hide overlaps with buffers that come after it.
        if buff_end_addr > furthest_end:
            furthest_end = buff_end_addr
            furthest_end_owner = buff_name


def print_memory_chunk_separation(address: int) -> None:
    """Print a dashed rule of WIDTH characters labelled with ``address`` in decimal and hex."""
    # WIDTH is only read here, so no ``global`` declaration is required.
    rule = "-" * WIDTH
    label = f"{address} = 0x{address:x}"
    print(rule + ": " + label)


def print_memory_chunk(name: str, size: int) -> None:
    """
    Print a three-row box of WIDTH characters with ``name`` and ``size``
    (decimal and hex) centered on the middle row.
    """
    label = f"{name}: size {size} = 0x{size:x}"
    spare = WIDTH - len(label)
    # When the total padding is odd, the left side gets the extra space.
    # If the label is too wide, the padding becomes "" and the row just grows.
    left_pad = " " * ((spare - 1) // 2)
    right_pad = " " * ((spare - 2) // 2)
    blank_row = "|" + " " * (WIDTH - 2) + "|"

    print(blank_row)
    print("|" + left_pad + label + right_pad + "|")
    print(blank_row)


def print_bank_memory_layout(
    allocations: dict[str, tuple[int, int]],
    allocations_name: list[str],
    bank_id,
    bank_size: int,
) -> None:
    """
    Print the memory layout of a single bank.

    Each allocation named in ``allocations_name`` is clipped to the bank's
    address range [bank_id * bank_size, (bank_id + 1) * bank_size) before it
    is drawn. A gap before an allocation is drawn as an "empty memory chunk";
    when allocations overlap, that chunk has a negative size. Any space left
    after the last allocation is drawn as a final empty chunk.

    :param allocations: buffer name -> (start address, size)
    :param allocations_name: names to draw, in drawing order
    :param bank_id: index of the bank to draw
    :param bank_size: size of one bank (same unit as the addresses)
    """
    print(f"### BANK Id {bank_id} ###")

    lower = bank_id * bank_size
    upper = lower + bank_size
    cursor = lower

    # Opening boundary of the bank.
    print_memory_chunk_separation(cursor)

    for name in allocations_name:
        start, size = allocations[name]

        # Keep only the part of the allocation that lies inside this bank.
        clipped_start = max(start, lower)
        clipped_end = min(start + size, upper)

        if clipped_start != cursor:
            # Positive for a hole, negative when this allocation overlaps the
            # previously drawn one.
            print_memory_chunk("empty memory chunk", clipped_start - cursor)
            print_memory_chunk_separation(clipped_start)

        print_memory_chunk(name, clipped_end - clipped_start)
        print_memory_chunk_separation(clipped_end)
        cursor = clipped_end

    # Unused space at the top of the bank.
    if cursor < upper:
        print_memory_chunk("empty memory chunk", upper - cursor)
        print_memory_chunk_separation(upper)


def print_memory_layout(
    allocations: dict[str, tuple[int, int]],
    bank_to_alloc: dict[int, list[str]],
    bank_size: int,
) -> None:
    """
    Print the memory layout of every bank listed in ``bank_to_alloc``.

    Banks are drawn in the iteration order of ``bank_to_alloc``; each bank
    draws the allocation names mapped to it.
    """
    for bank_id, names_in_bank in bank_to_alloc.items():
        print_bank_memory_layout(allocations, names_in_bank, bank_id, bank_size)


def map_allocations_to_bank(
    allocs: dict[str, tuple[int, int]], bank_size: int, bank_nb: int
) -> dict[int, list[str]]:
    """
    Group allocation names by the banks they touch.

    Bank ``i`` covers [i * bank_size, (i + 1) * bank_size) for i in
    range(bank_nb). An allocation is listed under every bank whose range
    intersects [start, start + size):
      - a name may appear under several banks;
      - an allocation lying outside all banks appears under none.
    Within a bank, names keep the iteration order of ``allocs``.

    :return: bank id -> allocation names (every bank id is present, possibly
        mapped to an empty list)
    """

    def touches(bank_id: int, start: int, size: int) -> bool:
        # Half-open interval intersection test against the bank's range.
        lo = bank_id * bank_size
        hi = lo + bank_size
        return start + size > lo and start < hi

    return {
        bank_id: [
            name
            for name, (start, size) in allocs.items()
            if touches(bank_id, start, size)
        ]
        for bank_id in range(bank_nb)
    }


def map_bank_id_to_allocations(
    allocs: dict[str, tuple[int, int]], bank_size: int, bank_nb: int
) -> dict[str, list[int]]:
    """
    For each allocation, list the ids of the banks it touches.

    Despite the function name, the result is keyed by allocation name. Bank
    ``i`` covers [i * bank_size, (i + 1) * bank_size) for i in range(bank_nb).
    Bank ids are listed in increasing order. An allocation outside every bank
    maps to an empty list.

    :return: allocation name -> list of bank ids
    """
    result: dict[str, list[int]] = {}
    for name, (start, size) in allocs.items():
        end = start + size
        # Half-open interval intersection against each bank's range.
        result[name] = [
            bank_id
            for bank_id in range(bank_nb)
            if start < (bank_id + 1) * bank_size and end > bank_id * bank_size
        ]
    return result


def check_bank_overlap(alloc_to_banks: dict[str, list[int]]) -> None:
    """
    Warn about every buffer allocation that spans more than one bank.

    :param alloc_to_banks: allocation name -> ids of the banks it touches, as
        produced by map_bank_id_to_allocations. The keys are names and the
        values are bank ids; the previous annotation had them swapped.
    """
    for allocation, banks in alloc_to_banks.items():
        if len(banks) > 1:
            print(
                f"WARNING: buffer '{allocation}' is overlapping between banks with ids {banks}"
            )


def print_memory_allocation(
    allocations: dict[str, BufferAllocation], bank_size: int, bank_nb: int
) -> None:
    """
    Print the layout of a set of buffer allocations, then report problems.

    Steps, in order:
      1. Flatten the allocations into name -> (address, size) entries.
      2. Sort the entries by (address, size).
      3. Draw every bank.
      4. Warn about buffers overlapping each other.
      5. Warn about buffers spanning more than one bank.

    :param allocations: buffer name -> BufferAllocation to visualize
    :param bank_size: size of one memory bank
    :param bank_nb: number of banks to draw
    """
    # Sorting by (start address, size) is the pre-condition of the overlap
    # check and fixes the order in which allocations are drawn in each bank.
    by_address = dict(
        sorted(get_flatten_alloc(allocations).items(), key=lambda entry: entry[1])
    )

    banks_to_names = map_allocations_to_bank(by_address, bank_size, bank_nb)
    print_memory_layout(by_address, banks_to_names, bank_size)

    check_memory_allocation_overlap(by_address)

    names_to_banks = map_bank_id_to_allocations(by_address, bank_size, bank_nb)
    check_bank_overlap(names_to_banks)


def main() -> int:
    """
    Command-line entry point.

    Loads the buffer allocations of the layer directory given as the single
    command-line argument. Prints the core memory layout (4 banks of
    CORE_BANK_SIZE) and then the memtile layout (1 bank of MEMTILE_SIZE).

    :return: process exit status (0 on success, 1 on wrong usage)
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        print(
            "Wrong number of arguments, please provide the layer directory as a single argument"
        )
        return 1

    (layer_dir,) = cli_args
    allocations: BufferAllocations = BufferAllocations.load(layer_dir)

    print("Core memory layout")
    print_memory_allocation(allocations.core_alloc, CORE_BANK_SIZE, bank_nb=4)
    # The allocator may have fallen back to a non-banked allocation for the core.
    non_banked = allocations.debug_info.core_alloc_non_banked
    if non_banked:
        print("bank boundaries have been ignored due to out of memory during banked allocation")

    print("")
    print("Memtile memory layout")
    print_memory_allocation(allocations.mem_alloc, MEMTILE_SIZE, bank_nb=1)

    return 0

if __name__ == "__main__":
    sys.exit(main())
