import math
import pytest
from buildscripts.common import normalize_shape


def fold_bdcast_identical_dims(shape_a: tuple[int, ...], shape_b: tuple[int, ...], shape_o: tuple[int, ...], fold_left: bool):
    """Collapse the run of dims whose size matches across A, B and the output.

    Axes are visited from the outermost one (``fold_left=True``) or from the
    innermost one (``fold_left=False``). Visiting stops at the first axis where
    the three shapes differ (i.e. a broadcast axis), so folding never crosses a
    broadcast dim. The visited run of matching axes is replaced by its product,
    and every resulting shape is passed through ``normalize_shape``.

    Examples (the tests in this file expect 4D results, so ``normalize_shape``
    presumably left-pads with 1s to rank 4 — confirm in ``buildscripts.common``):
        IFM A: (Y, X, C) IFM B: (1, X, C) --fold right-->  IFM A: (1, 1, Y, X*C) IFM B: (1, 1, 1, X*C)
        IFM A: (Y, X, C) IFM B: (Y, X, 1) --fold left-->   IFM A: (1, 1, Y*X, C) IFM B: (1, 1, Y*X, 1)

    Axes are enumerated from ``len(shape_a)``; all three shapes are assumed to
    have that same rank.

    Returns:
        ``(shape_a, shape_b, shape_o)`` after folding and normalization. If the
        first visited axis already differs, the shapes are only normalized.
    """
    rank = len(shape_a)
    visit_order = range(rank) if fold_left else range(rank - 1, -1, -1)

    n_matching = 0
    for ax in visit_order:
        if not (shape_a[ax] == shape_b[ax] == shape_o[ax]):
            break
        n_matching += 1

    if n_matching == 0:
        return normalize_shape(shape_a), normalize_shape(shape_b), normalize_shape(shape_o)

    def _collapse(shape):
        # Replace the matching run (leading or trailing) by the product of its sizes.
        if fold_left:
            merged = (math.prod(shape[:n_matching]),) + shape[n_matching:]
        else:
            merged = shape[:-n_matching] + (math.prod(shape[-n_matching:]),)
        return normalize_shape(merged)

    return tuple(_collapse(s) for s in (shape_a, shape_b, shape_o))


def _merge_adjacent_dims(shape: tuple[int, ...], axis: int) -> tuple[int, ...]:
    """Merge dims ``axis`` and ``axis + 1`` of a 4D shape, left-padding with 1 to stay 4D."""
    merged = shape[:axis] + (shape[axis] * shape[axis + 1],) + shape[axis + 2:]
    return (1,) + merged


def fold_bdcast_consecutive_dims(shape_a: tuple[int, ...], shape_b: tuple[int, ...], shape_o: tuple[int, ...]):
    """Fold pairs of adjacent dims that are broadcast (size 1) on one IFM.

    Should also fold when there are repeated broadcasted dimensions on the same IFM:
        IFM A: (1, Y, X, C) IFM B: (1, 1, 1, C) --fold left-->  IFM A: (1, 1, Y*X, C) IFM B: (1, 1, 1, C)
        IFM A: (1, Y, X, C) IFM B: (1, Y, 1, 1) --fold right-->  IFM A: (1, 1, Y, X*C) IFM B: (1, 1, Y, 1)
    and vice versa.

    Why merging is safe: when one IFM is 1 on both dims of a pair, the output
    equals the other IFM on those dims (assuming ``shape_o`` is the broadcast of
    A and B). Merging the pair into a single dim is then a valid row-major
    reshape of all three tensors. Each merge removes one dim and prepends a 1,
    so every other dim, including a non-unit leading N dim, is kept.

    Assumes normalized 4D shapes as input. Checks the pairs N,Y, then Y,X, then
    X,C. For each pair it tests IFM A before IFM B. Later checks see the shapes
    produced by earlier folds. A general solution is too complex and unreadable
    for now.

    Returns:
        ``(shape_a, shape_b, shape_o)`` after folding.
    """
    assert len(shape_a) == len(shape_b) == len(shape_o) == 4
    shapes = (shape_a, shape_b, shape_o)
    # axis 0 -> N,Y pair; axis 1 -> Y,X pair; axis 2 -> X,C pair.
    for axis in range(3):
        for ifm in (0, 1):  # IFM A first, then IFM B
            if shapes[ifm][axis:axis + 2] == (1, 1):
                shapes = tuple(_merge_adjacent_dims(shape, axis) for shape in shapes)
    return shapes


