import sys
import os
import glob
import traceback
import json
# Make the infra/, schedules/ and code_gen/ directories that sit next to this
# file importable, so the plain top-level imports further down (e.g.
# scheduler_utils, custom_dict, buffer_allocator, access_pattern) resolve.
# They are appended, so they are searched after every existing sys.path entry.
infra_path = (os.path.dirname(os.path.abspath(__file__))+"/infra/")
schedules_path = (os.path.dirname(os.path.abspath(__file__))+"/schedules/")
code_gen_path = (os.path.dirname(os.path.abspath(__file__))+"/code_gen/")
sys.path.append(infra_path)
sys.path.append(schedules_path)
sys.path.append(code_gen_path)
import argparse
import logging
import pdb
import queue
from dataclasses import dataclass, field
from typing import List

# From infra dir
import scheduler_utils as utils
import custom_dict
#From schedules dir
import OGOAT.src.Scheduling_Engine.schedules.scheduler as sch
import buffer_allocator
import access_pattern


class SchedulingEngine:
    """Thin wrapper that runs one scheduling stage on the shared pipeline data."""

    def __init__(self, _stage: sch.Stage):
        # The wrapped stage; it must expose ``execute(pipeline_data)``.
        self.stage = _stage

    def execute_stages(self, _pipeline_data: sch.SharedResource):
        """Execute the wrapped stage against ``_pipeline_data`` (returns None)."""
        run_stage = self.stage.execute
        run_stage(_pipeline_data)


def extract_fields(file_name: str):
    """Load and return the JSON document stored at ``file_name``.

    Args:
        file_name: Path to a JSON file.

    Returns:
        The decoded JSON value, or ``None`` when the file does not exist.

    Raises:
        json.JSONDecodeError: The file exists but does not contain valid JSON.
        OSError: Any I/O failure other than a missing file (for example, the
            path is a directory or is not readable).
    """
    try:
        # Open directly instead of checking os.path.exists() first. This avoids
        # a check-then-open race and a second filesystem lookup.
        # JSON is UTF-8 by specification, so don't depend on the locale's
        # default encoding (e.g. cp1252 on Windows).
        with open(file_name, "r", encoding="utf-8") as f:
            return json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        # Missing file (or a path component that is not a directory).
        return None

def add_stages(fifo):
    """Queue the scheduling stages in execution order, then run them.

    Both stages are built from the module-level ``artifacts_dict``, which must
    already be populated (``main`` does this before calling here).

    Args:
        fifo: A ``queue.Queue`` that receives the stages. It is drained by
            ``generate_schedule`` before this function returns.
    """
    buffer_allocation_schedule = SchedulingEngine(
        buffer_allocator.BufferAllocator(artifacts_dict)
    )
    access_pattern_schedule = SchedulingEngine(
        access_pattern.AccessPattern(artifacts_dict)
    )

    # FIFO order: buffer allocation runs first, then access-pattern generation.
    fifo.put(buffer_allocation_schedule)
    fifo.put(access_pattern_schedule)
    # Drain the queue that was just filled. Without this argument,
    # generate_schedule would read the global ``stage_queue``, which may be a
    # different object.
    generate_schedule(fifo)


def generate_schedule(fifo=None):
    """Pop and execute every queued stage in FIFO order.

    All stages receive the same ``sch.SharedResource`` instance, presumably so
    that later stages can read what earlier ones stored in it.

    Args:
        fifo: Queue of ``SchedulingEngine`` objects to drain. Defaults to the
            module-level ``stage_queue`` for backward compatibility.
    """
    shared_data = sch.SharedResource()
    pending = stage_queue if fifo is None else fifo

    while not pending.empty():
        stage = pending.get()
        stage.execute_stages(shared_data)


def init_args_Scheduler(args):
    """Resolve the output directory and, optionally, the combined-kernel list file.

    Args:
        args: Namespace-like object with ``output_dir`` and ``combine_kernels``
            attributes. ``combine_kernels`` enables kernel combining when its
            string form is ``"true"`` or ``"1"`` (case-insensitive). Strings,
            bools and ints are all accepted.

    Returns:
        A ``(test_dir, kernel_file)`` tuple. ``kernel_file`` is the path of the
        ``*kernel_list.json`` file inside ``test_dir`` when combining is
        enabled, otherwise ``0``.

    Raises:
        IndexError: Combining is enabled but ``test_dir`` contains no
            ``*kernel_list.json`` file.
    """
    test_dir = args.output_dir
    # str() so non-string values (argparse's default 0, a bool, None) don't
    # crash on .lower(). "True"/"true"/"1" still enable combining.
    combine_kernels = str(args.combine_kernels).lower() in ("true", "1")

    if not combine_kernels:
        return test_dir, 0

    # Escape test_dir so glob metacharacters in the path ("[", "*", "?") match
    # literally. Sort so the chosen file is deterministic if several match.
    pattern = os.path.join(glob.escape(test_dir), "*kernel_list.json")
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise IndexError(
            "combine_kernels is enabled but no '*kernel_list.json' file "
            f"was found in {test_dir!r}"
        )
    return test_dir, matches[0]

