from typing import Optional
from dataclasses import dataclass, field
from dataclass_wizard import JSONWizard
from enum import Enum

import os

from OGOAT.src.Scheduling_Engine.infra.const import BufAllocator_Idx

@dataclass
class MemoryAllocation:
	"""
	A single memory allocation: its size and the list of addresses at which
	it is placed. At least one address is required.
	"""
	size: int
	addresses: list[int]

	def __post_init__(self) -> None:
		# dataclass invokes a hook named exactly `__post_init__` right after the
		# generated __init__ (this also covers objects rebuilt from JSON).
		assert len(self.addresses) >= 1, "a memory allocation needs at least one address"


@dataclass
class BufferAllocation:
	"""
	Allocation of one named buffer: a mandatory ping allocation and an
	optional pong allocation.
	"""
	name: str
	ping: MemoryAllocation
	pong_optional: Optional[MemoryAllocation] = None

	def has_pong(self) -> bool:
		"""Tell whether a pong allocation is present."""
		return not (self.pong_optional is None)

	@property
	def pong(self) -> MemoryAllocation:
		"""The pong allocation; asserts that it exists (see has_pong)."""
		pong_alloc = self.pong_optional
		assert pong_alloc is not None
		return pong_alloc


@dataclass
class BufferAllocDebugInfo:
	"""
	Debugging information produced while running the buffer allocator.

	It helps understand what happened during allocation but is not part of
	the allocation result itself.

	core_alloc_non_banked -- True when allocating the core buffers within the
	                         bank boundaries ran out of memory in at least one
	                         bank, so the allocations were instead obtained by
	                         ignoring the bank boundaries as a fallback
	"""
	core_alloc_non_banked: bool = field(default=False)


@dataclass
class BufferAllocations(JSONWizard):
	class AllocType(Enum):
		CORE = 0
		MEM = 1

	core_alloc: dict[str, BufferAllocation] = field(default_factory=dict)
	mem_alloc: dict[str, BufferAllocation] = field(default_factory=dict)
	debug_info: BufferAllocDebugInfo = field(default_factory=BufferAllocDebugInfo)

	@staticmethod
	def get_json_path(dir_path: str) -> str:
		return os.path.join(dir_path, "buffer_allocation.json")

	def dump(self, output_dir: str) -> None:
		assert os.path.exists(output_dir), f"Directory provided does not exist: {output_dir}"
		output_path = BufferAllocations.get_json_path(output_dir)

		with open(output_path, "w") as fd:
			fd.write(self.to_json(indent=4))

	@staticmethod
	def load(input_dir: str):
		assert os.path.exists(input_dir), f"Directory provided does not exist: {input_dir}"

		input_path = BufferAllocations.get_json_path(input_dir)
		assert os.path.exists(input_path), f"Input path does not exist: {input_path}"

		with open(input_path, "r") as fd:
			json_str = fd.read()
		return BufferAllocations.from_json(json_str)

	def get_buffer_allocation_from_type(self, alloc_type: AllocType) -> dict[str, BufferAllocation]:
		match alloc_type:
			case self.AllocType.CORE:
				return self.core_alloc
			case self.AllocType.MEM:
				return self.mem_alloc
		raise ValueError(f"{alloc_type} is not implemented")

	def add_buffer_allocation(self, alloc: BufferAllocation, alloc_type: AllocType) -> None:
		buffer_allocations = self.get_buffer_allocation_from_type(alloc_type)

		# the buffer should not have already been allocated
		assert alloc.name not in buffer_allocations, f"buffer '{alloc.name}' was already inserted"

		buffer_allocations[alloc.name] = alloc

	def add_allocation(self, buffer_name: str, size: int, alloc_type: AllocType, ping_addresses: list[str], pong_addresses: Optional[list[str]] = None) -> None:
		ping_alloc = MemoryAllocation(size, ping_addresses)

		pong_alloc = None
		if pong_addresses is not None:
			pong_alloc = MemoryAllocation(size, pong_addresses)

		buffer_alloc = BufferAllocation(buffer_name, ping_alloc, pong_alloc)
		self.add_buffer_allocation(buffer_alloc, alloc_type)

	def add_core_allocation(self, buffer_name: str, size: int, ping_addresses: list[str], pong_addresses: Optional[list[str]] = None) -> None:
		self.add_allocation(buffer_name, size, BufferAllocations.AllocType.CORE, ping_addresses, pong_addresses)

	def add_mem_allocation(self, buffer_name: str, size: int, ping_addresses: list[str], pong_addresses: Optional[list[str]] = None) -> None:
		self.add_allocation(buffer_name, size, BufferAllocations.AllocType.MEM, ping_addresses, pong_addresses)

	def get_buffer_alloc(self, name: str, alloc_type: AllocType) -> BufferAllocation:
		buffer_allocs = self.get_buffer_allocation_from_type(alloc_type)
		assert name in buffer_allocs, f"buffer '{name}' was never inserted"

		return buffer_allocs[name]

	def get_core_alloc(self, name: str) -> BufferAllocation:
		return self.get_buffer_alloc(name, BufferAllocations.AllocType.CORE)

	def get_mem_alloc(self, name: str) -> BufferAllocation:
		return self.get_buffer_alloc(name, BufferAllocations.AllocType.MEM)


def sanitize_address_list(addresses: list[Optional[int]]) -> list[int]:
	"""
	Return a new list containing every entry of `addresses` except None.

	The buffer allocator sometimes appends a None marker at the end of an
	address list. Since it is unknown whether other code relies on that
	marker, it is kept in the allocator's data and only dropped here, when
	building the buffer allocation dataclasses. The input list is left
	unmodified.
	"""
	sanitized: list[int] = []
	for address in addresses:
		if address is None:
			continue
		sanitized.append(address)
	return sanitized


