import csv
from email.mime import image
import math
import pickle
import os
import re
import sys
from statistics import mode
from timeit import default_timer as timer
from collections import OrderedDict
import subprocess
import argparse
import textwrap
import numpy as np
import onnx
import onnxruntime as ort
from ml_dtypes import bfloat16
from sympy import O

# Repository root: the fourth ancestor of this file's path (three levels above
# the directory containing this script).
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
# Make OGOAT/src/L1_fusion and the repository root importable; presumably
# `config` / `custom_ops` resolve from the former and `tools.*` from the
# latter — confirm against the repo layout.
l1_fusion_path = os.path.join(REPO_ROOT, "OGOAT", "src", "L1_fusion")
sys.path.append(l1_fusion_path)
sys.path.append(REPO_ROOT)

from tools.extract_channels import extract_ifm_ofm_channels, extract_const_channels
from tools.utils import (
    append_outputs_to_model,
    append_outputs_to_fused_model,
    save_ifm_ofm,
    save_wgt_bias,
    save_wgt_bias_from_pickle
)

# Presumably a fail-fast check that importing ml_dtypes (above) made the
# "bfloat16" dtype name resolvable by NumPy; np.dtype raises TypeError if not.
np.dtype("bfloat16")
from onnxruntime_extensions import get_library_path

from config import (
    get_fused_model,
    get_ir_json,
    get_node_json,
    get_io_shapes_for_model,
    get_input_for_model,
    get_outs_from_graph,
    get_input_nodes,
    get_output_nodes,
    modify_input_tensors_for_fused_graph,
    results_dir,
)


from onnx import TensorProto, helper, numpy_helper, shape_inference
from onnx.helper import (
    make_attribute,
    make_graph,
    make_model,
    make_node,
    make_tensor_value_info,
)

from onnxruntime_extensions import PyCustomOpDef, PyOp, get_library_path, onnx_op


from datetime import datetime
from sys import exit

def remove_tmp_files():
    """Delete the temporary models written by the run_ortsession* helpers.

    Looks in the current working directory for the two temporary ONNX files
    and their external-data sidecars, and removes whichever of them exist.
    Missing files are silently skipped.
    """
    tmp_names = (
        'new_mod_fused.onnx',
        'new_mod_fused.onnx.data',
        'new_model.onnx',
        'new_model.onnx.data',
    )
    for tmp_name in tmp_names:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)

def run_ortsession_large_model(so, model_path, ir_json, node_json,
                   nodes = "unique", modified = False):
    """Build a CPU InferenceSession for a model loaded from disk with external data.

    The model is reloaded from ``model_path`` (including external data), extra
    graph outputs are appended so intermediate tensors can be captured, and the
    result is written to ``new_mod_fused.onnx`` (+ ``.onnx.data``) in the
    current directory before the session is created from that file.

    Args:
        so: ``onnxruntime.SessionOptions`` for the session.
        model_path: Path of the model to load.
        ir_json, node_json: Metadata consumed by the channel extractors.
        nodes: Edge-selection mode forwarded as ``all_nodes``.
        modified: When True, append the IFM/OFM and constant channels described
            by the JSON metadata; otherwise use ``append_outputs_to_model``.

    Returns:
        ``(session, output_names, ifm_ofm_dict, weights_qdq_dict)``; the two
        dicts are empty lists when ``modified`` is False.
    """
    model = onnx.load_model(model_path, load_external_data=True)
    if modified:
        ifm_ofm_dict = extract_ifm_ofm_channels(ir_json, node_json, all_nodes = nodes)
        weights_qdq_dict = extract_const_channels(ir_json)
        for channels in (ifm_ofm_dict, weights_qdq_dict):
            append_outputs_to_fused_model(model, channels)
    else:
        ifm_ofm_dict, weights_qdq_dict = [], []
        append_outputs_to_model(model)

    tmp_stem = "new_mod_fused"
    tmp_model_file = tmp_stem + ".onnx"
    onnx.save_model(
        model,
        tmp_model_file,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=tmp_stem + '.onnx.data',
    )
    session = ort.InferenceSession(tmp_model_file, so, providers=["CPUExecutionProvider"])
    output_names = [graph_out.name for graph_out in session.get_outputs()]
    return session, output_names, ifm_ofm_dict, weights_qdq_dict

