import random
from typing import List


def directive(ident: str, val: int, hw_run: bool) -> str:
    """Build a compile flag that defines ``ident`` for use in the testbench.

    Args:
        ident: Name of the macro to define.
        val: Value assigned to the macro.
        hw_run: If truthy, return the bare ``-D<ident>=<val>`` flag;
            otherwise return it wrapped as ``--Xpreproc="-D<ident>=<val>"``.

    Returns:
        The formatted flag string.
    """
    return f"-D{ident}={val}" if hw_run else f'--Xpreproc="-D{ident}={val}"'
    

def write_bin_file(idxs: List[int], filename: str):
    """Write the indices to a binary file.

    Each index is encoded as a 4-byte little-endian unsigned integer, and the
    encodings are written back to back in input order with no header.

    Args:
        idxs: Indices to serialize. Each must satisfy ``0 <= idx < 2**32``.
        filename: Path of the output file; an existing file is overwritten.

    Raises:
        OverflowError: If an index does not fit in an unsigned 32-bit
            integer. This is raised before the file is opened, so no
            truncated or partially written file is left behind.
    """
    # Encode everything up front: a bad index fails before the file is
    # touched, and the data goes out in one write instead of one per index.
    payload = b"".join(idx.to_bytes(4, byteorder='little') for idx in idxs)
    with open(filename, 'wb') as file:
        file.write(payload)

def get_random_idxs(
    input_shape: List[int], num_idxs: int, axis: int
) -> List[int]:
    """Generate random gather indices along ``axis``.

    Each index is drawn with ``random.randint`` from the inclusive range
    ``[0, input_shape[axis] - 1]``, so every value is a valid position along
    the dimension selected by ``axis``. The dimension is not fixed: it is
    whatever ``axis`` selects. The module-level ``random`` generator is
    used, so results are reproducible after ``random.seed``.

    Args:
        input_shape: Shape of the tensor being gathered from.
        num_idxs: Number of indices to generate.
        axis: Index into ``input_shape`` of the dimension being gathered.

    Returns:
        A list of ``num_idxs`` integers.

    Raises:
        ValueError: If ``input_shape[axis]`` is 0 and ``num_idxs > 0``
            (``random.randint`` is given an empty range).
    """
    max_value = input_shape[axis] - 1
    return [random.randint(0, max_value) for _ in range(num_idxs)]

class Dims:
    """Bookkeeping for the Gather operation.

    All constructor arguments are stored unchanged as same-named attributes.
    The shape lists are stored by reference, not copied.

    Attributes:
        input_shape: Shape of the input tensor.
        output_shape: Shape of the output tensor.
        is_qdq: Whether qdq is involved.
        axis: The axis being gathered.
        aie_cols, aie_rows: Shape of the array in use (default 8 x 4).
        input_bits, wgt_bits, output_bits: Per-element size of the input,
            weight, and output tensors, as bit counts (defaults 16, 64, 16).
        param_subv_size: Size of the layer params per core.
    """

    def __init__(
        self,
        input_shape: List[int],
        output_shape: List[int],
        is_qdq: bool,
        axis: int,
        param_subv_size: int,
        aie_cols: int = 8,
        aie_rows: int = 4,
        input_bits: int = 16,
        wgt_bits: int = 64,
        output_bits: int = 16,
    ):
        # Tensor description.
        self.input_shape, self.output_shape = input_shape, output_shape
        self.is_qdq, self.axis = is_qdq, axis
        # Array geometry.
        self.aie_cols, self.aie_rows = aie_cols, aie_rows
        # Element widths, in bits.
        self.input_bits, self.wgt_bits, self.output_bits = (
            input_bits,
            wgt_bits,
            output_bits,
        )
        self.param_subv_size = param_subv_size