from dataclasses import dataclass, field
from enum import Enum
from typing import List
import OGOAT.src.Scheduling_Engine.infra.scheduler_utils as utils

# Module-level integer constants; how they are consumed is not visible in
# this file.
DUMMY_CONST = -1
KERNEL_IDX = 3

# Operator type names listed as supported.
SUPPORTED_OPs = [
	"MatMul",
	"Conv",
	"PWLA",
	"Add",
	"Mul",
	"LayerNormalization",
	"GroupNormalization",
	"LpNormalization",
	"MHA",
	"RoPE",
	"Softmax",
]

@dataclass
class BufferGroup:
	"""
	A named set of buffers that must be kept together on the core tile and
	on the mem tile.

	Core tile: buffers of the group that land in the same bank are placed
	back to back within that bank (with padding for alignment). When the
	group spans two adjacent banks, the buffers of the first bank are placed
	at the end of that bank and the remaining buffers at the beginning of
	the second bank, so that the whole group is contiguous in memory.

	Mem tile: when ping/pong buffers are used, the ping buffers of the group
	are placed back to back (with padding for alignment) in the order given
	by buffer_names, and the pong buffers are laid out the same way.

	The group name has to be unique and must not collide with the name of
	any tensor used by the operator.

	Example: q and k for MHA on the memory tile. Their ping/pong buffers
	need this layout in the mem tile:
	<q ping><k ping><q pong><k pong>
	"""
	# Unique name of the group (must differ from every tensor name).
	group_name: str
	# Member buffer names; the order defines the mem tile layout order.
	buffer_names: List[str]

@dataclass
class ExpectedTensors:
	"""
	Tensor and buffer names expected for an operator.

	Holds the input tensor names, the output tensor names and the core
	buffer names, plus the optional list of buffer groups for buffers that
	have to be allocated together.
	"""
	input_tensors: List[str]
	output_tensors: List[str]
	core_buffer_names: List[str]
	buffer_groups: List[BufferGroup] = field(default_factory=list)

	# FIXME: This is a patch for the mhahead operator support.
	# The dataflow implementation expects the size of the prm buffer on the
	# memtile to be multiplied by the number of AIE rows.
	extended_memtile_prm_size: bool = False

	@property
	def input_and_output_tensors(self) -> List[str]:
		"""Return the input tensor names followed by the output tensor names."""
		io_names = self.input_tensors + self.output_tensors
		return io_names

	@property
	def all_tensors(self) -> List[str]:
		"""Return input names, then output names, then core buffer names."""
		io_names = self.input_tensors + self.output_tensors
		every_name = io_names + self.core_buffer_names
		return every_name

