import sys
import os
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
sys.path.append(infra_path)
import logging
import pdb
import ast
import numpy as np
import math
import json
import enum
import traceback
from dataclasses import dataclass
from types import MappingProxyType    #To make it read only
from typing import Callable, Dict, Iterable, List, Optional, Set

import OGOAT.src.Scheduling_Engine.infra.scheduler_utils as utils
import OGOAT.src.Scheduling_Engine.infra.const as const
import OGOAT.src.Scheduling_Engine.schedules.scheduler as schedule
from OGOAT.src.Scheduling_Engine.schedules import BufferAllocatorResult


@dataclass(init=False)
class TensorDim():
	"""
	Shape, datatype and byte information of a single tensor.

	Attributes (all None when no info dict is given):
	dim   -- value of _info['shape']; a 1-element shape is widened to (n, 1)
	dtype -- value of _info['datatype']
	bytes -- value of _info['bytes']
	"""
	def __init__ (self, _info):
		"""
		_info -- dict with keys 'shape', 'datatype' and 'bytes', or a falsy
		         value (None / empty dict) to create an empty description
		"""
		shape = _info['shape'] if _info else None
		# TODO: Clean up to handle 1 DIM tensor
		# (for now a 1-D shape is stored as a 2-D shape with a trailing 1)
		if shape and len(shape) == 1:
			shape = (shape[0], 1)
		self.dim = shape

		if _info:
			self.dtype = _info['datatype']
			self.bytes = _info['bytes']
		else:
			self.dtype = None
			self.bytes = None


@dataclass(init=False)
class KernelInfo():
	"""
	Holds the kernel's tensor placement as a dictionary.

	tensor_placement -- shallow copy of the key/value pairs of the info
	                    mapping passed to the constructor
	"""
	def __init__(self, _info):
		"""
		_info -- mapping whose items are copied into tensor_placement
		"""
		# copy the entries so that later changes to the top-level keys of
		# the caller's mapping do not show up here (values are shared)
		self.tensor_placement = {name: entry for name, entry in _info.items()}


@dataclass(init=False)
class TilingInfo:
	"""
	Tiling information restricted to the expected input/output tensors.

	Attributes:
	subV  -- {tensor_name: sub-volume}, taken from _info['subvols'];
	         1-element sub-volumes are widened to (n, 1)
	itr   -- {tensor_name: iterations} from _info['iters'], or None if the
	         'iters' key is absent; 1-element tuples are widened to (n, 1)
	sizes -- {tensor_name: size} from _info['sizes'], or None if the
	         'sizes' key is absent
	"""
	def __init__(self, _info, exp_tensors: const.ExpectedTensors):
		"""
		_info -- tiling dictionary; must contain 'subvols', may contain
		         'iters' and 'sizes'
		exp_tensors -- expected tensors; only names listed in its
		               input_and_output_tensors are picked up
		"""
		wanted = exp_tensors.input_and_output_tensors

		# sub-volumes: keep only expected tensors, widen 1-element entries
		self.subV = {}
		for name in wanted:
			if name in _info['subvols']:
				vol = _info['subvols'][name]
				self.subV[name] = (vol[0], 1) if len(vol) == 1 else vol

		# iterations: optional; only 1-element *tuples* are widened here
		if 'iters' not in _info:
			self.itr = None
		else:
			self.itr = {}
			for name in wanted:
				if name in _info['iters']:
					count = _info['iters'][name]
					if len(count) == 1 and isinstance(count, tuple):
						count = (count[0], 1)
					self.itr[name] = count

		# sizes: optional; copied through unchanged
		if 'sizes' not in _info:
			self.sizes = None
		else:
			self.sizes = {name: _info['sizes'][name] for name in wanted if name in _info['sizes']}

class SchedulingInfo():
	"""
	Per-tensor scheduling settings for the expected input/output tensors.

	Attributes:
	dataflow_mode    -- {tensor_name: _info.get(tensor_name)}
	ping_pong_enable -- {tensor_name: _info.get('<tensor_name>_ping_pong')}
	Entries missing from _info are stored as None.
	"""

	def __init__(self, _info, exp_tensors: const.ExpectedTensors):
		"""
		_info -- scheduling dictionary, must not be empty
		exp_tensors -- expected tensors; its input_and_output_tensors
		               determine which entries are looked up
		Raise ValueError if _info is empty / None.
		"""
		if not _info:
			raise ValueError("Unsupported Op: Input Json file with empty subv info")

		modes = {}
		ping_pong = {}
		for name in exp_tensors.input_and_output_tensors:
			modes[name] = _info.get(name)
			ping_pong[name] = _info.get(f'{name}_ping_pong')
		self.dataflow_mode = modes
		self.ping_pong_enable = ping_pong

class MemoryAllocator:
	"""
	Sequential allocator for memory blocks inside one address window that
	starts at a base address and spans a maximum size.
	Blocks are carved either off the low end of the free space (alloc) or
	off the high end (alloc_at_end). Every block start is aligned to a
	power-of-two alignment.
	"""

	class OutOfMemory(Exception):
		"""
		Signals that a requested block does not fit into the free space.
		"""

	def __init__(self, base_addr: int, max_size: int, default_alignment: int = 64, alloc_name: Optional[str] = None) -> None:
		"""
		Set up an allocator without any allocations.
		base_addr -- first address of the managed window (checked to be
		             non-negative via utils.sanity_check)
		max_size -- size of the managed window
		default_alignment -- alignment used when an allocation passes none
		alloc_name -- optional name of this allocator, prefixed to error
		              and warning messages
		"""
		utils.sanity_check(base_addr >= 0, f"base_addr {base_addr} must be non-negative")
		self._alloc_name = alloc_name
		self._default_alignment = default_alignment
		self._base_addr = base_addr
		self._max_size = max_size

		# free space is the range [_next_free_addr, _end_addr_of_free)
		self._next_free_addr = base_addr
		self._end_addr_of_free = base_addr + max_size

	def _check_out_of_memory(self, condition: bool, buf_name: Optional[str], addr: int, size: int, alignment: int) -> None:
		"""
		Raise OutOfMemory unless condition holds.
		condition -- True if enough memory, False if out of memory
		buf_name -- optional name of buffer being allocated, for error message
		addr, size, alignment -- details of the allocation that did not fit
		"""
		if condition:
			return
		remaining = self._end_addr_of_free - self._next_free_addr
		message = (
			f"out of memory, addr={addr}, size={size}, alignment={alignment}, "
			f"remaining={remaining}, "
			f"next_free_addr={self._next_free_addr}, end_addr_of_free={self._end_addr_of_free}"
		)
		raise MemoryAllocator.OutOfMemory(self._format_error_message(message, buf_name))

	def _format_error_message(self, message: str, buf_name: Optional[str]) -> str:
		"""
		Prefix message with "<alloc_name>: " and "buffer <buf_name>: "
		(in that order), each only if the respective name is set.
		"""
		buffer_prefix = "" if buf_name is None else f"buffer {buf_name}: "
		alloc_prefix = "" if self._alloc_name is None else f"{self._alloc_name}: "
		return f"{alloc_prefix}{buffer_prefix}{message}"

	def _sanity_check(self, condition: bool, buf_name: Optional[str], message: str, level: str = "Error") -> None:
		"""
		Forward to utils.sanity_check with the message prefixed by the
		allocator / buffer names.
		"""
		utils.sanity_check(condition, self._format_error_message(message, buf_name), level=level)

	def _require_power_of_two(self, alignment: int, buf_name: Optional[str]) -> None:
		"""
		Sanity-check that alignment is a positive power of two.
		"""
		is_power_of_two = alignment > 0 and (alignment & (alignment - 1)) == 0
		self._sanity_check(is_power_of_two, buf_name,
			f"alignment {alignment} is not a power of two")

	def align(self, addr: int, alignment: int, buf_name: Optional[str]) -> int:
		"""
		Round addr up to the next multiple of alignment and return it.
		The alignment needs to be a power of two.
		buf_name is an optional name for the memory buffer being allocated
		to be used in error and warning messages.
		"""
		self._require_power_of_two(alignment, buf_name)
		low_bits = alignment - 1
		return (addr + low_bits) & ~low_bits

	def align_backwards(self, addr: int, alignment: int, buf_name: Optional[str]) -> int:
		"""
		Round addr down (i.e. towards smaller addresses) to a multiple of
		alignment and return it.
		The alignment needs to be a power of two.
		buf_name is an optional name for the memory buffer being allocated
		to be used in error and warning messages.
		"""
		self._require_power_of_two(alignment, buf_name)
		low_bits = alignment - 1
		return addr & ~low_bits

	def alloc(self, size: int, alignment: Optional[int] = None, buf_name: Optional[str] = None) -> int:
		"""
		Allocate a block of the given size at the beginning of the free
		space and return its (aligned) start address.
		If alignment is left out, the default alignment of the allocator is used.
		buf_name is an optional name for the memory buffer being allocated
		to be used in error and warning messages.
		"""
		self._sanity_check(size >= 0, buf_name, f"size {size} must be non-negative")
		effective_alignment = self._default_alignment if alignment is None else alignment
		start = self.align(self._next_free_addr, effective_alignment, buf_name)
		# NOTE: the out-of-memory check is currently disabled, so the block
		# may extend beyond _end_addr_of_free:
		#self._check_out_of_memory(start + size <= self._end_addr_of_free, buf_name, start, size, effective_alignment)
		self._next_free_addr = start + size
		return start

	def alloc_at_end(self, size: int, alignment: Optional[int] = None, buf_name: Optional[str] = None) -> int:
		"""
		Allocate a block of the given size at the end of the free space and
		return its (aligned) start address.
		If alignment is left out, the default alignment of the allocator is used.
		buf_name is an optional name for the memory buffer being allocated
		to be used in error and warning messages.
		"""
		self._sanity_check(size >= 0, buf_name, f"size {size} must be non-negative")
		effective_alignment = self._default_alignment if alignment is None else alignment
		start = self.align_backwards(self._end_addr_of_free - size, effective_alignment, buf_name)
		# NOTE: the out-of-memory check is currently disabled, so the block
		# may start below _next_free_addr:
		#self._check_out_of_memory(start >= self._next_free_addr, buf_name, start, size, effective_alignment)
		self._end_addr_of_free = start
		return start

	def get_total_size(self) -> int:
		"""
		Return the total size of all memory allocated so far, i.e. the
		space consumed at the beginning plus the space consumed at the end
		(alignment gaps included).
		"""
		window_end = self._base_addr + self._max_size
		consumed_low = self._next_free_addr - self._base_addr
		consumed_high = window_end - self._end_addr_of_free
		return consumed_low + consumed_high

