import os
import sys
import json
import traceback
import subprocess
import random
import shutil
import numpy as np

REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__),"../../"))
tiler_path = f"{REPO_ROOT}/OGOAT/src/Tiler/"
scheduler_path = f"{REPO_ROOT}/OGOAT/src/Scheduling_Engine/"
sys.path.append(tiler_path)
sys.path.append(scheduler_path)
import OGOAT.src.Tiler.run_tiler as Tiler
import OGOAT.src.Scheduling_Engine.main as Scheduler

def clear_folder(folder):
    """Remove everything inside *folder* while keeping the folder itself.

    Regular files and symbolic links are unlinked; sub-directories are deleted
    recursively. A failure on one entry is printed and the sweep continues
    with the remaining entries.
    """
    targets = [os.path.join(folder, entry_name) for entry_name in os.listdir(folder)]
    for target in targets:
        try:
            if os.path.islink(target) or os.path.isfile(target):
                os.unlink(target)
            elif os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            # Best effort: report and move on to the next entry.
            print("Failed to delete %s. Reason: %s" % (target, err))

def extract_fields(file_name):
    """Parse the JSON file at *file_name*.

    Returns the decoded JSON value, or ``None`` when the path does not exist.
    """
    if not os.path.exists(file_name):
        return None
    with open(file_name, "r") as handle:
        return json.load(handle)

def custom_sort_key(string):
    """Return the integer layer number prefixing a ``"<layer>_..."`` name."""
    layer_id, _, _ = string.partition('_')
    return int(layer_id)


def _matmul_layer_entry(A, B, C, D, in_act_shape, wgt1_datatype, wgt1_bytes):
    """Build one ``MatMul_qdq_biasgelu`` IR entry for the shape (A, B, C, D).

    The two test variants differ only in the activation-input shape and the
    wgt1 tensor's datatype/byte size; every other field is shared. Key order
    matches the order in which the entries are dumped to the IR JSON.
    """
    return {
        "op_type"       : "MatMul_qdq_biasgelu_uint16xuint8xuint16",
        "inputs"        : "",
        "outputs"       : "",
        "in_act_shape"  : in_act_shape,
        "in_wgt_shape"  : [C, D],
        "in_wgt1_shape" : [D],
        "out_act_shape" : [A, B, D],
        "in_datatype"   : "uint16",
        "wgt_datatype"  : "uint8",
        "wgt1_datatype" : wgt1_datatype,
        "out_datatype"  : "uint16",
        "in_bytes"      : 2,
        "wgt_bytes"     : 1,
        "wgt1_bytes"    : wgt1_bytes,
        "out_bytes"     : 2,
        "qdq_symmetry"  : "None",
        "coeff_shape"   : [D],
        "in_act_residency": "L3",
        "out_act_residency": "L3",
    }


def _parse_layer_dir(name):
    """Split ``"<layer>_<A>x<B>x<C>x<D>"`` into ``(layer, (A, B, C, D))``.

    Returns None for names that do not follow that pattern, so unrelated
    entries in the output folder are skipped instead of crashing the run.
    """
    parts = name.split('_')
    if len(parts) < 2:
        return None
    try:
        return int(parts[0]), tuple(int(v) for v in parts[1].split('x'))
    except ValueError:
        return None


def _run_layer(test_dir, op, log):
    """Run the Scheduler and then the generated ``data_flow.py`` for one layer.

    Args:
        test_dir: Output root holding one sub-directory per tiled layer.
        op: Layer directory name, e.g. ``"0_1x50x768x3072"``.
        log: Open text file receiving per-layer status lines.
    """
    print(f"Generating dataflow for layer id: {op}")
    op_dir = os.path.join(test_dir, op)
    op_json_path = os.path.join(op_dir, f"{op}.json")
    # Without <op>/<op>.json there is nothing to schedule for this layer.
    if extract_fields(op_json_path) is None:
        log.write(f'Layer {op}, Tiler Failed\n')
        return

    scheduler_arg = {
        "input_file": op_json_path,
        "output_dir": op_dir,
        "combine_kernels": 0,
    }
    try:
        Scheduler.main(scheduler_arg)
    except (Exception, SystemExit):
        # SystemExit is caught too in case the Scheduler exits via sys.exit();
        # KeyboardInterrupt still propagates. The data_flow step below still
        # runs and logs its own outcome.
        log.write(f'Layer {op}, Scheduler Failed\n')
        log.write(traceback.format_exc())

    # Run in the layer directory through cwd= instead of os.chdir(), so this
    # process's working directory can never be left changed. sys.executable
    # makes the child use the same interpreter as this script.
    print(f"Running data flow in {op_dir}")
    try:
        subprocess.run(
            [sys.executable, 'data_flow.py', '-b', 'Adf'],
            cwd=op_dir, capture_output=True, text=True, check=True,
        )
    except subprocess.CalledProcessError as e:
        log.write(f'Layer {op}, dataflow Failed\n')
        log.write(e.stderr)
        print(e.stderr)
    else:
        log.write(f'Layer {op} PASS\n')
        print(f'Layer {op} PASS\n')