# Table mapping an operator key to the ExpectedTensors describing its input
# tensors, output tensors and core buffers.
expected_tensors = {
	# The key of this map can be the full op_type, the op_type with the
	# data-type suffix removed, or the orig_op_type.
	# The lookup tries the keys in this order and stops on the first match.
	"MatMul": ExpectedTensors(
		input_tensors=['ifm', 'wgt'],
		output_tensors=['ofm'],
		core_buffer_names=['ifm', 'wgt', 'ofm', 'ifm_sum', 'qdq', 'tdm', 'stack', 'Qout', 'DQout', 'Bufc0', 'scratch', 'WgtUnpack'],
	),
	"Conv": ExpectedTensors(
		input_tensors=['ifm', 'wgt'],
		output_tensors=['ofm'],
		core_buffer_names=['ifm', 'wgt', 'ofm', 'ifm_sum', 'qdq', 'tdm', 'stack', 'Qout', 'DQout'],
	),
	"PWLA": ExpectedTensors(
		input_tensors=['ifm'],
		output_tensors=['ofm'],
		core_buffer_names=['ifm', 'ofm', 'LUTab', 'LUTcd', 'stack', 'Qout', 'DQout', 'qdq', 'tdm'],
	),
	"Add": ExpectedTensors(
		input_tensors=['ifmA', 'ifmB'],
		output_tensors=['ofm'],
		core_buffer_names=['ifmA', 'ifmB', 'wgt', 'ofm', 'ifm_sum', 'qdq', 'tdm', 'stack', 'Qout', 'DQout'],
		# Extend the memtile prm buffer size by * aierows_nb
		extended_memtile_prm_size=True
	),
	"RoPE": ExpectedTensors(
		input_tensors=['ifm', 'sin', 'cos'],
		output_tensors=['ofm'],
		core_buffer_names=['ifm', 'sin', 'cos', 'ofm', 'qdq', 'tdm1', 'tdm2', 'stack', 'Qout', 'DQout'],
	),
	"Mul": ExpectedTensors(
		input_tensors=['ifmA', 'ifmB'],
		output_tensors=['ofm'],
		core_buffer_names=['ifmA', 'ifmB', 'wgt', 'ofm', 'ifm_sum', 'qdq', 'tdm', 'stack', 'Qout', 'DQout'],
		# Extend the memtile prm buffer size by * aierows_nb
		extended_memtile_prm_size=True
	),
	"LayerNormalization": ExpectedTensors(
		input_tensors=['ifm'],
		output_tensors=['ofm'],
		core_buffer_names=['ifm', 'ofm', 'qdq', 'stack'],
	),
	"GroupNormalization": ExpectedTensors(
		input_tensors=['ifm'],
		output_tensors=[ 'ofm'],
		core_buffer_names=['ifm', 'ofm', 'qdq', 'stack'],
	),
	"LpNormalization": ExpectedTensors(
		input_tensors=['ifm'],
		output_tensors=['ofm'],
		core_buffer_names=['ifm', 'ofm', 'qdq', 'stack'],
	),
	"Softmax": ExpectedTensors(
		input_tensors=['ifm'],
		output_tensors=['ofm'],
		core_buffer_names=['ifm', 'ofm', 'qdq', 'stack'],
	),
	"MHA_2p1_qdq": ExpectedTensors(
		input_tensors=["q", "k"],
		output_tensors=["ofm"],
		core_buffer_names=["q", "k", "tdm", "ofm", "qdq", "act1_sum", "act2_sum", "c0", "scratch", "stack"],
		# Grouping of q and k is currently disabled:
		# buffer_groups=[BufferGroup(group_name="qk", buffer_names=["q", "k"])],
		# Extend the memtile prm buffer size by * aierows_nb
		extended_memtile_prm_size=True
	),
	"MHA_2p1_bias_qdq": ExpectedTensors(
		# NOTE(review): "m" is listed as a core buffer but not as an input
		# tensor -- presumably for the same reason as the FIXME on
		# "MHA_3p0_1col_bias_qdq"; confirm.
		input_tensors=["q", "k"],
		output_tensors=["ofm"],
		core_buffer_names=["q", "k", "m", "tdm", "ofm", "qdq", "act1_sum", "act2_sum", "c0", "scratch", "stack"],
		# Grouping of q and k is currently disabled:
		# buffer_groups=[BufferGroup(group_name="qk", buffer_names=["q", "k"])],
		# Extend the memtile prm buffer size by * aierows_nb
		extended_memtile_prm_size=True
	),
	"MHA_3p0_qdq": ExpectedTensors(
		input_tensors=["q", "k", "v"],
		output_tensors=["ofm"],
		core_buffer_names=["q", "k", "tdm", "ofm", "qdq", "act1_sum", "act2_sum", "v", "c0_k", "c0_v", "scratch", "stack"],
		# k and v are allocated as one buffer group.
		buffer_groups=[BufferGroup(group_name="kv", buffer_names=["k", "v"])],
		# Extend the memtile prm buffer size by * aierows_nb
		extended_memtile_prm_size=True
	),
	"MHA_3p0_1col_qdq": ExpectedTensors(
		# NOTE(review): unlike "MHA_3p0_qdq", the core buffer list contains
		# neither "v" nor "ofm" -- confirm this is intentional.
		input_tensors=["q", "k", "v"],
		output_tensors=["ofm"],
		core_buffer_names=["q", "k", "tdm", "qdq", "act1_sum", "act2_sum", "c0_k", "c0_v", "scratch", "stack"],
		# Grouping of k and v is currently disabled:
		# buffer_groups=[BufferGroup(group_name="kv", buffer_names=["k", "v"])],
		# Extend the memtile prm buffer size by * aierows_nb
		extended_memtile_prm_size=True
	),
	"MHA_3p0_1col_bias_qdq": ExpectedTensors(
		# FIXME: add the m input once dataflow uses the buffer allocation
		# result, and extend the tiler to add memtile/shim/core subv/iters for m.
		# NOTE(review): as for "MHA_3p0_1col_qdq", the core buffer list
		# contains neither "v" nor "ofm" -- confirm this is intentional.
		input_tensors=["q", "k", "v"],
		output_tensors=["ofm"],
		core_buffer_names=["q", "k", "tdm", "m", "qdq", "act1_sum", "act2_sum", "c0_k", "c0_v", "scratch", "stack"],
		# Grouping of k and v is currently disabled:
		# buffer_groups=[BufferGroup(group_name="kv", buffer_names=["k", "v"])],
		# Extend the memtile prm buffer size by * aierows_nb
		extended_memtile_prm_size=True
	),
}


# Memory sizing constants (units presumably bytes -- confirm against users).
CORE_BANK_SIZE = 16 * 1024  # TODO: pull from HW metadata
MEMTILE_SIZE = 512 * 1024
PARAM_SIZE = 1024
BITS_PER_BYTE = 8
MAX_NUM_BANKS = 4


class BANK_MAPPING(Enum):
	"""
	Bank names BANK0..BANK3 mapped to the enum values 1..4.
	"""
	BANK0 = 1
	BANK1 = 2
	BANK2 = 3
	BANK3 = 4

	@classmethod
	def str2enum(cls, string_val):
		"""
		Return the member whose name equals string_val.

		For a name that is not a member, utils.sanity_check(False, ...) is
		called with a "String not found" message; if that call returns,
		the result is None.
		"""
		if string_val not in cls.__members__:
			utils.sanity_check(False, "String not found. Str: " + str(string_val))
			return None
		return cls[string_val]

# TODO: Pick these values up from the overlay object instead of hard-coding
# them here. The row/column constants below are currently disabled.
#AIE_ROW    = 4
#AIE_COL    = 4
AIE_ARRAYS = 1

#TODO: Put this info in tiling engine

class Kernel_OPCODE(Enum):
	"""Kernel opcodes; each member's value is the opcode string."""
	OPCODE_GEMM_TDM = "run_a16w8_gemm_tdm"
	OPCODE_GEMM_QDQ = "run_a16w8_gemm_qdq"

class BufAllocator_Idx(Enum):
    """
    Integer index constants 0..4.

    Presumably positions within a buffer-allocator result record -- confirm
    against the code that consumes them.
    """
    CORE_TILE_ADDR_IDX = 0
    CORE_TILE_SIZE_IDX = 1
    MEM_TILE_ADDR_IDX = 2
    BUFF_ALLOC_PARAM_IDX = 3
    DEBUG_INFO = 4
