import os
import sys
import json
import traceback
import subprocess
import random
import shutil
import argparse
import glob
import yaml
import numpy as np
import pandas as pd
import math

# Repository root: two directories above this script's location.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__),"../../"))

# Obviate the need to run 'source settings.sh' before running this script.
#
# It also avoids a tricky situation where PYTHONPATH was set by the settings.sh
# of another workspace: without this, the script would silently pick up the
# wrong Python modules from that other workspace. REPO_ROOT and
# REPO_ROOT/dataflow are inserted at the front of sys.path so they take
# precedence over any PYTHONPATH entries.
#
sys.path.insert(0, REPO_ROOT)
sys.path.insert(1, os.path.join(REPO_ROOT, "dataflow"))

# Locations of OGOAT components. collateral_path holds the
# "*kernel_metadata.yaml" files and pm_kernel_map.json read below.
# NOTE(review): L1_path is not referenced in this file as far as visible.
L1_path = f"{REPO_ROOT}/OGOAT/src/L1_fusion/"
tiler_path = f"{REPO_ROOT}/OGOAT/src/Tiler/"
scheduler_path = f"{REPO_ROOT}/OGOAT/src/Scheduling_Engine/"
collateral_path = f"{REPO_ROOT}/OGOAT/Collaterals/"
# The Tiler and Scheduler directories are appended (lowest priority) so
# their own top-level imports resolve.
sys.path.append(tiler_path)
sys.path.append(scheduler_path)
import OGOAT.src.Tiler.run_tiler as Tiler
import OGOAT.src.Scheduling_Engine.main as Scheduler
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

def clear_folder(folder):
    """Delete everything inside *folder* while keeping *folder* itself.

    Regular files and symlinks are unlinked; sub-directories are removed
    recursively. Failures are reported on stdout and do not stop the loop.
    """
    for entry_name in os.listdir(folder):
        target = os.path.join(folder, entry_name)
        try:
            is_file_like = os.path.isfile(target) or os.path.islink(target)
            if is_file_like:
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            print(f"Failed to delete {target}. Reason: {err}")

def extract_fields(file_name):
    """Load the JSON document stored at *file_name*.

    Returns the parsed object, or ``None`` when the path does not exist.
    """
    if not os.path.exists(file_name):
        return None
    with open(file_name, "r") as handle:
        return json.load(handle)

def custom_sort_key(string):
    """Sort key for layer directory names of the form ``<layer>_<tag>_<MxKxN>``.

    Returns ``[layer, M, K, N]`` as ints, so names sort by layer number first
    and then numerically by dimensions. The middle field is ignored.
    """
    fields = string.split('_')
    layer_number = int(fields[0])
    dims = [int(d) for d in fields[2].split('x')]
    return [layer_number, *dims]

def prod(arr):
    """Return the product of the elements of *arr*.

    Unlike ``math.prod``, an empty sequence yields 0 rather than 1; callers
    rely on that convention.
    """
    if len(arr) == 0:
        return 0
    return math.prod(arr)

def gen_kernel_list(matmul_list, meta_data):
    """Build the kernel-list dictionary for the op types used in *matmul_list*.

    Each element of *matmul_list* is a sequence whose first item is an op-type
    name (a full row tuple or just ``[op_type]``). The op type's entry in
    *meta_data* (the merged ``*kernel_metadata.yaml`` contents) supplies
    ``kernel_path.kernel_list`` and ``kernel_path.kernel_include``. Op types
    without metadata are reported on stdout and skipped.

    Returns a dict whose ``"pm_0"`` entry holds:
      * ``kernel_list``: kernel name -> index in ``kernel_func_list``, ordered
        by that index;
      * ``kernel_include``: de-duplicated include headers, sorted
        alphabetically, with ``super.hh`` moved to the front when present.

    Raises ValueError (from ``list.index``) if a kernel named in the metadata
    is missing from ``kernel_func_list``.
    """
    kernel_list = {}
    kernel_includes = set()

    for op in matmul_list:
        op_type = op[0]
        if op_type not in meta_data:
            print(f"Kernel metadata for {op_type} not found")
            continue
        kernel_path = meta_data[op_type]["kernel_path"]
        for kernel_name in kernel_path["kernel_list"]:
            kernel_list[kernel_name] = kernel_func_list.index(kernel_name)
        kernel_includes.update(kernel_path["kernel_include"])

    all_kernel_list = dict(sorted(kernel_list.items(), key=lambda item: item[1]))

    # Keep super.hh at the front of the include list. It is only moved when
    # some metadata actually listed it: the previous unconditional
    # list.remove('super.hh') raised ValueError for an empty matmul_list or
    # when every op type lacked metadata.
    has_super = 'super.hh' in kernel_includes
    kernel_includes.discard('super.hh')
    all_kernel_includes_list = sorted(kernel_includes)
    if has_super:
        all_kernel_includes_list.insert(0, 'super.hh')

    kernel_dict = {
        "pm_0": {
            "kernel_list": all_kernel_list,
            "kernel_include": all_kernel_includes_list,
        },
        "group_norm_in_model": False,
        "disable_fast_pm": False,
    }
    return kernel_dict

