import sys
import os
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
sys.path.append(infra_path)
import logging
import pdb
import numpy as np
import math
from types import MappingProxyType    #To make it read only

import scheduler_utils as utils
import const

class GemmTilingCheck():
	"""Cross-check the tile buffer sizes reported by the tiler against recomputed values.

	All configuration is supplied as a mapping whose keys become instance
	attributes (see ``__init__``). The methods read, among others:
	``ifm``/``wgt``/``ofm`` (objects with a ``dim`` sequence), ``core_tile`` and
	``mem_tile`` (objects with a ``subV`` dict keyed by tensor name),
	``mem_tile.sizes`` / ``shim_tile.sizes`` (dicts keyed by tensor name),
	``sch_attr.dataflow_mode``, ``padding``, ``aie_cols``, ``aie_arrays``,
	``wgt_core_subv_bytes``, ``ifm_bits``, ``ofm_bits`` and ``bits_per_byte``.
	"""

	def __init__(self, _data):
		"""Copy every key/value pair of the mapping ``_data`` onto the instance as attributes."""
		for attr, value in _data.items():
			setattr(self, attr, value)

	def sanity_check(self):
		"""Compare the tiler's mem-tile and shim-tile sizes with recomputed ones.

		Sizes for all three tensors are computed first; then the six comparisons
		are passed to ``utils.sanity_check`` in IFM, Wgt, OFM order (mem tile
		before shim tile). Whether a mismatch raises or only reports is decided
		by ``utils.sanity_check`` (scheduler_utils, not visible here).
		"""
		tensors = (("ifm", "IFM"), ("wgt", "Wgt"), ("ofm", "OFM"))
		# Compute every size up front so an error inside tensor_repr /
		# calc_buffer_size surfaces before any comparison is reported.
		computed = {name: self.calc_buffer_size(*self.tensor_repr(name)) for name, _ in tensors}
		for name, label in tensors:
			mem_size, shim_size = computed[name]
			tiler_mem = self.mem_tile.sizes[name]
			tiler_shim = self.shim_tile.sizes[name]
			utils.sanity_check((mem_size == tiler_mem), f"{label} Memtile tile size {tiler_mem} from tiler does not match computed value {mem_size}", "Message")
			utils.sanity_check((shim_size == tiler_shim), f"{label} Shim tile size {tiler_shim} from tiler does not match computed value {shim_size}", "Message")

	def _wgt_subv_grid(self):
		"""Return ``(subv_cols, subv_rows)``: the weight tensor measured in core-tile subvolumes.

		Columns are ``wgt.dim[1] // core_tile.subV['wgt'][1]``; rows are
		``wgt.dim[0]`` plus the summed ``pad_ifm_y`` padding, floor-divided by
		``core_tile.subV['ifm'][1]``.
		"""
		subv_cols = self.wgt.dim[1] // self.core_tile.subV['wgt'][1]
		# NOTE(review): rows are divided by the IFM subvolume's second dim
		# (presumably the shared reduction dim) rather than a wgt subV entry — confirm.
		padded_rows = self.wgt.dim[0] + sum(self.padding[0]['pad_ifm_y'])
		subv_rows = padded_rows // self.core_tile.subV['ifm'][1]
		return subv_cols, subv_rows

	def calc_wgt_pin_repr(self):
		"""Return ``(bits, mem_tile_repr, shim_tile_repr)`` for the "full" weight dataflow.

		``mem_tile_repr`` is ``[num_cols, subv_rows, wgt_core_subv_bytes]`` where
		``num_cols`` is the subvolume column count floor-divided by
		``aie_cols * aie_arrays``; ``shim_tile_repr`` is
		``[subv_cols, subv_rows, wgt_core_subv_bytes]``.
		"""
		subv_cols, subv_rows = self._wgt_subv_grid()
		num_cols = subv_cols // (self.aie_cols * self.aie_arrays)
		mem_tile_repr = [num_cols, subv_rows, self.wgt_core_subv_bytes]
		shim_tile_repr = [subv_cols, subv_rows, self.wgt_core_subv_bytes]
		# The innermost entry is already a byte count, so each "element" is one
		# byte wide; calc_buffer_size then returns the plain product.
		return self.bits_per_byte, mem_tile_repr, shim_tile_repr

	def calc_wgt_stream_repr(self):
		"""Return ``(bits, mem_tile_repr, shim_tile_repr)`` for the "stream" weight dataflow.

		The mem-tile size is a single core subvolume (``[wgt_core_subv_bytes]``);
		``shim_tile_repr`` is ``[subv_cols, subv_rows, wgt_core_subv_bytes]``.
		"""
		subv_cols, subv_rows = self._wgt_subv_grid()
		mem_tile_repr = [self.wgt_core_subv_bytes]
		shim_tile_repr = [subv_cols, subv_rows, self.wgt_core_subv_bytes]
		# Innermost entry is bytes -> one byte per element (see calc_wgt_pin_repr).
		return self.bits_per_byte, mem_tile_repr, shim_tile_repr

	def _wgt_repr(self):
		"""Dispatch the weight repr on ``sch_attr.dataflow_mode['wgt']`` ("full" or "stream").

		Returns ``(None, None, None)`` for an unsupported mode if
		``utils.sanity_check`` does not raise.
		"""
		wgt_mode = self.sch_attr.dataflow_mode['wgt']
		if wgt_mode == "full":
			return self.calc_wgt_pin_repr()
		if wgt_mode == "stream":
			return self.calc_wgt_stream_repr()
		utils.sanity_check(False,f"Unsupported wgt data flow mode: {wgt_mode}")
		return None, None, None

	def tensor_repr(self, tensor_type):
		"""Return ``(bits, mem_tile_repr, shim_tile_repr)`` for one tensor.

		Args:
			tensor_type: "ifm", "wgt" or "ofm".

		Returns:
			``bits`` is the element width handed to calc_buffer_size; each repr
			is a list whose product is the element count of the mem-tile /
			shim-tile buffer. The weight repr does not depend on ``actxact``.

		Raises:
			AssertionError: if no repr was produced (unknown tensor type or
				dataflow mode) and ``utils.sanity_check`` did not raise first.
		"""
		bits = None
		mem_tile_repr = None
		shim_tile_repr = None
		if tensor_type == "ifm":
			ifm_mode = self.sch_attr.dataflow_mode['ifm']
			if ifm_mode == "pin":
				mem_tile_repr = [self.mem_tile.subV['ifm'][0], self.mem_tile.subV['ifm'][1]]
				shim_tile_repr = [self.ifm.dim[0], self.mem_tile.subV['ifm'][1]]
				bits = self.ifm_bits
			elif ifm_mode == "stream":
				mem_tile_repr = [self.mem_tile.subV['ifm'][0], self.mem_tile.subV['ifm'][1]]
				# TBD: placeholder — the streamed shim repr currently equals the mem-tile repr.
				shim_tile_repr = [self.mem_tile.subV['ifm'][0], self.mem_tile.subV['ifm'][1]]
				bits = self.ifm_bits
			else:
				utils.sanity_check(False,f"Unsupported ifm data flow mode: {ifm_mode}")
		elif tensor_type == "ofm":
			mem_tile_repr = [self.mem_tile.subV['ofm'][0], self.mem_tile.subV['ofm'][1]]
			shim_tile_repr = [self.ofm.dim[0], self.ofm.dim[1]]
			bits = self.ofm_bits
		elif tensor_type == "wgt":
			bits, mem_tile_repr, shim_tile_repr = self._wgt_repr()
		else:
			utils.sanity_check(False, f"Invalid tensor type!!. Tensor Type:{tensor_type}")

		# Explicit raise (not `assert`) so the guard survives `python -O`.
		if bits is None or mem_tile_repr is None or shim_tile_repr is None:
			raise AssertionError(f"No tensor repr computed for tensor type {tensor_type!r}")
		return bits, mem_tile_repr, shim_tile_repr

	def calc_buffer_size(self, bits, mem_tile_param: list[int], shim_tile_param: list[int]):
		"""Return ``(mem_tile_size, shim_tile_size)`` in bytes.

		Each size is ``prod(param) * bits // bits_per_byte`` (floor division);
		an empty param list counts as one element. Products are taken over
		Python ints (``math.prod``), so large shapes cannot wrap around the way
		a fixed-width NumPy integer product can.
		"""
		# TODO: validate params (non-empty, non-negative dims).
		mem_elements = math.prod(int(dim) for dim in mem_tile_param)
		shim_elements = math.prod(int(dim) for dim in shim_tile_param)
		mem_tile_size = (mem_elements * int(bits)) // self.bits_per_byte
		shim_tile_size = (shim_elements * int(bits)) // self.bits_per_byte
		return int(mem_tile_size), int(shim_tile_size)