def run_ortsession(so, model, ir_json, node_json,
                   nodes = "unique", modified = False):
    """Build a CPU InferenceSession for an in-memory model with extra outputs.

    Extra graph outputs are appended to ``model`` (in place) so intermediate
    tensors can be captured. The model is then written to ``new_model.onnx``
    (+ ``.onnx.data``) in the current directory, and the session is created
    from that file.

    NOTE(review): ``onnx.save_model(..., save_as_external_data=True)`` appears
    to move large initializer data out of ``model`` in place. Callers that
    reuse ``model`` afterwards may see empty ``raw_data`` — confirm.

    Args:
        so: ``onnxruntime.SessionOptions`` for the session.
        model: Loaded ``onnx.ModelProto``; mutated by this call.
        ir_json, node_json: Metadata consumed by the channel extractors.
        nodes: Edge-selection mode forwarded as ``all_nodes``.
        modified: When True, append the IFM/OFM and constant channels described
            by the JSON metadata; otherwise use ``append_outputs_to_model``.

    Returns:
        ``(session, output_names, ifm_ofm_dict, weights_qdq_dict)``; the two
        dicts are empty lists when ``modified`` is False.
    """
    if not modified:
        ifm_ofm_dict, weights_qdq_dict = [], []
        append_outputs_to_model(model)
    else:
        ifm_ofm_dict = extract_ifm_ofm_channels(ir_json, node_json, all_nodes = nodes)
        weights_qdq_dict = extract_const_channels(ir_json)
        for channels in (ifm_ofm_dict, weights_qdq_dict):
            append_outputs_to_fused_model(model, channels)

    tmp_stem = "new_model"
    tmp_model_file = tmp_stem + ".onnx"
    onnx.save_model(
        model,
        tmp_model_file,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=tmp_stem + '.onnx.data',
    )
    session = ort.InferenceSession(tmp_model_file, so, providers=["CPUExecutionProvider"])
    output_names = [graph_out.name for graph_out in session.get_outputs()]
    return session, output_names, ifm_ofm_dict, weights_qdq_dict

def save_and_get_intermediate_outputs(model, outputs, ort_outs, ifm_ofm_dict, prefix, output_dir):
    """Dump the captured activations of one inference run.

    Pairs each output name with its value and hands them to ``save_ifm_ofm``.
    The target directory is ``<cwd>/<output_dir>/DataGen/Activations``; it is
    created if needed.

    Args:
        model: The model whose outputs were captured (forwarded to save_ifm_ofm).
        outputs: Output names, in the same order as ``ort_outs``.
        ort_outs: Output values returned by the session run.
        ifm_ofm_dict: Channel description from ``extract_ifm_ofm_channels``.
        prefix: Tag for the dumped files (e.g. ``"ort"`` or ``"carf"``).
        output_dir: Results directory, relative to the current working dir.
    """
    ort_outs_dict = OrderedDict(zip(outputs, ort_outs))
    activations_dir = os.path.join(os.getcwd(), output_dir, 'DataGen', 'Activations')
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(activations_dir, exist_ok=True)
    save_ifm_ofm(model, ifm_ofm_dict, ort_outs_dict, activations_dir, prefix)

def save_and_get_node_constants(outputs, ort_outs, weights_qdq_dict, output_dir):
    """Dump the captured constant (weight/bias) tensors of one inference run.

    Pairs each output name with its value and hands them to ``save_wgt_bias``.
    The target directory is ``<cwd>/<output_dir>/DataGen/Consts``; it is
    created if needed.

    Args:
        outputs: Output names, in the same order as ``ort_outs``.
        ort_outs: Output values returned by the session run.
        weights_qdq_dict: Channel description from ``extract_const_channels``.
        output_dir: Results directory, relative to the current working dir.
    """
    ort_outs_dict = OrderedDict(zip(outputs, ort_outs))
    consts_dir = os.path.join(os.getcwd(), output_dir, 'DataGen', 'Consts')
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(consts_dir, exist_ok=True)
    save_wgt_bias(weights_qdq_dict, ort_outs_dict, consts_dir)

def get_network_l2_norm_msft(ort_outs_mod, msft_out):
    """Return per-output L2 distances between fused outputs and reference outputs.

    One norm per entry of ``msft_out``, computed on float32 copies of both
    sides. Extra trailing entries in ``ort_outs_mod`` are ignored.
    """
    return [
        np.linalg.norm(ort_outs_mod[pos].astype(np.float32) - reference.astype(np.float32))
        for pos, reference in enumerate(msft_out)
    ]


