import argparse
import json
import logging
import os
import sys
import time
from functools import reduce
from sys import exit, maxsize
from typing import Any, Dict, List, Tuple

import onnx

from tiler_subgraph_npu import TilerSubgraphNPU

# from overlay_description import Overlay

# Create a Node class to store the information of each node in the reverse graph

class Node:
  """Plain record describing one node of the graph.

  The constructor only stores its arguments; nothing is validated or copied,
  so list arguments are kept by reference and may be mutated by the caller.

  Attributes:
    name: Node name.
    idx: Integer index of the node (supplied by the caller).
    parent_to_self: Names of the edges arriving at this node.
    self_to_child: Name of the edge leaving this node.
    is_child_of: Nodes this node is a child of.
    is_parent_of: Nodes this node is a parent of.
    has_weights: Flag supplied by the caller.
    is_multi_inp: Flag supplied by the caller.
  """

  def __init__(self,
               name: str,
               idx: int,
               parent_to_self: List[str],
               self_to_child: str,
               is_child_of: List['Node'],
               is_parent_of: List['Node'],
               has_weights: bool,
               is_multi_inp: bool
               ) -> None:
    # Identity.
    self.name, self.idx = name, idx
    # Edge names attached to the node.
    self.parent_to_self, self.self_to_child = parent_to_self, self_to_child
    # Links to neighbouring nodes.
    self.is_child_of, self.is_parent_of = is_child_of, is_parent_of
    # Boolean properties.
    self.has_weights, self.is_multi_inp = has_weights, is_multi_inp

class Edge:
  """Plain record describing one edge (tensor) of the graph.

  Besides the constructor arguments, every instance starts with
  ``sink_node_idx = -1`` and ``edge_len = 0``; callers are expected to set
  ``sink_node_idx`` themselves before calling :meth:`calc_edge_len`.

  Attributes:
    name: Edge name.
    parent: Name of the producing node.
    is_marked: Flag supplied by the caller.
    num_bits: Bit width associated with the edge.
    residency: Residency label (a string supplied by the caller).
    is_a_branch: Flag supplied by the caller.
    size_known: Whether ``size`` is meaningful.
    size: Size value supplied by the caller.
    src_node_idx: Index of the producing node.
    sink_node_idx: Index of the consuming node, ``-1`` until assigned.
    edge_len: Result of the last :meth:`calc_edge_len` call, ``0`` initially.
  """

  def __init__(self,
               name: str,
               parent: str,
               is_marked: bool,
               num_bits: int,
               residency: str,
               is_a_branch: bool,
               size_known: bool,
               size: int,
               src_node_idx: int
               ) -> None:
    self.name, self.parent = name, parent
    self.is_marked, self.num_bits = is_marked, num_bits
    self.residency, self.is_a_branch = residency, is_a_branch
    self.size_known, self.size = size_known, size
    self.src_node_idx = src_node_idx
    # Not constructor arguments: filled in later by the caller.
    self.sink_node_idx, self.edge_len = -1, 0

  def calc_edge_len(self):
    """Store ``sink_node_idx - src_node_idx`` in ``edge_len``."""
    span = self.sink_node_idx - self.src_node_idx
    self.edge_len = span

class Overlay:
  """Memory budget figures derived from the overlay dimensions.

  Attributes:
    multiplier: Fixed scaling factor (0.8) applied to the raw L2 capacity.
    total_l2_: ``int((1 << 22) * num_of_cols * num_memtile_row * multiplier)``.
    staging_cost_per_tensor: ``(1 << 19) * num_of_cols``.

  NOTE(review): the unit of both figures is not stated here - presumably the
  same unit as the edge sizes they are compared with; confirm with callers.
  """

  def __init__(self,
               num_of_cols: int,
               num_memtile_row: int) -> None:
    self.multiplier = 0.8
    # Unscaled capacity first, then apply the multiplier and truncate.
    unscaled_capacity = (1 << 22) * num_of_cols * num_memtile_row
    self.total_l2_ = int(unscaled_capacity * self.multiplier)
    per_column_staging = 1 << 19
    self.staging_cost_per_tensor = per_column_staging * num_of_cols

