import argparse
import logging
import os
import sys
import json
import subprocess
from pathlib import Path
import onnx
# Add WAIC repo root and relevant subdirectories to Python path
waic_root = Path(__file__).resolve().parents[2]  # WAIC/OGOAT/misc_tools/get_model_perf.py -> WAIC
src_dir = waic_root / "OGOAT" / "src"
l1_fusion_dir = src_dir / "L1_fusion"

sys.path.insert(0, str(waic_root))
sys.path.insert(0, str(src_dir))
sys.path.insert(0, str(l1_fusion_dir))

# Import required OGOAT modules with error handling
try:
    from OGOAT.src.L1_fusion.add_onnx_tensor_shapes import add_model_info
    from OGOAT.src.L1_fusion.parse_onnx_model import ParseOnnxModel
    from OGOAT.src.L1_fusion.L1_utils.model_IR_utils import get_unique_nodes_wrt_shapes_dtypes_attrs
    from OGOAT.src.utils.context import Context, Logger
    from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner
    from OGOAT.src.L1_fusion.L1_utils.utils import onnxTensor_dtype_to_np_dtype, onnxTensor_np_dtype_to_dtype, onnxTensorProto_to_array
except ImportError as e:
    print(f"Error importing OGOAT modules: {e}")
    print("Please ensure you are running this script from the WAIC repo root and that OGOAT modules are available.")
    sys.exit(1)

def optimize_onnx_model(model_path, output_dir):
    """
    Run a fixed pair of onnxoptimizer passes over a model and save the result.

    The model at ``model_path`` is loaded, passed through the
    "extract_constant_to_initializer" and "eliminate_unused_initializer"
    passes, and written to ``output_dir`` as ``<model name>_opt.onnx``.

    Args:
        model_path (str): Path to the input ONNX model file.
        output_dir (str): Directory where the optimized ONNX model will be saved.
    Returns:
        str: Path to the optimized ONNX model file.
    """
    print("Step 0: Optimizing ONNX model...")

    # Imported lazily so that onnxoptimizer is only required when this step runs.
    import onnxoptimizer

    # Destination: same file stem as the input, with an "_opt" suffix.
    stem, _ext = os.path.splitext(os.path.basename(model_path))
    optimized_model_path = os.path.join(output_dir, stem + "_opt.onnx")

    optimized_model = onnxoptimizer.optimize(
        onnx.load(model_path),
        ["extract_constant_to_initializer", "eliminate_unused_initializer"],
    )

    onnx.save(optimized_model, optimized_model_path)
    return optimized_model_path

def extract_shapes(model_path, load_data, output_dir):
    """
    Run add_model_info on a model and return the path of the resulting model.

    Args:
        model_path (str): Path to the ONNX model
        load_data (int): Whether to load external data (0/1); forwarded as a string
        output_dir (str): Output directory for generated files

    Returns:
        str: Path ``<output_dir>/<model name>_mod.onnx``

    Raises:
        FileNotFoundError: If the ``_mod.onnx`` file is missing after add_model_info ran.
        Exception: Any error from add_model_info is printed and re-raised unchanged.
    """
    print("Step 1: Extracting tensor shapes from ONNX model...")

    # Parameter set for add_model_info, matching L1_fusion_arg in WAIC.py.
    # Not passed at this stage: shape_infer_method, assign_new_dtypes and the
    # low/high precision activation/weight dtype options.
    shape_params = {
        "model_path": model_path,
        "load_data": str(load_data),
        "output_dir": output_dir,
        "input_names": "",
        "input_dims": "",
        "in_shape_params": "{}",
        "fixed_input_values": "{}",
        "shape_inference_outputs": 3000,
        "skip_step": [],
        "fast_pm": True,
        "no_dtype_downcast": False,
        "no_dtype_freeze": False,
        "prebuilt_mladf_mha": False,
        "fusion_seq": None,
        "target": None,
        "device": "strix",
        "debug": False,
    }

    try:
        add_model_info(shape_params)

        # The modified model is expected next to the other outputs, named after the input.
        stem, _ext = os.path.splitext(os.path.basename(model_path))
        modified_model_path = os.path.join(output_dir, stem + "_mod.onnx")

        if not os.path.exists(modified_model_path):
            raise FileNotFoundError(f"Modified model not found at: {modified_model_path}")

        print(f"✓ Successfully extracted shapes. Modified model saved to: {modified_model_path}")
        return modified_model_path

    except Exception as err:
        # Report the failure for the console user, then let the caller decide what to do.
        print(f"✗ Error extracting shapes: {err}")
        raise


