import onnx
import json
import argparse

class TilerSubgraphNPU:
    """Partition the nodes of an ONNX graph into NPU and CPU node-name sets.

    NPU node names are read from a JSON config: every top-level value must
    carry a ``["layer_info"]["nodenames"]`` list. Every other node of the ONNX
    graph is treated as a CPU node. The tensor (edge) names read or written by
    any CPU node are collected in ``cpu_edges`` so callers can query them via
    ``assignL3``.

    Nodes are identified purely by ``NodeProto.name``. ONNX node names are
    optional, so unnamed or duplicate-named nodes collapse into a single set
    entry. NOTE(review): confirm the models fed to this tool have unique,
    non-empty node names.
    """

    def __init__(self, onnx_model_path: str, json_file_path: str):
        """Load the model and config, then classify nodes and collect CPU edges.

        Args:
            onnx_model_path: Path to the ONNX model file.
            json_file_path: Path to the JSON NPU-subgraph configuration.

        Raises:
            ValueError: if either file cannot be loaded.
            AssertionError: if the config names a node that is missing from the
                model, or names the same node more than once.
        """
        self.onnx_model_path = onnx_model_path
        self.json_file_path = json_file_path
        self.model = None              # onnx.ModelProto after _load_model()
        self.config = None             # parsed JSON config after _load_config()
        self.npu_node_name_set = None  # set[str]: node names assigned to the NPU
        self.cpu_node_name_set = None  # set[str]: all remaining node names
        self.cpu_edges = None          # set[str]: tensor names touched by CPU nodes

        self._load_model()
        self._load_config()
        self._demarcate()
        self._populate_edge_names()

    def _load_model(self):
        """Load the ONNX model at ``self.onnx_model_path`` into ``self.model``.

        Raises:
            ValueError: wrapping (and chained to) whatever ``onnx.load`` raised.
        """
        try:
            self.model = onnx.load(self.onnx_model_path)
        except Exception as e:
            raise ValueError(f"Failed to load ONNX model from {self.onnx_model_path}: {e}") from e

    def _load_config(self):
        """Parse the JSON file at ``self.json_file_path`` into ``self.config``.

        Raises:
            ValueError: wrapping (and chained to) any I/O or JSON decode error.
        """
        try:
            # JSON text is UTF-8 by spec; don't depend on the platform locale.
            with open(self.json_file_path, 'r', encoding='utf-8') as f:
                self.config = json.load(f)
        except Exception as e:
            raise ValueError(f"Failed to load JSON file from {self.json_file_path}: {e}") from e

    def _populate_node_name(self):
        """Return the set of all node names in the loaded ONNX graph."""
        return {node.name for node in self.model.graph.node}

    def _populate_npu_node_name(self):
        """Return every NPU node name listed in the config, in config order.

        Duplicates are kept so that ``_demarcate`` can report them.
        """
        return [
            name
            for entry in self.config.values()
            for name in entry["layer_info"]["nodenames"]
        ]

    def _demarcate(self):
        """Split the model's node names into ``npu_node_name_set`` and ``cpu_node_name_set``.

        Raises:
            AssertionError: if a configured NPU node is absent from the model or
                is listed more than once in the config. The error is raised
                explicitly rather than via ``assert`` so the check still runs
                under ``python -O``.
        """
        all_node_names = self._populate_node_name()
        npu_node_name_list = self._populate_npu_node_name()

        npu_names = set()
        for node in npu_node_name_list:
            if node in npu_names:
                raise AssertionError(f"Node {node} is listed more than once in the NPU config.")
            if node not in all_node_names:
                raise AssertionError(f"Node {node} not found in ONNX model node names.")
            npu_names.add(node)

        self.npu_node_name_set = npu_names
        # Every model node not claimed by the NPU config runs on the CPU.
        self.cpu_node_name_set = all_node_names - npu_names

    def _populate_edge_names(self):
        """Collect every tensor name consumed or produced by a CPU node into ``cpu_edges``.

        Empty names are skipped. ONNX uses ``""`` to mark an omitted optional
        input/output, and that placeholder is not a real tensor.
        """
        self.cpu_edges = set()
        for node in self.model.graph.node:
            if self.npu_or_cpu(node.name):
                continue
            self.cpu_edges.update(edge for edge in (*node.input, *node.output) if edge)

    def npu_or_cpu(self, node_name: str) -> bool:
        """Return True if ``node_name`` is an NPU node, False if it is a CPU node.

        Raises:
            AssertionError: if the name is in neither set. The error is raised
                explicitly so that an unknown name is never silently reported as
                CPU (a falsy ``None``) under ``python -O``.
        """
        if node_name in self.npu_node_name_set:
            return True
        if node_name in self.cpu_node_name_set:
            return False
        raise AssertionError(f"{node_name} not found in either NPU or CPU node sets")

    def assignL3(self, edge_name: str) -> bool:
        """Return True if tensor ``edge_name`` is read or written by any CPU node.

        NOTE(review): presumably such edges must be placed in L3 so both the
        CPU and NPU can reach them. Confirm this against the consumer of this flag.
        """
        return edge_name in self.cpu_edges
        
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Tiler Subgraph NPU")
    parser.add_argument("--json_file_path", type=str, required=True, help="Path to the JSON configuration file")
    parser.add_argument("--onnx_file_path", type=str, required=True, help="Path to the ONNX model file")
    args = parser.parse_args()

    tiler = TilerSubgraphNPU(args.onnx_file_path, args.json_file_path)

    # Interactive lookup: classify node names until an empty line or EOF.
    while True:
        try:
            user_input = input("Enter a node name (press Enter to exit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: leave quietly instead of printing a traceback.
            print()
            break
        if not user_input:
            break
        try:
            is_npu = tiler.npu_or_cpu(user_input)
        except AssertionError:
            # npu_or_cpu raises for names in neither set; a typo should not
            # kill the session.
            print(f"{user_input} is not a node in the ONNX model.")
            continue
        if is_npu:
            print(f"{user_input} is an NPU node.")
        else:
            print(f"{user_input} is a CPU node.")