def run_Scheduler_single_op(args, bypass_scheduler, test_dir, kernel_file, op):
    """Run the scheduler for one operator and report the outcome.

    Reads ``<test_dir>/<op>/<op>.json``. Unless that file's
    ``layer_info.orig_op_type`` is in ``bypass_scheduler``, calls ``main`` with
    a scheduler-argument dict whose ``output_dir`` is ``<test_dir>/<op>``.

    Args:
        args: Namespace with ``disable_fast_pm``, ``call_DMAC`` and
            ``assert_on_error`` attributes.
        bypass_scheduler: Op types for which ``main`` is not called. Such ops
            are still reported as ``"pass"``.
        test_dir: Root directory containing one sub-directory per op.
        kernel_file: Forwarded to ``main`` as ``combine_kernels``.
        op: Operator/layer id (used via ``str``/f-string formatting).

    Returns:
        ``[str(op), str(op_type), "pass" | "fail"]``, or ``None`` when the op
        JSON file does not exist.

    Raises:
        AssertionError: ``args.assert_on_error`` is truthy and the scheduler
            raised. The original exception is chained as ``__cause__``.
        KeyError: The op JSON lacks ``layer_info.orig_op_type``.
    """
    print(f"Generating dataflow for layer id: {op}")
    op_dir = os.path.join(test_dir, str(op))
    op_json_path = os.path.join(op_dir, f"{op}.json")
    data = extract_fields(op_json_path)
    if data is None:
        return None
    op_type = data["layer_info"]["orig_op_type"]
    try:
        print(f"Running scheduler for {op}:{op_json_path}")
        if op_type not in bypass_scheduler:
            scheduler_arg = {
                "input_file": op_json_path,
                "output_dir": op_dir,
                "combine_kernels": kernel_file,
                "fast_pm": not args.disable_fast_pm,
                "call_DMAC": args.call_DMAC,
            }
            main(scheduler_arg)
    except Exception as exc:
        # Catch Exception, not a bare except, so Ctrl-C / SystemExit still
        # abort the whole run instead of being reported as a failed op.
        if args.assert_on_error:
            # Explicit raise rather than `assert`, so this still fires under
            # `python -O`.
            raise AssertionError(f"Error running scheduler for {op}:") from exc
        print(f"Error running scheduler for {op}:")
        print(traceback.format_exc())
        return [str(op), str(op_type), "fail"]

    return [str(op), str(op_type), "pass"]

def main(args):
    """Build the scheduling artifacts for one op description and run the stages.

    Args:
        args: Mapping with at least an ``"input_file"`` key (the op JSON path).
            The whole mapping is also handed to the stages as
            ``program_arg_obj``.

    Side effects:
        For non-Conv ops, rebinds the module-level globals ``artifacts_dict``
        and ``stage_queue``, then runs the stages via ``add_stages``. Conv ops
        return early and leave both globals untouched.
    """
    global artifacts_dict, stage_queue

    logging.info("Argument summary: %s", args)

    scenario = utils.load_file(args["input_file"])

    def _frozen(section):
        # Wrap one required top-level section of the scenario as a ReadOnlyDict.
        return custom_dict.ReadOnlyDict(scenario[section])

    # Construction order only matters for error reporting: a missing section
    # raises KeyError for the first absent key in this sequence.
    mem_tile_params = _frozen("mem_tile_params")
    core_tile_params = _frozen("core_tile_params")
    shim_tile_params = _frozen("shim_tile_params")
    scheduling = _frozen("scheduling")
    layer_info = _frozen("layer_info")
    overlay_info = _frozen("overlay_info")
    kernel_info = _frozen("kernel_info")
    layer_padding = _frozen("dma_layer_padding")
    host_padding = _frozen("host_layer_padding")
    # dram_params is optional in the scenario file; fall back to an empty dict.
    dram_params = custom_dict.ReadOnlyDict(scenario.get("dram_params", {}))

    # Conv ops are not handled by this scheduler.
    if layer_info.get_value("orig_op_type") == "Conv":
        print("Skipping Scheduler for Conv Op")
        return

    artifacts_dict = dict(
        kernel_info_obj=kernel_info,
        overlay_info_obj=overlay_info,
        layer_info_obj=layer_info,
        scheduling_obj=scheduling,
        mem_tile_params_obj=mem_tile_params,
        core_tile_params_obj=core_tile_params,
        layer_padding_obj=layer_padding,
        host_padding_obj=host_padding,
        shim_tile_params_obj=shim_tile_params,
        dram_params_obj=dram_params,
        program_arg_obj=args,
    )

    stage_queue = queue.Queue()
    add_stages(stage_queue)