def generate_ir(model_path, load_data, output_dir):
    """
    Generate the IR JSON for an ONNX model and derive its unique-nodes IR.

    The model is loaded, parsed to ``<model path without extension>_IR.json``
    through ParseOnnxModel (executed via SafeRunner), and then
    get_unique_nodes_wrt_shapes_dtypes_attrs is run on that IR file.

    Args:
        model_path (str): Path to the ONNX model (with shapes)
        load_data (int): Whether to load external data (0/1)
        output_dir (str): Output directory for generated files

    Returns:
        str: Path ``<IR path without extension>_unique_nodes.json``. Note this is
            the unique-nodes file, not the full IR file.

    Raises:
        FileNotFoundError: If the IR JSON file does not exist after parsing.
        Exception: Any other error is printed and re-raised unchanged.
    """
    print("Step 2: Generating IR representation...")

    # Prepare parameters for ParseOnnxModel, matching L1_fusion_arg in WAIC.py
    main_params = {
        "model_path": model_path,
        "load_data": str(load_data),
        "output_dir": output_dir,
        # "shape_infer_method": "onnx_shape_infer",
        "input_names": "",
        "input_dims": "",
        "assign_new_dtypes": "0",
        "low_precision_act_dtype": "uint16",
        "high_precision_act_dtype": "uint16",
        "low_precision_wgt_dtype": "uint8",
        "high_precision_wgt_dtype": "uint16",
        "skip_step": [],
        "fast_pm": True,
        # "shape_inference_outputs": 3000,
        "no_dtype_downcast": False,
        "no_dtype_freeze": False,
        "prebuilt_mladf_mha": False,
        "fusion_seq": None,
        "target": None,
        "device": "strix",
        "debug": False,
        "in_shape_params": "{}",
        "fixed_input_values": "{}",
    }

    try:
        # Create context and logger
        context = Context(output_dir=output_dir, debug=False)
        logger = Logger(name="L1_parsing", context=context)
        logger.info("Start L1 parsing stage")

        # Load the ONNX model
        logger.info(f"Loading model from: {model_path}")
        model = onnx.load_model(model_path, load_external_data=load_data)
        logger.info("Model loaded successfully")

        # Create SafeRunner
        runner = SafeRunner(
            logger=logger,
            output_dir_path=output_dir,
            summary_file_name="parsing_error_summary.txt",
        )

        # Create ParseOnnxModel instance and generate IR (now with model parameter)
        parse_onnx_model = ParseOnnxModel(main_params, logger, model, runner)
        model_IR_path = os.path.splitext(model_path)[0] + "_IR.json"

        runner.run(parse_onnx_model.parse_fused_model_to_ir, model_IR_path)

        if len(runner.errors_occured) != 0:
            logger.info(
                f"Some errors occurred during fusion. Check summary file: {runner.summary_file_path}"
            )
            runner.dump_error_summary()

        # Check for the IR file before touching it, so a missing file is reported
        # with the explicit message below rather than a bare open() failure.
        if not os.path.exists(model_IR_path):
            raise FileNotFoundError(f"IR file not found at: {model_IR_path}")

        # Parse the IR once so a malformed file fails here; the context manager
        # guarantees the handle is closed. The parsed content itself is not used.
        with open(model_IR_path) as ir_file:
            json.load(ir_file)

        # Generate unique nodes IR
        get_unique_nodes_wrt_shapes_dtypes_attrs(model_IR_path)

        print(f"✓ Successfully generated IR. IR file saved to: {model_IR_path}")
        # Return path to the unique nodes IR JSON file.
        # NOTE(review): assumes get_unique_nodes_wrt_shapes_dtypes_attrs writes its
        # output next to the IR file with this suffix - confirm against that helper.
        unique_nodes_ir_path = os.path.splitext(model_IR_path)[0] + "_unique_nodes.json"
        return unique_nodes_ir_path

    except Exception as e:
        print(f"✗ Error generating IR: {str(e)}")
        raise


