#
# Copyright (C) 2023 - 2025 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#

import json
import os
import random
from typing import Any

import numpy as np
import onnxruntime_genai as og  # type: ignore[import-untyped]
import torch
from lm_eval.evaluator import (  # type: ignore[import-not-found]
    eval_logger,
)
from tqdm import tqdm


def get_dtype(dtype_arg: str) -> torch.dtype:
    """Translate a short precision name into the matching ``torch.dtype``.

    Args:
        dtype_arg: One of ``"fp32"``, ``"fp16"`` or ``"bf16"``.

    Returns:
        ``torch.float32``, ``torch.float16`` or ``torch.bfloat16`` respectively.

    Raises:
        ValueError: If ``dtype_arg`` is not one of the supported names.
    """
    supported = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    try:
        return supported[dtype_arg]
    except KeyError:
        # Fail with an actionable message instead of an UnboundLocalError.
        raise ValueError(
            f"Unsupported dtype {dtype_arg!r}; expected one of {sorted(supported)}"
        ) from None


def _adjust_config(
    task_dict: dict[str, Any],
    predict_only: bool = False,
    num_fewshot: int | None = None,
    fewshot_random_seed: int = 1234,
    gen_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
    adjusted_task_dict: dict[str, Any] = {}
    for task_name, task_obj in task_dict.items():
        if isinstance(task_obj, dict):
            adjusted_task_dict = {
                **adjusted_task_dict,
                **{task_name: _adjust_config(task_obj)},
            }

        else:
            if task_obj.get_config("output_type") == "generate_until":
                if gen_kwargs is not None:
                    task_obj.set_config(key="generation_kwargs", value=gen_kwargs, update=True)

            if predict_only:
                eval_logger.info(f"Processing {task_name} in output-only mode. Metrics will not be calculated!")
                # we have to change the class properties post-hoc. This is pretty hacky.
                task_obj.override_metric(metric_name="bypass")

            # override tasks' fewshot values to the provided num_fewshot arg value
            # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
            if num_fewshot is not None:
                if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                    eval_logger.info(
                        f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                    )
                else:
                    eval_logger.warning(
                        f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                    )
                    task_obj.set_config(key="num_fewshot", value=num_fewshot)
            else:
                # if num_fewshot not provided, and the task does not define a default one, default to 0
                if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
                    task_obj.set_config(key="num_fewshot", value=0)
            # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
            task_obj.set_fewshot_seed(seed=fewshot_random_seed)

            adjusted_task_dict[task_name] = task_obj

    return adjusted_task_dict


def _resolve_prompt_prefix(case: str, tasks: str) -> str:
    """Return the text prepended to every input for ``case``; raise ``ValueError`` if unsupported."""
    if case == "default":
        return ""
    if case in ("psu_prompt", "psu_prompt_eos_stop") and tasks == "tinyGSM8k":
        # defined PSU prompt
        return "Please solve following problem and explain it to me. Then give me final answer at the end with a single number preceded by string '#### '. "
    raise ValueError(
        f"Unsupported combination case={case!r}, tasks={tasks!r}: expected case 'default', "
        "or case 'psu_prompt'/'psu_prompt_eos_stop' together with tasks 'tinyGSM8k'."
    )


def oga_generation(
    case: str,
    tasks: str,
    import_model_dir: str,
    eor: str,
    inputs: list[str],
    model_dir: str,
    filename: str,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
    seq_len: int = 512,
    max_seq_len: int = 1024,
) -> list[str]:
    """Generate one response per input with an ONNX Runtime GenAI model.

    Every response is appended to ``filename`` followed by a line containing
    ``eor``, and the file is flushed after each one so partial runs stay usable.

    Args:
        case: Prompting mode. ``"default"`` sends the input unchanged;
            ``"psu_prompt"`` and ``"psu_prompt_eos_stop"`` prepend the PSU
            instruction prompt (only supported when ``tasks`` is
            ``"tinyGSM8k"``). ``"psu_prompt_eos_stop"`` additionally stops
            generation as soon as an EOS token id is produced.
        tasks: Name of the task being evaluated.
        import_model_dir: Directory containing ``genai_config.json``, from which
            ``model.eos_token_id`` is read.
        eor: End-of-response marker written after each response.
        inputs: Prompts to run, in order.
        model_dir: Directory passed to ``og.Model``.
        filename: Output text file; it is overwritten.
        random_seed: Seed for Python's ``random`` module.
        numpy_random_seed: Seed for ``numpy.random``.
        torch_random_seed: Seed for ``torch``.
        seq_len: Maximum number of prompt tokens kept after encoding.
        max_seq_len: ``max_length`` search option handed to the generator.

    Returns:
        The generated responses, in the same order as ``inputs``.

    Raises:
        ValueError: If ``inputs`` is non-empty and ``case``/``tasks`` is not a
            supported combination.
    """
    # Validate before loading the model or truncating the output file. With no
    # inputs there is nothing to prompt, so any case is accepted, as before.
    prompt_prefix = _resolve_prompt_prefix(case, tasks) if inputs else ""

    # load the eos_token_id
    with open(os.path.join(import_model_dir, "genai_config.json"), encoding="utf-8") as f:
        config = json.load(f)
    eos_token_id = config["model"]["eos_token_id"]
    # The config value may be a single id or a list of ids; normalise to a set.
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_ids = set(eos_token_id)
    else:
        eos_token_ids = {eos_token_id}
    stop_on_eos = case == "psu_prompt_eos_stop"

    random.seed(random_seed)
    np.random.seed(numpy_random_seed)
    torch.manual_seed(torch_random_seed)

    model = og.Model(model_dir)
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()

    outputs: list[str] = []
    # NOTE(review): the output file uses the platform default encoding; consider
    # an explicit encoding once every reader of this file is known to agree.
    with open(filename, "w") as file:
        for text in tqdm(inputs):
            # Prepend the PSU prompt (empty for the default case), then keep at
            # most seq_len prompt tokens.
            input_tokens = tokenizer.encode(prompt_prefix + text)[:seq_len]

            params = og.GeneratorParams(model)
            params.set_search_options(max_length=max_seq_len)
            generator = og.Generator(model, params)
            generator.append_tokens(input_tokens)

            pieces: list[str] = []
            while not generator.is_done():
                generator.generate_next_token()
                # Only the first entry is used -- presumably batch size 1; TODO confirm.
                new_token = generator.get_next_tokens()[0]

                # early stopping w/eos
                if stop_on_eos and int(new_token) in eos_token_ids:
                    print(f"****eos triggered, {new_token}****")
                    break
                pieces.append(tokenizer_stream.decode(new_token))
            # Drop the generator before building the next one.
            del generator

            response = "".join(pieces)

            # saving OGA generations
            file.write(response + f"\n{eor}\n")
            file.flush()
            outputs.append(response)
    return outputs