def create_csv_file(output_dir = "WAIC_Outputs"):
    """(Re)create ``<output_dir>/error_status.csv`` with the status header.

    The file holds the fixed status columns and a single seed row whose
    "Model name" is ``matmul_regression``; all other cells are empty.
    """
    status_columns = [
        "Model name",
        "L1 fusion",
        "Tiler",
        "Op",
        "Op type",
        "Scheduler",
        "Sim",
        "HW",
        "Time",
        "L2 norm",
    ]
    # Columns not present in the data dict are filled with NaN (written as empty cells).
    seed = pd.DataFrame({"Model name": ['matmul_regression']}, columns=status_columns)
    csv_path = os.path.join(output_dir, "error_status.csv")
    seed.to_csv(csv_path, index=False)

def add_to_csv(output_dir = "WAIC_Outputs", op=0, op_type='MatMul', status='Pass'):
    """Append one scheduler result row to ``<output_dir>/error_status.csv``.

    The row fills only the "Op" (stringified *op*), "Op type" and "Scheduler"
    (*status*) columns; the file is read, extended and rewritten in place.
    """
    csv_path = os.path.join(output_dir, "error_status.csv")
    new_row = pd.DataFrame(
        {
            "Op": [str(op)],
            "Op type": [op_type],
            "Scheduler": [status],
        }
    )
    table = pd.concat([pd.read_csv(csv_path), new_row], ignore_index=True)
    table.to_csv(csv_path, index=False)

def get_permutation_order (node, attr, perm):
    """Store each axis of *perm* as ``node['attributes']['<attr>_<i>'] = [axis]``.

    Positions are 1-based and each axis is converted to ``int``. *node* is
    modified in place; nothing is returned.
    """
    for position, axis in enumerate(perm, start=1):
        node['attributes'][f'{attr}_{position}'] = [int(axis)]

def string_to_list (given):
    """Parse a shape/permutation cell such as ``"[0, 2, 1]"`` into a list of ints.

    Accepts:
      * strings with or without surrounding ``[...]`` and with arbitrary
        whitespace (``" [1, 0] "``, ``"3,4"``);
      * empty specs (``""``, ``"[]"``), which yield ``[]``;
      * already-parsed list/tuple/ndarray values, whose items are converted
        with ``int``.
    Empty comma-separated items (e.g. a trailing comma) are ignored.

    Raises ValueError if an item is not an integer literal.
    """
    if isinstance(given, (list, tuple, np.ndarray)):
        return [int(x) for x in given]
    text = str(given).strip()
    if text.startswith('[') and text.endswith(']'):
        text = text[1:-1]
    # Previously "" raised IndexError and "[]" or "1,2," raised ValueError on int('').
    return [int(piece) for piece in text.split(',') if piece.strip()]

def check_isnan(value):
    """Return True only if *value* converts to a float that is NaN.

    Values that cannot be converted to float (None, non-numeric strings,
    lists, ...) are treated as "not NaN" and yield False.
    """
    try:
        as_float = float(value)
    except (ValueError, TypeError):
        return False
    return math.isnan(as_float)