def calculate_macs(ir_json_path):
    """
    Run extract_size on an IR JSON file and return the resulting spreadsheet path.

    Args:
        ir_json_path (str): Path to the IR JSON file

    Returns:
        str: ``ir_json_path`` with its last five characters replaced by
            ``_size.xlsx``

    Raises:
        FileNotFoundError: If that ``.xlsx`` file does not exist after extract_size ran.
        Exception: Any other error is printed and re-raised unchanged.
    """
    print("Step 3: Calculating MACs and generating performance sheet...")

    try:
        # extract_size.py is looked up in this script's own directory, so that
        # directory is put at the front of the import path before importing it.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        sys.path.insert(0, str(script_dir))
        from extract_size import main as extract_size_main

        extract_size_main({"json_path": ir_json_path})

        # Drop the trailing 5 characters (presumably ".json") and append the suffix.
        excel_path = ir_json_path[:-5] + "_size.xlsx"

        if not os.path.exists(excel_path):
            raise FileNotFoundError(f"Excel file not found at: {excel_path}")

        print(f"✓ Successfully calculated MACs. Excel file saved to: {excel_path}")
        return excel_path

    except Exception as err:
        # Report the failure for the console user, then propagate it.
        print(f"✗ Error calculating MACs: {err}")
        raise


def get_dtype(graph, tensor_name):
    """
    Look up the dtype recorded for ``tensor_name`` in an ONNX graph.

    The graph's ``value_info``, ``input`` and ``output`` entries are searched in
    that order; the first entry with a matching name has its ``elem_type``
    converted with onnxTensor_dtype_to_np_dtype. If none match, the initializers
    are searched and the dtype of the first element returned by
    onnxTensorProto_to_array is used.

    Returns:
        The dtype found, or None when no entry carries that name.
    """
    # Typed entries first: value_info, then graph inputs, then graph outputs.
    for field_name in ("value_info", "input", "output"):
        for entry in getattr(graph, field_name):
            if entry.name == tensor_name:
                return onnxTensor_dtype_to_np_dtype(entry.type.tensor_type.elem_type)

    # Fall back to initializers, whose dtype comes from the decoded tensor data.
    for initializer in graph.initializer:
        if initializer.name != tensor_name:
            continue
        return onnxTensorProto_to_array(initializer)[0].dtype

    return None

def change_tensor_dtype_based_on_qdq(model_path, output_dir):
    """
    Rewrite tensor dtypes around Q/DQ nodes and save the result as a new model.

    - The first output of every DequantizeLinear node takes the dtype of that
      node's first input.
    - The first input of every QuantizeLinear node takes the dtype of that
      node's first output.

    Only tensors listed in the graph's value_info, inputs or outputs are
    rewritten. The model is saved as <original_name>_qdq.onnx in output_dir.

    Returns:
        str: Path to the saved model.
    """
    print("Step 1.1: Changing tensor dtypes based on Q/DQ nodes...")

    model = onnx.load(model_path)
    graph = model.graph

    # tensor name -> dtype it should be rewritten to
    dtype_overrides = {}

    # Pass 1: DequantizeLinear outputs inherit the dtype of the node's input.
    for dq_node in graph.node:
        if dq_node.op_type != "DequantizeLinear":
            continue
        if len(dq_node.input) < 1 or len(dq_node.output) < 1:
            continue
        source_name = dq_node.input[0]
        target_name = dq_node.output[0]
        source_dtype = get_dtype(graph, source_name)
        if source_dtype is not None:
            dtype_overrides[target_name] = source_dtype

    # Pass 2: QuantizeLinear inputs inherit the dtype of the node's output.
    # Kept as a separate, later pass so it wins when a tensor is hit by both rules.
    for q_node in graph.node:
        if q_node.op_type != "QuantizeLinear":
            continue
        if len(q_node.input) < 1 or len(q_node.output) < 1:
            continue
        target_name = q_node.input[0]
        source_name = q_node.output[0]
        source_dtype = get_dtype(graph, source_name)
        if source_dtype is not None:
            dtype_overrides[target_name] = source_dtype

    # Apply the overrides to every typed entry: value_info, then inputs, then outputs.
    typed_entries = list(graph.value_info) + list(graph.input) + list(graph.output)
    for entry in typed_entries:
        if entry.name not in dtype_overrides:
            continue
        entry.type.tensor_type.elem_type = onnxTensor_np_dtype_to_dtype(
            str(dtype_overrides[entry.name])
        )

    # Save next to the other outputs, named after the input model.
    stem, _ext = os.path.splitext(os.path.basename(model_path))
    new_model_path = os.path.join(output_dir, stem + "_qdq.onnx")
    onnx.save(model, new_model_path)
    print(f"✓ Dtype-changed model saved to: {new_model_path}")
    return new_model_path