def fold_bdcast(shape_a: tuple[int, ...], shape_b: tuple[int, ...], shape_o: tuple[int, ...]):
    """Run every broadcast-folding pass over the (A, B, output) shape triple.

    Passes run in this order, each consuming the previous result:
      1. identical-dims fold from the left,
      2. identical-dims fold from the right,
      3. consecutive broadcast-dims fold.
    The last pass asserts 4D shapes, so this relies on the identical-dims
    passes (via ``normalize_shape``) presumably yielding 4D shapes — confirm.

    Returns:
        ``(shape_a, shape_b, shape_o)`` after all folds.
    """
    shapes = (shape_a, shape_b, shape_o)
    for from_left in (True, False):
        shapes = fold_bdcast_identical_dims(*shapes, fold_left=from_left)
    folded_a, folded_b, folded_o = fold_bdcast_consecutive_dims(*shapes)
    return folded_a, folded_b, folded_o


def pad_to_32_bits(shape: tuple[int, ...], n_bytes: int) -> tuple[int, ...]:
    """Grow the first dimension greater than 1 until the total size is a multiple of 32 bits.

    Args:
        shape: Tensor shape to pad.
        n_bytes: Bytes per element.

    Returns:
        ``shape`` unchanged if ``prod(shape) * n_bytes * 8`` is already a
        multiple of 32; otherwise a new tuple where the first (outermost)
        dimension greater than 1 has been incremented just enough to reach
        alignment.

    Raises:
        AssertionError: padding is needed but every dimension equals 1.
        RuntimeError: alignment was not reached within 4 increments.
    """
    def total_bits(dims) -> int:
        return math.prod(dims) * n_bytes * 8

    if total_bits(shape) % 32 == 0:
        return shape

    pad_axis = next((i for i, dim in enumerate(shape) if dim > 1), -1)
    assert pad_axis != -1, "We shouldn't be padding to 32-bits if all dimensions are equal to 1."

    padded = list(shape)
    # With an integer n_bytes the element count must become a multiple of
    # 4 / gcd(4, rest * n_bytes), so at most 3 increments are ever required.
    # Check alignment after each increment, so the last allowed step can still succeed.
    for _ in range(4):
        padded[pad_axis] += 1
        if total_bits(padded) % 32 == 0:
            return tuple(padded)
    raise RuntimeError("Could not pad to 32-bits after 4 iterations, something is wrong.")


@pytest.mark.tiler
def test_bdcast_fold_identical_dims_right_fold():
    """Right fold merges the trailing dims shared by all shapes (8*16 -> 128)."""
    expected = (
        (1, 1, 4, 128),
        (1, 1, 1, 128),
        (1, 1, 4, 128),
    )
    result = fold_bdcast_identical_dims((4, 8, 16), (1, 8, 16), (4, 8, 16), fold_left=False)
    assert result == expected


@pytest.mark.tiler
def test_bdcast_fold_identical_dims_left_fold():
    """Left fold merges the leading dims shared by all shapes (4*8 -> 32)."""
    expected = (
        (1, 1, 32, 16),
        (1, 1, 32, 1),
        (1, 1, 32, 16),
    )
    result = fold_bdcast_identical_dims((4, 8, 16), (4, 8, 1), (4, 8, 16), fold_left=True)
    assert result == expected


@pytest.mark.tiler
def test_bdcast_fold_identical_dims_right_no_fold():
    """Only the last dim matches, so a right fold merely normalizes the shapes."""
    expected = (
        (1, 4, 8, 16),
        (1, 1, 1, 16),
        (1, 4, 8, 16),
    )
    result = fold_bdcast_identical_dims((4, 8, 16), (1, 1, 16), (4, 8, 16), fold_left=False)
    assert result == expected