def pm_bin_search(map, kernel_list, id=0):
    """Return the numeric id of the first ``pm_<n>`` entry covering *kernel_list*.

    *map* is the prebuilt PM table: keys whose first ``_``-separated field is
    ``pm`` are considered in iteration order, and an entry matches when every
    name in *kernel_list* appears in its ``'kernel_list'``. Other keys are
    ignored. If nothing matches, a message tagged with *id* is printed and
    0 (the default PM id) is returned.
    """
    # NOTE: ``map`` and ``id`` shadow builtins but are kept for caller compatibility.
    for pm_key, pm_entry in map.items():
        if pm_key.partition('_')[0] != 'pm':
            continue
        if set(kernel_list) <= set(pm_entry['kernel_list']):
            return int(pm_key.rsplit('_', 1)[-1])
    print(f"Prebuilt xclbin not found for the given kernel list {id}")
    return 0

# Bytes per element for each datatype name that can appear in an Op_type
# suffix such as "uint16xuint8xuint16" (sub-byte/block formats are fractional).
_BYTES_FOR = {"bool": 1/8, "mx9" : 9/8, "bfp16" : 9/8, "float16" : 2, "bfloat16" : 2, "fp32" : 4, "float32" : 4,
              "double" : 8, "int4" : 1/2, "uint4": 1/2, "int8" : 1, "uint8": 1, "int16": 2, "uint16": 2,
              "int32": 4, "uint32": 4, "int64": 8, "uint64": 8, "": 0}

def _read_perm(node, key, identity):
    """Return permutation column *key* of *node* as ints, or a copy of *identity* if absent/NaN."""
    value = node.get(key, math.nan)
    return list(identity) if check_isnan(value) else string_to_list(value)

def _has_3d_perm(node):
    """True when any of PermA/PermB/PermY is present and has exactly three entries."""
    return any(
        not check_isnan(node.get(key, math.nan)) and len(string_to_list(node[key])) == 3
        for key in ('PermA', 'PermB', 'PermY')
    )

def _annotate_matmul_node(node, meta_data, prebuilt_map):
    """Derive MatMul dims, 4D shapes and attributes for one spreadsheet row.

    Mutates *node* in place, adding ``no_batches``, ``M``, ``K``, ``N``,
    ``in_act_shape``, ``in_wgt_shape``, ``out_act_shape``, ``outputs``,
    ``attributes`` and ``debug_info``. Raises Exception("Check the Weight
    shape") for a legacy 3D permutation with a weight of rank other than 2/3.
    """
    in_act_shape = string_to_list(node['Input_shape'])
    in_wgt_shape = string_to_list(node['Weight_shape'])
    ofm_dtype = node['Op_type'].split('_')[-1].split('x')[-1]

    # NOTE: permutation support is enabled only for 3D. Legacy 3D permutations
    # are widened to 4D (a leading batch axis) once K and N are extracted.
    is_3d = _has_3d_perm(node)
    identity = [0, 1, 2] if is_3d else [0, 1, 2, 3]
    permA = _read_perm(node, 'PermA', identity)
    permB = _read_perm(node, 'PermB', identity)
    permY = _read_perm(node, 'PermY', identity)
    InTransposeA = int(permA != identity)
    InTransposeB = int(permB != identity)
    OutTranspose = int(permY != identity)
    Input_size = prod(in_act_shape)

    if len(in_wgt_shape) == 2:
        no_batches = [1, 1]
        K = in_wgt_shape[permB[0]]
        N = in_wgt_shape[permB[1]]
    elif len(in_wgt_shape) == 3:
        no_batches = [1, in_wgt_shape[permB[0]]]
        K = in_wgt_shape[permB[1]]
        N = in_wgt_shape[permB[2]]
    elif is_3d:
        raise Exception("Check the Weight shape")
    else:
        wgt_perm = [in_wgt_shape[x] for x in permB]
        no_batches = wgt_perm[:-2]
        K = wgt_perm[-2]
        N = wgt_perm[-1]

    if is_3d:
        permA = [0] + [x + 1 for x in permA]
        permB = [0] + [x + 1 for x in permB]
        permY = [0] + [x + 1 for x in permY]

    M = int(Input_size / (K * prod(no_batches)))
    in_act_dim = no_batches + [M, K]
    in_wgt_dim = no_batches + [K, N]
    out_dim = no_batches + [M, N]
    # Inputs are stored in their pre-transpose layout, hence the inverse perms.
    rev_permA = np.argsort(permA)
    rev_permB = np.argsort(permB)

    kernel_dict = gen_kernel_list([[node['Op_type']]], meta_data)
    pm_id = pm_bin_search(prebuilt_map, list(kernel_dict['pm_0']['kernel_list'].keys()), f"{node}")

    node['no_batches'] = prod(no_batches)
    node['M'] = M
    node['K'] = K
    node['N'] = N
    node['in_act_shape'] = [in_act_dim[rev_permA[i]] for i in range(4)]
    node['in_wgt_shape'] = [in_wgt_dim[rev_permB[i]] for i in range(4)]
    node['out_act_shape'] = [out_dim[permY[i]] for i in range(4)]
    node['outputs'] = "[{'type': 'act', 'dtype': '" + str(ofm_dtype) + "'}]"
    node['attributes'] = {
        'pm_id'       : pm_id,
        'num_batches' : prod(no_batches),
        'InTransposeA': InTransposeA,
        'InTransposeB': InTransposeB,
        'OutTranspose': OutTranspose,
        'PermA'       : permA,
        'PermB'       : permB,
        'PermY'       : permY,
    }
    node['debug_info'] = {}
    for field in ['midx', 'sched', 'valid_idx', 'pingpong']:
        value = node.get(field, None)
        node['debug_info'][field] = None if check_isnan(value) else value