def main():
    """
    Command-line entry point: parse arguments and run the full pipeline.

    Steps, in order: optimize the model, extract shapes, optionally rewrite
    dtypes around Q/DQ nodes, generate the IR, then calculate MACs.

    Returns:
        int: 0 on success; 1 if the model file does not exist or any step raises.
    """
    cli = argparse.ArgumentParser(
        description="Extract shapes, generate IR, and calculate MACs for ONNX model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python get_model_perf.py --model_path model.onnx --load_data 1
    python get_model_perf.py --model_path /path/to/model.onnx --load_data 0
    python get_model_perf.py --model_path model.onnx --qdq_model 1
    python get_model_perf.py --model_path model.onnx --load_data 1 --qdq_model 1
        """
    )

    # Shared settings for the 0/1 integer switches.
    zero_one_switch = {"type": int, "choices": [0, 1], "default": 0}

    cli.add_argument(
        "--model_path", "-mp",
        required=True,
        help="Path to the ONNX model file"
    )
    cli.add_argument(
        "--load_data", "-ld",
        help="Whether to load external data (0=No, 1=Yes). Default: 0",
        **zero_one_switch
    )
    cli.add_argument(
        "--output_dir", "-out",
        help="Output directory for generated files. Default: same as model directory"
    )
    cli.add_argument(
        "--qdq_model", "-qdq",
        help="Whether to change tensor dtypes based on Q/DQ nodes (0=No, 1=Yes). Default: 0",
        **zero_one_switch
    )
    cli.add_argument(
        "-d", "--debug",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
        help="Enable debug logging"
    )

    args = cli.parse_args()

    # loglevel is None unless --debug was given, in which case INFO is used.
    log_level = args.loglevel if args.loglevel else logging.INFO
    logging.basicConfig(
        level=log_level,
        format='%(asctime)s - %(levelname)s - %(message)s'
    )

    # Bail out early when the model path does not exist.
    if not os.path.exists(args.model_path):
        print(f"✗ Error: Model file not found: {args.model_path}")
        return 1

    # Outputs go to --output_dir when given, otherwise next to the model.
    output_dir = args.output_dir or os.path.dirname(os.path.abspath(args.model_path))
    os.makedirs(output_dir, exist_ok=True)

    external_data_label = 'Yes' if args.load_data else 'No'
    print(f"Processing model: {args.model_path}")
    print(f"Output directory: {output_dir}")
    print(f"Load external data: {external_data_label}")
    print("-" * 60)

    try:
        # Step 0: optimize the ONNX model before further processing.
        current_model = optimize_onnx_model(args.model_path, output_dir)

        # Step 1: extract shapes.
        current_model = extract_shapes(current_model, args.load_data, output_dir)

        # Optional step: change tensor dtypes based on Q/DQ nodes.
        if args.qdq_model == 1:
            current_model = change_tensor_dtype_based_on_qdq(current_model, output_dir)

        # Step 2: generate IR.
        ir_json_path = generate_ir(current_model, args.load_data, output_dir)

        # Step 3: calculate MACs.
        results_path = calculate_macs(ir_json_path)

        print("\n✓ Pipeline completed successfully!")
        print(f"Final results: {results_path}")
        return 0

    except Exception as err:
        print(f"\n✗ Pipeline failed: {err}")
        # The full traceback is only shown when --debug was requested.
        if args.loglevel == logging.DEBUG:
            import traceback
            traceback.print_exc()
        return 1


# Run the pipeline only when executed as a script; main()'s return value
# (0 on success, 1 on failure) becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