@dataclass
class SchedulerConfig:
    """Driver-level scheduler settings.

    Attributes:
        bypass_scheduler: Op types (compared against ``orig_op_type``) for
            which the scheduler should not be run. Presumably passed as the
            ``bypass_scheduler`` argument of ``run_Scheduler_single_op``;
            confirm with the caller.
    """

    # default_factory builds a fresh list per instance, so mutating one
    # config's list never affects another instance.
    bypass_scheduler: List[str] = field(
        default_factory=lambda: list(
            (
                "Conv",
                "Concat",
                "Transpose",
                "Slice",
                "Slice_qdq",
                "Slice_neg",
                "Resize",
                "DepthToSpace",
                "Quant",
                "Dequant",
                "BilinearResize",
                "MHA",
                "MaxPool",
            )
        )
    )


if __name__ == "__main__":
    # Command-line entry point: parse options, optionally set up file
    # logging, then run main() with or without call-graph profiling.
    parser = argparse.ArgumentParser(
        description="Generate compiler scheduler for MLOps",
        usage='use "%(prog)s --help" for more info',
        formatter_class=argparse.RawTextHelpFormatter,
    )

    # required knobs
    parser.add_argument('-test', '--input_file', required=True, help="High level summary of the ML OPs and the scenario to run")
    # debug/profile knobs
    # NOTE(review): no type= on -ck / -fast_pm, so a value given on the command
    # line arrives as a string (e.g. "False" is truthy) while the defaults are
    # int/bool. Confirm how downstream consumers interpret these.
    parser.add_argument('-ck', '--combine_kernels', help="Use combine kernel file", default=0)
    parser.add_argument('-d', '--debug', help="Dump dbg log to 'dbg_log.txt'", action="store_true", default=False)
    parser.add_argument('-df', '--debug_file_name', help="Debug log file name", default="dbg_log.txt")
    parser.add_argument('-v', '--verbose', choices=['debug', 'info', 'error'], help="Verbosity for debug logs", default='debug')
    parser.add_argument('-p', '--profile', help="Profile auto scheduler", action="store_true", default=False)
    parser.add_argument('-pf', '--profile_graph_name', help="Profile graph file name", default="dbg_call_graph.png")
    parser.add_argument('-o', '--output_dir', help="path to output directory", default="./")
    parser.add_argument('-fast_pm', '--fast_pm', help="fast_pm control", default=True)
    parser.add_argument('--call_DMAC', help="Call DMA Compiler instead of dumping .py files", action="store_true", default=False)

    args = parser.parse_args()

    utils.check_file_type(args.input_file, ".json")

    if args.debug:
        filename = args.debug_file_name
        verbose = utils.DEBUG_VERBOSE.str2enum(args.verbose).value
        # abspath reports the right location for both relative and absolute
        # --debug_file_name values.
        print("Saving debug log as :", os.path.abspath(filename))

        logging.basicConfig(
            filename=filename,
            filemode='w',
            format='[%(asctime)s,%(msecs)d] [%(levelname)s]: %(message)s',
            # Hours first: '%M:%H:%S' printed minutes in the hour slot.
            datefmt='%H:%M:%S',
            level=verbose,
        )

    if args.profile:
        # Run with profiling. pycallgraph2 is imported lazily because only
        # this path needs it.
        from pycallgraph2 import Config, GlobbingFilter, PyCallGraph
        from pycallgraph2.output import GraphvizOutput

        print("Saving profile graph as :", os.path.abspath(args.profile_graph_name))
        config = Config()
        config.trace_filter = GlobbingFilter(exclude=[
            'pycallgraph.*',
            # The tracer in use lives in the pycallgraph2 package, which
            # 'pycallgraph.*' does not match (the character after
            # 'pycallgraph' is '2', not '.').
            'pycallgraph2.*',
            'custom_dict.*',
        ])
        graphviz = GraphvizOutput(output_file=args.profile_graph_name)
        with PyCallGraph(output=graphviz, config=config):
            main(vars(args))

    else:
        # Run without profiling
        main(vars(args))
