'''This file contains common functions used across the buildtest framework.'''
import os
import re
import csv
import shutil
from contextlib import contextmanager
from enum import Enum
from typing import Optional
from HW_requirements.aie4_test_script import HW_test
from utils.build_utils import get_ml_timeline_log_level

# Root of the AIE4 checkout, read from the AIE4_ROOT_DIR environment variable
# (None when the variable is unset).
REPO_ROOT = os.getenv("AIE4_ROOT_DIR")
# cert_sim directory under the repo root. NOTE: when AIE4_ROOT_DIR is unset,
# os.path.join(None, ...) raises TypeError while this module is being imported.
CERT_SIM = os.path.join(REPO_ROOT, "cert_sim")


@contextmanager
def change_dir(target_dir):
    """
    Temporarily make ``target_dir`` the working directory.

    The directory that was current on entry is restored when the ``with``
    block exits, whether it finishes normally or raises.
    NOTE: os.chdir changes the working directory of the whole process.
    """
    previous_cwd = os.getcwd()
    try:
        # chdir stays inside the try so the finally-restore is always armed.
        os.chdir(target_dir)
        yield
    finally:
        os.chdir(previous_cwd)


def extract_simulation_time(sim_log_content: str) -> float:
    """
    Extract the total simulation time from simulator log content.

    Looks for the line written when the simulation finishes with result 0,
    e.g. ``...Sim result: 0 Total Simulation time 2.5 us, Wall ...``, and
    converts the reported value to nanoseconds.

    :param sim_log_content: Full text of the simulator log.
    :return: Simulation time in nanoseconds. Returns 0.0 when the marker line
             is absent or its time field cannot be parsed. A value whose unit
             is not one of ps/ns/us/ms/s is returned unconverted.
    """
    sim_time_extract_string = '[INFO] : Simulation Finished, Sim result: 0 Total Simulation time '
    match = re.search(f'{re.escape(sim_time_extract_string)}(.+?), Wall',
                      sim_log_content)
    if match is None:
        return 0.0

    try:
        val, unit = match.group(1).split()
        sim_time = float(val)
    except ValueError:
        # Malformed time field (e.g. "12ns" with no space, extra tokens, or a
        # non-numeric value): treat it like a missing marker instead of crashing.
        return 0.0

    if unit == 'ps':
        # Division (not * 1e-3) keeps results bit-identical to the historical /1000.
        return sim_time / 1000
    # Exact integer multipliers to nanoseconds; unknown units fall back to x1.
    ns_multipliers = {'ns': 1, 'us': 1000, 'ms': 1_000_000, 's': 1_000_000_000}
    return sim_time * ns_multipliers.get(unit, 1)


def process_simulation_results(sim_log: str,
                               shape_index: int,
                               results_list: list,
                               simtime_list: list) -> None:
    """
    Record one shape's simulation outcome in the shared result lists.

    ``results_list[shape_index]`` receives a status string:
      * 'COMPILE FAIL'   - ``sim_log`` does not exist (``simtime_list`` is left untouched);
      * 'DI PASS'        - the log text contains 'DI_PASS';
      * 'DI FAIL'        - the log text contains 'DI_FAIL' but not 'DI_PASS';
      * 'SIM INCOMPLETE' - neither marker is present.
    When the log exists, ``simtime_list[shape_index]`` receives the value
    returned by extract_simulation_time for the log text.
    """
    if os.path.exists(sim_log):
        with open(sim_log, 'r', encoding="utf-8", errors='ignore') as handle:
            content = handle.read()

        # First matching marker wins, so DI_PASS takes precedence over DI_FAIL.
        verdicts = (('DI_PASS', 'DI PASS'), ('DI_FAIL', 'DI FAIL'))
        results_list[shape_index] = next(
            (label for marker, label in verdicts if marker in content),
            'SIM INCOMPLETE',
        )
        simtime_list[shape_index] = extract_simulation_time(content)
    else:
        results_list[shape_index] = 'COMPILE FAIL'