def _build_tiler_entry(op_type, N, in_act_shape, in_wgt_shape, out_act_shape, attr, outputs, debug_info):
    """Return the Tiler IR dict for one MatMul layer.

    Datatypes come from the ``<in>x<wgt>x<out>`` suffix of *op_type*. A bias
    input of shape [N] is added only for op types containing 'bias'.
    """
    dtypes = op_type.split('_')[-1].split('x')
    in_datatype = dtypes[0]
    wgt_datatype = dtypes[1]
    out_datatype = dtypes[2]
    if 'bias' in op_type:
        in_wgt1_shape = [N]
        wgt1_datatype = in_datatype
    else:
        in_wgt1_shape = []
        wgt1_datatype = "float32"

    return {
        "op_type"       : op_type,
        "inputs"        : "[]",
        "outputs"       : outputs,
        "in_act_shape"  : in_act_shape,
        "in_wgt_shape"  : in_wgt_shape,
        "in_wgt1_shape" : in_wgt1_shape,
        "out_act_shape" : out_act_shape,
        "in_datatype"   : in_datatype,
        "wgt_datatype"  : wgt_datatype,
        "wgt1_datatype" : wgt1_datatype,
        "out_datatype"  : out_datatype,
        "in_bytes"      : _BYTES_FOR[in_datatype],
        "wgt_bytes"     : _BYTES_FOR[wgt_datatype],
        "wgt1_bytes"    : _BYTES_FOR[wgt1_datatype],
        "out_bytes"     : _BYTES_FOR[out_datatype],
        "qdq_symmetry"  : 0,
        "coeff_shape"   : [N],
        "in_act_residency": "L3",
        "out_act_residency": "L3",
        "attributes"    : {
            'pm_id'       : [int(attr['pm_id'])],
            'num_batches' : [int(attr['num_batches'])],
            'disable_q'   : [0],
            'InTransposeA': [int(attr['InTransposeA'])],
            'InTransposeB': [int(attr['InTransposeB'])],
            'OutTranspose': [int(attr['OutTranspose'])],
            'permA'       : attr['PermA'],
            'permB'       : attr['PermB'],
            'permY'       : attr['PermY'],
        },
        "debug_info": debug_info,
    }

