#!/usr/bin/env python3
"""
Simplified script to extract specific fields from the Model_PSR_v1.1_mod__tilings.json file structure.
Based on the observed JSON structure with layers like Add_0, Conv_0, etc.
"""

import json
import csv
import argparse
import os
import sys
import pandas as pd


# Column order shared by the Excel export and the CSV fallback.
_EXPORT_COLUMNS = (
    'layer_name', 'orig_op_type', 'frequency', 'dtype_str', 'in_act_shape',
    'in_wgt_shape', 'out_act_shape', 'min_cycles', 'estimated_cycles',
    'mac_efficiency', 'MACs', 'estimated_cycles x frequency',
    'estimated_latency x frequency (ms)',
)

# Datatype combinations whose frequency-weighted MACs are reported separately.
_MAC_DTYPE_KEYS = ('int16xint16', 'int16xint8', 'int8xint8')

# Alternative key names under which the minimum cycle count may be stored.
_MIN_CYCLE_KEYS = ('min_cycles', 'min_layer_cycles')


def _is_number(value):
    """Return True for int/float values, excluding bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _first_present(mapping, keys):
    """Return str() of the value of the first key of *keys* found in *mapping*, else ''."""
    for key in keys:
        if key in mapping:
            return str(mapping[key])
    return ''


def _strip_unsigned(dtype):
    """Return *dtype* as a string without a leading 'u' (e.g. 'uint8' -> 'int8')."""
    if not dtype:
        return ''
    dtype = str(dtype)
    return dtype[1:] if dtype.startswith('u') else dtype


def _find_min_cycles(layer_data):
    """Look up the minimum cycle count on the layer, then in cycle_counts, then in layer_macs."""
    containers = (
        layer_data,
        layer_data.get('cycle_counts'),
        layer_data.get('layer_macs'),
    )
    for container in containers:
        if not isinstance(container, dict):
            continue
        for key in _MIN_CYCLE_KEYS:
            value = container.get(key)
            if value is not None:
                return value
    return None


def _quant_cycles(layer_info):
    """Estimate Quant/Dequant cycles as (product of numeric output dims) / 8."""
    if 'out_act_shape' in layer_info:
        shape = layer_info['out_act_shape']
    else:
        shape = layer_info.get('out_ofm_shape')
    if not shape or not isinstance(shape, list):
        return None
    total_elements = 1
    for dim in shape:
        if _is_number(dim):
            total_elements *= dim
    # NOTE(review): the divisor 8 is presumably elements processed per cycle - confirm.
    return total_elements / 8


def _first_projected(projected):
    """Return the first entry of a projected_cycles list/dict, or the scalar itself."""
    if isinstance(projected, list):
        return projected[0] if projected else None
    if isinstance(projected, dict):
        return next(iter(projected.values()), None)
    return projected


def _find_estimated_cycles(op_type, layer_data, layer_info):
    """Pick the estimated cycle count for a layer from the first source that yields one."""
    estimate = None

    # Quant/Dequant layers are estimated from their output shape.
    if op_type in ('Quant', 'Dequant'):
        estimate = _quant_cycles(layer_info)

    # Conv layers prefer the first performance_metrics entry.
    if not estimate and op_type == 'Conv':
        metrics = layer_data.get('performance_metrics')
        if isinstance(metrics, list) and metrics:
            first_metric = metrics[0]
            if isinstance(first_metric, dict) and 'projected_cycles' in first_metric:
                estimate = first_metric['projected_cycles']

    if not estimate:
        if 'layer_cycles' in layer_data:
            estimate = layer_data['layer_cycles']
        if not estimate and 'projected_cycles' in layer_data:
            estimate = _first_projected(layer_data['projected_cycles'])
        cycle_counts = layer_data.get('cycle_counts')
        if not estimate and isinstance(cycle_counts, dict):
            if 'layer_cycles' in cycle_counts:
                estimate = cycle_counts['layer_cycles']
            elif 'projected_cycles' in cycle_counts:
                estimate = _first_projected(cycle_counts['projected_cycles'])
    return estimate


def _find_macs(layer_data):
    """Return the MAC count from cycle_counts['macs'], falling back to layer_macs."""
    macs = None
    cycle_counts = layer_data.get('cycle_counts')
    if isinstance(cycle_counts, dict) and 'macs' in cycle_counts:
        macs = cycle_counts['macs']

    if not macs and 'layer_macs' in layer_data:
        layer_macs = layer_data['layer_macs']
        if not isinstance(layer_macs, dict):
            macs = layer_macs
        elif 'macs' in layer_macs:
            macs = layer_macs['macs']
        elif layer_macs:
            # No explicit 'macs' key: take the first value stored in the dict.
            macs = next(iter(layer_macs.values()))
    return macs


def _build_layer_row(layer_name, layer_data, cycles_per_ms):
    """Build the output row (one dict) for a single layer entry."""
    layer_info = layer_data.get('layer_info')
    if not isinstance(layer_info, dict):
        layer_info = {}

    row = {
        'layer_name': layer_name,
        'orig_op_type': '',
        'frequency': layer_info.get('Frequency'),
        'in_act_shape': _first_present(
            layer_info, ('in_act_shape', 'in_ifmA_shape', 'in_ifm_shape')),
        'in_wgt_shape': _first_present(
            layer_info, ('in_wgt_shape', 'in_ifmB_shape')),
        'out_act_shape': _first_present(
            layer_info, ('out_act_shape', 'out_ofm_shape')),
        'min_cycles': None,
        'mac_efficiency': None,
        'estimated_cycles': None,
        'MACs': None,
        'dtype_str': '',
    }

    if 'orig_op_type' in layer_info:
        row['orig_op_type'] = str(layer_info['orig_op_type'])
    if not row['orig_op_type'] and '_' in layer_name:
        # Fall back to the layer-name prefix, e.g. 'Conv_0' -> 'Conv'.
        row['orig_op_type'] = layer_name.split('_')[0]

    in_dtype = _strip_unsigned(
        layer_info.get('in_datatype', '') or layer_info.get('in_ifm_datatype', ''))
    wgt_dtype = _strip_unsigned(
        layer_info.get('wgt_datatype', '') or layer_info.get('in_wgt_datatype', ''))
    row['dtype_str'] = 'x'.join(part for part in (in_dtype, wgt_dtype) if part)

    min_cycles = _find_min_cycles(layer_data)
    estimated = _find_estimated_cycles(row['orig_op_type'], layer_data, layer_info)
    row['min_cycles'] = min_cycles
    row['estimated_cycles'] = estimated

    if _is_number(min_cycles) and _is_number(estimated) and estimated:
        row['mac_efficiency'] = min_cycles / estimated

    row['MACs'] = _find_macs(layer_data)

    # A layer without a numeric cycle estimate or frequency contributes 0.
    frequency = row['frequency']
    if _is_number(estimated) and _is_number(frequency):
        weighted_cycles = estimated * frequency
    else:
        weighted_cycles = 0
    row['estimated_cycles x frequency'] = weighted_cycles
    row['estimated_latency x frequency (ms)'] = weighted_cycles / cycles_per_ms
    return row


def _save_rows(rows, output_excel):
    """Write *rows* to Excel, falling back to CSV when the Excel engine is missing."""
    columns = list(_EXPORT_COLUMNS)
    try:
        df = pd.DataFrame(rows)
        df = df[columns]
        df.to_excel(output_excel, index=False, engine='openpyxl')
        print(f"\nData saved to: {output_excel}")
    except ImportError:
        print("Error: pandas and openpyxl are required for Excel output. Falling back to CSV...")
        # Swap only the file extension; directory names are left untouched.
        csv_file = os.path.splitext(output_excel)[0] + '.csv'
        try:
            with open(csv_file, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=columns)
                writer.writeheader()
                writer.writerows(rows)
        except OSError as e:
            print(f"Error saving CSV fallback: {e}")
        else:
            print(f"CSV fallback saved to: {csv_file}")
    except Exception as e:
        # Saving is best-effort: report the problem and still return the data.
        print(f"Error saving to Excel: {e}")


def extract_tilings_data_simple(json_file_path, output_excel=None, cycles_per_ms=1.8e6):
    """
    Extract specified fields from tilings JSON file.

    Every dict-valued top-level entry of the JSON object is treated as one
    layer. After the per-layer rows, a 'Totals' row and one row per datatype
    combination in ('int16xint16', 'int16xint8', 'int8xint8') are appended;
    these summary rows carry their label in the 'estimated_cycles' column.

    Args:
        json_file_path (str): Path to the JSON file
        output_excel (str, optional): Path to output Excel file. If the
            openpyxl engine cannot be imported, a CSV file with the same base
            name is written instead.
        cycles_per_ms (float, optional): Divisor converting cycles to
            milliseconds. The default 1.8e6 corresponds to a 1.8 GHz clock.

    Returns:
        list: List of dictionaries with extracted data (layer rows followed by
        the summary rows), or None if the file is missing, unreadable, not
        valid JSON, or its top-level value is not an object.
    """
    print(f"Reading JSON file: {json_file_path}")

    # Load JSON data
    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"Error: File {json_file_path} not found")
        return None
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        print(f"Error: Invalid JSON format - {e}")
        return None
    except OSError as e:
        print(f"Error: Could not read {json_file_path} - {e}")
        return None

    if not isinstance(data, dict):
        print("Error: Expected a JSON object mapping layer names to layer data")
        return None

    # Extract data for each layer
    layer_rows = [
        _build_layer_row(layer_name, layer_data, cycles_per_ms)
        for layer_name, layer_data in data.items()
        if isinstance(layer_data, dict)
    ]

    # Accumulate frequency-weighted totals
    total_cycles = 0
    total_macs = 0
    macs_totals = dict.fromkeys(_MAC_DTYPE_KEYS, 0)
    for row in layer_rows:
        total_cycles += row['estimated_cycles x frequency']
        macs, frequency = row['MACs'], row['frequency']
        if macs and frequency and _is_number(macs) and _is_number(frequency):
            weighted_macs = macs * frequency
            total_macs += weighted_macs
            if row['dtype_str'] in macs_totals:
                macs_totals[row['dtype_str']] += weighted_macs

    # Summary rows are built from the fixed column list, so they exist even
    # when the file contains no layers.
    totals_row = dict.fromkeys(_EXPORT_COLUMNS)
    totals_row['estimated_cycles'] = 'Totals'
    totals_row['MACs'] = total_macs
    totals_row['estimated_cycles x frequency'] = total_cycles
    totals_row['estimated_latency x frequency (ms)'] = total_cycles / cycles_per_ms
    summary_rows = [totals_row]

    # Add macs_totals breakdown
    for dtype_key, macs_total in macs_totals.items():
        macs_row = dict.fromkeys(_EXPORT_COLUMNS)
        macs_row['estimated_cycles'] = dtype_key
        macs_row['MACs'] = macs_total
        summary_rows.append(macs_row)

    extracted_data = layer_rows + summary_rows

    print(macs_totals)

    # All counts below cover real layers only, not the summary rows.
    print(f"Extracted data for {len(layer_rows)} layers")

    # Display summary
    print("\nSummary:")
    print(f"Total layers: {len(layer_rows)}")

    # Count by operation type
    op_types = {}
    for row in layer_rows:
        op_type = row['orig_op_type']
        if op_type:
            op_types[op_type] = op_types.get(op_type, 0) + 1

    print("\nOperation types found:")
    for op_type, count in sorted(op_types.items()):
        print(f"  {op_type}: {count}")

    # Count layers with data
    layers_with_shapes = sum(
        1 for row in layer_rows if row['in_act_shape'] or row['out_act_shape'])
    layers_with_estimated_cycles = sum(
        1 for row in layer_rows if row['estimated_cycles'])

    print(f"\nLayers with shape information: {layers_with_shapes}")
    print(f"Layers with estimated cycles: {layers_with_estimated_cycles}")

    # Save to Excel if specified
    if output_excel:
        _save_rows(extracted_data, output_excel)

    return extracted_data


def main():
    """
    Parse command line arguments, run the extraction and print a preview.

    When --output_excel is omitted, the output path defaults to
    "<json base name>_extracted.xlsx" in the same directory as the input file.

    Returns:
        int: 0 on success, 1 if the input file is missing or extraction failed.
    """
    parser = argparse.ArgumentParser(
        description="Extract specific fields from tilings JSON file",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python extract_tilings_simple.py --json_path Model_PSR_v1.1_mod__tilings.json
    python extract_tilings_simple.py --json_path Model_PSR_v1.1_mod__tilings.json --output_excel results.xlsx
        """
    )

    parser.add_argument(
        "--json_path", "-j",
        required=True,
        help="Path to the tilings JSON file"
    )

    parser.add_argument(
        "--output_excel", "-o",
        help="Path to output Excel file (optional)"
    )

    parser.add_argument(
        "--show_details", "-d",
        action="store_true",
        help="Show detailed information for each layer"
    )

    args = parser.parse_args()

    # Validate input file
    if not os.path.exists(args.json_path):
        print(f"Error: JSON file not found: {args.json_path}")
        return 1

    # Set default output Excel name if not provided
    if not args.output_excel:
        # Place the output next to the input JSON file
        json_dir = os.path.dirname(os.path.abspath(args.json_path))
        base_name = os.path.splitext(os.path.basename(args.json_path))[0]
        args.output_excel = os.path.join(json_dir, f"{base_name}_extracted.xlsx")

    # Extract data
    extracted_data = extract_tilings_data_simple(args.json_path, args.output_excel)

    if extracted_data is None:
        return 1

    # Show details if requested
    if args.show_details:
        print("\n" + "="*80)
        print("DETAILED LAYER INFORMATION")
        print("="*80)

        for row in extracted_data[:10]:  # Show first 10 layers
            print(f"\nLayer: {row['layer_name']}")
            for key, value in row.items():
                if key != 'layer_name' and value:
                    print(f"  {key}: {value}")

        if len(extracted_data) > 10:
            print(f"\n... and {len(extracted_data) - 10} more layers")

    # Show preview
    print("\n" + "="*80)
    print("PREVIEW (first 5 layers)")
    print("="*80)
    for i, row in enumerate(extracted_data[:5]):
        print(f"{i+1}. {row['layer_name']} ({row['orig_op_type']})")
        for key, value in row.items():
            if key in ('layer_name', 'orig_op_type') or not value:
                continue
            text = str(value)
            if len(text) > 100:
                # Actually shorten long values; the ellipsis marks the cut.
                text = text[:100] + '...'
            print(f"   {key}: {text}")

    print("\n✓ Extraction completed successfully!")
    # The extraction step only reports save failures on stdout, so confirm the
    # file is really there before claiming it was written.
    if os.path.exists(args.output_excel):
        print(f"Output saved to: {args.output_excel}")
    else:
        print(f"Warning: output file was not created: {args.output_excel}")

    return 0


if __name__ == "__main__":
    # Run the CLI and hand main()'s return value to the interpreter as the exit status.
    raise SystemExit(main())
