import yaml
import glob
from collections import defaultdict

import os

def merge_dicts(dict1, dict2):
    """Deep-merge ``dict2`` into ``dict1`` in place and return None.

    For every key of ``dict2``:

    * key absent from ``dict1``: the value from ``dict2`` is stored as-is
      (the same object, not a copy);
    * both values are dicts: they are merged recursively;
    * both values are lists: ``dict1``'s list is extended in place;
    * ``dict1``'s value is a list but the new one is not: ``dict1`` gets a new
      list with the value appended;
    * anything else: ``dict1`` gets a two-element list ``[old, new]``.
    """
    for name, incoming in dict2.items():
        if name not in dict1:
            dict1[name] = incoming
            continue

        existing = dict1[name]
        if isinstance(existing, dict) and isinstance(incoming, dict):
            merge_dicts(existing, incoming)
        elif isinstance(existing, list) and isinstance(incoming, list):
            existing.extend(incoming)
        elif isinstance(existing, list):
            # Builds a fresh list rather than mutating the existing one.
            dict1[name] = existing + [incoming]
        else:
            dict1[name] = [existing, incoming]

def _as_list(value):
    """Return *value* as a new list: None -> [], list -> shallow copy, scalar -> [value]."""
    if value is None:
        return []
    if isinstance(value, list):
        return list(value)
    return [value]


def get_kernels_for_model(metadata_path, ops_list, print_to_file=False):
    """Collect the kernels and kernel includes required by a set of ops.

    Every file matching ``*kernel_metadata.yaml`` inside ``metadata_path`` is
    loaded with ``yaml.safe_load`` and deep-merged via ``merge_dicts``. The
    result is the union of the top-level ``base_kernel_list`` /
    ``base_kernel_include`` entries and the ``kernel_path.kernel_list`` /
    ``kernel_path.kernel_include`` entries of each op in ``ops_list``, with
    duplicates removed.

    Args:
        metadata_path: Directory holding the ``*kernel_metadata.yaml`` files
            (with or without a trailing path separator).
        ops_list: Op names (top-level keys of the merged metadata) whose
            kernels are needed.
        print_to_file: If true, also write
            ``{"kernel_list": ..., "kernel_include": ...}`` to
            ``merged_file.yaml`` in the current working directory.

    Returns:
        A ``(kernel_list, kernel_include_list)`` tuple of de-duplicated lists,
        in first-seen order (base entries first, then ops in order).

    Raises:
        KeyError: If an op in ``ops_list`` has no ``kernel_path`` entry
            containing both ``kernel_list`` and ``kernel_include``.
    """
    # join() works whether or not metadata_path ends with a separator;
    # sorting makes the merge order, and therefore the output order, stable.
    pattern = os.path.join(metadata_path, "*kernel_metadata.yaml")
    yaml_files = sorted(glob.glob(pattern))

    merged_data = {}
    for file_name in yaml_files:
        with open(file_name, 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)
        # An empty YAML document loads as None; there is nothing to merge.
        if data is None:
            continue
        merge_dicts(merged_data, data)

    # Copies, so extending them does not mutate merged_data. Missing or empty
    # base entries simply contribute nothing.
    all_kernels = _as_list(merged_data.get("base_kernel_list"))
    all_includes = _as_list(merged_data.get("base_kernel_include"))

    for op in ops_list:
        try:
            kernel_path = merged_data[op]["kernel_path"]
            op_kernels = kernel_path["kernel_list"]
            op_includes = kernel_path["kernel_include"]
        except KeyError as err:
            raise KeyError(
                f"incomplete kernel metadata for op {op!r}: missing key {err}"
            ) from err
        # _as_list keeps a scalar string from being extended char-by-char.
        all_kernels.extend(_as_list(op_kernels))
        all_includes.extend(_as_list(op_includes))

    # dict.fromkeys de-duplicates while keeping first-seen order, unlike
    # list(set(...)), whose order varies between runs for strings.
    kernel_list = list(dict.fromkeys(all_kernels))
    kernel_include_list = list(dict.fromkeys(all_includes))

    if print_to_file:
        with open('merged_file.yaml', 'w', encoding='utf-8') as file:
            yaml.safe_dump(
                {"kernel_list": kernel_list, "kernel_include": kernel_include_list},
                file,
            )

    return kernel_list, kernel_include_list


# Debug entry point: resolve kernels for a sample set of ops found in
# ../Collaterals/ and write the result to merged_file.yaml.
if __name__ == "__main__":
    debug_ops = [
        "Add_qdq_int16xint8xint16",
        "MatMul_qdq_int16xint8xint16",
        "SiLu_qdq_bf16",
    ]
    get_kernels_for_model("../Collaterals/", debug_ops, print_to_file=True)