def convert_size_to_bytes(edge_data: Dict[str, Any]) -> Dict[str, Any]:
  """Replace each known edge size by a human-readable ``"<value> kB"`` string.

  For every record whose ``"size_known"`` entry is truthy, ``"size"`` is
  divided by 8192 and then formatted with four decimals followed by ``" kB"``.
  Records with unknown size are left untouched.

  Args:
    edge_data: Mapping of edge name to its record; modified in place.

  Returns:
    The same ``edge_data`` object.
  """
  for record in edge_data.values():
    if not record["size_known"]:
      continue
    record["size"] /= 8192
    record["size"] = "%.4f" % record["size"] + " kB"
  return edge_data

def annotate_edges(edge_data: Dict[str, Any], nodes: List[onnx.NodeProto]) -> Dict[str, Any]:
  """Tag every not-yet-tagged output edge of the given nodes with a type.

  An output edge that exists in ``edge_data`` and whose ``"tagged"`` entry is
  falsy gets ``"tagged" = True`` and a ``"type"``:

  * ``"init"`` when the producing node is a ``DequantizeLinear`` whose input
    list compares equal to ``[]``;
  * ``"act"`` in every other case.

  Edges that are already tagged, or unknown to ``edge_data``, are skipped.

  NOTE(review): ONNX DequantizeLinear nodes normally carry inputs, so the
  ``"init"`` case may never trigger in practice - confirm the intent.

  Args:
    edge_data: Mapping of edge name to its record; modified in place.
    nodes: Nodes whose ``op_type``, ``input`` and ``output`` are inspected.

  Returns:
    The same ``edge_data`` object.
  """
  for node in nodes:
    # Only an input-less DequantizeLinear yields an initializer edge.
    if node.op_type != "DequantizeLinear" or node.input != []:
      label = "act"
    else:
      label = "init"

    for out_name in node.output:
      if out_name not in edge_data:
        continue
      record = edge_data[out_name]
      if not record["tagged"]:
        record["tagged"] = True
        record["type"] = label
  return edge_data

def transfer_info(edge_data: Dict[str, Any], node_data: Dict[str, Any]) -> Dict[str, Any]:
  """Enrich each node's residency maps with edge size and type.

  Every value of ``"in_act_residency"`` and ``"out_act_residency"`` is wrapped
  into a one-element list; when the edge name is also a key of ``edge_data``
  that edge's ``"size"`` and ``"type"`` are appended, giving
  ``[residency, size, type]``.

  Args:
    edge_data: Mapping of edge name to its record (read only).
    node_data: Mapping of node name to its record; modified in place.

  Returns:
    The same ``node_data`` object.
  """
  def _attach(residency_by_edge):
    # Rewrites the mapping in place so key order and identity are preserved.
    for edge_name in residency_by_edge:
      wrapped = [residency_by_edge[edge_name]]
      residency_by_edge[edge_name] = wrapped
      if edge_name in edge_data:
        wrapped.extend((edge_data[edge_name]["size"], edge_data[edge_name]["type"]))
    return residency_by_edge

  for record in node_data.values():
    incoming = record["in_act_residency"]
    outgoing = record["out_act_residency"]
    record["in_act_residency"] = _attach(incoming)
    record["out_act_residency"] = _attach(outgoing)
  return node_data
  
def _abort(*messages: str) -> None:
  """Log every message at ERROR level, then stop the program with status 1."""
  for message in messages:
    logging.error(message)
  exit(1)