def find_buffer_group(exp_tensors: const.ExpectedTensors, buffer_name: str) -> Optional[const.BufferGroup]:
	"""
	Look up the buffer group that lists <buffer_name> among its
	buffer_names. Groups are searched in the order of
	exp_tensors.buffer_groups and the first match is returned.
	Return None if <buffer_name> is not in any group.
	"""
	matching_groups = (grp for grp in exp_tensors.buffer_groups
		if buffer_name in grp.buffer_names)
	return next(matching_groups, None)

class CoreBufferAllocator:
	"""
	Allocate buffers for cores. Respect banks and sizes described in
	placement_constraints in YAML. Also respect buffer groups described in
	expected tensors in const.py

	Usage: construct, call allocate_buffers(), then read the results via
	get_addrs() and get_sizes().
	"""

	class OrderingReq(enum.Enum):
		"""
		Ordering requirement of a buffer within a bank.
		ANYWHERE: place buffer anywhere in bank (default)
		BEGIN: place buffer at beginning of bank
		END: place buffer at end of bank
		"""
		ANYWHERE = 0
		BEGIN = 1
		END = 2

	class BankNameMapper:
		"""
		Mapping from bank name in format "BANK<id>[.<subbank id>][.{BEGIN|END}]"
		to bank ID and ordering requirement.
		Also supports mapping bank ID back to main bank name "BANK<id>".
		The subbank ID is accepted as part of the name, but it is neither
		validated nor returned by this class.

		Note: Tiler/utils.py get_bank_name() also parses the same format
		"""

		def __init__(self, max_num_banks: int) -> None:
			"""
			max_num_banks -- number of banks on the core
			"""
			self._max_num_banks = max_num_banks

		def bank_id_to_name(self, bank_id: int) -> str:
			"""
			Return bank name for bank with specified ID.
			Raise ValueError if the ID is outside [0, max_num_banks).
			"""
			if bank_id < 0 or bank_id >= self._max_num_banks:
				raise ValueError(f"invalid bank ID {bank_id} (max_num_banks={self._max_num_banks})")
			return f"BANK{bank_id}"

		def bank_name_to_id(self, bank_name: str) -> int:
			"""
			Return bank ID for bank name.
			Raise ValueError if conversion of name to ID fails.
			"""
			bank_id, _ordering_req = self.bank_name_to_id_and_ordering(bank_name)
			return bank_id

		def bank_name_to_id_and_ordering(self, bank_name: str) -> tuple[int, "CoreBufferAllocator.OrderingReq"]:
			"""
			Return bank ID and ordering requirement for bank name.
			Raise ValueError if conversion of name to ID fails.
			"""
			expected_prefix = "BANK"
			if not bank_name.startswith(expected_prefix):
				raise ValueError(f"invalid bank name {bank_name}: does not start with {expected_prefix}")
			parts = bank_name[len(expected_prefix):].split('.')

			# a trailing BEGIN/END counts as ordering suffix only if it is
			# not the first part (the first part must be the bank ID)
			suffix_to_req = {
				"BEGIN": CoreBufferAllocator.OrderingReq.BEGIN,
				"END": CoreBufferAllocator.OrderingReq.END,
			}
			ordering_req = CoreBufferAllocator.OrderingReq.ANYWHERE
			if len(parts) >= 2 and parts[-1] in suffix_to_req:
				ordering_req = suffix_to_req[parts.pop()]

			try:
				bank_id = int(parts[0])
			except ValueError as exc:
				# keep the original parse error attached as the cause
				raise ValueError(f"invalid bank name {bank_name}: {str(exc)}") from exc
			if bank_id < 0 or bank_id >= self._max_num_banks:
				raise ValueError(f"invalid bank ID {bank_id} in bank name {bank_name} (max_num_banks={self._max_num_banks})")
			return bank_id, ordering_req

	class SingleBankNameMapper(BankNameMapper):
		"""
		Modify bank name mapping to map everything to a single bank (BANK0).
		This is used in the fallback mode if banked core buffer allocation
		runs into out-of-memory in one bank.
		The constructor parameter must be actual number of banks on the core,
		because this is the range of numbers occurring in the bank names in
		metadata.
		"""

		def bank_id_to_name(self, bank_id: int) -> str:
			"""
			Return bank name for bank with specified ID.
			Only bank ID 0 is valid in single bank mode; raise ValueError
			for any other ID.
			"""
			if bank_id != 0:
				raise ValueError(f"invalid bank ID {bank_id} (single bank mode, max_num_banks={self._max_num_banks})")
			return f"BANK{bank_id}"

		def bank_name_to_id_and_ordering(self, bank_name: str) -> tuple[int, "CoreBufferAllocator.OrderingReq"]:
			"""
			Return bank ID and ordering requirement for bank name.
			The returned bank ID is always 0; the ordering requirement is
			relaxed to ANYWHERE unless it refers to the begin of the first
			bank or the end of the last bank.
			Raise ValueError if conversion of name to ID fails.
			"""
			bank_id, ordering_req = super().bank_name_to_id_and_ordering(bank_name)
			# remap to single bank
			#  - only beginning of bank 0 goes to beginning
			if bank_id != 0 and ordering_req == CoreBufferAllocator.OrderingReq.BEGIN:
				ordering_req = CoreBufferAllocator.OrderingReq.ANYWHERE
			#  - only end of bank (self._max_num_banks - 1) goes to end
			if bank_id != self._max_num_banks - 1 and ordering_req == CoreBufferAllocator.OrderingReq.END:
				ordering_req = CoreBufferAllocator.OrderingReq.ANYWHERE
			#  - everything goes to bank ID 0
			return 0, ordering_req

	class PlacementInfo:
		"""
		Pre-computations based on placement.
		"""
		def __init__(self, placement: Dict[str, Dict[str, int]],
				bank_name_mapper: "CoreBufferAllocator.BankNameMapper") -> None:
			"""
			placement -- placement constraints from YAML file, with formulas for
			             size evaluated to integers already,
			             placement[buffer_name][bank_name] = size
			bank_name_mapper -- mapper used to translate bank names to bank IDs
			"""
			self._placement = placement
			self._bank_name_mapper = bank_name_mapper

			# IDs of banks used
			self._used_bank_ids: Set[int] = set()

			# placement per bank (with bank by ID)
			# self._placement_by_bank[bank_id][buffer_name][bank_name] = size
			self._placement_by_bank: Dict[int, Dict[str, Dict[str, int]]] = {}

			# IDs of banks on which each buffer exists
			# self._banks_by_buffer[buffer_name] = {bank_id, ...}
			self._banks_by_buffer: Dict[str, Set[int]] = {}

			# names of banks/subbanks on which each buffer exists
			# self._bank_names_by_buffer[buffer_name] = {bank_name, ...}
			self._bank_names_by_buffer: Dict[str, Set[str]] = {}

			# pre-compute information
			for buffer_name, sizes_by_bank in self._placement.items():
				for bank_name, size in sizes_by_bank.items():
					bank_id = self._bank_name_mapper.bank_name_to_id(bank_name)
					self._used_bank_ids.add(bank_id)
					self._placement_by_bank.setdefault(bank_id, {}).setdefault(buffer_name, {})[bank_name] = size
					self._banks_by_buffer.setdefault(buffer_name, set()).add(bank_id)
					self._bank_names_by_buffer.setdefault(buffer_name, set()).add(bank_name)

		def get_bank_ids(self, buffer_name: str) -> Set[int]:
			"""
			Return set of IDs of banks on which buffer lives.
			Raise KeyError if the buffer has no placement entry.
			"""
			return self._banks_by_buffer[buffer_name]

		def get_bank_ids_for_list(self, buffer_names: List[str]) -> Set[int]:
			"""
			Return set of IDs of banks on which list of buffers live
			(a new set; empty if buffer_names is empty).
			Raise KeyError if one of the buffers has no placement entry.
			"""
			# starting from an empty set keeps this working for an empty list
			return set().union(*(self._banks_by_buffer[buffer_name] for buffer_name in buffer_names))

		def get_bank_names(self, buffer_name: str) -> Set[str]:
			"""
			Return set of bank/subbank names on which a buffer lives.
			Raise KeyError if the buffer has no placement entry.
			"""
			return self._bank_names_by_buffer[buffer_name]

		def get_buffer_names(self) -> List[str]:
			"""
			Get names of all buffers in the placement.
			"""
			return list(self._placement.keys())

		def get_bank_sizes(self, bank_id: int, buffer_name: str) -> Dict[str, int]:
			"""
			Return dictionary of bank name (full one, including subbank) to
			size for a certain buffer in a certain bank (by ID).
			Raise KeyError if the buffer is not placed in that bank.
			"""
			return self._placement_by_bank[bank_id][buffer_name]

	class BufferOrder:
		"""
		Ordering of core buffers on a bank.
		"""

		def __init__(self, pl_info: "CoreBufferAllocator.PlacementInfo",
				exp_tensors: const.ExpectedTensors, bank_id: int,
				bank_name_mapper: "CoreBufferAllocator.BankNameMapper") -> None:
			"""
			Initialize a buffer order instance.
			pl_info -- pre-computed placement information
			exp_tensors -- expected tensors (and buffer groups) for the operator,
			               from const.py
			bank_id -- numerical identifier of the bank
			bank_name_mapper -- mapper between bank names and bank IDs
			"""
			self._pl_info = pl_info
			self._exp_tensors = exp_tensors
			self._bank_id = bank_id
			self._bank_name_mapper = bank_name_mapper

			self._bank_name = self._bank_name_mapper.bank_id_to_name(self._bank_id)

			self._grps_proc: Set[str] = set()  # names of buffer groups already processed
			self._buffers_begin: List[str] = []  # names of buffers that need to go to beginning of bank
			self._buffers: List[str] = []  # names of buffers that can go anywhere in bank
			self._buffers_end: List[str] = []  # names of buffers that need to go to end of bank

		def _process_buffer(self, buff_name: str) -> None:
			"""
			Determine ordering of buffer within bank.
			"""
			utils.sanity_check(buff_name in self._exp_tensors.core_buffer_names,
				f'Invalid buffer name: {buff_name}')

			# check if buffer lives on current bank, nothing to do if not
			banks = self._pl_info.get_bank_ids(buff_name)
			if self._bank_id not in banks:
				return

			# get buffer group that contains buffer, if any
			buff_grp = find_buffer_group(self._exp_tensors, buff_name)
			if buff_grp is not None:
				# buffer is part of group -> switch to buffer group processing
				self._process_buffer_in_group(buff_name, buff_grp)
				return

			# buffer is not part of a group

			# add buffer to begin, anywhere, or to end of bank
			# (a buffer might go to a combination of begin/anywhere/end due to
			# being allocated multiple times in a bank via subbanks)
			ordering_reqs: Set[CoreBufferAllocator.OrderingReq] = set()
			for bank_name in self._pl_info.get_bank_names(buff_name):
				bank_id, ordering_req = self._bank_name_mapper.bank_name_to_id_and_ordering(bank_name)
				if bank_id == self._bank_id:
					ordering_reqs.add(ordering_req)
			if CoreBufferAllocator.OrderingReq.ANYWHERE in ordering_reqs:
				self._buffers.append(buff_name)
			if CoreBufferAllocator.OrderingReq.BEGIN in ordering_reqs:
				self._buffers_begin.append(buff_name)
			if CoreBufferAllocator.OrderingReq.END in ordering_reqs:
				self._buffers_end.append(buff_name)

		def _process_buffer_in_group(self, buff_name: str, buff_grp: const.BufferGroup) -> None:
			"""
			Determine ordering of buffer that is part of a buffer group within bank.
			The whole group is placed on first encounter; later buffers of
			the same group are skipped.
			"""

			# a buffer being allocated multiple times (in multiple banks or in
			# multiple subbanks) while appearing in a buffer group currently
			# doesn't have a defined meaning, so check for this
			if len(self._pl_info.get_bank_names(buff_name)) != 1:
				utils.sanity_check(False,
					f"Buffer group {buff_grp} contains buffer {buff_name} "
					"and this buffer is alloacted multiple times. "
					"This is not defined yet.")

			# check if buffer group is already processed; if so, done
			if buff_grp.group_name in self._grps_proc:
				return
			self._grps_proc.add(buff_grp.group_name)

			# get banks on which group lives
			grp_banks = self._pl_info.get_bank_ids_for_list(buff_grp.buffer_names)

			# check placement constraints within bank due to group using a
			# neighboring bank as well
			uses_prev_bank = self._bank_id - 1 in grp_banks
			uses_next_bank = self._bank_id + 1 in grp_banks
			utils.sanity_check(not uses_prev_bank or not uses_next_bank,
				f"Buffer group {buff_grp} (buffer {buff_name}) "
				f"cannot spread to prev and next bank of {self._bank_name}")

			# get names of all buffers of the group that exist in the current bank
			buffs_in_bank = [b_name for b_name in buff_grp.buffer_names
				if self._bank_id in self._pl_info.get_bank_ids(b_name)]

			# add buffers of this group depending on placement constraints:
			# at beginning / anywhere / at end
			if uses_prev_bank:
				utils.sanity_check(not self._buffers_begin,
					f"Buffer group {buff_grp} (buffer {buff_name}) "
					f"cannot go to beginning of bank {self._bank_name}, "
					f"buffers {self._buffers_begin} are already there")
				self._buffers_begin += buffs_in_bank
			elif uses_next_bank:
				utils.sanity_check(not self._buffers_end,
					f"Buffer group {buff_grp} (buffer {buff_name}) "
					f"cannot go to end of bank {self._bank_name}, "
					f"buffers {self._buffers_end} are already there")
				self._buffers_end += buffs_in_bank
			else:
				self._buffers += buffs_in_bank

		def compute(self) -> None:
			"""
			Compute buffer ordering.
			Only placement entries whose name is listed in the expected
			tensors' core_buffer_names are considered.
			"""
			for buff_name in self._pl_info.get_buffer_names():
				if buff_name in self._exp_tensors.core_buffer_names:
					self._process_buffer(buff_name)

		def get_buffers(self) -> List[str]:
			"""
			Return list of names of buffers that need to be allocated anywhere
			in bank, but still in the order that is returned here.
			"""
			return self._buffers

		def get_buffers_begin(self) -> List[str]:
			"""
			Return list of names of buffers that need to be allocated at the
			beginning of the bank.
			"""
			return self._buffers_begin

		def get_buffers_end(self) -> List[str]:
			"""
			Return list of names of buffers that need to be allocated at the
			end of the bank.
			"""
			return self._buffers_end


	def __init__(self, placement: Dict[str, Dict[str, int]],
			exp_tensors: const.ExpectedTensors,
			treat_as_single_bank: bool = False,
			max_num_banks: int = const.MAX_NUM_BANKS,
			bank_size: int = const.CORE_BANK_SIZE) -> None:
		"""
		Initialize CoreBufferAllocator instance.
		placement -- placement constraints from YAML file, with formulas for
		             size evaluated to integers already,
		             placement[buffer_name][bank_name] = size
		exp_tensors -- expected tensors (and buffer groups) for the operator,
		               from const.py
		treat_as_single_bank -- if True, treat entire core memory as single bank,
		                        used in fallback mode if banked alloc out-of-mem
		max_num_banks -- number of banks on the core
		bank_size -- size of one bank
		"""
		self._placement = placement
		self._exp_tensors = exp_tensors
		self._treat_as_single_bank = treat_as_single_bank
		self._max_num_banks = max_num_banks
		self._bank_size = bank_size

		if self._treat_as_single_bank:
			# the mapper still needs the real bank count to parse bank names;
			# afterwards the allocator works on one bank of the combined size
			self._bank_name_mapper = self.SingleBankNameMapper(self._max_num_banks)
			self._bank_size *= self._max_num_banks
			self._max_num_banks = 1
		else:
			self._bank_name_mapper = self.BankNameMapper(self._max_num_banks)
		self._pl_info = self.PlacementInfo(self._placement, self._bank_name_mapper)

		# result tables, keyed 'Core<Name>Addr' / 'Core<Name>Size'
		# (note: str.capitalize() lower-cases everything after the first character)
		self._buffer_addrs = {f'Core{name.capitalize()}Addr': [] for name in self._placement}
		self._buffer_sizes = {f'Core{name.capitalize()}Size': const.DUMMY_CONST for name in self._placement}

	def _allocate_buffer(self, bank_id: int, ordering_req: OrderingReq,
			buff_name: str, mem_alloc: MemoryAllocator) -> None:
		"""
		Allocate the buffers with the passed name in the bank with the passed
		bank ID that have the passed ordering requirement.
		Note that there can be multiple buffers with the same name in the same
		bank, for example, if subbanks ("BANK<id>.<subbank no>") are used.
		For every allocated instance the address is appended to the buffer's
		address list (None for instances with size <= 0) and the buffer's
		size entry is overwritten with the instance size.
		"""
		sizes_by_bank_name = self._pl_info.get_bank_sizes(bank_id, buff_name)
		# group membership does not depend on the bank name, so look it up once
		is_grouped = find_buffer_group(self._exp_tensors, buff_name) is not None
		key_str = f'Core{buff_name.capitalize()}'

		for bank_name, buff_size in sizes_by_bank_name.items():
			# grouped buffers don't respect placement requirements from metadata, so skip filtering for those here
			if not is_grouped:
				# non-grouped buffers are just allocated regarding the ordering requirement in metadata
				buff_bank_id, buff_ordering_req = self._bank_name_mapper.bank_name_to_id_and_ordering(bank_name)
				utils.sanity_check(buff_bank_id == bank_id,
					f"buffer's ({buff_name}) bank ID {buff_bank_id} from get_bank_sizes() "
					f"does not match bank ID {bank_id} -> internal error in PlacementInfo")
				if buff_ordering_req != ordering_req:
					# non-grouped buffer and ordering requirement does not match -> filter-out here
					continue
			if buff_size > 0:
				if ordering_req == CoreBufferAllocator.OrderingReq.END:
					buff_addr = mem_alloc.alloc_at_end(buff_size, buf_name=buff_name)
				else:
					buff_addr = mem_alloc.alloc(buff_size, buf_name=buff_name)
				self._buffer_addrs[f'{key_str}Addr'].append(buff_addr)
				self._buffer_sizes[f'{key_str}Size'] = buff_size
			else:
				# nothing to allocate; record a placeholder address
				self._buffer_addrs[f'{key_str}Addr'].append(None)

	def allocate_buffers(self) -> None:
		"""
		Allocate all needed buffers in all banks.
		Results are stored internally and can be read via get_addrs() and
		get_sizes().
		"""
		# allocator of the previously processed bank (None for the first bank)
		mem_alloc: Optional[MemoryAllocator] = None
		for bank_id in range(self._max_num_banks):
			bank_name = self._bank_name_mapper.bank_id_to_name(bank_id)

			# get order of buffers in this bank
			buf_ord = self.BufferOrder(self._pl_info, self._exp_tensors, bank_id, self._bank_name_mapper)
			buf_ord.compute()

			# allocate buffers in this bank
			# start at the nominal bank base, unless the allocations of the
			# previous bank already ran past it; then continue right after them
			base_addr = bank_id * self._bank_size
			if mem_alloc is not None and mem_alloc._next_free_addr > base_addr:
				base_addr = mem_alloc._next_free_addr
			mem_alloc = MemoryAllocator(base_addr,
				self._bank_size, default_alignment=64,
				alloc_name=bank_name)
			# first, allocate buffers that need to go to begin
			# second, allocate buffers that can go anywhere
			for buff_name in buf_ord.get_buffers_begin():
				self._allocate_buffer(bank_id, CoreBufferAllocator.OrderingReq.BEGIN, buff_name, mem_alloc)
			for buff_name in buf_ord.get_buffers():
				self._allocate_buffer(bank_id, CoreBufferAllocator.OrderingReq.ANYWHERE, buff_name, mem_alloc)
			# process buffers at end in reverse,
			# so last one is actually at end of bank
			for buff_name in reversed(buf_ord.get_buffers_end()):
				self._allocate_buffer(bank_id, CoreBufferAllocator.OrderingReq.END, buff_name, mem_alloc)

	def get_addrs(self) -> Dict[str, List[Optional[int]]]:
		"""
		Export addresses of buffers as dict
		{'Core<Name>Addr': [addr0, addr1, ...]}.
		An address entry is None for a buffer instance with size <= 0.
		"""
		return self._buffer_addrs

	def get_sizes(self) -> Dict[str, int]:
		"""
		Export sizes of buffers as dict {'Core<Name>Size': size}.
		Buffers that were never allocated keep the initial const.DUMMY_CONST.
		"""
		return self._buffer_sizes


