import os
import json
import csv
import numpy
import argparse
import logging
from collections import defaultdict

def prod(arr):
    """Multiply together every item of ``arr``.

    Returns 0 (not 1) when ``arr`` is falsy (``None``, empty list/tuple) or
    has length zero, so a missing shape counts as zero elements.
    """
    has_items = bool(arr) and len(arr) != 0
    if not has_items:
        return 0

    running = 1
    for factor in arr:
        running *= factor
    return running

# Divisor turning a cycle count into the reported latency value.
# NOTE(review): presumably a 1.8 GHz clock yielding milliseconds - confirm.
_CYCLES_PER_LATENCY_UNIT = 1.8e6

# (summary label, matching 'datatype' strings, MACs per cycle).
_MAC_BUCKETS = (
    ('int16xint16', ('uint16xuint16', 'int16xint16'), 128),
    ('int16xint8', ('uint16xuint8', 'int16xint8'), 256),
    ('int8xint8', ('uint8xuint8', 'int8xint8'), 512),
)
# MACs per cycle used for every datatype combination not listed above.
_DEFAULT_MACS_PER_CYCLE = 32

# Ordered (op-name substrings, output elements per ideal compute cycle).
# Order matters: the first entry with a matching substring wins.
_OUTPUT_ELEMS_PER_CYCLE = (
    (('LayerNormalization',), 7.5),
    (('LpNormalization',), 9),
    (('GroupNormalization', 'InstanceNormalization'), 2.5),
    (('Softmax',), 6),
    (('Silu', 'Gelu', 'Sigmoid', 'Swish', 'Tanh'), 9),
)
_ELEMENTWISE_OPS = ('Add', 'Sub', 'Mul', 'Div')

# Key order of a per-node row without the optional "x frequency" columns;
# used as the summary-row template when the IR contains no nodes.
_BASE_COLUMNS = (
    'Op_type', 'Input_shape', 'Weight_shape', 'Output_shape', 'Frequency',
    'No_batches', 'M', 'K', 'N', 'MAC', 'datatype', 'comp_efficiency',
    'io_efficiency', 'ideal_compute_cycles', 'io_cycles', 'shape_cycles',
)


def _classify_datatype(datatype):
    """Return ``(summary_label_or_None, macs_per_cycle)`` for a datatype string."""
    for label, names, macs_per_cycle in _MAC_BUCKETS:
        if datatype in names:
            return label, macs_per_cycle
    return None, _DEFAULT_MACS_PER_CYCLE


def _matmul_dims(node_dict, wgt_shape, input_elems):
    """Derive ``(no_batches, M, K, N)`` for a MatMul/Gemm node.

    Leading size-1 weight dimensions are dropped while more than two
    dimensions remain. A 2-D weight honours ``attributes['transB']``; 3-D and
    4-D weights treat every dimension before the last two as batch.

    Raises:
        ValueError: if the stripped weight shape is not 2-D, 3-D or 4-D.
        ZeroDivisionError: if ``K * no_batches`` is zero.
    """
    lead = 0
    # Length is checked first so an empty shape reaches the ValueError below
    # instead of failing with an IndexError here.
    while len(wgt_shape) - lead > 2 and wgt_shape[lead] == 1:
        lead += 1
    dims = wgt_shape[lead:]

    if len(dims) == 2:
        no_batches = 1
        attrs = node_dict.get('attributes') or {}
        # A missing transB means "not transposed"; previously K and N were
        # silently left at 1 when attributes existed without transB.
        trans_b = attrs['transB'] if 'transB' in attrs else 0
        if isinstance(trans_b, (list, tuple)):
            trans_b = trans_b[0] if trans_b else 0
        if trans_b == 1:
            K, N = dims[-1], dims[-2]
        else:
            K, N = dims[-2], dims[-1]
    elif len(dims) in (3, 4):
        no_batches = prod(dims[:-2])
        K, N = dims[-2], dims[-1]
    else:
        raise ValueError(f"Check the Weight shape: {wgt_shape}")

    M = input_elems / (K * no_batches)
    return no_batches, M, K, N