def edge_length_marker(file_path: str, onnx_path: str, test_name:str = "", threshold_level:int = 5) -> Tuple[str, str]:
  """Assign an L2/L3 residency to every graph edge and dump JSON reports.

  Processing order:

  1. Load the JSON IR at ``file_path`` (node name -> node record) and the ONNX
     model at ``onnx_path``.
  2. Create one ``Edge`` per model input and per node output. The size of a
     node output is the product of ``out_act_shape`` times the bit width of
     ``out_datatype`` (0 when the shape is missing or empty).
  3. Compute each edge length (consumer index minus producer index); edges
     longer than ``threshold_level`` are marked, i.e. their residency is
     frozen.
  4. For each node, move all of its unmarked edges to L2 when together they
     fit in the L2 budget that remains after the marked L2 edges and the
     staging cost of the marked L3 edges (plus one for weights).
  5. Write ``<stem>_nodes.json`` and ``<stem>_edges.json``. They go into
     ``./output`` when the root logger's effective level is DEBUG, else into
     the current directory.

  When the root logger's effective level is exactly WARNING the function runs
  in "investigation" mode: after step 3 it logs the attributes of the edge
  named ``test_name`` and exits without writing any file.

  Args:
    file_path: Path of the JSON IR; its last five characters (``.json``) are
      stripped to build the output file names.
    onnx_path: Path of the ONNX model. The tilings file is looked up at
      ``onnx_path.split("nhwc_fused")[0] + "_tilings.json"``.
    test_name: Edge to report on in investigation mode.
    threshold_level: Maximum edge length that may still be moved to L2.

  Returns:
    ``(nodes_json_path, edges_json_path)``, both relative paths.

  Raises:
    SystemExit: With status 1 when the IR violates an assumption (several
      empty input names, several output edges, malformed input lists) or
      when ``test_name`` is unknown; with status 0 at the normal end of
      investigation mode.
  """
  logging.info(f"Edge Length Threshold Level at {threshold_level}")
  # Reading the JSON IR file
  with open(file_path, "r") as file:
    node_data = json.load(file)

  # Bit width of one element for every datatype name found in the IR.
  bits_for = {
    "int8" : 8,
    "mx9" : 9,
    "fp32" : 32,
    "bfloat16" : 16,
    "bfp16" : 9,
    "uint16": 16,
    "int16": 16,
    "debug": 1,
    "uint8": 8,
    "float32" : 32,
    "int64" : 64,
    "bool" : 1,
    "int32" : 32
    }
  ol = Overlay(8, 1)
  json_path = onnx_path.split("nhwc_fused")[0] + "_tilings.json"
  tiler = TilerSubgraphNPU(onnx_path, json_path)

  # Determining where output file should be written
  logging.debug("Output IR will be present inside subfolder titled \'output\'")
  file_name = os.path.basename(file_path)
  stem = file_name[:-5]  # drops the ".json" suffix
  if logging.getLogger().getEffectiveLevel() == logging.DEBUG:
    os.makedirs("output", exist_ok=True)
    out_path_n = "output" + os.path.sep + stem + "_nodes.json"
    out_path_e = "output" + os.path.sep + stem + "_edges.json"
  else:
    out_path_n = stem + "_nodes.json"
    out_path_e = stem + "_edges.json"

  logging.info("Calculating total number of nodes in graph")
  total_number_of_nodes = len(node_data)
  logging.info("Total number of nodes in graph is %d" % total_number_of_nodes)

  # Determining the global inputs and outputs of this graph
  # NOTE(review): load_external_data is documented as a boolean; the string
  # passed here is only used for its truthiness - confirm that is intended.
  if os.path.exists(onnx_path + '_data'):
    model = onnx.load_model(onnx_path, load_external_data=onnx_path + '_data')
  else:
    model = onnx.load_model(onnx_path)
  name_to_node_map = dict()
  name_to_edge_map = dict()
  edge_data = dict()
  total_number_of_edges = 0
  global_input_outputs = set()
  for _inp in model.graph.input:
    global_input_outputs.add(_inp.name)
    name_to_edge_map[_inp.name] = Edge(_inp.name, "", True, -1, "L3", False, False, 0, -1)
    # Model inputs get a huge sink index so their edge length is maxsize.
    name_to_edge_map[_inp.name].sink_node_idx = maxsize - 1
    edge_data[_inp.name] = {
      "edge_name": _inp.name,
      "parent_node_name": "",
      "act_num_bits": -1,
      "residency": "L3",
      "is_a_branch": False,
      "size_known": False,
      "size": 0,
      "global_edge": True,
      "NPU_edge": False,
      "tagged": True,
      "type": "act"
    }
  for _out in model.graph.output:
    global_input_outputs.add(_out.name)
    total_number_of_edges += 1

  all_inp_act_edge_set = set()
  no_inp_edge_set = {"Constant"}
  out_edge_set = set()
  branch_counter = 0
  # Name fragments of edges that are never flagged as branches.
  dont_cares = ["onstant", "cale", "ero_point", "onst"]

  for idx, k in enumerate(node_data):
    node_info = node_data[k]
    out_act_signal_name = node_info["out_act_signal_name"]
    enum_inp_edge = 0
    err_e = -1
    has_weights = False
    for pos, e in enumerate(node_info["inputs"]):
      if e["name"] == "":
        if err_e >= 0:
          _abort("Node %s has multiple inp edges that are empty strings in the source IR" %k,
                 "Program terminating")
        err_e = pos
        # No basicConfig() juggling here: it is a no-op once handlers exist
        # and otherwise leaves the root logger stuck at CRITICAL.
        logging.critical("Node %s has an input edge that is an empty string in the source IR" %k)
        logging.critical("Empty string edge will be ignored")
        continue
      if e["name"] in out_edge_set:
        # Produced by an earlier node: this node becomes its (latest) sink.
        name_to_edge_map[e["name"]].sink_node_idx = idx
        enum_inp_edge += 1
      elif e["name"] in global_input_outputs:
        enum_inp_edge += 1
      else:
        # Neither a node output seen so far nor a model input/output.
        has_weights = True
    if err_e >= 0:
      del node_info["inputs"][err_e]
    is_multi_inp_node = enum_inp_edge > 1

    if isinstance(out_act_signal_name, list):
      _abort("Multiple output edges originating from node",
             "Violates assumption that there is one unique edge from every node",
             "Check node %s inside input parser" %k,
             "Program terminating")
    if (not is_multi_inp_node) and (node_info["op_type"] not in no_inp_edge_set):
      single_in_name = node_info["in_act_signal_name"]
      # Validate the raw value; wrapping it in a list first made the
      # original check always succeed.
      if isinstance(single_in_name, list):
        _abort("Input JSON parsed wrong",
               "Violates assumption that all nodes will have a unique in_act_signal_name",
               "Check node %s inside input parser" %k,
               "Program terminating")
      in_act_signal_name = [single_in_name]
    else:
      in_act_signal_name = [x["name"] for x in node_info["inputs"]]
      if len(in_act_signal_name) == 1:
        _abort("Input JSON parsed wrong",
               "%s node cannot have single input" %node_info["inputs"],
               "Check node %s inside input parser" %k,
               "Program terminating")

    if "activation_residency" in node_info:
      del node_info["activation_residency"]

    name_to_node_map[k] = Node(k, idx, in_act_signal_name, out_act_signal_name, [], [], has_weights, is_multi_inp_node)
    if node_info["op_type"] not in no_inp_edge_set:
      for edges in in_act_signal_name:
        total_number_of_edges += 1
        if edges not in all_inp_act_edge_set:
          all_inp_act_edge_set.add(edges)
          continue
        # Seen before: the edge feeds more than one input. Names without an
        # Edge object (e.g. weights) are counted but cannot be flagged.
        if edges in name_to_edge_map and all(term not in edges for term in dont_cares):
          name_to_edge_map[edges].is_a_branch = True
          edge_data[edges]["is_a_branch"] = True
        branch_counter += 1
    try:
      size = reduce(lambda a, b: a * b, node_info["out_act_shape"])
    except (KeyError, TypeError):
      # reduce() raises TypeError on an empty or non-iterable shape.
      size = 0
      logging.info("Node %s has empty output activation shape. Assuming output activation has 0 size" %k)

    # Edges the tiler pins to L3 are fixed; other edges are fixed only when
    # they cannot fit in L2 at all.
    pinned_to_l3 = bool(tiler.assignL3(out_act_signal_name))
    num_bits = bits_for[node_info["out_datatype"]]
    size_in_bits = size * num_bits
    is_fixed = pinned_to_l3 or size_in_bits > ol.total_l2_
    name_to_edge_map[out_act_signal_name] = Edge(out_act_signal_name, k, is_fixed, num_bits, "L3", False, True, size_in_bits, idx)
    edge_data[out_act_signal_name] = {
      "edge_name": out_act_signal_name,
      "parent_node_name": k,
      "act_num_bits": num_bits,
      "residency": "L3",
      "is_a_branch": False,
      "size_known": True,
      "size": size_in_bits,
      "global_edge": False,
      "NPU_edge": not pinned_to_l3,
      "tagged": pinned_to_l3,
      "type": ""
    }
    if out_act_signal_name in global_input_outputs:
      edge_data[out_act_signal_name]["global_edge"] = True
    out_edge_set.add(out_act_signal_name)
    node_info["has_weights"] = has_weights
    node_info["is_multi_inp_node"] = is_multi_inp_node
    node_info["idx"] = idx

  # Link parents and children through the Node objects.
  for layer in name_to_node_map:
    for name in node_data[layer]["children_names"]:
      name_to_node_map[layer].is_parent_of.append(name_to_node_map[name])
      name_to_node_map[name].is_child_of.append(name_to_node_map[layer])

  for e in name_to_edge_map:
    edge = name_to_edge_map[e]
    edge.calc_edge_len()
    if edge.edge_len <= 0:
      # No consumer recorded after the producer: treat as infinitely long.
      edge.edge_len = maxsize
    edge_data[e]["edge_len"] = edge.edge_len
    # NOTE(review): this overwrites the is_marked value derived above from
    # the edge size / tiler.assignL3 - confirm that is intended.
    edge.is_marked = edge.edge_len > threshold_level

  logging.info("Calculating total number of edges in graph")
  logging.info("Total number of edges in graph is %d" % total_number_of_edges)
  logging.info("Total number of unique edges in graph is %d" % len(edge_data))
  logging.info("Total number of branches in graph is %d" % branch_counter)

  if logging.getLogger().getEffectiveLevel() <= logging.WARNING:
    print("*" * 100)

  # Investigation mode: report on a single edge and stop.
  if logging.getLogger().getEffectiveLevel() == logging.WARNING:
    if test_name not in name_to_edge_map:
      _abort("Edge provided doesn't exist in parsed output json.",
             "Program exiting")
    probe = name_to_edge_map[test_name]
    logging.warning("Printing information on %s" %probe.name)
    logging.warning("Parent Node %s" %probe.parent)
    logging.warning("Already fixed - %d" %probe.is_marked)
    logging.warning("Encoded with %d bits" %probe.num_bits)
    logging.warning("Current Residency: %s" %probe.residency)
    logging.warning("Is a branch - %d" %probe.is_a_branch)
    logging.warning("Size known - %d" %probe.size_known)
    logging.warning("Size %d kB" %(probe.size >> 13))
    logging.warning("Edge length - %d" %probe.edge_len)
    exit(0)

  for k in node_data:
    node_info = node_data[k]
    # One staging slot is charged up front when the node has weights.
    total_staging_cost = int(node_info["has_weights"])
    current_edges = [x["name"] for x in node_info["inputs"] + node_info["outputs"]]
    sep = len(node_info["inputs"])  # positions below sep are inputs
    movable = []
    total_reassign_possible = 0
    current_available_l2 = ol.total_l2_
    # Only names that have an Edge object take part (weights do not).
    known = [(pos, name) for pos, name in enumerate(current_edges) if name in name_to_edge_map]
    for _, name in known:
      edge = name_to_edge_map[name]
      if not edge.is_marked:
        movable.append(name)
        total_reassign_possible += edge.size
      elif edge.residency == "L3":
        total_staging_cost += 1
      else:
        current_available_l2 -= edge.size

    total_staging_cost *= ol.staging_cost_per_tensor
    current_available_l2 -= total_staging_cost

    logging.info("For node %s" %k)
    if len(movable) > 0:
      logging.debug("Edges at this node that can be re-assigned:")
      for edges in movable:
        logging.debug("\t\t%s" %edges)
    # All-or-nothing: the movable edges go to L2 only if they fit together.
    if ((total_reassign_possible != 0) and (current_available_l2 > total_reassign_possible)):
      logging.debug("%d edge residency adjusted for this node" %len(movable))
      for edges in movable:
        name_to_edge_map[edges].is_marked = True
        name_to_edge_map[edges].residency = "L2"
        edge_data[edges]["residency"] = "L2"
        current_available_l2 -= name_to_edge_map[edges].size
        total_reassign_possible -= name_to_edge_map[edges].size

    logging.info("L2 memory available %d kB" %(current_available_l2 >> 13))
    logging.info("Total reassign Possible %d kB\n" %(total_reassign_possible >> 13))

    in_act_residency = dict()
    out_act_residency = dict()
    for pos, name in known:
      target = in_act_residency if pos < sep else out_act_residency
      target[name] = name_to_edge_map[name].residency
    node_info["in_act_residency"] = in_act_residency
    node_info["out_act_residency"] = out_act_residency

  if (logging.getLogger().getEffectiveLevel() <= logging.INFO):
    print("*" * 100)

  edge_data = convert_size_to_bytes(edge_data)
  edge_data = annotate_edges(edge_data, model.graph.node)
  node_data = transfer_info(edge_data, node_data)

  with open(out_path_n, "w") as file:
    json.dump(node_data, file, indent=4)
  with open(out_path_e, "w") as file:
    json.dump(edge_data, file, indent=4)

  return (out_path_n, out_path_e)

