import sys
import os
# Append the "infra/" directory that sits next to this file to sys.path.
# Presumably the project modules imported further down (const, template_base,
# gen_kernel_param) live there -- TODO confirm.
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
sys.path.append(infra_path)
import logging
import pdb
from enum import Enum

import const
from template_base import BaseTemplate, BaseDims
import gen_kernel_param

class ConvDims(BaseDims):
	"""Dimension record built from a single flat sequence of 28 fields.

	The first 26 fields are handed, in order, to ``BaseDims.__init__``; the
	final two are kept on the instance as ``Mpad`` and ``Kpad``.
	"""
	# Adds a per-instance __dict__ so Mpad/Kpad can be assigned as plain
	# attributes (presumably BaseDims restricts attributes via __slots__ --
	# TODO confirm against template_base).
	__slots__ = ['__dict__']

	def __init__(self, data):
		"""Unpack ``data`` and initialise the base dimensions plus padding.

		Args:
			data: iterable of exactly 28 values, ordered as in the unpack
				below. Any other length raises ``ValueError`` from the unpack.
		"""
		(
			M, K, N,
			M_subv, K_subv, N_subv,
			aie_rows, aie_cols, aie_arrays,
			act_bits, out_bits, qdq_bytes,
			outer_loop, inner_loop, acc_loop,
			wgt_subv_rows, wgt_subv_cols,
			act_subv_bytes, wgt_subv_bytes, out_subv_bytes,
			sum_subv_bytes, tdm_subv_bytes,
			wgt_bits, bits_per_byte, bias_bits,
			tdm_bits,
			pad_m, pad_k,
		) = data

		# Everything except the two padding values belongs to the base class.
		base_args = (
			M, K, N,
			M_subv, K_subv, N_subv,
			aie_rows, aie_cols, aie_arrays,
			act_bits, out_bits, qdq_bytes,
			outer_loop, inner_loop, acc_loop,
			wgt_subv_rows, wgt_subv_cols,
			act_subv_bytes, wgt_subv_bytes, out_subv_bytes,
			sum_subv_bytes, tdm_subv_bytes,
			wgt_bits, bits_per_byte, bias_bits,
			tdm_bits,
		)
		super().__init__(*base_args)

		self.Mpad, self.Kpad = pad_m, pad_k