def _analyze_node(node_dict):
    """Build the report row (a dict) for one IR node.

    Reads ``op_type``, the three shape entries, ``Frequency``, the two
    datatype entries, the optional ``*_bytes`` entries (default 0) and, for
    2-D MatMul/Gemm weights, ``attributes['transB']``.
    """
    op_type = node_dict['op_type']
    print(op_type)
    in_act_shape = node_dict['in_act_shape']
    in_wgt_shape = node_dict['in_wgt_shape']
    out_act_shape = node_dict['out_act_shape']

    row = {
        'Op_type': op_type,
        'Input_shape': in_act_shape,
        'Weight_shape': in_wgt_shape,
        'Output_shape': out_act_shape,
        'Frequency': node_dict['Frequency'],
        'No_batches': None,
        'M': None,
        'K': None,
        'N': None,
        'MAC': None,
    }

    # Formatting (rather than '+') keeps nodes without datatype entries from
    # raising TypeError; they fall into the default MACs-per-cycle bucket.
    datatype = f"{node_dict.get('in_datatype')}x{node_dict.get('wgt_datatype')}"
    row['datatype'] = datatype
    bucket, macs_per_cycle = _classify_datatype(datatype)

    input_elems = prod(in_act_shape)
    weight_elems = prod(in_wgt_shape)
    output_elems = prod(out_act_shape)

    in_bytes = node_dict.get('in_bytes', 0)
    weight_bytes = node_dict.get('wgt_bytes', 0)
    out_bytes = node_dict.get('out_bytes', 0)

    row['comp_efficiency'] = 1
    row['io_efficiency'] = 0.8

    # ``ideal`` stays None for ops assumed to be fused/optimized away.
    # NOTE(review): the /32 divisors below are taken as-is from the model;
    # their hardware meaning is not visible here - confirm.
    ideal = None
    if 'MatMul' in op_type or 'Gemm' in op_type:
        no_batches, M, K, N = _matmul_dims(node_dict, in_wgt_shape, input_elems)
        row['No_batches'] = no_batches
        row['M'] = M
        row['K'] = K
        row['N'] = N
        row['MAC'] = no_batches * M * K * N
        ideal = row['MAC'] / macs_per_cycle / 32
        if no_batches > 1:
            row['comp_efficiency'] = 0.1
        elif bucket == 'int16xint16':
            row['comp_efficiency'] = 0.2
        else:
            row['comp_efficiency'] = 0.3
    elif 'Conv' in op_type:
        # assume NCHW activations and CoCiKxKy weights
        row['MAC'] = output_elems * prod(in_wgt_shape[1:])
        ideal = row['MAC'] / macs_per_cycle / 32
        row['comp_efficiency'] = 0.4
    else:
        for names, elems_per_cycle in _OUTPUT_ELEMS_PER_CYCLE:
            if any(name in op_type for name in names):
                ideal = output_elems / elems_per_cycle
                break
        else:
            if any(name in op_type for name in _ELEMENTWISE_OPS):
                ideal = output_elems / 32 / 32

    if ideal is None:
        row['ideal_compute_cycles'] = None
        row['io_cycles'] = None
        row['shape_cycles'] = None
    else:
        io_cycles = max(in_bytes * input_elems,
                        weight_elems * weight_bytes,
                        output_elems * out_bytes) / 32
        row['ideal_compute_cycles'] = ideal
        row['io_cycles'] = io_cycles
        # The slower of derated compute and derated I/O bounds the op.
        row['shape_cycles'] = max(ideal / row['comp_efficiency'],
                                  io_cycles / row['io_efficiency'])

    if row['shape_cycles']:
        op_cycles = row['shape_cycles'] * row['Frequency']
        row['shape_cycles x frequency'] = op_cycles
        row['shape_latency x frequency'] = op_cycles / _CYCLES_PER_LATENCY_UNIT

    return row


def _excel_value(value):
    """Return ``value`` unchanged if it is a plain scalar, else its ``str()``."""
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    return str(value)