def create_hw_package(build_dir: str) -> None:
    """
    Gather hardware-run artefacts of ``build_dir`` into ``build_dir/hw_package``.

    Files are moved with os.rename (not copied):
      * from ``build_dir``: every ``*.bin`` file and ``external_buffer_id.json``;
        when get_ml_timeline_log_level() > 0, also ``ml_timeline_metadata.json``,
        ``xrt.ini`` and ``aie_trace_config.json``;
      * from ``build_dir/Work_AIE4``: every ``*.asm`` file, ``config.json`` and
        ``control.elf``.

    If ``build_dir`` does not exist, a message is printed and nothing is created.
    NOTE(review): a missing ``Work_AIE4`` directory makes os.listdir raise
    FileNotFoundError after the ``build_dir`` files have already been moved.
    """
    if not os.path.exists(build_dir):
        print(f"Build directory {build_dir} does not exist.")
        return

    package_dir = os.path.join(build_dir, 'hw_package')
    work_dir = os.path.join(build_dir, 'Work_AIE4')
    os.makedirs(package_dir, exist_ok=True)

    def _relocate(source_dir: str, suffixes: tuple, exact_names: set) -> None:
        # Move every entry of source_dir whose name ends with one of `suffixes`
        # or is listed in `exact_names` into package_dir.
        for entry in os.listdir(source_dir):
            if entry.endswith(suffixes) or entry in exact_names:
                os.rename(os.path.join(source_dir, entry),
                          os.path.join(package_dir, entry))

    bundled_metadata = {"external_buffer_id.json"}
    if get_ml_timeline_log_level() > 0:
        # Extra timeline/trace metadata is packaged only for a positive log level.
        bundled_metadata |= {"ml_timeline_metadata.json", "xrt.ini", "aie_trace_config.json"}
    _relocate(build_dir, (".bin",), bundled_metadata)
    _relocate(work_dir, (".asm",), {"config.json", "control.elf"})


def write_csv(shape_table, results_list, simtime_list, output_path, fieldnames, field_mapping_fn):
    """
    General CSV writer for any op: one row per shape plus 'result' and 'sim_time' columns.

    :param shape_table: Iterable of shapes; the i-th shape pairs with
                        ``results_list[i]`` and ``simtime_list[i]``.
    :param results_list: Per-shape status values, indexed like ``shape_table``.
    :param simtime_list: Per-shape simulation times, indexed like ``shape_table``.
    :param output_path: Destination CSV file; overwritten if it exists.
    :param fieldnames: Sequence (list or tuple) of column names produced by
                       ``field_mapping_fn``, in output order.
    :param field_mapping_fn: Callable mapping a shape to a dict (or mapping)
                             keyed by ``fieldnames``.
    :raises IndexError: if ``results_list``/``simtime_list`` are shorter than ``shape_table``.
    :raises ValueError: from csv.DictWriter if a row contains keys not in the header.
    """
    # Unpacking accepts any sequence; `fieldnames + [...]` failed for tuples.
    header = [*fieldnames, 'result', 'sim_time']

    # Rows are built before the file is opened so a failing mapper does not
    # truncate an existing output file.
    csv_data = []
    for idx, shape in enumerate(shape_table):
        # Copy so a mapper that returns a shared/cached dict is not mutated.
        row = dict(field_mapping_fn(shape))
        row["result"] = results_list[idx]
        row["sim_time"] = simtime_list[idx]
        csv_data.append(row)

    with open(output_path, 'w', newline='', encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=header)
        writer.writeheader()
        writer.writerows(csv_data)