def gen_tiler_json():
    """Generate a MatMul IR, tile and schedule it, and run every layer's dataflow.

    Steps:
      1. Clear (or create) ``<REPO_ROOT>/WAIC_Outputs``.
      2. Write the IR JSON with one entry per unique shape and variant.
      3. Run the Tiler on it (device "strix", overlay "8x4").
      4. Report every IR layer for which the Tiler produced no output directory.
      5. For each produced layer directory, run the Scheduler and then
         ``data_flow.py -b Adf``.

    All status lines are written to ``WAIC_Outputs/error.log``.
    """
    fname_log = 'error.log'
    test_dir = os.path.join(REPO_ROOT, "WAIC_Outputs")
    ir_file_name = "test_matmul_IR_unique_nodes.json"
    ir_path = os.path.join(test_dir, ir_file_name)

    if os.path.exists(test_dir):
        print(f"Output dir already exist. Deleting it!! {test_dir}")
        clear_folder(test_dir)
    else:
        os.mkdir(test_dir)

    # Variant 1: activation flattened to [A*B, C], int32 wgt1 tensor.
    flat_act_shapes = [(1, 50, 768, 3072), (10, 77, 512, 2048), (1, 128, 768, 3072), (1, 197, 768, 3072)]
    # Variant 2: activation kept as [A, B, C], uint16 wgt1 tensor.
    batched_act_shapes = [(1, 128, 768, 3072)]

    # Deduplicate and sort on the full tuple so layer numbering is fully
    # deterministic, rather than depending on set iteration order for ties.
    Tiler_json = {}
    layer = 0
    for (A, B, C, D) in sorted(set(flat_act_shapes)):
        Tiler_json[f"{layer}_{A}x{B}x{C}x{D}"] = _matmul_layer_entry(A, B, C, D, [A * B, C], "int32", 4)
        layer += 1
    for (A, B, C, D) in sorted(set(batched_act_shapes)):
        Tiler_json[f"{layer}_{A}x{B}x{C}x{D}"] = _matmul_layer_entry(A, B, C, D, [A, B, C], "uint16", 2)
        layer += 1

    with open(os.path.join(test_dir, fname_log), "w") as log:
        with open(ir_path, "w") as fir:
            json.dump(Tiler_json, fir, indent=2)

        # run tiler
        Tiler_arg = {
            "ir_json"  : ir_path,
            "device"   : "strix",
            "overlay"  : "8x4",
            "tiler_bfm": False,
            #"mode_select" : "0,2",
        }
        Tiler.main(Tiler_arg)

        # Map each layer directory the Tiler produced to its (layer, shape)
        # identity. Names that do not parse as "<layer>_AxBxCxD" are ignored.
        produced = {}
        for entry in os.scandir(test_dir):
            if entry.is_dir():
                parsed = _parse_layer_dir(entry.name)
                if parsed is not None:
                    produced[parsed] = entry.name

        # Compare per layer (not per shape), across BOTH shape lists, so a
        # shape used by two layers cannot hide a failure of one of them.
        Tiler_fail = [key for key in Tiler_json if _parse_layer_dir(key) not in produced]
        print('Tiler Failed shapes')
        log.write('Tiler Failed shapes\n')
        print(f'{Tiler_fail}')
        log.write(f'{Tiler_fail}\n')

        # run scheduler + dataflow, in layer order
        for op in sorted(produced.values(), key=custom_sort_key):
            _run_layer(test_dir, op, log)

if __name__ == "__main__":
    # Script entry point: build the IR, run Tiler/Scheduler/dataflow, log results.
    gen_tiler_json()