def get_network_l2_norm_orig(ort_outs_mod, ort_outs_orig, count = 1):
    """Return L2 distances between the first ``count`` outputs of two runs.

    Both sides are cast to float32 before subtracting; the result holds one
    norm per compared output index.
    """
    norms = []
    for k in range(count):
        lhs = ort_outs_mod[k].astype(np.float32)
        rhs = ort_outs_orig[k].astype(np.float32)
        norms.append(np.linalg.norm(lhs - rhs))
    return norms


def ort_single_input(data_list, args, so, output_dir):
    """Run the original and/or fused model on one data point and compare outputs.

    ``args['data_dump']`` is a comma-separated list that selects what runs:

    - ``all`` or ``ort`` runs the original model and dumps its edges.
    - ``all``, ``carf`` and/or ``const`` run the fused model and dump its
      edges and/or constants.

    Args:
        data_list: Data for one data point as returned by ``get_input_for_model``.
            Entries from index 1000 onwards are treated as the reference
            ("MSFT") outputs. Presumably the inputs occupy the slots before
            that — TODO confirm against ``get_input_for_model``.
        args: Parsed command-line arguments as a dict. Reads ``model_name``,
            ``out_dir``, ``edges``, ``ld`` and ``data_dump``.
        so: Not used; each ``generate_*`` helper builds its own session
            options. Kept for interface compatibility.
        output_dir: Not used; ``args['out_dir']`` takes precedence.

    Returns:
        Tuple ``(msft_vs_orig, fused_vs_msft, fused_vs_orig)`` of per-output
        L2-norm lists. An entry is ``None`` when a model it needs was not run.
    """
    orig_model_name = args['model_name']
    output_dir = args['out_dir']  # deliberately overrides the parameter
    model_name = get_fused_model(orig_model_name, output_dir)
    ir_json = get_ir_json(orig_model_name, output_dir)
    node_json = get_node_json(orig_model_name, output_dir)
    msft_out = data_list[1000:]
    edges = args['edges']
    load_data = args['ld']
    requested = args['data_dump'].split(',')

    save_orig = "all" in requested or "ort" in requested
    if "all" in requested or ("carf" in requested and "const" in requested):
        save_mod = "all"
    elif "carf" in requested:
        save_mod = "carf"
    elif "const" in requested:
        save_mod = "const"
    else:
        save_mod = ""

    ort_outs_orig = None
    ort_outs_mod = None
    if save_orig:
        ort_outs_orig = generate_ORT_data(orig_model_name,
                                          node_json, ir_json,
                                          data_list, output_dir, nodes = edges, save = True)
    if save_mod:
        ort_outs_mod = generate_CARF_data(model_name, node_json,
                                          ir_json, data_list, load_data, output_dir,
                                          nodes = edges, save = save_mod)

    # Test against None explicitly: the question is "was this model run?",
    # not the truthiness of its outputs (ambiguous for array-like results).
    l2_norm_msft_against_orig = None
    l2_norm_against_msft = None
    l2_norm_against_orig = None
    if ort_outs_orig is not None:
        l2_norm_msft_against_orig = get_network_l2_norm_orig(
            msft_out, ort_outs_orig, len(msft_out))
    if ort_outs_mod is not None:
        l2_norm_against_msft = get_network_l2_norm_msft(ort_outs_mod, msft_out)
    if ort_outs_orig is not None and ort_outs_mod is not None:
        l2_norm_against_orig = get_network_l2_norm_orig(
            ort_outs_mod, ort_outs_orig, len(msft_out))

    return l2_norm_msft_against_orig, l2_norm_against_msft, l2_norm_against_orig


def save_results(partial, result_lines, model_name, output_dir):
    """Write result lines to a timestamped text file inside ``output_dir``.

    The file name is ``[partial_]<model>_results_<YYYYmmdd_HHMMSS>.txt``,
    where ``<model>`` is the base name of ``model_name`` without its extension.
    Earlier code stripped only the last five characters and kept any
    directories. That wrote into a non-existent subdirectory of
    ``output_dir``, or bypassed ``output_dir`` entirely for absolute model
    paths.

    Args:
        partial: Prefix the file with ``partial_`` (interrupted/failed run).
        result_lines: Lines to write, joined with newlines.
        model_name: Model path; only its stem is used in the file name.
        output_dir: Directory for the results file; created if missing.
    """
    model_base = os.path.splitext(os.path.basename(model_name))[0]
    formatted_datetime = datetime.now().strftime("%Y%m%d_%H%M%S")
    res_filename = f"{model_base}_results_{formatted_datetime}.txt"
    if partial:
        res_filename = f"partial_{res_filename}"
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    res_path = os.path.join(output_dir, res_filename)
    with open(res_path, "w") as f:
        f.write("\n".join(result_lines))
    print(f"Results saved to {res_path}")
    print("Total number of files processed: ", len(result_lines))
    print("Total number of files processed: ", len(result_lines))