def _generate_from_xlsx(xlsx_path, sheet, test_dir, ir_file_name, layer_ids, meta_data, prebuilt_map):
    """skip == 0 flow: build the Tiler IR from the spreadsheet and run the Tiler.

    Returns ``(matmul_list, matmul_nRun, tested_shapes)``:
      * matmul_list: one tuple per spreadsheet row (all rows);
      * matmul_nRun: ``(no_batches, M, K, N)`` of rows whose Run flag is False;
      * tested_shapes: ``(no_batches, M, K, N)`` of rows actually sent to the Tiler.
    """
    if sheet == 'all':  # read all sheets and concat in one table
        df = pd.concat(pd.read_excel(xlsx_path, None).values(), ignore_index=True)
    else:  # read individual sheet only
        df = pd.read_excel(xlsx_path, sheet)

    matmul_dict = df.to_dict(orient='records')
    for node in matmul_dict:
        _annotate_matmul_node(node, meta_data, prebuilt_map)

    matmul_list = [
        (node['Op_type'], node['no_batches'], node['M'], node['K'], node['N'],
         node['in_act_shape'], node['in_wgt_shape'], node['out_act_shape'],
         node['attributes'], node['Run'], node['outputs'], node['debug_info'])
        for node in matmul_dict
    ]

    if os.path.exists(test_dir):
        print(f"Output dir already exist. Deleting it!! {test_dir}")
        # NOTE(review): clear_folder(test_dir) is disabled here, so existing
        # layer directories are kept despite the message above.
    else:
        os.makedirs(test_dir)

    matmul_nRun = []
    tested_shapes = []
    Tiler_json = {}
    for layer, row in enumerate(matmul_list):
        (op_type, no_batches, M, K, N, in_act_shape, in_wgt_shape, out_act_shape,
         attr, Run_case, outputs, debug_info) = row
        if layer_ids and layer not in layer_ids:
            print(f"skip layer {layer}")
            continue
        # '== False' (not 'is False') so 0 / numpy.bool_ False cells also match.
        if Run_case == False:  # noqa: E712
            matmul_nRun.append((no_batches, M, K, N))
            continue
        tested_shapes.append((no_batches, M, K, N))
        Tiler_json[f"{layer}_B{no_batches}_{M}x{K}x{N}"] = _build_tiler_entry(
            op_type, N, in_act_shape, in_wgt_shape, out_act_shape, attr, outputs, debug_info)

    with open(os.path.join(test_dir, ir_file_name), "w") as fir:
        json.dump(Tiler_json, fir, indent=2)

    Tiler.main({
        "ir_json"  : os.path.join(test_dir, ir_file_name),
        "device"   : "strix",
        "overlay"  : "8x4",
        "tiler_bfm": False,
        "kernel_list": "",
        "build_txn": None,
        "multiprocess": True,
        "output_dir" : False,
        "j"        : 8,
    })
    return matmul_list, matmul_nRun, tested_shapes

def _generate_from_tiler_output(tiler_json_path, test_dir, ir_file_name):
    """skip == 2 flow: unpack an existing Tiler output JSON into layer directories.

    Every entry is treated as MatMul_qdq_uint16xuint8xuint16 with one batch.
    Returns ``(matmul_list, tested_shapes)``.
    """
    if os.path.exists(test_dir):
        print(f"Output dir already exist. Deleting it!! {test_dir}")
        clear_folder(test_dir)
    else:
        os.makedirs(test_dir)

    with open(tiler_json_path, "r") as tiler_file:
        Tiler_output_json = json.load(tiler_file)

    op_type = "MatMul_qdq_uint16xuint8xuint16"
    matmul_list = []
    tested_shapes = []
    Tiler_json = {}
    for layer, opt_tiling in Tiler_output_json.items():
        M = opt_tiling['layer_info']['in_ifm_shape'][0]
        K = opt_tiling['layer_info']['in_ifm_shape'][1]
        N = opt_tiling['layer_info']['out_ofm_shape'][1]
        matmul_list.append((op_type, 1, M, K, N))
        tested_shapes.append((1, M, K, N))
        layer_name = f"{layer}_{M}x{K}x{N}"
        Tiler_json[layer_name] = {
            "op_type"       : op_type,
            "inputs"        : "[]",
            "outputs"       : "[]",
            "in_act_shape"  : [1, M, K],
            "in_wgt_shape"  : [K, N],
            "in_wgt1_shape" : [N],
            "out_act_shape" : [1, M, N],
            "in_datatype"   : "uint16",
            "wgt_datatype"  : "uint8",
            "wgt1_datatype" : "uint16",
            "out_datatype"  : "uint16",
            "in_bytes"      : 2,
            "wgt_bytes"     : 1,
            "wgt1_bytes"    : 2,
            "out_bytes"     : 2,
            "qdq_symmetry"  : 0,
            "coeff_shape"   : [N],
            "in_act_residency": "L3",
            "out_act_residency": "L3",
            # Was [no_batches], an undefined name here (NameError); this flow is single-batch.
            "attributes"    : {'num_batches': [1], 'pm_id': [0]},
        }

        # Layer dirs must live under test_dir, which is what the scheduler
        # stage scans (previously written relative to the current directory).
        sub_dir = os.path.join(test_dir, layer_name)
        os.makedirs(sub_dir, exist_ok=True)
        with open(os.path.join(sub_dir, layer_name + ".json"), "w") as f:
            json.dump(opt_tiling, f, indent=2)

    with open(os.path.join(test_dir, ir_file_name), "w") as fir:
        json.dump(Tiler_json, fir, indent=2)
    return matmul_list, tested_shapes