def main(main_params):
  """Resolve the input paths and options, run the edge marker, print results.

  Args:
    main_params: Mapping with the optional keys ``"json_path"``,
      ``"model_path"``, ``"test_edge"`` and ``"threshold_level"``. A missing
      key is treated like a ``None`` value. At least one of the two paths is
      required; the other is derived from it:

      * only ``model_path``: JSON path is ``model_path[:-5] + "_IR.json"``;
      * only ``json_path``: model path is ``json_path[:-8] + ".onnx"``.

      ``threshold_level`` defaults to 8 and ``test_edge`` to ``""``.

  Raises:
    ValueError: If neither ``json_path`` nor ``model_path`` is provided.
  """
  json_arg = main_params.get("json_path")
  model_arg = main_params.get("model_path")
  if json_arg is None and model_arg is None:
    # Previously str(None) was sliced into the bogus path "_IR.json".
    raise ValueError("At least one of 'json_path' and 'model_path' must be provided")

  if json_arg is None:
    model_path = str(model_arg)
    json_path = model_path[:-5] + "_IR.json"
  elif model_arg is None:
    json_path = str(json_arg)
    model_path = json_path[:-8] + ".onnx"
  else:
    json_path = str(json_arg)
    model_path = str(model_arg)

  # An absent/None test edge must stay empty rather than become "None".
  test_edge_arg = main_params.get("test_edge")
  test_edge = "" if test_edge_arg is None else str(test_edge_arg)

  threshold_arg = main_params.get("threshold_level")
  threshold_level = 8 if threshold_arg is None else int(threshold_arg)

  start = time.time()
  out_paths = edge_length_marker(json_path, model_path, test_edge, threshold_level)
  print("Edge List populated in %f seconds" %(time.time() - start))
  print("Output File Path:")
  for out_path in out_paths[:2]:
    print("\t" + os.getcwd() + os.path.sep + out_path)