def generate_ORT_data(orig_model_name,
                      node_json,
                      ir_json,
                      data_list, output_dir,
                      nodes = "unique",
                      save = True):
    """Run the original (unfused) model and optionally dump its activations.

    Args:
        orig_model_name: Path to the original model.
        node_json, ir_json: Metadata forwarded to ``run_ortsession``.
        data_list: Input data for ``get_outs_from_graph``. ``None`` only
            triggers a printed warning.
        output_dir: Results directory for the dump.
        nodes: Forwarded to ``run_ortsession``. It has no effect here because
            that call uses ``modified=False``.
        save: When truthy, dump activations with the ``"ort"`` prefix.

    Returns:
        The outputs from ``get_outs_from_graph`` on the original model.
    """
    session_opts = ort.SessionOptions()
    session_opts.register_custom_ops_library(get_library_path())
    session_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    if data_list is None:
        print("Invalid: No data")
    model_orig = onnx.load(orig_model_name)  # Original model

    print("Run orig model...")
    session, output_names, ifm_ofm_channels, _unused_consts = run_ortsession(
        session_opts, model_orig, ir_json, node_json, nodes, False)
    ort_outs_orig = get_outs_from_graph(session,
                                        output_names,
                                        data_list,
                                        get_input_nodes(model_orig))
    if save:
        save_and_get_intermediate_outputs(model_orig, output_names, ort_outs_orig,
                                          ifm_ofm_channels, "ort", output_dir)
    del session
    return ort_outs_orig


def generate_CARF_data(model_name,
                       node_json,
                       ir_json,
                       data_list,
                       load_data, output_dir,
                       nodes = "unique",
                       save = "all"):
    """Run the fused model and dump the requested activations and/or constants.

    Args:
        model_name: Path to the fused model.
        node_json, ir_json: Metadata forwarded to the session builders.
        data_list: Input data. It is first adapted with
            ``modify_input_tensors_for_fused_graph``. ``None`` only triggers
            a printed warning.
        load_data: When truthy, use ``run_ortsession_large_model``, which
            reloads ``model_name`` with external data. Otherwise use
            ``run_ortsession`` on the already-loaded model.
        output_dir: Results directory for the dumps.
        nodes: Edge-selection mode forwarded to ``extract_ifm_ofm_channels``.
        save: Selects what is dumped:
            - ``"all"``: activations and constants.
            - ``"carf"``: activations only.
            - ``"const"``: constants only.
            - anything else: nothing is dumped.

    Returns:
        The outputs from ``get_outs_from_graph`` on the fused model.
    """
    session_opts = ort.SessionOptions()
    session_opts.register_custom_ops_library(get_library_path())
    session_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    if data_list is None:
        print("Invalid: No data")
    model_mod = onnx.load(model_name)  # Modified model from config
    print("Run modified model...")
    if load_data:
        run_result = run_ortsession_large_model(session_opts, model_name, ir_json, node_json, nodes, True)
    else:
        run_result = run_ortsession(session_opts, model_mod, ir_json, node_json, nodes, True)
    session, output_names, ifm_ofm_dict, weights_qdq_dict = run_result

    graph_inputs = get_input_nodes(model_mod)
    fused_inputs = modify_input_tensors_for_fused_graph(model_name, data_list)
    ort_outs_mod = get_outs_from_graph(session, output_names, fused_inputs, graph_inputs)

    dump_activations = save in ("all", "carf")
    dump_consts = save in ("all", "const")
    if dump_activations:
        save_and_get_intermediate_outputs(model_mod, output_names, ort_outs_mod,
                                          ifm_ofm_dict, "carf", output_dir)
    if dump_consts:
        save_and_get_node_constants(output_names, ort_outs_mod, weights_qdq_dict, output_dir)
    del session
    return ort_outs_mod

