"""
This module provides memory configs for L2 memory on AIE4.
"""

# pylint: disable=import-error,no-name-in-module
from graph.allocation_types import (
    MemoryBlock,
    MemoryConfig,
    ParamBlock,
    WeightBlock,
    WeightPingPongBlock,
)
from graph.common import (
    NUM_AIE_ROWS,
    PARAM_SIZE,
    REGION_SIZE,
    SRAM_TOTAL,
    WEIGHT_PING_PONG_SIZE,
    WEIGHT_SIZE,
)

# Memory config used for ResNet50: tensor data is allocated only on the middle
# column (column-1) of AIE4.
# _____________________________________________________________________
# |   |        |        |        |   |          |   |        |         |
# |P0 |W00,W01 |M0      |W10,W11 |P1 |M1        |P2 |W20,W21 |M2       |
# |___|________|________|________|___|__________|___|________|_________|
#
# P0, P1, P2 = Layer parameter blocks, each of size 4KB (PARAM_SIZE)
# W00, W01 = Weight ping (256KB) and pong (256KB) blocks for column-0
# W10, W11 = Weight ping (256KB) and pong (256KB) blocks for column-1
# W20, W21 = Weight ping (256KB) and pong (256KB) blocks for column-2
# M0 = Free memory region on column-0 (not handed to the tensor allocator)
# M1 = Memory region used for tensor allocation (the only MemoryBlock below)
# M2 = Free memory region on column-2 (not handed to the tensor allocator)
COL1_MEMORY_CONFIG = MemoryConfig(
    # Each of the three regions opens with its parameter block: P_k lives at
    # k * REGION_SIZE and is PARAM_SIZE bytes long.
    params=tuple(ParamBlock(k * REGION_SIZE, PARAM_SIZE) for k in range(3)),
    # NOTE: WEIGHT_SIZE / WEIGHT_PING_PONG_SIZE here are the values imported
    # from graph.common; this module rebinds both names further down, after
    # this config has already been built.
    weights=(
        # Column-0 pair: packed immediately after P0 at the start of region 0.
        WeightPingPongBlock(
            ping=WeightBlock(PARAM_SIZE, WEIGHT_SIZE),
            pong=WeightBlock(WEIGHT_SIZE + PARAM_SIZE, WEIGHT_SIZE),
        ),
        # Column-1 pair: packed against the END of region 0, so that region 1
        # is left for P1 and the tensor region.
        WeightPingPongBlock(
            ping=WeightBlock(REGION_SIZE - WEIGHT_PING_PONG_SIZE, WEIGHT_SIZE),
            pong=WeightBlock(REGION_SIZE - WEIGHT_SIZE, WEIGHT_SIZE),
        ),
        # Column-2 pair: packed immediately after P2 at the start of region 2.
        WeightPingPongBlock(
            ping=WeightBlock(PARAM_SIZE + 2 * REGION_SIZE, WEIGHT_SIZE),
            pong=WeightBlock(WEIGHT_SIZE + PARAM_SIZE + 2 * REGION_SIZE, WEIGHT_SIZE),
        ),
    ),
    # Only M1 (inside region 1) is exposed for tensor allocation; it runs up to
    # the start of region 2.
    # NOTE(review): P1 ends at REGION_SIZE + PARAM_SIZE, but M1 starts
    # WEIGHT_SIZE bytes later and no weight block occupies that gap -- the
    # layout diagram shows M1 directly after P1. Confirm the gap is intended.
    # The trailing True is a MemoryBlock flag whose meaning is defined in
    # graph.allocation_types -- TODO confirm its semantics there.
    memory=[
        MemoryBlock(
            REGION_SIZE + PARAM_SIZE + WEIGHT_SIZE,
            REGION_SIZE - PARAM_SIZE - WEIGHT_SIZE,
            True,
        )
    ],
)

