'''
Define unit tests for maxpool dataflow
'''
import unittest
import os
from dmacompiler import (
    memory_tile, BackEnd
)
from utils.utils_common import (
    L2Alloc,
)
from scheduler.maxpool.maxpool import (
    generate_maxpool_mappings,
    compile_maxpool_dataflow
)

# ResNet-50 max-pool layers exercised by the test below. Each entry is:
#   (Yi, Xi, C), (Yo, Xo, C), (Ky, Kx), (Sy, Sx), (Py, Px)
# i.e. input shape, output shape, kernel size, stride and padding.
# NOTE(review): a 3x3 / stride-2 window over 112 rows with padding 0 gives
# 55 outputs under floor-mode pooling; the 56 listed here matches ceil-mode
# pooling or padding 1. Confirm which convention generate_maxpool_mappings
# expects before adding more shapes.
RESNET50_MAXPOOL_SHAPES = [
    [
        (112, 112, 64),  # input   (Yi, Xi, C)
        (56, 56, 64),    # output  (Yo, Xo, C)
        (3, 3),          # kernel  (Ky, Kx)
        (2, 2),          # stride  (Sy, Sx)
        (0, 0),          # padding (Py, Px)
    ],
]


class TestMaxpool(unittest.TestCase):
    '''Test maxpool dataflow mapping and compilation'''

    @staticmethod
    def _remove_if_exists(path: str) -> None:
        '''Delete ``path`` if present; a file that was never created is ignored.'''
        try:
            os.remove(path)
        except FileNotFoundError:
            # Best-effort cleanup: the test body itself reports missing outputs.
            pass

    def test_resnet50_maxpool_shapes(self) -> None:
        '''Compile the first generated mapping of every ResNet-50 maxpool layer

        For each entry of RESNET50_MAXPOOL_SHAPES, generate the candidate
        maxpool mappings and compile the first one with the ADF backend.
        Afterwards the generated source files must exist; they are removed
        when the test finishes, whether it passes or fails.
        '''
        layer_file_name = 'dma.hpp'
        # Files this test expects compilation to leave in the working directory.
        generated_files = [layer_file_name, 'graph.hpp', 'super.cc', 'super.hh']
        # Register cleanup before compiling so that a failure part-way through
        # does not leave generated sources behind (cleanups run on failure too).
        for path in generated_files:
            self.addCleanup(self._remove_if_exists, path)

        wgt_relative_base_addr = 2 * 2**20  # 2 MiB
        wgt_reserved_size = 64 * 1024 * 4  # 64 KiB * 4 (unicast)
        # Ping and pong buffers sit back-to-back from the base address; the
        # parameter buffer starts right after the pong buffer.
        wgt_ping_addr = wgt_relative_base_addr
        wgt_pong_addr = wgt_relative_base_addr + wgt_reserved_size
        prm_addr = wgt_pong_addr + wgt_reserved_size
        fusion_params = L2Alloc(
            (memory_tile(1), 0 * 2**20),
            (memory_tile(0), 0),
            (memory_tile(1), 1 * 2**20),
            [
                [memory_tile(0), wgt_ping_addr, wgt_pong_addr],
                [memory_tile(1), wgt_ping_addr, wgt_pong_addr],
                [memory_tile(2), wgt_ping_addr, wgt_pong_addr],
            ],
            [
                [memory_tile(0), prm_addr],
                [memory_tile(1), prm_addr],
                [memory_tile(2), prm_addr],
            ],
            True, True, True,
        )
        kernel_names = ['run_maxpool_int8x8']
        kernel_includes = ['super.hh', 'maxpool/maxpool_int8x8_wrapper.cc']
        aie_cols = 3
        aie_rows = 4
        ifm_bits = 8
        ofm_bits = 8
        bits_per_byte = 8
        Y_gran, X_gran, C_gran = (1, 1, 64)
        # NOTE(review): the kernel was described as unrolling the channel
        # dimension by 2, yet the factor used here is 1 (so Cs_min == C_gran).
        # Confirm which value the kernel actually requires.
        kernel_Cs_unroll = 1
        Cs_min = C_gran * kernel_Cs_unroll
        for shape in RESNET50_MAXPOOL_SHAPES:
            (Yi, Xi, Ci), (Yo, Xo, Co), (Ky, Kx), (Sy, Sx), (Py, Px) = shape
            with self.subTest(shape=shape):
                # A unittest assertion, unlike a bare `assert`, is not
                # stripped when Python runs with -O.
                self.assertEqual(Ci, Co, "Input and output channels must be the same for maxpool")
                maxpool_mapped_soln = generate_maxpool_mappings(
                    Yi, Xi, Ci, Yo, Xo, Ky, Kx, Sy, Sx, Py, Px,
                    Y_gran, X_gran, C_gran, Cs_min,
                    aie_cols, aie_rows,
                    ifm_bits, ofm_bits, bits_per_byte,
                )
                # Fail with a descriptive message instead of an IndexError below.
                self.assertGreater(len(maxpool_mapped_soln), 0,
                                   f"No maxpool mapping generated for shape {shape}")

                compile_maxpool_dataflow(maxpool_mapped_soln[0], fusion_params,
                                         kernel_names, kernel_includes, layer_file_name, BackEnd.Adf)

        # Compilation is expected to produce every file in generated_files.
        for path in generated_files:
            self.assertTrue(os.path.isfile(path), f"{path} was not generated")


if __name__ == '__main__':
    # Run directly as a script (not collected by a test runner): execute
    # every TestCase defined in this module.
    unittest.main(module='__main__')