def _shape_from_dir_name(name):
    """Parse ``<layer>_B<batches>_<M>x<K>x<N>`` into ``(batches, M, K, N)`` ints.

    Replaces the former ``eval`` of directory names. Malformed names raise
    ValueError/IndexError.
    """
    fields = name.split('_')
    batches = fields[1].replace('B', '')
    return tuple(int(v) for v in [batches] + fields[2].split('x'))

def _schedule_layer(op, test_dir, log, mode):
    """Run the scheduler and, in python mode, data_flow.py for one layer directory.

    Results go to *log* (an open text file) and error_status.csv.
    Returns ``(data_flow.py exists, dma.hpp exists)`` afterwards.
    """
    layer_dir = os.path.join(test_dir, str(op))
    # Remove stale outputs so their existence below reflects this run only.
    for file_name in ['data_flow.py', 'dma.hpp', 'graph.hpp', 'super.cc', 'super.hh']:
        stale_path = os.path.join(layer_dir, file_name)
        if os.path.exists(stale_path):
            os.remove(stale_path)

    op_json_path = os.path.join(layer_dir, f"{op}.json")
    if extract_fields(op_json_path) is None:
        log.write(f'Layer {op}, Tiler Failed\n')
        return False, False

    try:
        if mode == 'python':
            Scheduler.main({
                "input_file": op_json_path,
                "output_dir": layer_dir,
                "combine_kernels": False,
                "fast_pm" : True,
            })
        elif mode == 'cpp':
            subprocess.run(['./WAIC_CPP/build/OGOAT/Release/run_waic',
                            '--ir_sched', op_json_path,
                            '--output_dir', layer_dir,
                            '-df', '--verify_txn_on_HW'], capture_output=True, text=True, check=True)
        else:
            raise Exception("mode should be python or cpp")
        add_to_csv(test_dir, op, 'MatMul', 'Pass')
    except (Exception, SystemExit):
        # SystemExit is still caught (as the old bare except did) in case the
        # scheduler exits via sys.exit; KeyboardInterrupt now aborts the run.
        add_to_csv(test_dir, op, 'MatMul', 'Fail')
        log.write(f'Layer {op}, Scheduler Failed\n')
        log.write(traceback.format_exc())

    if mode == 'python':
        print(f"Running data flow in {layer_dir}")
        try:
            # cwd= instead of os.chdir: no process-wide state to restore on error.
            subprocess.run(['python', 'data_flow.py', '-b', 'TxnHostPatch'],
                           cwd=layer_dir, capture_output=True, text=True, check=True)
            log.write(f'Layer {op} PASS\n')
            print(f'Layer {op} PASS\n')
        except subprocess.CalledProcessError as e:
            print(f'Layer {op}, dataflow Failed')
            print(e.stderr)
            log.write(f'Layer {op}, dataflow Failed\n')
            log.write(e.stderr)

    return (os.path.exists(os.path.join(layer_dir, 'data_flow.py')),
            os.path.exists(os.path.join(layer_dir, 'dma.hpp')))