class BufferAllocator(schedule.Stage):
	def __init__(self, artifacts_dict, tiler_pass: bool = False):
			"""
			Gather everything the buffer allocation stage needs from `artifacts_dict`.

			`artifacts_dict` must provide the parameter objects under the keys
			'overlay_info_obj', 'layer_info_obj', 'kernel_info_obj',
			'scheduling_obj', 'core_tile_params_obj', 'mem_tile_params_obj',
			'shim_tile_params_obj', 'layer_padding_obj', 'host_padding_obj',
			'dram_params_obj' and 'program_arg_obj'. All of them except
			'program_arg_obj' are read through get_dict() and/or get_value().

			`tiler_pass` is stored as self.tiler_pass; when true, the total memtile
			utilization check in calc_mem_tile_addr() is not enforced.

			The constructor
			  1. looks up the expected tensors of the operator in const.expected_tensors,
			  2. stores overlay / tiling / scheduling / placement information and one
			     TensorDim attribute per expected input and output tensor,
			  3. unpacks layer padding and host padding,
			  4. runs param_checker(),
			  5. derives operator specific attributes depending on self.orig_op_type.

			Problems (no expected tensors found, unknown orig_op_type, failed padding
			checks) are reported through utils.sanity_check().
			"""
			self.tiler_pass = tiler_pass
			#Extract all sub dict from the ctr input
			overlay_params   = artifacts_dict['overlay_info_obj']
			layer_params     = artifacts_dict['layer_info_obj']
			kernel_params    = artifacts_dict['kernel_info_obj']
			schedule_params  = artifacts_dict['scheduling_obj']
			core_tile_params = artifacts_dict['core_tile_params_obj']
			mem_tile_params  = artifacts_dict['mem_tile_params_obj']
			shim_tile_params = artifacts_dict['shim_tile_params_obj']
			padding_params   = artifacts_dict['layer_padding_obj']
			host_padding_params   = artifacts_dict['host_padding_obj']
			dram_params      = artifacts_dict['dram_params_obj']
			program_args     = artifacts_dict['program_arg_obj']

			# Build the {'shape', 'datatype', 'bytes'} dict that TensorDim expects
			# from the '<tensor_type>_<tensor_name>_*' entries of _params.
			def create_info_dict(tensor_type, tensor_name, _params):
				return {
						'shape': _params.get_value(f'{tensor_type}_{tensor_name}_shape'),
						'datatype': _params.get_value(f'{tensor_type}_{tensor_name}_datatype'),
						'bytes': _params.get_value(f'{tensor_type}_{tensor_name}_bytes'),
				}

			#NOTE- Bias treated differently
			# (it is always read from the 'in_wgt1_*' entries of the layer params)
			bias_info        = {
													'shape': layer_params.get_value('in_wgt1_shape'),
													'datatype': layer_params.get_value('in_wgt1_datatype'),
													'bytes': layer_params.get_value('in_wgt1_bytes'),
												}


			#Extract info from tiling param
			overlay_info       = overlay_params.get_dict()
			layer_info         = layer_params.get_dict()
			kernel_info        = kernel_params.get_dict()
			scheduling_info    = schedule_params.get_dict()
			core_info          = core_tile_params.get_dict()
			mem_info           = mem_tile_params.get_dict()
			shim_info          = shim_tile_params.get_dict()
			padding_info       = padding_params.get_dict()
			host_padding_info  = host_padding_params.get_dict()
			# NOTE: from here on `dram_params` is the plain dict, not the parameter object
			dram_params        = dram_params.get_dict()

			#Extract info from tiling param
			buf_placement_info = kernel_params.get_value('placement_constraints')

			# Get expected tensors for operator.
			# First try maximum specific lookup of full op_type, then op_type
			# without data types and then most generic original op type.
			# This method allows to use the orig_op_type in cost.py if all
			# op_types use the same expected tensors and be more specific if
			# there are differences among the expected tensors for different
			# op_types.
			orig_op_type      = layer_params.get_value('orig_op_type')
			op_type           = layer_params.get_value('op_type')
			op_type_wo_dtypes = "_".join(op_type.split("_")[:-1])
			op_t_lst = [op_type, op_type_wo_dtypes, orig_op_type]
			# for/else: the else branch only runs when none of the lookups succeeded
			for op_t in op_t_lst:
				exp_tensors = const.expected_tensors.get(op_t)
				if exp_tensors is not None:
					break
			else:
				utils.sanity_check(False, f"Expected tensors not found for {op_t_lst}")

			#Store Node/Kernel Info
			self.program_args  = program_args
			self.orig_op_type  = layer_info['orig_op_type']
			self.op_type       = layer_info['op_type']
			self.op_name       = layer_info['op_type'].split('_')[0]
			self.op_mode       = layer_info.get('qdq_symmetry', 0) #default ASYM
			self.actxact       = 'actxact' in self.op_type
			self.overlay       = overlay_info['overlay']
			self.mode          = overlay_info['mode']
			self.aie_rows      = overlay_info['shape']['row']
			self.aie_cols      = overlay_info['shape']['col']
			self.transposeB      = False
			self.transposeA      = False
			# NOTE: chained assignment, permA/rev_permA (and permB/rev_permB) share one list object
			self.permA = self.rev_permA = [0, 1, 2]
			self.permB = self.rev_permB = [0, 1, 2]
			self.permY = [0, 1, 2]
			if overlay_info.get('enabled') is None:
				#NOTE - Temp fix till tiler is fully ready
				self.active_col = self.get_enable_cols(self.aie_rows, self.aie_cols)
			else:
				self.active_col = self.get_enable_cols(overlay_info['enabled']['row'], overlay_info['enabled']['col'])
			self.aie_arrays    = const.AIE_ARRAYS
			self.bits_per_byte = const.BITS_PER_BYTE
			self.ioinfo        = exp_tensors
			self.mem_tile      = TilingInfo(mem_info, self.ioinfo)
			self.core_tile     = TilingInfo(core_info, self.ioinfo)
			self.shim_tile     = TilingInfo(shim_info, self.ioinfo)
			self.sch_attr      = SchedulingInfo(scheduling_info, self.ioinfo)
			self.placement     = kernel_info['placement_constraints']
			self.dram_sizes    = dram_params.get('sizes')
			#self.tdm_dtype     = kernel_params.get_value('tdm_datatype')
			self.core_buffer_placement = KernelInfo(buf_placement_info)
			# One TensorDim attribute per expected tensor, named after the tensor
			# (e.g. self.ifm, self.wgt, self.ofm, depending on the operator).
			input_info_dict = {name: create_info_dict('in', name, layer_params) for name in self.ioinfo.input_tensors}
			for tensor_name in self.ioinfo.input_tensors:     #for ifm, wgt,...,q,<input_tensor>
				setattr(self, f'{tensor_name}', TensorDim(input_info_dict.get(f'{tensor_name}')))
			output_info_dict = {name: create_info_dict('out', name, layer_params) for name in self.ioinfo.output_tensors}
			for tensor_name in self.ioinfo.output_tensors:    #for ofm,....,<output_tensor>
				setattr(self, f'{tensor_name}', TensorDim(output_info_dict.get(f'{tensor_name}')))
			self.bias          = TensorDim(bias_info)

			qdq_input = self.ioinfo.input_tensors[0] #TODO- Check on this. Assuming always input0
			# NOTE(review): the eval below only reads the attribute named by qdq_input;
			# getattr(self, qdq_input).dtype would do the same without eval.
			# For non-bfp16/bfloat inputs the qdq size is the first value of the
			# 'qdq' placement entry.
			if eval(f'self.{qdq_input}.dtype') not in ('bfp16','bfloat'):
				self.qdq_bytes     = list(self.placement.get('qdq').values())[0]
			else:
				self.qdq_bytes     = 0

			# padding_info / host_padding_info are indexed by position; every entry
			# that carries one of the expected tensors contributes one dict
			# (see unpack_padding_info) to self.padding / self.host_padding.
			self.padding = []
			for idx in range(len(padding_info)):
				for tensor in self.ioinfo.input_and_output_tensors:
					if tensor in padding_info[idx].keys():
						self.padding.append(self.unpack_padding_info(tensor, list(padding_info[idx][tensor]['dims'])))
			self.host_padding = []
			for idx in range(len(host_padding_info)):
				for tensor in self.ioinfo.input_and_output_tensors:
					if tensor in host_padding_info[idx].keys():
						self.host_padding.append(self.unpack_padding_info(tensor, list(host_padding_info[idx][tensor]['dims'])))
			#for tensor in self.ioinfo.input_and_output_tensors:
				#self.padding.append(self.unpack_padding_info(tensor, padding_info.get(tensor)))
			self.param_checker(overlay_info, layer_info, kernel_info, scheduling_info, core_info, mem_info, shim_info, padding_info)

			has_wgt = 'wgt' in self.ioinfo.input_tensors

			# ---- operator specific attributes ----
			if self.orig_op_type in ["QKt_SM"]:
				pass
			elif self.orig_op_type in ["Add", "Mul"]:
				self.actxact       = layer_info.get('actxact')
				self.ifmA_param_type = scheduling_info.get('ifmA_param_type')
				self.ifmB_param_type = scheduling_info.get('ifmB_param_type')
				self.input_shape   = ast.literal_eval(layer_info.get('inputs')) if layer_info.get('inputs') != '' else None
				self.output_shape  = ast.literal_eval(layer_info.get('outputs')) if layer_info.get('outputs') != '' else None
				self.in_ifmA_shape  = layer_info.get('in_ifmA_shape')
				self.in_ifmB_shape  = layer_info.get('in_ifmB_shape')
				self.out_ofm_shape  = layer_info.get('out_ofm_shape')
				self.inner_dim_is_1 = scheduling_info.get('inner_most_dim_is_1', None)

				self.ifmA_scale_factor = scheduling_info.get('ifmA_scale_factor', None)
				self.ifmB_scale_factor = scheduling_info.get('ifmB_scale_factor', None)
				self.ofm_scale_factor  = scheduling_info.get('ofm_scale_factor', None)

				attr               = layer_info.get("attributes", None)
				self.num_batches    = None
				if attr != None:
					self.num_batches   = layer_info.get("attributes").get("num_batches")
				self.multi_ch_batch_bcast = scheduling_info.get('multi_ch_batch_bcast')

				self.ifmA_mode     = scheduling_info['ifmA']
				self.ifmB_mode     = scheduling_info['ifmB']
				self.ofm_mode     = scheduling_info['ofm']

				# With batches, the qdq buffer must hold one qdq struct per batch plus
				# an item count; sizes above 64 bytes are rounded up to a multiple of 64
				# and override the qdq_bytes computed above.
				if self.num_batches != None:
					SIZE_OF_UINT16_T = 2
					single_qdq_struct_size = (SIZE_OF_UINT16_T * 3) # 3 items in qdq struct (enable, zp, sc) - all in uint16_t
					qdq_struct_size = (single_qdq_struct_size * self.num_batches[0]) + SIZE_OF_UINT16_T # (numBatch * qdq struct) + nItems - all in uint16_t
					if qdq_struct_size > 64:
						self.qdq_bytes = int(np.ceil(qdq_struct_size / 64.0) * 64)
				
				self.M_subV        = self.core_tile.subV['ifmB'][0]
				#Store tensor bits info
				self.ofm_bits      = round(self.ofm.bytes * self.bits_per_byte)
				self.ifmA_bits     = round(self.ifmA.bytes * self.bits_per_byte)
				self.ifmB_bits     = round(self.ifmB.bytes * self.bits_per_byte)
				
				# outer loop count comes from ifmA, inner loop count from ifmB (memtile iterations)
				def calc_final_itr():
					inner = self.mem_tile.itr['ifmB'][0]
					outer = self.mem_tile.itr['ifmA'][0]
					return outer, inner

				self.outer_loop, self.inner_loop  = calc_final_itr()

				self.core_outer_loop, self.core_inner_loop = self.core_tile.itr['ifmA'], self.core_tile.itr['ifmB']
				
				#K Padding Checker
				#TODO: Double check if this is correct with Kyle. Check multiple of 8 on "before"/"after" both OR just total pad?
				# NOTE: uses the second padding entry (index 1), presumably the one of ifmB - confirm
				utils.sanity_check((self.padding[1]['pad_ifmB_y'][0] % 8 == 0),f"K padding value must be integer multiple of 8")
				utils.sanity_check((self.padding[1]['pad_ifmB_y'][1] % 8 == 0),f"K padding value must be integer multiple of 8")

			elif self.orig_op_type in ["LpNormalization", "LayerNormalization", "Softmax", "GroupNormalization"]:
				pass
			elif self.orig_op_type in ["RoPE"]:
				self.actxact       = 'actxact' in self.op_type
				self.M             = self.ifm.dim[0]
				self.K             = self.ifm.dim[1]
				self.N             = 1
				self.M_subV        = self.core_tile.subV['ifm'][0]
				self.K_subV        = self.core_tile.subV['ifm'][1]
				self.N_subV        = 1
				self.bias_bits     = 0
				self.wgt_bits      = round(self.wgt.bytes * self.bits_per_byte) if has_wgt else 0
				self.ofm_bits      = round(self.ofm.bytes * self.bits_per_byte)
				self.ifm_bits  	   = round(self.ifm.bytes * self.bits_per_byte)
				self.tdm_bits      = 32

				# core sub-volume sizes in bits; wgt and tdm are fixed to 0 for RoPE
				ifm_core_subv_bits = (self.core_tile.subV['ifm'][0] * self.core_tile.subV['ifm'][1] * self.ifm_bits)
				sum_core_subv_bits = (self.core_tile.subV['ifm'][0] * self.tdm_bits)
				wgt_core_subv_bits = 0
				ofm_core_subv_bits = (self.core_tile.subV['ofm'][0] * self.core_tile.subV['ofm'][1] * self.ofm_bits)
				tdm_core_subv_bits = 0

				self.ifm_core_subv_bytes = (ifm_core_subv_bits // self.bits_per_byte)
				self.wgt_core_subv_bytes = (wgt_core_subv_bits // self.bits_per_byte)
				self.ofm_core_subv_bytes = (ofm_core_subv_bits // self.bits_per_byte)
				self.sum_core_subv_bytes = (sum_core_subv_bits // self.bits_per_byte)
				self.tdm_core_subv_bytes = (tdm_core_subv_bits // self.bits_per_byte)

				#TODO: Check with Sourabh and Kyle on this
				# per dimension the larger of the memtile and core iteration counts is used
				def calc_final_itr():
					tm = max(self.mem_tile.itr['ifm'][0], self.core_tile.itr['ifm'][0])
					tk = max(self.mem_tile.itr['ifm'][1], self.core_tile.itr['ifm'][1])
					tn = 0
					return tm, tk, tn

				self.outer_loop, self.acc_loop, self.inner_loop  = calc_final_itr()
				self.wgt_subv_rows = (self.K +sum(self.padding[0]['pad_ifm_y'])) // self.K_subV
				self.wgt_subv_cols = (self.N) // self.N_subV
				#K Padding Checker
				#TODO: Double check if this is correct with Kyle. Check multiple of 8 on "before"/"after" both OR just total pad?
				utils.sanity_check((self.padding[0]['pad_ifm_y'][0] % 8 == 0),f"K padding value must be integer multiple of 8")
				utils.sanity_check((self.padding[0]['pad_ifm_y'][1] % 8 == 0),f"K padding value must be integer multiple of 8")


			elif self.orig_op_type in ["MatMul", "Conv", "PWLA"]:
				self.actxact       = 'actxact' in self.op_type
				#supported_matmul_fusion = ["pwla", "RoPE"]
				# The attributes in this if-block only exist when "MatMul" is part of
				# op_type; other methods test for them with hasattr().
				if  "MatMul" in self.op_type:
					self.is_pwla_fused  = True if "pwla" in layer_info.get('op_type') else False
					self.is_rope_fused  = True if "RoPE" in layer_info.get('op_type') else False
					self.is_elew_fused  = True if "Add" in layer_info.get('op_type') else False
					attr = layer_info.get('attributes', {})
					self.is_fused_rope_actxact = True if self.is_rope_fused and not attr.get('sin_cos_const', [0])[0] else False
					# batch split is parsed from the overlay mode string: the number between 'B' and 'M'
					self.B_split       = int((self.mode).split('B')[1].split('M')[0]) if 'B' in self.mode else 1
					self.B             = int(scheduling_info['Tbatch'])
					self.B_itr         = math.ceil (self.B / self.B_split)
                                        #NOTE: Getting host padded B dims only head per sub-array case else setting it to B.
					self.Bpad_A         = int(sum(self.host_padding[0]['pad_ifm_x']))
					self.Bpad_B         = int(sum(self.host_padding[1]['pad_wgt_x' if not self.actxact else 'pad_ifm_x']))
					self.Bpad_Y         = int(sum(self.host_padding[2]['pad_ofm_x']))
					# One flag each for A, B and Y: set when the transpose attribute sums to 1,
					# the original shape has 4 dims and none of these dims is 1.
					self.transpose_4d   = [(sum(layer_info['attributes'].get('InTransposeA', [0]))==1 and 
											len(layer_info['in_act_shape_orig'])==4 and 
											(not 1 in layer_info['in_act_shape_orig'])),
										   (sum(layer_info['attributes'].get('InTransposeB', [0]))==1 and 
											len(layer_info['in_wgt_shape_orig'])==4 and 
											(not 1 in layer_info['in_wgt_shape_orig'])),
										   (sum(layer_info['attributes'].get('OutTranspose', [0]))==1 and 
											len(layer_info['out_act_shape_orig'])==4 and 
											(not 1 in layer_info['out_act_shape_orig']))]
					self.info_4d   = {'in_act_shape_orig': layer_info['in_act_shape_orig'],
									  'in_wgt_shape_orig': layer_info['in_wgt_shape_orig'],
									  'out_act_shape_orig': layer_info['out_act_shape_orig'],
									  'permA': layer_info['attributes'].get('permA', [0,1,2,3]),
									  'permB': layer_info['attributes'].get('permB', [0,1,2,3]),
									  'permY': layer_info['attributes'].get('permY', [0,1,2,3]),
									  }
				self.M             = self.ifm.dim[-2]
				self.K             = self.ifm.dim[-1]
				self.K_ifmB        = self.wgt.dim[-2]  if self.orig_op_type in ["MatMul"] else 0 #NOTE - Only make a diff in actxact matmul
				self.N             = self.wgt.dim[-1] if has_wgt else 1
				self.M_subV        = self.core_tile.subV['ifm'][-2]
				self.K_subV        = self.core_tile.subV['ifm'][-1]
				self.N_subV        = 1 if self.core_tile.subV.get('wgt') is None else self.core_tile.subV['wgt'][-1]
				#Store tensor bits info
				self.bias_bits     = round(self.bias.bytes * self.bits_per_byte) if self.bias.bytes else 0
				self.tdm_bits      = 32                                                                              #TODO: Check on this
				self.ifm_bits      = round(self.ifm.bytes * self.bits_per_byte)
				self.wgt_bits      = round(self.wgt.bytes * self.bits_per_byte) if has_wgt else 0
				self.ofm_bits      = round(self.ofm.bytes * self.bits_per_byte)

				# For MatMul the permutations come from the layer info; for the other
				# ops of this branch the identity defaults set above stay in place.
				if self.orig_op_type == "MatMul":
					self.rev_permA     = layer_info['rev_permA']
					self.rev_permB     = layer_info['rev_permB']
					self.permY         = layer_info['permY']

				self.transposeA = self.rev_permA[2] == 1		# Set when M and K dims are swapped
				self.transposeB = self.rev_permB[2] == 1		# Set when K and N dims are swapped
				self.transposeC = self.permY[2]     == 1		# Set when M and N dims are swapped

				# NOTE(review): `len(self.core_tile.subV) > 2` is used as the test for
				# "a wgt sub-volume exists" - presumably ifm/wgt/ofm; confirm.
				ifm_core_subv_bits = (np.prod(self.core_tile.subV['ifm']) * self.ifm_bits)
				wgt_core_subv_bits = (np.prod(self.core_tile.subV['wgt']) * self.wgt_bits + self.core_tile.subV['wgt'][1] * self.bias_bits) if len(self.core_tile.subV) > 2 else 0
				ofm_core_subv_bits = (np.prod(self.core_tile.subV['ofm']) * self.ofm_bits)
				sum_core_subv_bits = (self.core_tile.subV['ifm'][-2] * self.tdm_bits)
				tdm_core_subv_bits = (self.core_tile.subV['wgt'][-1] * sum_core_subv_bits) if len(self.core_tile.subV) > 2 else 0

				self.ifm_core_subv_bytes = (ifm_core_subv_bits // self.bits_per_byte)
				self.wgt_core_subv_bytes = (wgt_core_subv_bits // self.bits_per_byte)
				self.ofm_core_subv_bytes = (ofm_core_subv_bits // self.bits_per_byte)
				self.sum_core_subv_bytes = (sum_core_subv_bits // self.bits_per_byte)
				self.tdm_core_subv_bytes = (tdm_core_subv_bits // self.bits_per_byte)

				#TODO: Check with Sourabh and Kyle on this
				# Each loop count is memtile iterations * core iterations, taking the
				# larger value of the two tensors that share the dimension.
				def calc_final_itr():
					tm = max(self.mem_tile.itr['ifm'][-2] * self.core_tile.itr['ifm'][-2], self.mem_tile.itr['ofm'][-2] * self.core_tile.itr['ofm'][-2])
					tk = max(self.mem_tile.itr['ifm'][-1] * self.core_tile.itr['ifm'][-1], self.mem_tile.itr['wgt'][-2] * self.core_tile.itr['wgt'][-2])
					if (len(self.core_tile.itr) > 2 and len(self.mem_tile.itr) > 2):
						tn = max(self.mem_tile.itr['wgt'][-1] * self.core_tile.itr['wgt'][-1], self.mem_tile.itr['ofm'][-1] * self.core_tile.itr['ofm'][-1])
					else:
						tn = 0
					return tm, tk, tn

				# NOTE(review): "MatAdd" cannot occur in this branch (it only handles
				# MatMul/Conv/PWLA), so Conv and PWLA always get (0,0,0) here.
				self.outer_loop, self.acc_loop, self.inner_loop  = calc_final_itr() if self.orig_op_type in ["MatMul", "MatAdd"] else (0,0,0)
				self.wgt_subv_rows = (self.K + sum(self.padding[0]['pad_ifm_z' if self.orig_op_type in ["MatMul"] else 'pad_ifm_y'])) // self.K_subV
				self.wgt_subv_cols = (self.N)                           // self.N_subV

				#check input parameters
				#utils.sanity_check((self.M + sum(self.padding[0]['pad_ifm_x'])) % (self.M_subV * self.aie_cols) == 0,f"Unsupported tiling param. M: {self.M}, K_subV: {self.M_subV}", "Message")
				#if self.op_name in ["MatMul", "Conv"]:

					#utils.sanity_check((self.K + sum(self.padding[0]['pad_ifm_y'])) % self.K_subV == 0,f"Unsupported tiling param. K: {self.K}, K_subV: {self.K_subV}")
					#utils.sanity_check((self.N) % (self.N_subV * self.aie_rows * self.aie_arrays) == 0,f"Unsupported tiling param. N: {self.N}, N_subV: {self.N_subV}")

				#utils.sanity_check(sum(self.padding[0]['pad_ifm_x']) < self.M_subV * self.aie_cols, f"Pad value must be less than subvol size")
				#utils.sanity_check(self.padding[0][1][0] < self.K_subV * self.aie_cols, 'Error', f"Pad value must be less than subvol size")
				#utils.sanity_check(self.padding[0][1][1] < self.K_subV * self.aie_cols, 'Error', f"Pad value must be less than subvol size")

				#disable M padding
				#utils.sanity_check((sum(self.padding[0]['pad_ifm_x'])==0) or self.mode=="M1K1N32" or self.mode=="M1K1N16",f"M padding is not supported in MatMul op")

				#K Padding Checker
				#TODO: Double check if this is correct with Kyle. Check multiple of 8 on "before"/"after" both OR just total pad?
				#utils.sanity_check((self.padding[0]['pad_ifm_y'][0] % 8 == 0),f"K padding value must be integer multiple of 8")
				#utils.sanity_check((self.padding[0]['pad_ifm_y'][1] % 8 == 0),f"K padding value must be integer multiple of 8")
			elif self.orig_op_type == "MHA":
				pass
			else:
				utils.sanity_check(False,f"Declare var for the new op: {self.orig_op_type}")

			# Debug record filled in during allocation (see calc_core_tile_addr)
			self._debug_info = BufferAllocatorResult.BufferAllocDebugInfo()

	def unpack_padding_info(self, tensor_name, info):
		"""
		Turn the per-dimension padding amounts of one tensor into a dict.

		The i-th entry of `info` is stored under the key
		'pad_<tensor_name>_<axis>' with axis 'x', 'y', 'z', 'w' for i = 0..3,
		and the stored value is the tuple (0, entry). A single-entry `info`
		is first extended with a 0 for the second dimension.

		Returns None when `info` is None, otherwise the dict described above.
		"""
		if info is None:
			return None

		axis_names = {0: "x", 1: "y", 2: "z", 3: "w"}
		if len(info) == 1:
			# 1-dim input: add a second dimension without padding
			info = (info[0], 0)

		utils.sanity_check(len(info) <= len(axis_names), f"Unsupported dim. Check again. Dim: {len(info)}, Max Dim: {len(axis_names)}")
		return {
			f'pad_{tensor_name}_{axis_names[axis]}': (0, amount)
			for axis, amount in enumerate(info)
		}

	def get_all_attribute(self):
		"""
		Return the instance attribute dict of this object.

		The live dict is returned, not a copy: attributes set later on the
		instance show up in the returned mapping as well.
		"""
		return vars(self)


	def get_enable_cols(self, active_row, active_col):
		"""
		Return the list of enabled column indices [0, 1, ..., active_col - 1].

		Before building the list, three conditions are verified through
		utils.sanity_check: the overlay has non-zero rows and columns
		(self.aie_rows / self.aie_cols), all rows are enabled
		(active_row == self.aie_rows) and active_col lies in
		1..self.aie_cols.
		"""
		overlay_non_empty = self.aie_cols > 0 and self.aie_rows > 0
		utils.sanity_check(overlay_non_empty, f"Expected to have non-zero values. max_row: {self.aie_rows}, max_col: {self.aie_cols}")

		all_rows_enabled = active_row == self.aie_rows
		utils.sanity_check(all_rows_enabled, f"Expected to have all rows being enabled. active_row: {active_row}, max_rows: {self.aie_rows}")

		col_in_range = active_col > 0 and active_col <= self.aie_cols
		utils.sanity_check(col_in_range, f"Active cols out of expected range. col: {active_col}, max_col: {self.aie_cols}")

		return list(range(active_col))

	def execute(self, _pipeline_data):
		"""
		Run the buffer allocation stage.

		For a supported operator (self.orig_op_type in const.SUPPORTED_OPs) the
		result tuple of allocate_buffer(), extended by the attribute dict of
		this object and the debug info, is stored in
		_pipeline_data.info['BuffAllocator']. Any other operator is reported
		through utils.sanity_check.
		"""
		logging.info("Executing Buffer Allocator stage. Info: %s", self.op_name)

		if self.orig_op_type not in const.SUPPORTED_OPs:
			utils.sanity_check(False,"Unsupported op type. Check again!!. Op: "+self.op_name)
			return

		alloc_result = self.allocate_buffer()
		extras = (self.get_all_attribute(), self._debug_info)
		_pipeline_data.info['BuffAllocator'] = alloc_result + extras

	def allocate_buffer(self):
		"""
		Publish per-tensor tile sizes as attributes and compute all addresses.

		For every expected input/output tensor the attributes
		'<tensor>_mem_tile_size' and '<tensor>_shim_tile_size' are set from
		self.mem_tile.sizes and self.shim_tile.sizes. Afterwards the result of
		calc_addr() is returned unchanged.
		"""
		# NOTE: an optional GEMM tiling size check (gemm_tiling_check) used to be
		# run here for MatMul; it is currently disabled.
		for tensor in self.ioinfo.input_and_output_tensors:
			for level, tile in (('mem', self.mem_tile), ('shim', self.shim_tile)):
				setattr(self, f'{tensor}_{level}_tile_size', tile.sizes[tensor])

		return self.calc_addr()

	def calc_addr(self):
		"""
		Compute core tile and mem tile buffer addresses.

		Side effect: self.LUTSize is set to the sum of the core LUT buffer
		sizes 'CoreLutabSize' and 'CoreLutcdSize' taken from the core size
		dict; an entry that is not present counts as 0. This must happen
		before calc_mem_tile_addr() is called, because that method allocates
		the memtile LUT buffer with self.LUTSize bytes.

		Returns the tuple (core buffer addresses, core buffer sizes,
		memtile address dict).
		"""
		Core_Buff_Addr, CoreSizeDict = self.calc_core_tile_addr()

		# Default each LUT size to 0. Adding the raw .get() result failed with a
		# TypeError (int += None) when only one of the two entries was present.
		self.LUTSize = sum(CoreSizeDict.get(lut_key, 0) for lut_key in ('CoreLutabSize', 'CoreLutcdSize'))

		Mem_Tile_Buff_Addr = self.calc_mem_tile_addr()

		return Core_Buff_Addr, CoreSizeDict, Mem_Tile_Buff_Addr

	def calc_core_tile_addr(self):
		"""
		Allocate the core tile buffers and return (addresses, sizes).

		The allocation is first attempted per memory bank. If that attempt
		raises MemoryAllocator.OutOfMemory, the exception is printed,
		self._debug_info.core_alloc_non_banked is set to True and the
		allocation is repeated with treat_as_single_bank=True.

		Every address list is padded with a None entry if it has fewer than two
		entries. For the operator types listed below, the
		*_core_subv_bytes attributes are overwritten with the allocated sizes
		(0 when a buffer was not allocated).
		"""
		# first, try banked allocation (i.e. normal core buffer allocation)
		self._debug_info.core_alloc_non_banked = False
		try:
			core_alloc = CoreBufferAllocator(self.placement, self.ioinfo)
			core_alloc.allocate_buffers()
		except MemoryAllocator.OutOfMemory as banked_oom:
			traceback.print_exception(banked_oom)
			self._debug_info.core_alloc_non_banked = True
			print("core buffer allocation respecting banks ran out of memory, trying allocation ignoring bank boundaries")
			# banked alloc ran in to out-of-memory, try single-bank fallback
			core_alloc = CoreBufferAllocator(self.placement, self.ioinfo, treat_as_single_bank=True)
			core_alloc.allocate_buffers()

		# either first or second try worked, we have an allocation
		addr_table = core_alloc.get_addrs()
		size_table = core_alloc.get_sizes()

		# make sure every buffer has at least two address slots
		for addr_list in addr_table.values():
			if len(addr_list) < 2:
				addr_list.append(None)
		#TODO: ensure pingpong buffer are on different memory banks

		ops_using_subv_bytes = (
			"MatMul", "Conv", "PWLA", "MatAdd", "Add", "Mul",
			"LayerNormalization", "GroupNormalization", "LpNormalization", "Softmax",
		)
		if self.orig_op_type in ops_using_subv_bytes:
			# attribute to refresh -> key of the allocated size
			refreshed_sizes = (
				('ifm_core_subv_bytes', 'CoreIfmSize'),
				('wgt_core_subv_bytes', 'CoreWgtSize'),
				('ofm_core_subv_bytes', 'CoreOfmSize'),
				('sum_core_subv_bytes', 'CoreIfm_sumSize'),
				('tdm_core_subv_bytes', 'CoreTdmSize'),
			)
			for attr_name, size_key in refreshed_sizes:
				setattr(self, attr_name, size_table.get(size_key, 0))

		return addr_table, size_table

	def calc_mem_tile_addr(self):
		"""
		Lay out all buffers of the mem tile and return their addresses and sizes.

		Buffers are allocated from one MemoryAllocator starting at address 0, in
		this order: parameters ("prm"), LUT, QDQ, the optional RoPE buffer (when
		is_rope_fused is set), the optional ifmB buffer (when is_elew_fused is
		set), and finally the ping (and, if enabled, pong) buffers of every
		input/output tensor, group by group.

		Requires self.LUTSize to be set already (calc_addr() does that).

		Returns a dict with 'Memtile<Name>Addr' / 'Memtile<Name>Size' entries
		and 'MemtileTotal'. Pong entries are None when ping-pong is disabled
		for the tensor's group.
		"""
		base_addr = 0
		mem_alloc = MemoryAllocator(base_addr, max_size=const.MEMTILE_SIZE, default_alignment=64, alloc_name="memtile")
		MemTileAddrDict= {}

		# some operator like mha need to have bigger memtile allocation
		# for the prm buffer. In that case extend the size by multiplying
		# it by the number of rows in the overlay
		prm_size = const.PARAM_SIZE
		if self.ioinfo.extended_memtile_prm_size:
			prm_size *= self.aie_rows

		MemTileAddrDict['MemtilePrmAddr'] = mem_alloc.alloc(prm_size, buf_name="prm") #NOTE- Expect Param to never have ping pong
		MemTileAddrDict['MemtileLutAddr'] = mem_alloc.alloc(self.LUTSize, buf_name="lut")  #NOTE- Expect LUT to never have ping pong
		MemTileAddrDict['MemtileQdqAddr'] = mem_alloc.alloc(self.qdq_bytes, buf_name="qdq") #NOTE- Expect QDQ to never have ping pong
		# is_rope_fused / is_elew_fused are only defined for some operators, hence hasattr()
		if hasattr(self, 'is_rope_fused') and self.is_rope_fused:
			MemTileAddrDict['MemtileRoPEAddr'] = mem_alloc.alloc(self.mem_tile.sizes['ofm']*2, buf_name="rope") #For sin/cos
			MemTileAddrDict['MemtileRoPESize'] = self.mem_tile.sizes['ofm']*2
		if hasattr(self, 'is_elew_fused') and self.is_elew_fused:
			MemTileAddrDict['MemtileifmBAddr'] = mem_alloc.alloc(self.mem_tile.sizes['ofm'], buf_name="ifmB") #For ifmB
			MemTileAddrDict['MemtileifmBSize'] = self.mem_tile.sizes['ofm']

		MemTileAddrDict['MemtilePrmSize'] = prm_size
		MemTileAddrDict['MemtileLutSize'] = self.LUTSize
		MemTileAddrDict['MemtileQdqSize'] = self.qdq_bytes

		# Some buffers need to be grouped (back to back in memory, see docstr of const.BufferGroup class).
		# Find buffer groups required for in/out tensors.
		groups: Dict[str, const.BufferGroup] = {}  # key is group_name
		for tensor in self.ioinfo.input_and_output_tensors:
			buf_grp = find_buffer_group(self.ioinfo, tensor)
			if buf_grp:
				# Tensor is part of a group. It is required to allocate this group.
				# Several tensors of the same group map to the same key, so the
				# group is only kept once.
				groups[buf_grp.group_name] = buf_grp
			else:
				# Use artificial single entry groups for non-grouped in/out tensors
				# in order to use same code below for all buffers.
				groups[tensor] = const.BufferGroup(group_name=tensor, buffer_names=[tensor])

		# Allocate memory for groups, making sure that grouped buffers are back
		# to back in memory inside the ping and pong buffers for the group.
		# For a group with q and k this means memory layout
		# <q ping><k ping><q pong><k pong>
		for group in groups.values():
			# Get whether to enable ping pong or not.
			enable_ping_pong_list = [self.sch_attr.ping_pong_enable[tensor] for tensor in group.buffer_names]
			enable_ping_pong = any(enable_ping_pong_list)
			# Sanity-check that all buffers in group agree about ping-pong or ping-only.
			utils.sanity_check(enable_ping_pong == all(enable_ping_pong_list),
			                   f"buffer group {group} does not agree about ping/pong or ping-only")
			# Allocate ping buffers of group and pong buffers if enabled
			# (the helper is redefined per iteration and called right away, so it
			# always works on the current `group`)
			def alloc_ping_or_pong(suffix:str, enable: bool):
				for tensor in group.buffer_names:
					alloc_size = self.mem_tile.sizes[tensor]
					buf_name = f"{tensor.capitalize()}{suffix}"
					MemTileAddrDict[f'Memtile{buf_name}Addr'] = mem_alloc.alloc(alloc_size, buf_name=buf_name) if enable else None
					MemTileAddrDict[f'Memtile{buf_name}Size'] = alloc_size if enable else None
			alloc_ping_or_pong("Ping", True)
			alloc_ping_or_pong("Pong", enable_ping_pong)

		allocated_size = mem_alloc.get_total_size()
		MemTileAddrDict['MemtileTotal'] = base_addr + allocated_size
		# The size limit is not enforced during a tiler pass.
		# NOTE(review): the limit is hard-coded to 512*1024 while the allocator is
		# created with const.MEMTILE_SIZE - confirm both are meant to be the same value.
		utils.sanity_check((allocated_size <= 512*1024) or self.tiler_pass, f"Total Memtile Utilization must be less than 512K, Current utilization({allocated_size})", "Message")
		return MemTileAddrDict

	def param_checker(self, overlay_info, layer_info, kernel_info, scheduling_info, core_info, mem_info, shim_info, padding_info):
			# Validate the constructor inputs. The checks below verify that the
			# overlay name in overlay_info['overlay'] matches the overlay shape
			# stored in self.aie_rows / self.aie_cols:
			#   '4x4' -> rows == cols
			#   '8x4' -> 2 * rows == cols
			#   anything else is reported as invalid.
			# Failures are reported through utils.sanity_check. The remaining
			# parameters are not used by these checks.
			#check layer_info
			# NOTE: the two strings below are only needed by the disabled M dimension check
			input_tensor_str = f'in_{self.ioinfo.input_tensors[0]}_shape'
			output_tensor_str = f'out_{self.ioinfo.output_tensors[0]}_shape'
			#utils.sanity_check((layer_info[input_tensor_str][1] == layer_info[output_tensor_str][0]), f"M dimension doesn't match in layer info")
			if overlay_info['overlay'] == '4x4':
				utils.sanity_check(self.aie_rows == self.aie_cols, f"Unsupported overlay. aie_row: {self.aie_rows}, aie_col: {self.aie_cols}")
			elif overlay_info['overlay'] == '8x4':
				utils.sanity_check(2*self.aie_rows == self.aie_cols, f"Unsupported overlay. aie_row: {self.aie_rows}, aie_col: {self.aie_cols}")
			else:
				utils.sanity_check(False, f"Invalid overlay. Check again. overlaye: {overlay_info['overlay']}")