if __name__ == "__main__":
  # Command-line entry point: parse options, configure logging on stdout,
  # then hand the parsed arguments to main() as a plain dict.
  parser = argparse.ArgumentParser()
  parser.add_argument("-d", "--debug", help="Print lots of debugging statements", action="store_const", dest="loglevel", const=logging.DEBUG)
  parser.add_argument("-v", "--verbose", help="Print verbose info statements", action="store_const", dest="loglevel", const=logging.INFO)
  parser.add_argument("-t", "--test_edge", help="In-case information on a particular edge is needed")
  parser.add_argument("-p", "--json_path", help="path to json input file. Required Field")
  parser.add_argument("-m", "--model_path", help="path to onnx input model. Required Field")
  # type=int makes argparse reject non-numeric thresholds with a usage error.
  parser.add_argument("-q", "--threshold_level", type=int, help="Threshold edge-length beyond which edge marked L3")
  args = parser.parse_args()
  if args.threshold_level is None:
    # main() falls back to 8 when no threshold is given.
    print("Explicit threshold length not provided. Currently set to default value of 8. Please provide threshold length with -q/--threshold_level option")
  if (not args.json_path) and (not args.model_path):
    parser.error("Please pass path/to/json/file using -p or --json_path and pass path/to/onnx/model using -m or --model_path flags.\npython3 edge_parser.py --help\n\t\t\tfor further info.")
  # Asking about a specific edge switches to "investigation" mode, which the
  # rest of the program detects through the WARNING log level.
  if args.test_edge:
    args.loglevel = logging.WARNING
  if args.loglevel is None:
    args.loglevel = logging.ERROR
  logging.basicConfig(level=args.loglevel)

  # Remove all handlers associated with the root logger.
  for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

  # Create a handler for sys.stdout
  stdout_handler = logging.StreamHandler(sys.stdout)
  stdout_handler.setLevel(args.loglevel)

  # Create a formatter and set it for the handler
  formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  stdout_handler.setFormatter(formatter)

  # Add the handler to the logger
  logging.getLogger().addHandler(stdout_handler)

  logging.debug("Debug mode is enabled!")
  if logging.getLogger().getEffectiveLevel() == logging.INFO:
    logging.info("Verbose mode is enabled!")
  if logging.getLogger().getEffectiveLevel() == logging.WARNING:
    logging.warning("Investigation mode is enabled")

  main_params = vars(args)
  main(main_params)