def gen_tiler_json(args):
    """Generate MatMul Tiler IR, run the Tiler, then scheduler/dataflow per layer.

    *args* is a dict (normally ``vars()`` of the CLI namespace) with keys:
      xlsx_path  -- MatMul spreadsheet (skip == 0) or Tiler output JSON (skip == 2)
      sheet      -- sheet name, or 'all' to concatenate every sheet
      skip       -- 0: build IR from the spreadsheet and run the Tiler;
                    1: reuse existing layer directories under output_dir;
                    2: unpack a Tiler output JSON into layer directories
      stop       -- 1 to stop after reporting Tiler results
      output_dir -- output directory, resolved against REPO_ROOT
      layer      -- comma-separated layer indices to restrict to ('' = all)
      mode       -- 'python' or 'cpp' scheduler backend

    Progress and failures are written to ``<output_dir>/error.log`` and
    ``<output_dir>/error_status.csv``.
    """
    xlsx_path = args['xlsx_path']
    sheet = args['sheet']
    skip = int(args['skip'])
    stop = int(args['stop'])
    output_dir = args['output_dir']
    fname_log = 'error.log'
    test_dir = os.path.join(REPO_ROOT, output_dir)
    log_path = os.path.join(test_dir, fname_log)
    ir_file_name = "test_matmul_IR_unique_nodes.json"
    ir_kernel_name = "test_matmul_IR_kernel_list.json"

    layer_ids = []
    if len(args['layer']) != 0:
        layer_ids = [int(y) for y in args['layer'].split(',')]

    try:
        with open(os.path.join(collateral_path, 'pm_kernel_map.json'), 'r') as f:
            prebuilt_map = json.load(f)
    except (OSError, ValueError):  # missing or malformed map -> default pm ids
        prebuilt_map = {}

    meta_data = {}
    for file_name in glob.glob(os.path.join(collateral_path, "*kernel_metadata.yaml")):
        with open(file_name, 'r') as file:
            meta_data.update(yaml.safe_load(file) or {})  # empty YAML loads as None

    matmul_list = []
    matmul_nRun = []
    # Shapes actually handed to the Tiler. Previously derived via x[-1], which
    # tested debug_info rather than the Run flag, so skipped rows were
    # reported as Tiler failures.
    tested_shapes = []
    if skip == 0:
        matmul_list, matmul_nRun, tested_shapes = _generate_from_xlsx(
            xlsx_path, sheet, test_dir, ir_file_name, layer_ids, meta_data, prebuilt_map)
    elif skip == 2 and xlsx_path is not None:
        matmul_list, tested_shapes = _generate_from_tiler_output(xlsx_path, test_dir, ir_file_name)

    create_csv_file(test_dir)

    if skip != 1:
        kernel_dict = gen_kernel_list(matmul_list, meta_data)
        with open(os.path.join(test_dir, ir_kernel_name), "w") as kernel_file:
            json.dump(kernel_dict, kernel_file, indent=4)

    # Check Tiler results: every non-empty layer directory counts as a pass.
    df_dir = [entry.name for entry in os.scandir(test_dir) if entry.is_dir() and os.listdir(entry)]
    Tiler_pass = [_shape_from_dir_name(name) for name in df_dir]
    new_matmul_list = tested_shapes
    Tiler_fail = list(set(new_matmul_list) - set(Tiler_pass))

    with open(log_path, "w") as f:
        f.write(f"new_matmul_list = {new_matmul_list}\n")
        f.write(f"Tiler_pass = {Tiler_pass}\n")
        print(f'Tiler Failed shapes, {len(Tiler_fail)} instances')
        f.write(f'Tiler Failed shapes, {len(Tiler_fail)} instances\n')
        print(f'{Tiler_fail}')
        f.write(f'{Tiler_fail}\n')

    if stop == 1:
        return

    scheduler_pass = []
    dmacompiler_pass = []
    df_dir.sort(key=custom_sort_key)
    for op in df_dir:
        # Reopened per layer so progress is flushed to disk after each one.
        with open(log_path, "a") as f:
            print(f"Generating dataflow for layer id: {op}")
            has_dataflow, has_dma = _schedule_layer(op, test_dir, f, args['mode'])
        if has_dataflow:
            scheduler_pass.append(_shape_from_dir_name(op))
        if has_dma:
            dmacompiler_pass.append(_shape_from_dir_name(op))

    scheduler_fail = list(set(Tiler_pass) - set(scheduler_pass))
    dmacompiler_fail = list(set(scheduler_pass) - set(dmacompiler_pass))
    with open(log_path, "a") as f:
        print(f'Not tested shapse, {matmul_nRun}')
        print(f'Tiler Failed shapes, {len(Tiler_fail)} instances')
        print(f'scheduler fail: {scheduler_fail}, instances {len(scheduler_fail)}')
        print(f'dmacompiler fail: {dmacompiler_fail}, instances {len(dmacompiler_fail)}')
        f.write(f'Not tested shapse, {matmul_nRun}\n')
        f.write(f'Tiler Failed shapes, {len(Tiler_fail)} instances\n')
        f.write(f'scheduler fail: {scheduler_fail}, instances {len(scheduler_fail)}\n')
        f.write(f'dmacompiler fail: {dmacompiler_fail}, instances {len(dmacompiler_fail)}\n')