def _write_report(rows, json_path):
    """Write ``rows`` next to ``json_path`` as ``<stem>_size.xlsx`` (or ``.csv``).

    Tries pandas first, then openpyxl directly, and finally the csv module
    when neither Excel path can be imported.
    """
    # splitext instead of [:-5] so paths that do not end in ".json" still
    # produce a sensible output name.
    stem = os.path.splitext(json_path)[0]
    output_filename = stem + "_size.xlsx"

    try:
        import pandas as pd

        df = pd.DataFrame(rows)
        df.to_excel(output_filename, index=False)
        print(f"✓ Successfully saved results to Excel: {output_filename}")
        return
    except ImportError:
        pass

    # Rows do not all share the same keys (the "x frequency" columns are
    # optional), so the fallbacks use the union of keys in first-seen order.
    columns = list(dict.fromkeys(key for row in rows for key in row))

    try:
        from openpyxl import Workbook
    except ImportError:
        print("Warning: Excel libraries not available, falling back to CSV format")
        output_filename = stem + "_size.csv"
        with open(output_filename, 'w', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=columns)
            writer.writeheader()
            writer.writerows(rows)
        print(f"✓ Successfully saved results to CSV: {output_filename}")
        return

    wb = Workbook()
    ws = wb.active
    ws.title = "Performance_Data"
    for col_idx, header in enumerate(columns, 1):
        ws.cell(row=1, column=col_idx, value=header)
    for row_idx, row_data in enumerate(rows, 2):
        for col_idx, header in enumerate(columns, 1):
            # Shape lists cannot be stored in a cell directly.
            ws.cell(row=row_idx, column=col_idx,
                    value=_excel_value(row_data.get(header)))
    wb.save(output_filename)
    print(f"✓ Successfully saved results to Excel: {output_filename}")


def main(args):
    """Estimate per-op cycles/MACs from an IR JSON file and write a report.

    Args:
        args: mapping with key ``'json_path'``, the path of a JSON file whose
            top level is an object mapping node ids to node dicts.

    The report has one row per node, then a "Totals" row and one MAC-total
    row per datatype bucket. It is written beside the input as
    ``<stem>_size.xlsx``, or ``<stem>_size.csv`` when no Excel library can be
    imported.

    Raises:
        KeyError: if a node lacks a required entry.
        ValueError: if a MatMul/Gemm weight shape is unsupported.
    """
    json_path = args['json_path']

    # load ir json
    with open(json_path, encoding='utf-8') as f:
        mdict = json.load(f)

    rows = []
    total_cycles = 0
    mac_totals = {label: 0 for label, _names, _rate in _MAC_BUCKETS}

    for node_dict in mdict.values():
        row = _analyze_node(node_dict)
        if row['shape_cycles']:
            total_cycles += row['shape_cycles x frequency']
        if row['MAC']:
            bucket, _rate = _classify_datatype(row['datatype'])
            if bucket is not None:
                mac_totals[bucket] += row['MAC'] * row['Frequency']
        rows.append(row)

    # Summary rows mirror the last node row's keys; an empty IR falls back to
    # the base columns instead of failing on an unbound loop variable.
    summary_keys = list(rows[-1]) if rows else list(_BASE_COLUMNS)

    totals = dict.fromkeys(summary_keys)
    totals['shape_cycles'] = 'Totals'
    totals['shape_cycles x frequency'] = total_cycles
    totals['shape_latency x frequency'] = total_cycles / _CYCLES_PER_LATENCY_UNIT
    rows.append(totals)

    for label, macs in mac_totals.items():
        summary = dict.fromkeys(summary_keys)
        summary['N'] = label
        summary['MAC'] = macs
        rows.append(summary)

    _write_report(rows, json_path)


if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging, run main().
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-d", "--debug",
        help="Print lots of debugging statements",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
    )
    cli.add_argument(
        "-f", "--json_path",
        help="path to json file.Required Field",
    )
    options = cli.parse_args()

    # Missing or empty path: exit through argparse's own error reporting.
    if not options.json_path:
        cli.error("Please pass path/to/json_file")

    # loglevel is None unless -d was given, leaving the default level as-is.
    logging.basicConfig(level=options.loglevel)
    logging.debug("Debug mode is enabled!")

    main(vars(options))
