import os
import numpy as np
import argparse

# Supported --data_type names mapped to numpy element types (native byte order).
# The first three entries are the original set; the rest are additive.
_DTYPES = {
    "int8": np.int8,
    "uint16": np.uint16,
    "float": np.float32,
    "uint8": np.uint8,
    "int16": np.int16,
    "int32": np.int32,
    "float16": np.float16,
    "float32": np.float32,
}

# Axis permutation applied to the 3-D buffer for each --format value:
# NHWC input (H, W, C) -> (C, H, W); NCHW input (C, H, W) -> (H, W, C).
_TRANSPOSE_AXES = {
    "NHWC": (2, 0, 1),
    "NCHW": (1, 2, 0),
}


def _parse_shape(text):
    """Parse a ``d0,d1,d2`` string into a tuple of three non-negative ints.

    Surrounding whitespace and surrounding parentheses/brackets are
    tolerated, so both ``513,513,3`` and ``(513,513,3)`` (the form shown in
    the --shape help text) are accepted. Empty components, e.g. from a
    trailing comma, are ignored.

    Raises:
        ValueError: if there are not exactly three integer components or
            any component is negative.
    """
    body = text.strip().strip("()[]")
    parts = [p for p in body.split(",") if p.strip()]
    if len(parts) != 3:
        raise ValueError(
            "expected 3 comma-separated dimensions, got %d" % len(parts))
    dims = tuple(int(p) for p in parts)
    if any(d < 0 for d in dims):
        raise ValueError("dimensions must be non-negative")
    return dims


def _output_path(path):
    """Return ``path`` with ``_transposed`` inserted before its extension.

    ``os.path.splitext`` only looks at the final path component, so
    extensionless files and dotted directory names are handled correctly.
    """
    root, ext = os.path.splitext(path)
    return root + "_transposed" + ext


def main():
    """Transpose a raw binary 3-D tensor file between layouts.

    Command-line arguments:
      --file_path  headerless binary file of native-endian elements.
      --shape      three comma-separated dimensions, applied to the raw
                   buffer in the order the data is stored in the file.
      --format     "NHWC": the buffer is reshaped to the given dims and its
                   last axis is moved first (H, W, C -> C, H, W).
                   "NCHW": its first axis is moved last (C, H, W -> H, W, C).
      --data_type  element type name; one of the keys of ``_DTYPES``.

    NOTE: because --shape is applied in file order, an NCHW file must be
    described as C,H,W (e.g. 3,513,513), not H,W,C.

    The shapes before and after the transpose are printed to stdout. The
    result is written next to the input as ``<stem>_transposed<ext>``.
    Invalid arguments, a missing input file, or an element-count mismatch
    abort via ``parser.error``: the message goes to stderr and the process
    exits with status 2.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_path", required=True)
    parser.add_argument("--shape", help="Shape (513,513,3)", required=True)
    parser.add_argument("--format", help="NHWC", required=False, default="NHWC")
    parser.add_argument("--data_type", help="int8", required=False, default="int8")
    args = parser.parse_args()

    # Validate every argument before touching the file system.
    d_type = _DTYPES.get(args.data_type)
    if d_type is None:
        parser.error("Specified data type not supported")
    axes = _TRANSPOSE_AXES.get(args.format)
    if axes is None:
        parser.error("Unrecognized input format")
    try:
        dims = _parse_shape(args.shape)
    except ValueError as err:
        parser.error("Invalid --shape %r: %s" % (args.shape, err))

    full_path = os.path.abspath(args.file_path)
    if not os.path.isfile(full_path):
        parser.error("Input file does not exist")

    # np.fromfile opens the path itself in binary mode and closes it again.
    tensor = np.fromfile(full_path, dtype=d_type)
    expected = dims[0] * dims[1] * dims[2]
    # Explicit check instead of `assert`, which `python -O` would strip.
    if tensor.size != expected:
        parser.error("File holds %d elements of %s but shape %s needs %d"
                     % (tensor.size, args.data_type, dims, expected))

    tsr = tensor.reshape(dims)
    print(tsr.shape)
    # Materialize the permuted view in C order; dtype is already d_type,
    # so no further cast is needed before writing.
    transposed = np.ascontiguousarray(tsr.transpose(axes))
    print(transposed.shape)
    transposed.tofile(_output_path(full_path))


if __name__ == "__main__":
    # Run the CLI only when executed as a script; importing this module does
    # not trigger argument parsing or any file I/O.
    main()
