import os
import sys
import json
import traceback
import subprocess
import random
import shutil
import argparse
import glob
import yaml
import numpy as np
import pandas as pd
import math

def extract_fields(file_name):
    """Load and return the decoded JSON content of ``file_name``.

    Args:
        file_name: Path to a JSON file.

    Returns:
        The decoded JSON value, or ``None`` if the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
        OSError: For I/O errors other than a missing file (e.g. the path is a
            directory or is not readable).
    """
    try:
        # JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
        with open(file_name, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        # A missing file is the expected "no result" case; callers test for None.
        return None

def custom_sort_key(string):
    """Build a numeric sort key from a name shaped like ``<int>_<any>_<AxBx...>``.

    The key is the leading integer (the layer number), followed by every
    integer obtained by splitting the third ``_``-separated field on ``'x'``.
    Names therefore order by layer number first, then by those dimensions.
    """
    fields = string.split('_')
    layer = int(fields[0])
    dims = [int(d) for d in fields[2].split('x')]
    return [layer] + dims

def _tiling_summary(cfg):
    """Pull the fields printed by ``extract_tiling`` out of one layer config dict.

    Returns:
        Tuple ``(split, ifm_mode, wgt_mode, Tm, Tk, Tn, pad_m, pad_k, pad_n)``.

    Raises:
        KeyError / IndexError: If ``cfg`` lacks any of the expected entries.
    """
    split = cfg['overlay_info']['mode']
    ifm_mode = cfg['scheduling']['ifm']
    wgt_mode = cfg['scheduling']['wgt']
    core_iters = cfg['core_tile_params']['iters']
    mem_iters = cfg['mem_tile_params']['iters']
    # Each T* is the core-tile iteration count multiplied by the mem-tile one.
    Tm = core_iters['ifm'][1] * mem_iters['ifm'][1]
    Tk = core_iters['wgt'][1] * mem_iters['wgt'][1]
    Tn = core_iters['wgt'][2] * mem_iters['wgt'][2]
    # NOTE(review): dims[1]/dims[2] of the ifm entry and dims[2] of the wgt entry
    # are presumably the M/K and N paddings respectively — confirm against the tiler.
    padding = cfg['dma_layer_padding']
    pad_m = padding[0]['ifm']['dims'][1]
    pad_k = padding[0]['ifm']['dims'][2]
    pad_n = padding[1]['wgt']['dims'][2]
    return split, ifm_mode, wgt_mode, Tm, Tk, Tn, pad_m, pad_k, pad_n


def extract_tiling(args):
    """Print a one-line tiling summary for every layer directory in the output dir.

    Args:
        args: Mapping with key ``'output_dir'``. The path is joined onto the
            current working directory, so absolute paths are used as-is.

    Each non-empty subdirectory ``<op>`` is expected to contain
    ``<op>/<op>.json``. Subdirectories are visited in ``custom_sort_key``
    order. A layer whose JSON file is missing is reported as
    "Tiler Failed" and skipped.
    """
    test_dir = os.path.join(os.getcwd(), args['output_dir'])
    # Close the scandir handle deterministically; skip files and empty directories.
    with os.scandir(test_dir) as entries:
        layer_dirs = [entry.name for entry in entries
                      if entry.is_dir() and os.listdir(entry.path)]
    layer_dirs.sort(key=custom_sort_key)
    for op in layer_dirs:
        cfg = extract_fields(os.path.join(test_dir, op, f"{op}.json"))
        if cfg is None:
            print(f'Layer {op}, Tiler Failed\n')
            continue
        split, ifm_mode, wgt_mode, Tm, Tk, Tn, pad_m, pad_k, pad_n = _tiling_summary(cfg)
        print(f'Layer {op:<20}, split {split:<10}, ifm mode {ifm_mode:<5}, wgt mode {wgt_mode:<5}, '
              f'Tm {Tm:<2}, Tk {Tk:<2}, Tn {Tn:<2}, pad_m {pad_m:<2}, pad_k {pad_k:<2}, pad_n {pad_n:<2}')

if __name__ == "__main__":
    # Command-line entry point: read the output directory name and print the report.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-o",
        "--output_dir",
        help="Output directory name. Optional Field. Default value = WAIC_Outputs",
        default='WAIC_Outputs',
    )
    extract_tiling(vars(cli.parse_args()))