@pytest.mark.tiler
def test_bdcast_fold_identical_dims_left_no_fold():
    """Only the first dim matches, so a left fold merely normalizes the shapes."""
    expected = (
        (1, 4, 8, 16),
        (1, 4, 1, 1),
        (1, 4, 8, 16),
    )
    result = fold_bdcast_identical_dims((4, 8, 16), (4, 1, 1), (4, 8, 16), fold_left=True)
    assert result == expected


@pytest.mark.tiler
def test_bdcast_fold_consecutive_dims_left_fold():
    """B broadcast on both Y and X lets Y*X collapse into one dim."""
    expected = (
        (1, 1, 32, 16),
        (1, 1, 1, 16),
        (1, 1, 32, 16),
    )
    result = fold_bdcast_consecutive_dims((1, 4, 8, 16), (1, 1, 1, 16), (1, 4, 8, 16))
    assert result == expected


@pytest.mark.tiler
def test_bdcast_fold_consecutive_dims_right_fold():
    """B broadcast on both X and C lets X*C collapse into one dim."""
    expected = (
        (1, 1, 4, 128),
        (1, 1, 4, 1),
        (1, 1, 4, 128),
    )
    result = fold_bdcast_consecutive_dims((1, 4, 8, 16), (1, 4, 1, 1), (1, 4, 8, 16))
    assert result == expected


@pytest.mark.tiler
def test_bdcast_fold_consecutive_dims_no_fold():
    """Broadcast dims on B are not adjacent past N, so the shapes are unchanged."""
    expected = (
        (1, 4, 8, 16),
        (1, 1, 8, 1),
        (1, 4, 8, 16),
    )
    result = fold_bdcast_consecutive_dims((1, 4, 8, 16), (1, 1, 8, 1), (1, 4, 8, 16))
    assert result == expected


@pytest.mark.tiler
def test_bdcast_fold_consecutive_dims_no_fold_2():
    """A single broadcast dim on B (X) is not a foldable pair, so nothing changes."""
    expected = (
        (1, 4, 8, 16),
        (1, 4, 1, 16),
        (1, 4, 8, 16),
    )
    result = fold_bdcast_consecutive_dims((1, 4, 8, 16), (1, 4, 1, 16), (1, 4, 8, 16))
    assert result == expected


@pytest.mark.tiler
def test_bdcast_multiple_folds():
    """Put all the gotchas here: inputs that need several fold passes."""
    cases = [
        (
            ((1, 12, 77, 77), (1, 1, 77, 77), (1, 12, 77, 77)),
            ((1, 1, 12, 5929), (1, 1, 1, 5929), (1, 1, 12, 5929)),
        ),
        (
            ((1, 64, 64, 320), (1, 1, 1, 320), (1, 64, 64, 320)),
            ((1, 1, 4096, 320), (1, 1, 1, 320), (1, 1, 4096, 320)),
        ),
        (
            ((1, 12, 512, 1), (1, 1, 1, 1), (1, 12, 512, 1)),
            ((1, 1, 1, 6144), (1, 1, 1, 1), (1, 1, 1, 6144)),
        ),
        # laion shape
        (
            ((10, 8, 77, 128), (10, 1, 77, 128), (10, 8, 77, 128)),
            ((1, 10, 8, 9856), (1, 10, 1, 9856), (1, 10, 8, 9856)),
        ),
    ]
    for input_shapes, expected in cases:
        assert fold_bdcast(*input_shapes) == expected


@pytest.mark.tiler
def test_pad_to_32_bits_no_pad():
    """An already 32-bit-aligned shape is returned unchanged."""
    shape = (4, 8, 16)  # 4*8*16 elements * 16 bits = 8192 bits, already a multiple of 32
    assert pad_to_32_bits(shape, 2) == shape


@pytest.mark.tiler
def test_pad_to_32_bits_pad():
    """A misaligned shape gets its first dim > 1 bumped until aligned."""
    # 3*5*7 elements * 16 bits = 1680 bits -> not a multiple of 32; 4*5*7*16 = 2240 is.
    padded = pad_to_32_bits((1, 3, 5, 7), 2)
    assert padded == (1, 4, 5, 7)