def generate_const_data(orig_model_name, node_json, ir_json, data_list, output_dir):
    """Run the original model once and pass its outputs to the constant dumper.

    NOTE(review): ``run_ortsession`` is called with ``modified=False``, so the
    weights dict it returns is always ``[]``. ``save_and_get_node_constants``
    therefore receives no channels and presumably saves nothing. ``nodes=False``
    is likewise ignored on that path. Confirm whether this function is still
    expected to produce output.

    Returns:
        The outputs from ``get_outs_from_graph`` on the original model.
    """
    session_opts = ort.SessionOptions()
    session_opts.register_custom_ops_library(get_library_path())
    session_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    if data_list is None:
        print("Invalid: No data")
    model_orig = onnx.load(orig_model_name)  # Original model

    print("Run orig model for const generation ...")
    session, output_names, _ifm_ofm_unused, const_channels = run_ortsession(
        session_opts, model_orig, ir_json, node_json, False, False)
    graph_inputs = get_input_nodes(model_orig)
    ort_outs_orig = get_outs_from_graph(session, output_names, data_list, graph_inputs)
    save_and_get_node_constants(output_names, ort_outs_orig, const_channels, output_dir)
    del session
    return ort_outs_orig

def convert_int8_to_int4_as_int8(bytes_in):
    """Unpack packed int4 pairs into one int8 per int4 value.

    Each input byte holds two signed 4-bit values: the first in the low
    nibble, the second in the high nibble. The output interleaves them as
    ``[lo0, hi0, lo1, hi1, ...]``, each sign-extended to int8 (range -8..7).

    Args:
        bytes_in: Packed bytes. Accepts a NumPy array or other array-like of
            any shape (it is flattened; values are wrapped to int8), or a
            bytes-like object (bytes/bytearray/memoryview).

    Returns:
        1-D ``np.int8`` array of length ``2 * N`` for ``N`` input bytes.
    """
    if isinstance(bytes_in, (bytes, bytearray, memoryview)):
        y = np.frombuffer(bytes_in, dtype=np.int8)
    else:
        # reshape(-1) so non-1D inputs work; len() only saw the first axis.
        y = np.asarray(bytes_in).astype(np.int8).reshape(-1)

    # Lower nibble: mask to 0..15, then map 8..15 onto -8..-1 (two's complement).
    y_lower = y & 0xF
    y_lower_int8 = np.where(y_lower >= 8, y_lower - 16, y_lower).astype(np.int8)

    # Upper nibble: an arithmetic right shift on int8 sign-extends.
    y_upper_int8 = y >> 4

    out = np.empty(y.size * 2, dtype=np.int8)
    out[0::2] = y_lower_int8
    out[1::2] = y_upper_int8
    return out

def construct_initializer_dict(model):
    """Map each graph initializer name to its value as a NumPy array.

    INT4 initializers are unpacked to one int8 per element, reshaped to the
    initializer's ``dims``. All other types go through
    ``numpy_helper.to_array``. If several initializers share a name, only the
    first one is kept.

    Args:
        model: Loaded ``onnx.ModelProto``.

    Returns:
        ``dict`` of initializer name -> ``np.ndarray``.
    """
    initializer_dict = {}
    for initializer in model.graph.initializer:
        if initializer.name in initializer_dict:
            continue
        if initializer.data_type == TensorProto.INT4:
            # Two int4 values are packed per byte (first in the low nibble).
            if initializer.HasField('raw_data'):
                packed = np.frombuffer(initializer.raw_data, np.int8)
            else:
                # int32_data holds one packed byte per entry. Route through
                # int32 so values >127 wrap into int8 instead of overflowing.
                packed = np.asarray(initializer.int32_data, dtype=np.int32).astype(np.int8)
            shape = tuple(initializer.dims)
            numel = int(np.prod(shape))
            # With an odd element count the last byte carries a padding
            # nibble; drop it so the reshape matches dims.
            unpacked = convert_int8_to_int4_as_int8(packed)[:numel]
            initializer_dict[initializer.name] = unpacked.reshape(shape)
        else:
            initializer_dict[initializer.name] = numpy_helper.to_array(initializer)
    return initializer_dict

