import dataclasses
import os
from traceback import format_exc
from datetime import datetime
from typing import Any
from OGOAT.src.utils.context import Logger


@dataclasses.dataclass
class ErrorSummary:
    """Record of a single error: its backtrace text and the time it was recorded."""

    # backtrace text of the error, appended verbatim after the header line
    backtrace: str
    # moment at which the error was recorded, interpolated as-is in the header line
    timestamp: str

    def __str__(self) -> str:
        """Return a header line carrying the timestamp, followed by the backtrace."""
        header = f"Error raised at {self.timestamp}:\n"
        return header + self.backtrace


@dataclasses.dataclass
class SafeRunner:
    """
    Run functions while catching the exceptions they raise. Each caught exception is stored as an
    ErrorSummary so that all of them can be reported together and dumped to a summary file.
    """

    logger: Logger

    # directory in which the error summary file is written
    output_dir_path: str

    # name of the error summary file, joined to output_dir_path
    summary_file_name: str = "error_summary.txt"

    # list of the ErrorSummary recorded by run(), in the order they occurred
    errors_occured: list[ErrorSummary] = dataclasses.field(default_factory=list)

    # True if the last call to run() caught an unexpected exception, False otherwise.
    # Not annotated, hence a plain class-level default and not a dataclass field
    # (absent from __init__/__repr__/__eq__); run() assigns it on the instance.
    has_failed = False

    class SafeRunnerError(Exception):
        """Raised by raise_error(); run() catches it and returns None without recording it."""

    def run(self, main_func, *args, **kwargs) -> Any:
        """
        Run main_func(*args, **kwargs). If no exception was raised, return the return value of that function.
        A SafeRunnerError is caught and None is returned without recording anything. Any other Exception
        is caught, logged at debug level and stored in the errors_occured list, has_failed is set to True
        and None is returned.
        """
        self.has_failed = False
        try:
            return main_func(*args, **kwargs)
        # An error was raised and caught in one of the child function of main_func
        # stop there and let the user handle it how it sees fit
        except SafeRunner.SafeRunnerError:
            return None
        except Exception:
            self.has_failed = True
            backtrace = format_exc()
            self.logger.debug(f"Failed with exception: {backtrace}")
            # ErrorSummary.timestamp is declared as str: convert here instead of storing the
            # datetime object (str() gives the same text the summary rendered before).
            self.errors_occured.append(
                ErrorSummary(backtrace=backtrace, timestamp=str(datetime.now()))
            )
            return None

    def raise_error(self):
        """Raise a SafeRunnerError, which an enclosing run() call catches to stop its function."""
        raise SafeRunner.SafeRunnerError()

    def get_error_summary(self) -> str:
        """
        Return "No error found" if nothing was recorded, otherwise a line with the number of
        errors followed by the text of every recorded ErrorSummary.
        """
        if not self.errors_occured:
            return "No error found"

        header = f"Found {len(self.errors_occured)} errors.\n"
        return header + "".join(str(error) for error in self.errors_occured)

    @property
    def summary_file_path(self) -> str:
        """Path of the summary file: summary_file_name inside output_dir_path."""
        return os.path.join(self.output_dir_path, self.summary_file_name)

    def dump_error_summary(self) -> None:
        """Replace the summary file with a fresh one containing get_error_summary()."""
        # Remove any previous summary; try/except avoids the race between checking that the
        # file exists and removing it.
        try:
            os.remove(self.summary_file_path)
        except FileNotFoundError:
            pass

        with open(self.summary_file_path, "w", encoding="utf-8") as fd:
            fd.write(self.get_error_summary())
