from typing import Any

import numpy as np
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Node


class DataflowAttributeGenerator:
    """
    A helper class to generate attributes for dataflow ops (currently slice)
    from the inputs of a matched node.
    """

    def generate_slice_attributes(self, n: Node) -> dict[str, Any]:
        """
        Build normalized slice attributes from the inputs of node ``n``.

        Reads the shape of the ``data`` input and the ``starts`` / ``ends``
        initializers. ``axes`` and ``steps`` are used when they are
        initializers; otherwise they default to ``0..len(starts)-1`` and
        all-ones respectively.

        Negative axes are shifted by the rank of ``data``. Negative starts and
        ends are shifted by the size of the sliced dimension, then clamped:

        * positive step: start and end are clamped to ``[0, dim]``
        * negative step: start is clamped to ``[0, dim - 1]`` and end to
          ``[-1, dim - 1]`` (``-1`` meaning "run through index 0")

        These ranges follow the clamping rules of the ONNX Slice operator.

        NOTE(review): ``require_tensor`` / ``require_initializer`` presumably
        raise when the input is not of the required kind -- confirm against
        the ``Node`` API.

        Returns:
            A dict with the keys ``starts``/``start`` and ``ends``/``end``
            (both spellings carry the same adjusted values, typed with the
            dtype of the respective initializer), plus ``axes`` (normalized to
            non-negative) and ``steps`` as int64 arrays.

        Raises:
            ValueError: if ``starts``, ``ends``, ``axes`` and ``steps`` do not
                all have the same length, or if any step is 0.
            IndexError: if an axis is outside ``[-rank, rank - 1]``.
        """
        dims = n("data").require_tensor().get_shape()
        rank = len(dims)
        starts_init = n("starts").require_initializer()
        starts = starts_init.get_value()
        ends_init = n("ends").require_initializer()
        ends = ends_init.get_value()
        axes = (
            n("axes").require_initializer().get_value()
            if n("axes").check_initializer()
            else list(range(len(starts)))
        )
        steps = (
            n("steps").require_initializer().get_value()
            if n("steps").check_initializer()
            else [1] * len(starts)
        )

        # All four per-axis inputs must line up one-to-one; otherwise the
        # emitted attributes would be inconsistent with each other.
        if not (len(starts) == len(ends) == len(axes) == len(steps)):
            raise ValueError(
                "Slice inputs must have equal lengths, got "
                f"starts={len(starts)}, ends={len(ends)}, "
                f"axes={len(axes)}, steps={len(steps)}"
            )

        # Adjust negative axes. An axis below -rank would still be negative
        # after the shift and silently wrap around when indexing dims, so it
        # is rejected explicitly.
        normalized_axes = []
        for axis in axes:
            if not -rank <= axis < rank:
                raise IndexError(
                    f"Slice axis {axis} is out of range for rank {rank}"
                )
            normalized_axes.append(axis + rank if axis < 0 else axis)
        axes = normalized_axes

        starts_adj = []
        ends_adj = []
        for start, end, axis, step in zip(starts, ends, axes, steps):
            if step == 0:
                raise ValueError("Slice step cannot be 0")
            dim = dims[axis]
            # adjust negative starts/ends (counted from the end of the axis)
            if start < 0:
                start = start + dim
            if end < 0:
                end = end + dim
            # clamp starts and ends to the valid range for the step direction
            if step > 0:
                start = max(0, min(start, dim))
                end = max(0, min(end, dim))
            else:
                start = max(0, min(start, dim - 1))
                # -1 lets a backwards slice include index 0
                end = max(-1, min(end, dim - 1))

            starts_adj.append(start)
            ends_adj.append(end)

        starts_dtype = starts_init.get_dtype()
        ends_dtype = ends_init.get_dtype()

        # Both singular and plural key spellings are emitted, each backed by
        # its own array so consumers can modify one without affecting the other.
        return {
            "starts": np.array(starts_adj, dtype=starts_dtype),
            "start": np.array(starts_adj, dtype=starts_dtype),
            "ends": np.array(ends_adj, dtype=ends_dtype),
            "end": np.array(ends_adj, dtype=ends_dtype),
            "axes": np.array(axes, dtype=np.int64),
            "steps": np.array(steps, dtype=np.int64),
        }