def update_pm_id(args):
    """Rewrite ``layer_info.attributes.pm_id`` in every per-layer IR JSON.

    Scans the non-empty sub-directories of ``REPO_ROOT/<args['output_dir']>``
    and, for each ``<dir>/<dir>.json``, looks up the kernels of
    ``layer_info.op_type`` in the kernel metadata. It then picks the covering
    prebuilt PM from ``pm_kernel_map.json`` and writes the id back to the file.
    Layers whose JSON is missing, or whose op type has no kernel metadata,
    are reported and left untouched.
    """
    output_dir = args['output_dir']
    test_dir = os.path.join(REPO_ROOT, output_dir)
    yaml_files = glob.glob(os.path.join(collateral_path, "*kernel_metadata.yaml"))
    df_dir = [entry.name for entry in os.scandir(test_dir) if entry.is_dir() and os.listdir(entry)]

    try:
        with open(os.path.join(collateral_path, 'pm_kernel_map.json'), 'r') as f:
            prebuilt_map = json.load(f)
    except (OSError, ValueError):  # missing or malformed map -> default pm ids
        prebuilt_map = {}

    meta_data = {}
    for file_name in yaml_files:
        with open(file_name, 'r') as file:
            meta_data.update(yaml.safe_load(file) or {})  # empty YAML loads as None

    for op in df_dir:
        print(f"updating {op} pm id")
        op_json_path = os.path.join(test_dir, str(op), f"{op}.json")
        data = extract_fields(op_json_path)
        if data is None:
            print(f'Layer {op}, Tiler Failed\n')
            continue

        layer_info = data.get("layer_info")
        op_type = layer_info.get("op_type") if isinstance(layer_info, dict) else None
        # The old default of {} was unhashable (TypeError in gen_kernel_list),
        # and an unknown op type produced an empty kernel list that matched an
        # arbitrary PM. Skip such layers explicitly instead.
        if not isinstance(op_type, str) or op_type not in meta_data:
            print(f"Layer {op}: kernel metadata for op_type {op_type!r} not found, pm id not updated")
            continue

        kernel_dict = gen_kernel_list([[op_type]], meta_data)
        pm_id = pm_bin_search(prebuilt_map, list(kernel_dict['pm_0']['kernel_list'].keys()), f"{data}")
        # setdefault so the id is persisted even if "attributes" was absent
        # (previously it was written into a throwaway dict).
        layer_info.setdefault("attributes", {})['pm_id'] = [pm_id]
        with open(op_json_path, "w") as fir:
            json.dump(data, fir, indent=2)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-f", "--xlsx_path", help="path to xlsx file (or Tiler output JSON with --skip 2). Required unless --update_pm_id is given")
    parser.add_argument("-s", "--sheet", help="sheet name. Optional Field. Default value = all", default='all',
                        choices=['all','int16xint4', 'int16xint4_lut', 'int16xint8', 'int16xint8_lut', 'int16xint16_actxact'])
    parser.add_argument("-skip", "--skip", help = "Skip Tiler if value is not 0. Optional Field. Default value = 0", default='0')
    parser.add_argument("-stop", "--stop", help = "Stop after Tiler if value is 1. Optional Field. Default value = 0", default='0')
    parser.add_argument("-o", "--output_dir", help="Output directory name. Optional Field. Default value = WAIC_Outputs", default='WAIC_Outputs')
    parser.add_argument("-l", "--layer", help="specify comma-separated layer ids; if not specified all layers will be generated", default='')
    parser.add_argument("-m", "--mode", help="python or cpp", default='python', choices=['python','cpp'])
    parser.add_argument("-pm", "--update_pm_id", help="update pm id in ir json files based on pm_kernel_map.json in collaterals",
                        default=False, action='store_true')

    args = parser.parse_args()
    # update_pm_id only rewrites the existing IR JSONs under --output_dir, so
    # the input path is required only for the generation flow.
    if not args.xlsx_path and not args.update_pm_id:
        parser.error("Please pass path/to/xlsx_file")

    if args.update_pm_id:
        update_pm_id(vars(args))
    else:
        gen_tiler_json(vars(args))