def _import_ogoat_custom_ops():
    """Import the OGOAT custom-op modules needed to run the fused graph.

    None of the imported names are used afterwards. The imports are kept for
    their side effects (presumably each module registers its ops with
    onnxruntime_extensions via ``@onnx_op`` — TODO confirm).
    """
    # OGOAT custom ops
    from custom_ops.conv import (
            Conv_qdq_bias_uint16xuint8xuint16,
            Conv_qdq_uint16xuint8xuint16,
            Conv_qdq_biasleakyrelu_uint16xuint8xuint16,
            Conv_qdq_leakyrelu_uint16xuint8xuint16,
            Conv_qdq_relu_uint16xuint8xuint16
    )
    from custom_ops.norm import (
            GroupNormalization_qdq_uint16xuint16xuint16,
            LayerNormalization_qdq_uint16xuint8xuint16
    )
    from custom_ops.silu import Silu_qdq_uint16xuint16
    from custom_ops.sigmoid import Sigmoid_qdq_uint16xuint16
    from custom_ops.swish import Swish_qdq_uint16xuint16
    from custom_ops.relu import Relu_qdq_uint16xuint16
    from custom_ops.tanh import Tanh_qdq_uint16xuint16
    from custom_ops.gemm import Gemm_qdq_WBias_uint16xuint8xuint16
    from custom_ops.gemm import MatMul_qdq_Unsqueeze_WBias_uint16xuint8xuint16
    from custom_ops.gemm import MatMul_qdq_biasleakyrelu_uint16xuint8xuint16
    from custom_ops.gemm import MatMul_qdq_bias_uint16xint8xuint16
    from custom_ops.gemm import MatMul_qdq_biasgelu_uint16xint8xuint16
    from custom_ops.add import Add_qdq_EleWise_uint16xuint16xuint16
    from custom_ops.add import Add_qdq_EleWise_uint8xuint8xuint8
    from custom_ops.add import Add_qdq_BroadCast_uint16xuint16xuint16
    from custom_ops.add import Add_qdq_BroadCast_uint8xuint8xuint8
    from custom_ops.sub import Sub_qdq_uint16xuint16xuint16
    from custom_ops.div import Div_qdq_uint16xuint16xuint16
    from custom_ops.mul import Mul_qdq_EleWise_uint16xuint16xuint16
    from custom_ops.mul import Mul_qdq_BroadCast_uint16xuint16xuint16
    from custom_ops.matmul import MatMul_qdq_uint16xuint8xuint16
    from custom_ops.matmul import MatMul_qdq_uint16xuint16xuint16
    from custom_ops.matmul_act import MatMul_qdq_actxact_uint16xuint16xuint16
    from custom_ops.matmul_bias import MatMul_qdq_bias_uint16xuint8xuint16
    from custom_ops.matmul_bias import MatMul_qdq_bias_Transpose_uint16xuint8xuint16
    from custom_ops.matmul_bias_gelu import MatMul_qdq_biasgelu_uint16xuint8xuint16
    from custom_ops.matmul_bias_swish import MatMul_qdq_biasswish_uint16xuint8xuint16
    from custom_ops.matmul_bias_tanh import MatMul_qdq_biastanh_uint16xuint8xuint16
    from custom_ops.matmul_bias_sigmoid import MatMul_qdq_biassigmoid_uint16xuint8xuint16
    from custom_ops.matmul_bias_silu import MatMul_qdq_biassilu_uint16xuint8xuint16
    from custom_ops.matmul_bias_elu import MatMul_qdq_biaselu_uint16xuint8xuint16
    from custom_ops.matmul_bias_relu import MatMul_qdq_biasrelu_uint16xuint8xuint16
    from custom_ops.mha import (
            MHA_2p1_qdq_uint16xuint16xuint16,
            MHA_3p0_1col_qdq_uint16xuint16xuint16,
            MHA_2p1_bias_qdq_uint16xuint16xuint16,
            MHA_3p0_1col_bias_qdq_uint16xuint16xuint16,
    )
    from custom_ops.conv_to_matmul import (
            MatMul_qdq_uint16xint4xuint16,
            MatMul_qdq_uint16xint8xuint16,
    )
    from custom_ops.conv_slice_silu_to_matmul import (
            MatMul_qdq_slice_silu_uint16xint4xuint16,
            MatMul_qdq_slice_uint16xint4xuint16
    )
    from custom_ops.globalaveragepool import GlobalAveragePool_qdq_uint16xuint16
    from custom_ops.softmax import Softmax_qdq_uint16xuint16
    from custom_ops.gelu import Gelu_qdq_uint16xuint16
    from custom_ops.ggelu import GGelu_qdq_uint16xuint16
    from custom_ops.gather import Gather_qdq_uint16xuint16
    from custom_ops.nni import Resize_qdq_uint16xuint16
    #from custom_ops.concat import Concat20_qdq_uint16
    #from custom_ops.concat import Concat10_qdq_uint16
    #from custom_ops.concat import Concat5_qdq_uint16
    #from custom_ops.concat import Concat4_qdq_uint16
    #from custom_ops.concat import Concat2_qdq_uint16
    from custom_ops.slice import Slice_qdq_uint16xuint16
    #from custom_ops.split import Split_qdq_uint16
    from custom_ops.averagepool import AveragePool_qdq_uint16xuint16
    from custom_ops.depthtospace import DepthToSpace_qdq_uint16xuint16
    from custom_ops.psu_lp_norm import LpNormalization_qdq_uint16xuint16xuint16
    from custom_ops.lp_norm import LpNormalization_qdq_uint16xuint16
    #from custom_ops.transpose import Transpose_noop
    #from custom_ops.transpose import Transpose_qdq_uint16xuint16
    from custom_ops.flatten import Flatten_noop
    from custom_ops.reshape import Reshape_noop
    from custom_ops.reshape import Reshape_qdq_uint16xuint16
    from custom_ops.gather_elements import GatherElements_qdq_uint16xuint16
    from custom_ops.quant_linear import QuantizeLinear_qdq
    from custom_ops.dequant_linear import DequantizeLinear_qdq
    from custom_ops.squeeze import Squeeze_noop
    from custom_ops.unsqueeze import Unsqueeze_noop
    from custom_ops.neg import Neg_qdq_uint16xuint16
    from custom_ops.rope import (
            RoPE_actxact_qdq_uint16xuint16,
            RoPE_qdq_uint16xuint16
    )
    from custom_ops.reducesum import ReduceSum_noop
    from custom_ops.pwla import PWLA_qdq_uint16xuint16
    from custom_ops.pwla import MatMul_qdq_biaspwla_uint16xuint8xuint16