def default_row_mapper(fieldnames):
    """
    Return a row-mapping function that zips ``fieldnames`` with a shape's elements.

    :param fieldnames: Sequence of column names; its length must equal the
                       length of every shape passed to the returned mapper.
    :return: ``mapper(shape) -> dict`` mapping each field name to the
             corresponding shape element.
    The returned mapper raises AssertionError if shape and fieldnames lengths
    differ. The check is an explicit raise, so it also runs under ``python -O``.
    """
    def mapper(shape):
        if len(fieldnames) != len(shape):
            # Explicit raise instead of `assert`: asserts are stripped under -O.
            # AssertionError is kept so existing callers' handlers still match.
            raise AssertionError(
                f"Fieldnames ({len(fieldnames)}) and shape ({len(shape)}) length mismatch.\n"
                f"Fieldnames: {fieldnames}\n"
                f"Shape: {shape}"
            )
        return dict(zip(fieldnames, shape))
    return mapper


class DataflowType(str, Enum):
    """Enumeration for dataflow types.

    Members also subclass ``str``, so they compare equal to their plain string
    values (e.g. ``DataflowType.L2 == 'l2'``).
    NOTE: from Python 3.11, format()/f-strings of (str, Enum) members render
    the member name (``DataflowType.L2``); use ``.value`` for the raw string.
    """
    L2 = 'l2'
    L3 = 'l3'


class BuildTarget(str, Enum):
    """Enumeration for build targets.

    Members also subclass ``str``, so they compare equal to their plain string
    values (e.g. ``BuildTarget.SIM == 'sim'``).
    NOTE: ``BuildTarget.CERT_SIM`` is the string 'cert_sim'. It is unrelated to
    the module-level ``CERT_SIM`` path constant despite the shared name.
    """
    DATAFLOW = 'dataflow'
    SIM = 'sim'
    CERT_SIM = 'cert_sim'
    CERT = 'cert'


def run_hw_validation(out_dir: str,
                      dtype: str = "int16",
                      hw_req_dir: Optional[str] = None,
                      host: Optional[str] = None,
                      perf_testing: bool = False,
                      golden_io: Optional[str] = None,
                      mode: str = "op",
                      filter_patterns: Optional[str] = None,
                      debug: bool = False) -> None:
    """
    Invoke HW_test on ``out_dir`` with optional overrides.

    :param out_dir: Build output directory passed to HW_test.
    :param dtype: Data-type string forwarded to HW_test.
    :param hw_req_dir: Requirements directory; defaults to ``HW_requirements``
                       inside the parent of this module's directory.
    :param host: Target host. When None (the default), it is read from the
                 ``HW_HOST`` environment variable at call time, falling back
                 to "10.228.203.217".
    :param perf_testing: Forwarded to HW_test.
    :param golden_io: Forwarded to HW_test.
    :param mode: Forwarded to HW_test.
    :param filter_patterns: Forwarded to HW_test.
    :param debug: Forwarded to HW_test.
    """
    if host is None:
        # Resolved per call: a default expression in the signature is evaluated
        # once at import time and would ignore later changes to HW_HOST.
        host = os.getenv("HW_HOST", "10.228.203.217")
    parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    hw_req = hw_req_dir or os.path.join(parent_dir, "HW_requirements")
    HW_test(
        out_dir,
        hw_req,
        host=host,
        perf_testing=perf_testing,
        golden_io=golden_io,
        dtype=dtype,
        mode=mode,
        filter_patterns=filter_patterns,
        debug=debug
    )


def clean_output_dir(out_dir: str, clean: bool = False) -> None:
    """
    Optionally reset ``out_dir`` to an empty directory.

    If ``clean`` is true, an existing directory at ``out_dir`` is deleted
    recursively and then recreated (missing parents included). If ``clean``
    is false, nothing happens.
    """
    if clean:
        if os.path.isdir(out_dir):
            shutil.rmtree(out_dir)
        os.makedirs(out_dir, exist_ok=True)


class Counter:
    """Callable that returns 1, 2, 3, ... on successive calls.

    ``count`` holds the most recently returned value (0 before the first call).
    Unrelated to ``collections.Counter``.
    """
    def __init__(self):
        self.count = 0

    def __call__(self):
        advanced = self.count + 1
        self.count = advanced
        return advanced