# FIXME: this conversion function is a stopgap introduced to start cleaning
# up how data is conveyed from the buffer allocation to later stages
# (e.g. access_pattern or dataflow).
# The end goal is to directly create a BufferAllocations object in
# the buffer allocator and remove the need for the conversion to happen.
def get_buffer_allocations(_pipeline_data) -> BufferAllocations:
	"""
	Convert the pipeline_data object into a BufferAllocations dataclass
	to facilitate access to, and serialization/deserialization of, the data.

	Everything is read from _pipeline_data.info["BuffAllocator"], indexed by
	BufAllocator_Idx:
	- CORE_TILE_ADDR_IDX / CORE_TILE_SIZE_IDX: "Core<Name>Addr" and
	  "Core<Name>Size" entries for the core buffers;
	- MEM_TILE_ADDR_IDX: either "Memtile<Name>Addr" / "Memtile<Name>Size"
	  entries, or "Memtile<Name>Ping{Addr,Size}" and
	  "Memtile<Name>Pong{Addr,Size}" entries;
	- BUFF_ALLOC_PARAM_IDX: its "ioinfo" entry lists the buffer names;
	- DEBUG_INFO: stored as-is in the result's debug_info.
	<Name> is the buffer name passed through str.capitalize().

	Core buffers without an address entry are skipped. A mem buffer with
	neither a plain nor a ping/pong entry makes the lookup fail (KeyError
	for dict inputs).
	"""
	buffer_alloc = BufferAllocations()

	buff_allocator_info = _pipeline_data.info["BuffAllocator"]
	io_info = buff_allocator_info[BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]["ioinfo"]

	_add_core_buffer_allocations(buffer_alloc, buff_allocator_info, io_info)
	_add_mem_buffer_allocations(buffer_alloc, buff_allocator_info, io_info)

	# NOTE(review): presumably a BufferAllocDebugInfo produced by the buffer
	# allocator; it is assigned without any conversion — confirm with the producer.
	buffer_alloc.debug_info = buff_allocator_info[BufAllocator_Idx.DEBUG_INFO.value]

	return buffer_alloc


def _add_core_buffer_allocations(buffer_alloc: BufferAllocations, buff_allocator_info, io_info) -> None:
	"""Add a CORE allocation for every core buffer of io_info that has an address entry."""
	core_buffs_addr = buff_allocator_info[BufAllocator_Idx.CORE_TILE_ADDR_IDX.value]
	core_buffs_size = buff_allocator_info[BufAllocator_Idx.CORE_TILE_SIZE_IDX.value]

	for buffer_name in io_info.core_buffer_names:
		key_prefix = "Core" + buffer_name.capitalize()
		addr_key = key_prefix + "Addr"
		# buffers without an address entry are skipped
		if addr_key not in core_buffs_addr:
			continue

		# drop the None marker the allocator may append to the address list
		core_buffer_addr = sanitize_address_list(core_buffs_addr[addr_key])
		core_buffer_size = core_buffs_size[key_prefix + "Size"]

		core_alloc = MemoryAllocation(size=core_buffer_size, addresses=core_buffer_addr)
		buffer_alloc.add_buffer_allocation(
			BufferAllocation(buffer_name, core_alloc),
			BufferAllocations.AllocType.CORE,
		)


def _add_mem_buffer_allocations(buffer_alloc: BufferAllocations, buff_allocator_info, io_info) -> None:
	"""Add a MEM allocation for every input/output tensor and the fixed "prm", "lut", "qdq" buffers."""
	inout_buffer_names = io_info.input_tensors + io_info.output_tensors + ["prm", "lut", "qdq"]
	mem_buffs_info = buff_allocator_info[BufAllocator_Idx.MEM_TILE_ADDR_IDX.value]

	for buffer_name in inout_buffer_names:
		capitalized_name = buffer_name.capitalize()
		key_prefix = "Memtile" + capitalized_name
		# mem allocations are keyed by the lower-cased name
		alloc_name = capitalized_name.lower()

		if key_prefix + "Addr" in mem_buffs_info:
			# single allocation (no pong)
			mem_buffer_addr = mem_buffs_info[key_prefix + "Addr"]
			mem_buffer_size = mem_buffs_info[key_prefix + "Size"]
			mem_alloc = MemoryAllocation(mem_buffer_size, [mem_buffer_addr])

			buffer_alloc.add_buffer_allocation(
				BufferAllocation(alloc_name, mem_alloc),
				BufferAllocations.AllocType.MEM,
			)
			continue

		# otherwise the buffer must have both ping and pong entries
		mem_buffer_ping_addr = mem_buffs_info[key_prefix + "PingAddr"]
		mem_buffer_ping_size = mem_buffs_info[key_prefix + "PingSize"]
		mem_ping_alloc = MemoryAllocation(mem_buffer_ping_size, [mem_buffer_ping_addr])

		mem_buffer_pong_addr = mem_buffs_info[key_prefix + "PongAddr"]
		mem_buffer_pong_size = mem_buffs_info[key_prefix + "PongSize"]
		mem_pong_alloc = MemoryAllocation(mem_buffer_pong_size, [mem_buffer_pong_addr])

		buffer_alloc.add_buffer_allocation(
			BufferAllocation(alloc_name, mem_ping_alloc, mem_pong_alloc),
			BufferAllocations.AllocType.MEM,
		)