def _save_consts_from_initializers(orig_model_name, output_dir):
    """Dump constants straight from the fused model's initializers (``wgt`` mode).

    No ONNX Runtime session is created. Files go to
    ``<cwd>/<output_dir>/DataGen/Consts``.
    """
    model_name = get_fused_model(orig_model_name, output_dir)
    ir_json = get_ir_json(orig_model_name, output_dir)
    weights_qdq_dict = extract_const_channels(ir_json)
    model_mod = onnx.load(model_name)  # Modified model from config
    ini_dict = construct_initializer_dict(model_mod)
    consts_dir = os.path.join(os.getcwd(), output_dir, 'DataGen', 'Consts')
    os.makedirs(consts_dir, exist_ok=True)
    save_wgt_bias_from_pickle(weights_qdq_dict, ini_dict, consts_dir)
    print("Saved constants without running fused model")


def _run_all_data_points(args, so, model_name, data_folder, output_dir):
    """Evaluate every data point found under ``<data_folder>/msft_output``.

    Writes one result line per data point via ``save_results``. On an
    unexpected error, partial results are saved and the process exits with
    status 1. On Ctrl-C, partial results are saved and it exits with status 0.
    """
    print("Running against all data points for Model", model_name)
    result_lines = []
    num_data_points = len(os.listdir(os.path.join(data_folder, "msft_output")))
    # Running element-wise sum of the per-output norm lists (0.0 broadcasts
    # on the first addition).
    sum_l2_norm_against_msft = 0.0
    count = 0
    for idx in range(num_data_points):
        try:
            data_list, file_name_list = get_input_for_model(model_name, idx)
            l2_norm_msft_against_orig, l2_norm_against_msft, \
                    l2_norm_against_orig = ort_single_input(data_list, args, so, output_dir)
            dpl = " ".join(file_name_list)
            result_line = f"Data point: {dpl}, {idx}"
            print(
                "Orig Model L2 Norm with MSFT is:",
                l2_norm_msft_against_orig,
            )
            print(
                "Fused Model L2 Norm with MSFT is:",
                l2_norm_against_msft,
            )
            # The norms are a list (or None if the fused model was not run);
            # adding a list to a float raised TypeError and aborted the loop.
            if l2_norm_against_msft is not None:
                sum_l2_norm_against_msft = sum_l2_norm_against_msft + np.asarray(
                    l2_norm_against_msft, dtype=np.float64)
                count += 1
                print(
                    "Average L2 Norm against MSFT so far: ",
                    sum_l2_norm_against_msft / count,
                )
            # Append with +=; str.join would interleave result_line between
            # every character of its argument.
            result_line += f", L2 Norm against MSFT output: {l2_norm_against_msft}"
            print(
                "Fused Model L2 Norm with Orig Model is:",
                l2_norm_against_orig,
            )
            result_line += f", L2 Norm against Orig Model output: {l2_norm_against_orig}"
            result_lines.append(result_line)
        except Exception as e:
            print(
                f"Script Failed. An unexpected error occurred: {e} \nPartial results saved to txt file"
            )
            save_results(True, result_lines, model_name, output_dir)
            sys.exit(1)  # non-zero status: the run failed
        except KeyboardInterrupt:
            print("Script Cancelled. Partial results saved to txt file")
            save_results(True, result_lines, model_name, output_dir)
            sys.exit(0)

    save_results(False, result_lines, model_name, output_dir)
    remove_tmp_files()