class M4N4(BaseTemplate):
	"""Code-generation template that renders data-flow source text.

	Each ``gen_*`` method returns a string of generated code. The buffer
	allocator result is read from ``_data.info['BuffAllocator']`` and kept in
	``self.data``; it is indexed as a sequence whose entries are mappings.
	"""

	def __init__(self, _data):
		"""Store the buffer-allocator info and build the ``ConvDims`` record.

		Args:
			_data: object exposing ``info.get('BuffAllocator')``.
		"""
		super().__init__()
		self.data = _data.info.get('BuffAllocator')
		# helper_func is not defined in this class (presumably inherited from
		# BaseTemplate); the attribute is rebound to the value it returns.
		self.helper_func = self.helper_func()
		buff = tuple(self.data)
		# Entry 2 holds the allocation parameters (presumably the same slot as
		# const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX -- TODO confirm).
		prm = buff[2]
		tuple_var = (
			prm['M'], prm['K'], prm['N'],
			prm['M_subV'], prm['K_subV'], prm['N_subV'],
			prm['aie_rows'], prm['aie_cols'], prm['aie_arrays'],
			prm['ifm_bits'], prm['ofm_bits'], prm['qdq_bytes'],
			prm['outer_loop'], prm['inner_loop'], prm['acc_loop'],
			prm['wgt_subv_rows'], prm['wgt_subv_cols'],
			prm['ifm_core_subv_bytes'], prm['wgt_core_subv_bytes'],
			prm['ofm_core_subv_bytes'],
			prm['sum_core_subv_bytes'], prm['tdm_core_subv_bytes'],
			prm['wgt_bits'], const.BITS_PER_BYTE, prm['bias_bits'],
			prm['tdm_bits'],
			0, 0,  # Mpad, Kpad
		)
		# Keep the dims object; it used to be built and then discarded.
		self.dims = ConvDims(tuple_var)

	def gen_buff_addr(self):
		"""Return the text block listing buffer sizes and tile addresses.

		Size lines are emitted with a leading ``#``; address lines are emitted
		as plain assignments.
		"""
		# self.data already is the 'BuffAllocator' entry (see __init__), so it
		# is indexed directly instead of going through ``.info`` a second time.
		buff_alloc_info = self.data
		buff_prm  = buff_alloc_info[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]
		mem_tile_addr = buff_alloc_info[const.BufAllocator_Idx.MEM_TILE_ADDR_IDX.value]
		core_tile_addr = buff_alloc_info[const.BufAllocator_Idx.CORE_TILE_ADDR_IDX.value]
		# NOTE(review): the OFM line used to print 'ifm_mem_tile_size', which
		# looks like a copy-paste slip. Prefer 'ofm_mem_tile_size' when the
		# allocator provides it and fall back to the old value otherwise.
		ofm_mem_tile_size = buff_prm.get('ofm_mem_tile_size', buff_prm['ifm_mem_tile_size'])
		return f"""
		#ParamSize = {const.PARAM_SIZE}
		#MemtileIfmSize = {buff_prm['ifm_mem_tile_size']}
		#MemtileWgtSize = {buff_prm['wgt_mem_tile_size']}
		#MemtileOfmSize = {ofm_mem_tile_size}
		#CoreIfmSize = {buff_prm['ifm_core_tile_size']}
		#CoreWgtSize = {buff_prm['wgt_core_tile_size']}
		#CoreOfmSize = {buff_prm['ofm_core_tile_size']}
		#CoreBankSize = {const.CORE_BANK_SIZE}


		MemtilePrmPingAddr = {mem_tile_addr['MemtilePrmPingAddr']}
		MemtileIfmPingAddr = {mem_tile_addr['MemtileIfmPingAddr']}
		MemtileWgtPingAddr = {mem_tile_addr['MemtileWgtPingAddr']}
		MemtileWgtPongAddr = {mem_tile_addr['MemtileWgtPongAddr']}
		MemtileOfmPingAddr = {mem_tile_addr['MemtileOfmPingAddr']}

		#TODO: REVISIT after perf review to see if we need new buff allocation policy
		CoreIfmPingAddr = {core_tile_addr['CoreIfmPingAddr']}
		CoreWgtPingAddr = {core_tile_addr['CoreWgtPingAddr']}
		CoreOfmPingAddr = {core_tile_addr['CoreOfmPingAddr']}
		CoreTdmPingAddr = {core_tile_addr['CoreTdmPingAddr']}
		CoreIfmPongAddr = {core_tile_addr['CoreIfmPongAddr']}
		CoreWgtPongAddr = {core_tile_addr['CoreWgtPongAddr']}
		CoreTdmPongAddr = {core_tile_addr['CoreTdmPongAddr']}
		CoreSumPingAddr = {core_tile_addr['CoreSumPingAddr']}
		CoreStackAddr = {core_tile_addr['CoreStackAddr']}

		#TODO: Enable this in buff allocator
		#assert CoreTdmPingAddr + CoreTdmSize < CoreIfmPongAddr
		#assert CoreSumPingAddr + CoreSumSize < CoreStackAddr
		"""

	# The methods below used to be indented inside gen_buff_addr, after its
	# return statement, so they were never defined on the class.

	def gen_core_instr(self, dims):
		"""Return the ``core_instrs`` template text.

		Args:
			dims: object with an ``acc_loop`` attribute. ``acc_loop == 1``
				selects the single-kernel-call template; any other value
				selects the first / middle / last accumulation template.

		The templates contain no placeholders, so names such as
		``dims.outer_loop`` are emitted literally.
		"""
		if dims.acc_loop == 1:
			return f"""
				#TODO: Replace VAR with it's corresponding value
					core_instrs = [
						ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
						ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize),
						ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, None, CoreOfmSize),
						Loop(dims.outer_loop, [
								Loop(dims.inner_loop, [
										AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
										AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
										AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
										CallKernel(
												'run_conv_2d',
												kernel_params=a16w8_conv_params(
														dims, 1, 1, CoreTdmPingAddr, CoreTdmPongAddr, CoreSumPingAddr)
										),
										RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
										RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
										RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
								]),
						]),
				]
				"""
		else:
			return f"""
				#TODO: Replace VAR with it's corresponding value
				core_instrs = [
						Loop(dims.outer_loop, [
								ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreIfmPingAddr, CoreIfmPongAddr, CoreIfmSize),
								ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreWgtPingAddr, CoreWgtPongAddr, CoreWgtSize),
								ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOfmPingAddr, None, CoreOfmSize),
								Loop(dims.inner_loop, [
										AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
										AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
										CallKernel(
												'run_conv_2d',
												kernel_params=a16w8_conv_params(
														dims, 1, 0, CoreTdmPingAddr, CoreTdmPongAddr, CoreSumPingAddr)
										),
										RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
										RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
										Loop(dims.acc_loop - 2, [
												AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
												AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
												CallKernel(
														'run_conv_2d',
														kernel_params=a16w8_conv_params(
																dims, 0, 0, CoreTdmPingAddr, CoreTdmPongAddr, CoreSumPingAddr)
												),
												RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
												RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
										]),
										AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
										AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
										AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
										CallKernel(
												'run_conv_2d',
												kernel_params=a16w8_conv_params(
														dims, 0, 1, CoreTdmPingAddr, CoreTdmPongAddr, CoreSumPingAddr)
										),
										RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
										RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
										RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
								]),
						]),
				]
			"""

	def gen_memtile_instr(self):
		"""Return the ``memtile_transfers`` template text.

		The text is a fixed template: it has no placeholders, so every name in
		it is emitted literally.
		"""
		gen_memtile_transfers = f"""
				memtile_transfers = [
				DataTransfer(
						[1] + [0] * (dims.outer_loop - 1),
						AieTile(TileType.Memtile, col), [MemtilePrmPingAddr], ParamSize,
						[access_linear_buffer(memtile_dma(col, DmaDir.S2MM, 1), ParamSize)],
						[access_linear_buffer(memtile_dma(col, DmaDir.MM2S, 1 + row), ParamSize)
						 for row in range(dims.aie_rows)]
				) for col in range(dims.aie_cols)
		] + [
				DataTransfer(
						pack_ifm_repeat(dims),
						AieTile(TileType.Memtile, col), [MemtileIfmPingAddr], MemtileIfmSize,
						[pack_ifm_header_body_footer(
								dims,
								memtile_dma(col, DmaDir.S2MM, 1),
								ifm_memtile_memory(dims, col),
								(ifm_memtile_s2mm(dims),) * 3,
								IfmBits)],
						[pack_ifm_header_body_footer(
								dims,
								memtile_dma(col, DmaDir.MM2S, 1 + row),
								ifm_memtile_memory(dims, col),
								mm2s_fmt,
								IfmBits,
								enable_padding=True) for row in range(dims.aie_rows)
																		 for _ in range(ifm_chain_length(dims))
																		 for mm2s_fmt in ifm_memtile_mm2s(dims, col)],
						reuse_ratio=(dims.inner_loop // ifm_chain_length(dims)),
						sync_strategy=SyncStrategy.Parallel_1_to_N,
				) for col in range(dims.aie_cols)
		] + [
				DataTransfer(
						[dims.acc_loop * dims.inner_loop] * dims.outer_loop,
						AieTile(TileType.Memtile, col), [MemtileWgtPingAddr, MemtileWgtPongAddr], MemtileWgtSize,
						[generate_transfer_params(
								memtile_dma(col, DmaDir.S2MM, 0),
								wgt_memtile_memory(CoreWgtSize),
								wgt_memtile_s2mm(),
								WgtBits)],
						[generate_transfer_params(
								memtile_dma(col, DmaDir.MM2S, 0),
								wgt_memtile_memory(CoreWgtSize),
								wgt_memtile_mm2s(),
								WgtBits)],
						sync_strategy=SyncStrategy.Parallel_1_to_N,
				) for col in range(dims.aie_cols)
		] + [
				DataTransfer(
						[dims.inner_loop] * dims.outer_loop,
						AieTile(TileType.Memtile, col), [MemtileOfmPingAddr], MemtileOfmSize,
						[generate_transfer_params(
								memtile_dma(col, DmaDir.S2MM, 2 + row),
								ofm_memtile_memory(dims),
								ofm_memtile_s2mm(dims, row),
								OfmBits) for row in range(dims.aie_rows)],
						[generate_transfer_params(
								memtile_dma(col, DmaDir.MM2S, 5),
								ofm_memtile_memory(dims),
								ofm_memtile_mm2s(dims),
								OfmBits)],
						sync_strategy=SyncStrategy.Parallel_N_to_1,
				) for col in range(dims.aie_cols)
		]
			"""
		return gen_memtile_transfers

	def gen_shimtile_instr(self):
		"""Return the ``shim_transfers`` template text.

		The text is a fixed template: it has no placeholders, so every name in
		it is emitted literally.
		"""
		gen_shim_transfers = f"""
			shim_transfers = [
				DataTransfer(
						[1] + [0] * (dims.outer_loop - 1),
						AieTile(TileType.Shim, col), [ShimPrmBufferIdx], ParamSize,
						[],
						[access_linear_buffer(shim_dma(col, DmaDir.MM2S, 1), ParamSize)]
				) for col in range(dims.aie_cols)
			] + [
					generate_shim_data_transfer(
							[1] + [0] * (dims.outer_loop - 1),
							shim_dma(col, DmaDir.MM2S, 1),
							ShimIfmBufferIdx,
							ifm_shim_memory(dims),
							fmt,
							IfmBits,
					) for col in range(dims.aie_cols) for fmt in ifm_shim_mm2s(dims, col)
			] + [
					generate_shim_data_transfer(
							[dims.outer_loop] + [0] * (dims.outer_loop - 1),
							shim_dma(col, DmaDir.MM2S, 0),
							ShimWgtBufferIdx,
							wgt_shim_memory(dims, CoreWgtSize),
							wgt_shim_mm2s(dims, col),
							WgtBits,
					) for col in range(dims.aie_cols)
			] + [
					generate_shim_data_transfer(
							[1] + [0] * (dims.outer_loop - 1),
							shim_dma(col, DmaDir.S2MM, 0),
							ShimOfmBufferIdx,
							ofm_shim_memory(dims),
							ofm_shim_s2mm(dims, col),
							OfmBits,
					) for col in range(dims.aie_cols)
			]
		"""
		return gen_shim_transfers

	def gen_helper_func(self, enable_pack_shim, enable_pack_ifm, dims):
		"""Concatenate the optional helper-function snippets.

		Args:
			enable_pack_shim: when truthy, include the text returned by
				``gen_pack_shim_transfer_with_iter_step()``.
			enable_pack_ifm: when truthy, include ``gen_GemmDims(dims)``
				followed by the ``gen_pack_TransferParams`` text for
				``'ConvDims'``.
			dims: forwarded to ``gen_GemmDims``.

		Returns:
			The concatenated text; an empty string when both flags are falsy.

		The ``gen_*`` helpers called here are not defined in this class
		(presumably provided by BaseTemplate -- TODO confirm).
		"""
		code = ""
		if enable_pack_shim:
			code += self.gen_pack_shim_transfer_with_iter_step()
		if enable_pack_ifm:
			code += self.gen_GemmDims(dims)
			code += self.gen_pack_TransferParams('ConvDims', self.gen_ifm_repeat_counts('ConvDims'))
		return code

