import onnx 
import sys
import os
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
sys.path.append(infra_path)

from collections import defaultdict
import numpy as np
import pdb
import argparse
import logging


#From infra dir
import L2_utils as utils

# Bytes per element for each ONNX TensorProto data type.
# Dtypes not listed here (e.g. STRING, whose elements are variable-length, or
# newer FLOAT8 variants) fall back to the 2-byte default applied where this
# table is read.
ONNX_DTYPE_SIZES = {
    onnx.TensorProto.FLOAT: 4,
    onnx.TensorProto.DOUBLE: 8,
    onnx.TensorProto.FLOAT16: 2,
    onnx.TensorProto.BFLOAT16: 2,
    onnx.TensorProto.INT8: 1,
    onnx.TensorProto.UINT8: 1,
    onnx.TensorProto.INT16: 2,
    onnx.TensorProto.UINT16: 2,
    onnx.TensorProto.INT32: 4,
    onnx.TensorProto.UINT32: 4,
    onnx.TensorProto.INT64: 8,
    onnx.TensorProto.UINT64: 8,
    onnx.TensorProto.BOOL: 1,
    onnx.TensorProto.COMPLEX64: 8,
    onnx.TensorProto.COMPLEX128: 16,
}

class ONNXLivenessAnalyzer:
    """Static tensor-liveness analysis over an ONNX graph.

    Each node's position in ``graph.node`` is treated as its execution step.
    A tensor is live from the first step that produces or consumes it to the
    last such step. Graph inputs/outputs are not extended to the start/end of
    execution.

    Typical use: ``parse_model()`` -> ``compute_liveness_intervals()`` ->
    ``print_liveness_info()`` / ``plot_liveness_chart()`` /
    ``get_live_tensors_at()``.
    """

    def __init__(self, model_path):
        self.model = onnx.load(model_path)
        self.execution_order = []  # op names ("<op_type>_<node index>") in step order
        self.producers = {}  # {tensor name: op that produces it}
        self.consumers = defaultdict(set)  # {tensor name: set of ops that consume it}
        self.liveness_intervals = {}  # {tensor name: (first step, last step)}, inclusive
        self.tensor_sizes = {}  # {tensor name: size in bytes}

    def parse_model(self):
        """Extract the computation graph from the loaded ONNX model.

        Populates ``execution_order``, ``producers``, ``consumers`` and
        ``tensor_sizes``. Nodes are taken in ``graph.node`` order (the ONNX
        spec requires that list to be topologically sorted).
        """
        graph = self.model.graph
        for idx, node in enumerate(graph.node):
            op_name = f"{node.op_type}_{idx}"
            self.execution_order.append(op_name)
            # ONNX encodes an omitted optional input/output as the empty
            # string; it is not a tensor and must not get a liveness interval.
            for output in node.output:
                if output:
                    self.producers[output] = op_name
            for input_tensor in node.input:
                if input_tensor:
                    self.consumers[input_tensor].add(op_name)

        # Initializer (weight/constant) sizes are exact, so record them first.
        # This runs once, not once per node.
        for tensor in graph.initializer:
            self.tensor_sizes[tensor.name] = self.get_tensor_size(tensor)
        # Fill in remaining tensors from declared/inferred value infos. Any
        # tensor without static shape info ends up with size 0.
        for value_infos in (graph.input, graph.output, graph.value_info):
            for value_info in value_infos:
                if value_info.name not in self.tensor_sizes:
                    self.tensor_sizes[value_info.name] = self.get_tensor_size(value_info)

    def get_tensor_size(self, tensor):
        """Return the size in bytes of a TensorProto or ValueInfoProto.

        TensorProto (initializers) keep their shape in ``dims`` and dtype in
        ``data_type``. ValueInfoProto keeps both under ``type.tensor_type``.

        Returns 0 for:
          - a missing shape,
          - any symbolic/unknown dimension,
          - an object that is neither kind.

        A rank-0 shape is a scalar (one element). Dtypes missing from
        ONNX_DTYPE_SIZES are counted as 2 bytes per element.
        """
        if hasattr(tensor, "dims") and hasattr(tensor, "data_type"):
            # TensorProto: the shape is always concrete.
            dtype = tensor.data_type
            shape = list(tensor.dims)
        elif hasattr(tensor, "type"):
            # ValueInfoProto: the shape may be absent or partially symbolic.
            tensor_type = tensor.type.tensor_type
            if not tensor_type.HasField("shape"):
                return 0
            dtype = tensor_type.elem_type
            shape = []
            for dim in tensor_type.shape.dim:
                if not dim.HasField("dim_value"):
                    return 0  # dim_param (symbolic) or unset dimension
                shape.append(dim.dim_value)
        else:
            return 0

        num_elements = int(np.prod(shape, dtype=np.int64)) if shape else 1
        return num_elements * ONNX_DTYPE_SIZES.get(dtype, 2)

    def compute_memory_and_live_tensors(self):
        """Compute total live bytes and live-tensor count at each step.

        Returns:
            (memory_usage, live_tensor_count): two lists indexed by step,
            covering steps 0..max end step. Both lists are empty when there
            are no liveness intervals. Tensors with no known size count as
            0 bytes.
        """
        if not self.liveness_intervals:
            return [], []
        max_time = max(end for _, end in self.liveness_intervals.values()) + 1
        memory_usage = [0] * max_time
        live_tensor_count = [0] * max_time

        for tensor, (start, end) in self.liveness_intervals.items():
            size = self.tensor_sizes.get(tensor, 0)
            for t in range(start, end + 1):
                memory_usage[t] += size
                live_tensor_count[t] += 1

        return memory_usage, live_tensor_count

    def compute_liveness_intervals(self):
        """Determine the (first use, last use) step of each tensor.

        A "use" is either producing or consuming the tensor. The resulting
        dict is ordered by first-use step. This runs in time linear in the
        number of producer/consumer edges rather than steps x tensors.
        """
        step_of = {op: step for step, op in enumerate(self.execution_order)}
        spans = {}

        def _touch(tensor, step):
            """Widen tensor's span to include step."""
            lo, hi = spans.get(tensor, (step, step))
            spans[tensor] = (min(lo, step), max(hi, step))

        for tensor, ops in self.consumers.items():
            for op in ops:
                _touch(tensor, step_of[op])
        for tensor, op in self.producers.items():
            _touch(tensor, step_of[op])

        # Stable sort by first use, matching step-by-step discovery order.
        self.liveness_intervals = dict(sorted(spans.items(), key=lambda item: item[1][0]))

    def get_live_tensors_at(self, time_step):
        """Return the tensors that are live at the given step.

        time_step may be an int or an int-like string, e.g. a raw CLI value.
        """
        time_step = int(time_step)
        return [tensor for tensor, (start, end) in self.liveness_intervals.items() if start <= time_step <= end]

    def print_liveness_info(self):
        """Print every tensor's liveness interval, ordered by first use."""
        print("\nTensor Liveness Intervals:")
        for tensor, (start, end) in sorted(self.liveness_intervals.items(), key=lambda x: x[1][0]):
            print(f"Tensor: {tensor}, Live from {start} to {end}")

    def plot_liveness_chart(self, output_path="tensor_liveness_analysis.png"):
        """Save a horizontal-bar time chart of tensor liveness intervals.

        Args:
            output_path: image file to write. Defaults to
                'tensor_liveness_analysis.png' in the current directory.
        """
        import matplotlib
        matplotlib.use('Agg')  # headless backend: write to file, never open a window
        import matplotlib.pyplot as plt

        if not self.liveness_intervals:
            # A zero-height figure cannot be rendered; nothing to draw anyway.
            print("No tensor liveness intervals to plot; skipping chart.")
            return

        tensors = list(self.liveness_intervals)
        starts = np.array([self.liveness_intervals[t][0] for t in tensors])
        ends = np.array([self.liveness_intervals[t][1] for t in tensors])
        y_ticks = np.arange(len(tensors))  # one row per tensor

        fig_width, fig_height = 60, len(tensors) * 0.5
        fig, ax = plt.subplots(figsize=(fig_width, fig_height))
        # Single vectorized call; the end step is inclusive, hence +1.
        ax.barh(y_ticks, ends - starts + 1, left=starts, color="skyblue", edgecolor="black", rasterized=True)

        ax.set_yticks(y_ticks)
        ax.set_yticklabels(tensors)
        ax.set_xlabel("Execution Steps")
        ax.set_ylabel("Tensors")
        ax.set_title("Tensor Liveness Analysis")
        ax.grid(axis="x", linestyle="--", alpha=0.5)

        # The Agg backend rejects images of 2**16 pixels or more per side. The
        # 0.5in-per-tensor layout hits that limit at roughly 1300 tensors at
        # the default dpi, so lower the dpi only as much as needed to fit.
        dpi = min(fig.dpi, 60000 / max(fig_width, fig_height))
        plt.savefig(output_path, dpi=dpi)
        print(f"Dumping liveness chart: {os.path.abspath(output_path)}")
        plt.close(fig)