"""
Memory Config which puts parameters and weights at the end of SRAM on entire AIE4
__________________________________________________
|         |   |   |   |        |        |        |
|M        |P0 |P1 |P2 |W00,W01 |W10,W11 |W20,W21 |
|_________|___|___|___|________|________|________|

P0, P1, P2 = Layer parameter blocks, each of size 4KB
W00, W01 = Weight Ping (256KB) and Pong (256KB) block for column-0
W10, W11 = Weight Ping (256KB) and Pong (256KB) block for column-1
W20, W21 = Weight Ping (256KB) and Pong (256KB) block for column-2
M = Memory region used for tensor allocation
"""
# Weight sizing used by END_MEMORY_CONFIG and START_END_MEMORY_CONFIG below.
# NOTE: the next two assignments REBIND WEIGHT_SIZE and WEIGHT_PING_PONG_SIZE,
# which were imported from graph.common at the top of this module.
# COL1_MEMORY_CONFIG above was already built with the imported values. Every
# later use in this module, and any `from <this module> import WEIGHT_SIZE`,
# sees the values computed here instead.
WEIGHT_CORE_SIZE = 40 * 1024  # Bytes reserved per AIE core for weights
# Bytes in ONE weight buffer (ping or pong) for a column: one WEIGHT_CORE_SIZE
# slice for each of the NUM_AIE_ROWS cores in that column.
WEIGHT_SIZE = WEIGHT_CORE_SIZE * NUM_AIE_ROWS
# A ping/pong pair holds two such buffers back to back.
WEIGHT_PING_PONG_SIZE = WEIGHT_SIZE * 2
END_MEMORY_CONFIG = MemoryConfig(
    # P0, P1, P2 sit back to back directly below the three weight pairs; the
    # slot number is how many PARAM_SIZE blocks below that boundary each starts.
    params=tuple(
        ParamBlock(SRAM_TOTAL - 3 * WEIGHT_PING_PONG_SIZE - k * PARAM_SIZE, PARAM_SIZE)
        for k in (3, 2, 1)
    ),
    # The column-0, column-1 and column-2 ping/pong pairs fill the last
    # 3 * WEIGHT_PING_PONG_SIZE bytes of SRAM, in that order. Within a pair,
    # pong follows ping by WEIGHT_SIZE bytes.
    weights=tuple(
        WeightPingPongBlock(
            ping=WeightBlock(pair_start, WEIGHT_SIZE),
            pong=WeightBlock(pair_start + WEIGHT_SIZE, WEIGHT_SIZE),
        )
        for pair_start in (
            SRAM_TOTAL - 3 * WEIGHT_PING_PONG_SIZE,  # column-0
            SRAM_TOTAL - 2 * WEIGHT_PING_PONG_SIZE,  # column-1
            SRAM_TOTAL - WEIGHT_PING_PONG_SIZE,  # column-2
        )
    ),
    # Everything from address 0 up to P0 is one contiguous tensor region.
    # The trailing True is a MemoryBlock flag defined in graph.allocation_types.
    memory=[MemoryBlock(0, SRAM_TOTAL - 3 * WEIGHT_PING_PONG_SIZE - 3 * PARAM_SIZE, True)],
)
"""
Memory Config which puts parameters and weights at the start and end of SRAM on entire AIE4
___________________________________________________________________
|   |   |        |        |                          |   |        |
|P0 |P1 |W00,W01 |W10,W11 |M                         |P2 |W20,W21 |
|___|___|________|________|__________________________|___|________|

P0, P1, P2 = Layer parameter blocks, each of size 4KB
W00, W01 = Weight Ping (256KB) and Pong (256KB) block for column-0
W10, W11 = Weight Ping (256KB) and Pong (256KB) block for column-1
W20, W21 = Weight Ping (256KB) and Pong (256KB) block for column-2
M = Memory region used for tensor allocation
"""
START_END_MEMORY_CONFIG = MemoryConfig(
    params=(
        # P0 and P1 are packed at the very start of SRAM.
        ParamBlock(0, PARAM_SIZE),
        ParamBlock(PARAM_SIZE, PARAM_SIZE),
        # P2 sits directly below the column-2 weight pair at the end of SRAM.
        ParamBlock(SRAM_TOTAL - WEIGHT_PING_PONG_SIZE - PARAM_SIZE, PARAM_SIZE),
    ),
    # Each pair is (ping, pong) with pong WEIGHT_SIZE bytes after ping. The
    # column-0 and column-1 pairs follow P0/P1; the column-2 pair occupies the
    # last WEIGHT_PING_PONG_SIZE bytes of SRAM.
    weights=tuple(
        WeightPingPongBlock(
            ping=WeightBlock(pair_start, WEIGHT_SIZE),
            pong=WeightBlock(pair_start + WEIGHT_SIZE, WEIGHT_SIZE),
        )
        for pair_start in (
            2 * PARAM_SIZE,  # column-0
            2 * PARAM_SIZE + WEIGHT_PING_PONG_SIZE,  # column-1
            SRAM_TOTAL - WEIGHT_PING_PONG_SIZE,  # column-2
        )
    ),
    # The tensor region begins at REGION_SIZE rather than right after the
    # column-1 pair (2 * PARAM_SIZE + 2 * WEIGHT_PING_PONG_SIZE), and ends
    # where P2 begins. Presumably this is an intentional alignment choice --
    # confirm. It also assumes 2 * PARAM_SIZE + 2 * WEIGHT_PING_PONG_SIZE
    # <= REGION_SIZE; otherwise M would overlap the column-1 weights.
    # The trailing True is a MemoryBlock flag defined in graph.allocation_types.
    memory=[
        MemoryBlock(
            REGION_SIZE,
            SRAM_TOTAL - REGION_SIZE - WEIGHT_PING_PONG_SIZE - PARAM_SIZE,
            True,
        )
    ],
)