def main(args):
    """Entry point: dump data and report L2 norms for one or all data points.

    Args:
        args: Parsed command-line arguments as a dict. ``data_dump`` containing
            ``wgt`` dumps constants from the fused model's initializers
            without running it. Otherwise the custom ops are imported and
            either one data point (``idx``) or all of them (``all``) are
            evaluated.
    """
    remove_tmp_files()
    model_name = args['model_name']  # Model to test
    data_idx = int(args['idx'])  # Data point number
    data_folder = os.path.dirname(model_name)
    runall_data = args['all']  # Run against all data points
    output_dir = args["out_dir"]
    outputs = args['data_dump'].split(',')

    if "wgt" in outputs:
        _save_consts_from_initializers(model_name, output_dir)
        return

    _import_ogoat_custom_ops()

    so = ort.SessionOptions()
    so.register_custom_ops_library(get_library_path())
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    if runall_data:
        _run_all_data_points(args, so, model_name, data_folder, output_dir)
        return

    data_list, file_name_list = get_input_for_model(model_name, data_idx)
    l2_norm_msft_against_orig, l2_norm_against_msft, \
            l2_norm_against_orig = ort_single_input(data_list, args, so, output_dir)
    print(
        "Orig Model L2 Norm with MSFT is:",
        l2_norm_msft_against_orig,
    )
    print(
        "Fused Model L2 Norm with MSFT is:",
        l2_norm_against_msft,
    )
    print(
        "Fused Model L2 Norm with Orig Model is:",
        l2_norm_against_orig,
    )


if __name__ == "__main__":
    def _parse_bool(value):
        """Parse a boolean CLI value.

        argparse's ``type=bool`` treats any non-empty string as True, so
        ``--all False`` used to enable the option.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y", "on"):
            return True
        if normalized in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="run_ort.py, run the original and fused graph and dump the specified data",
                                  usage='use "%(prog)s --help" for more info',
                                  formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model_name", required=True)
    parser.add_argument(
        "--ld", type=int, help="load data for large models", default=0, required=False
    )
    parser.add_argument(
        "--idx", type=int, help="against datapoint number", default=0, required=False
    )
    parser.add_argument(
        "--edges",
        type=str,
        help=textwrap.dedent('''\
        edges to dump
        all - extract all ifm/ofm-s from fused graph
        unique - extract ifm/ofm from unique nodes
        fused - extract ifm/ofm from fused nodes only
        fused_unique - extract ifm/ofm from unique
                       fused nodes only
        <text file path> - path to a text file with
                           specified channel names
                           to be extracted'''),
        default="unique", required=False
    )
    parser.add_argument(
        "--all",
        type=_parse_bool,
        nargs="?",
        const=True,  # bare `--all` enables it
        help="run against all datapoints (`--all` or `--all true/false`)",
        default=False,
        required=False,
    )
    parser.add_argument(
        "--data_dump",
        type=str,
        help=textwrap.dedent('''\
        data to be dumped,
        all   - dump all specified edges and all consts
        ort   - dump all specified edges from original model
        carf  - dump all specified edges from fused model
        const - dump all const data
        wgt   - dump consts from the fused model's initializers
                without running any model
        multiple options can be specified separated by ',' comma'''),
        default="all",
        required=False
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        help=textwrap.dedent('''\
        output directory, default is from config.py'''),
        default=results_dir,
        required=False
    )
    args = parser.parse_args()

    if not args.model_name:
        parser.error("Please pass model path with --model_name flags.")

    main(vars(args))