def main(args):
    """Run the liveness-analysis pipeline for the parsed CLI arguments.

    Args:
        args: dict view of the argparse namespace. Reads 'model_path' and
            'at_time_stamp'; the latter is None when no step was requested.
    """
    logging.info("Argument summary: %s", args)
    analyzer = ONNXLivenessAnalyzer(args['model_path'])
    pipeline = (
        analyzer.parse_model,
        analyzer.compute_liveness_intervals,
        analyzer.print_liveness_info,
        analyzer.plot_liveness_chart,
    )
    for stage in pipeline:
        stage()

    requested_step = args['at_time_stamp']
    if requested_step is None:
        return
    live_now = analyzer.get_live_tensors_at(requested_step)
    print(f"\nLive tensors at step {requested_step}: {live_now}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Tensor liveness analysis for ML Workload",
                                     usage='use "%(prog)s --help" for more info',
                                     formatter_class=argparse.RawTextHelpFormatter)
    # required knobs
    parser.add_argument("-mp", "--model_path", required=True, help="path to L1 fused onnx model")
    # optional knobs
    # type=int: liveness intervals hold integer steps. Comparing them against
    # the raw CLI string would raise TypeError in get_live_tensors_at().
    parser.add_argument("--at_time_stamp", type=int, default=None, help="view live tensor at specific step")
    # debug/profile knobs
    parser.add_argument('-d', '--debug', help="Dump dbg log to 'dbg_log.txt'", action="store_true", default=False)
    parser.add_argument('-df', '--debug_file_name', help="Debug log file name", default="dbg_log.txt")
    parser.add_argument('-v', '--verbose', choices=['debug', 'info', 'error'], help="Verbosity for debug logs", default='debug')

    args = parser.parse_args()

    # Project helper; presumably rejects paths without a .onnx extension.
    utils.check_file_type(args.model_path, ".onnx")

    if args.debug:
        filename = args.debug_file_name
        verbose = utils.DEBUG_VERBOSE.str2enum(args.verbose).value
        print("Saving debug log as :", os.getcwd()+"/"+args.debug_file_name)

        logging.basicConfig(
            filename=filename,
            filemode='w',
            format='[%(asctime)s,%(msecs)d] [%(levelname)s]: %(message)s',
            datefmt='%H:%M:%S',  # hours:minutes:seconds (was '%M:%H:%S', swapped)
            level=verbose,
        )

    main(vars(args))
