#include "txn_utils.hpp"
#include "ipu_hw_config.hpp"
#include "op_utils.hpp"

// Size, in bytes, of the wrapper header that transaction_op places in front
// of a transaction. The wrapper is written as two 32-bit words: the op code
// followed by a size word (transaction size + this wrapper header).
// This is the wrapper header around the txn format supported by aie-rt.
constexpr size_t TXN_OP_SIZE = 8;
// Op code stored in the first 32-bit word of the wrapper header.
constexpr uint32_t TXN_OP_CODE = 0;
namespace waic_runner {
    void txn_util::pass_through(uint8_t** ptr) {
        auto op_hdr = (XAie_OpHdr*)(*ptr);
        switch (op_hdr->Op) {
        case XAIE_IO_WRITE: {
            XAie_Write32Hdr* w_hdr = (XAie_Write32Hdr*)(*ptr);
            *ptr = *ptr + w_hdr->Size;
            break;
        }
        case XAIE_IO_BLOCKWRITE: {
            XAie_BlockWrite32Hdr* bw_header = (XAie_BlockWrite32Hdr*)(*ptr);
            *ptr = *ptr + bw_header->Size;
            break;
        }
        case XAIE_IO_MASKWRITE: {
            XAie_MaskWrite32Hdr* mw_header = (XAie_MaskWrite32Hdr*)(*ptr);
            *ptr = *ptr + mw_header->Size;
            break;
        }
        case XAIE_IO_MASKPOLL:
        case XAIE_IO_MASKPOLL_BUSY: {
            XAie_MaskPoll32Hdr* mp_header = (XAie_MaskPoll32Hdr*)(*ptr);
            *ptr = *ptr + mp_header->Size;
            break;
        }
        case (XAIE_IO_CUSTOM_OP_TCT):
        case (XAIE_IO_CUSTOM_OP_DDR_PATCH):
        case (XAIE_IO_CUSTOM_OP_READ_REGS):
        case (XAIE_IO_CUSTOM_OP_RECORD_TIMER):
        case (XAIE_IO_CUSTOM_OP_MERGE_SYNC): {
            XAie_CustomOpHdr* Hdr = (XAie_CustomOpHdr*)(*ptr);
            *ptr = *ptr + Hdr->Size;
            break;
        }
        case (XAIE_IO_PREEMPT): {
            XAie_PreemptHdr* Hdr = (XAie_PreemptHdr*)(*ptr);
            *ptr = *ptr + sizeof(*Hdr);
            break;
        }
        case (XAIE_IO_LOAD_PM_START): {
            XAie_PmLoadHdr* Hdr = (XAie_PmLoadHdr*)(*ptr);
            *ptr = *ptr + sizeof(*Hdr);
            break;
        }
        default:
            std::cout << std::to_string(op_hdr->Op) << std::endl;
            throw std::runtime_error("Unknown op to pass through");
        }
    }

    // Construct a txn_util holding a private copy of the transaction bytes in
    // `txn_vec`.
    //
    // `txn_vec` must start with an XAie_TxnHeader whose TxnSize equals the
    // length of the vector. The header's TxnSize and NumOps are cached in
    // txn_size_ and num_txn_ops_.
    //
    // Throws std::runtime_error if the vector is too short to contain a
    // header, or if its length disagrees with the header's TxnSize.
    txn_util::txn_util(const std::vector<uint8_t>& txn_vec) {
        // The header must be fully present before any of its fields are read;
        // otherwise an empty or truncated vector is read out of bounds.
        if (txn_vec.size() < sizeof(XAie_TxnHeader)) {
            throw std::runtime_error(
                "Invalid Transaction Vec : Size of input transaction vector is "
                "smaller than the transaction header.");
        }

        const auto* Hdr = reinterpret_cast<const XAie_TxnHeader*>(txn_vec.data());
        if (txn_vec.size() != Hdr->TxnSize) {
            throw std::runtime_error(
                "Invalid Transaction Vec : Size of input transaction vector and the "
                "size specified in its header doesn't match.");
        }

        // Sizes are known to match here, so copying the whole vector is the
        // same as copying Hdr->TxnSize bytes.
        txn.assign(txn_vec.begin(), txn_vec.end());
        txn_size_ = Hdr->TxnSize;
        num_txn_ops_ = Hdr->NumOps;
    }

    // Return a copy of the transaction bytes currently held by this object.
    std::vector<uint8_t> txn_util::to_vector() {
        std::vector<uint8_t> snapshot(txn);
        return snapshot;
    }

    // Rewrite every DDR-patch custom op in the held transaction so that its
    // (argidx, argplus) pair refers to the fused buffer layout described by
    // `meta`. All other ops are skipped unchanged via pass_through().
    //
    // For each XAIE_IO_CUSTOM_OP_DDR_PATCH op, the op's current argidx and
    // argplus are looked up in `args_map` (via find_op_arg) and replaced
    // depending on the matched argument type:
    //   - CONST_KERNEL_PARAM_INPUT: offset into the super-kernel instruction map
    //   - CONST_INPUT:              offset into the const map
    //   - CTRL_PKT_BIN:             placed after the "super_instr" fused tensor
    //   - SCRATCH_PAD:              placed after the "scratch" fused tensor
    //   - anything else:            resolved through the tensor map using the
    //                               op's ONNX argument label
    //
    // Params:
    //   op_info  - the op whose transaction is being patched
    //   meta     - fused metadata (tensor/const/super-instr/ctrl-pkt maps)
    //   args_map - argument map of the op
    //
    // Throws std::runtime_error on an out-of-range argument index, scratch-pad
    // offset or runtime-offset index, on a transaction op that does not fit
    // inside the buffer, and (from pass_through) on an unknown op code.
    // Lookups through MAP_AT / ARRAY_AT / .at() may raise their own errors.
    void txn_util::patch(const Metadata::OpInfo& op_info,
        const Metadata& meta,
        const std::vector<OpArgMap>& args_map) {

        const auto& tensor_map = meta.tensor_map;
        const auto& super_instr_map = meta.super_instr_map;
        const auto& const_map = meta.const_map;
        const auto intermediate_scratch_size =
            MAP_AT(meta.fused_tensors, "scratch").size;
        const auto max_op_scratch_pad_size = meta.max_op_scratch_pad_size;
        auto args = get_op_args(op_info);
        const std::vector<size_t> runtime_in_offset = get_runtime_in_offset(op_info.attr);
        // NOTE(review): the output offsets are obtained with the *input* helper.
        // This looks like a copy/paste slip, but no output-side helper is visible
        // from this file, so the call is left as-is - confirm against op_utils.
        const std::vector<size_t> runtime_out_offset = get_runtime_in_offset(op_info.attr);

        // Index of the first output within `args` (inputs and consts come first).
        const size_t out_offset = op_info.in_args.size() + op_info.const_args.size();
        const auto total_args_size = args.size();

        const auto argmap_partition = partition_argmap(args_map);

        if (txn.size() < sizeof(XAie_TxnHeader)) {
            throw std::runtime_error(
                "Transaction buffer is smaller than the transaction header");
        }

        const auto* Hdr = reinterpret_cast<const XAie_TxnHeader*>(txn.data());
        const size_t num_ops = Hdr->NumOps;
        uint8_t* ptr = txn.data() + sizeof(XAie_TxnHeader);
        const uint8_t* const txn_end = txn.data() + txn.size();

        // Guard against walking off the end of the buffer when NumOps or an
        // op's size field is inconsistent with the buffer length.
        const auto require_bytes = [&ptr, txn_end](size_t needed) {
            if (ptr > txn_end || static_cast<size_t>(txn_end - ptr) < needed) {
                throw std::runtime_error(
                    "Transaction op extends past the end of the transaction buffer");
            }
        };

        for (size_t n = 0; n < num_ops; ++n) {
            require_bytes(sizeof(XAie_OpHdr));
            const auto* op_hdr = reinterpret_cast<const XAie_OpHdr*>(ptr);
            switch (op_hdr->Op) {
            case XAIE_IO_CUSTOM_OP_DDR_PATCH: {
                // The patch payload sits directly behind the custom-op header.
                require_bytes(sizeof(XAie_CustomOpHdr) + sizeof(patch_op_t));
                const auto* hdr = reinterpret_cast<const XAie_CustomOpHdr*>(ptr);
                const std::uint32_t size = hdr->Size;
                auto* op = reinterpret_cast<patch_op_t*>(ptr + sizeof(*hdr));

                const auto curr_argidx = op->argidx;
                const auto curr_offset = op->argplus;

                // Support additional args for super kernels and for initializing
                // const params. Super kernel params can be sent to the NPU even
                // though the ONNX node may not list them as an input; some
                // operators need to send LUTs to the NPU from DDR for
                // functionality, which is not represented as an input in the
                // ONNX node. Example kernels - bf16 Silu/Gelu ops.
                if (curr_argidx >= total_args_size + 3) {
                    throw std::runtime_error("curr_argidx() >= # op_args() + 3");
                }

                if (curr_argidx >= args_map.size()) {
                    throw std::runtime_error("curr_argidx() >= args_map size()");
                }

                const auto& op_arg = find_op_arg(
                    argmap_partition, curr_argidx, curr_offset);

                if (op_arg.arg_type == OpArgMap::CONST_KERNEL_PARAM_INPUT) {
                    op->argidx = OpArgMap::CONST_KERNEL_PARAM_INPUT;
                    op->argplus = curr_offset + super_instr_map.at(op_info.name).offset;
                }
                else if (op_arg.arg_type == OpArgMap::CONST_INPUT) {
                    op->argidx = OpArgMap::CONST_INPUT;
                    op->argplus = curr_offset + MAP_AT(const_map, op_info.name).offset;
                }
                else if (op_arg.arg_type == OpArgMap::CTRL_PKT_BIN) {
                    // Ctrl Pkt bin will be packed with super kernel instructions BO
                    const auto super_kernel_size = meta.fused_tensors.at("super_instr").size;
                    op->argidx = OpArgMap::CONST_KERNEL_PARAM_INPUT;
                    op->argplus = curr_offset + meta.ctrl_pkt_map.at(op_info.name).offset +
                        super_kernel_size;
                }
                else if (op_arg.arg_type == OpArgMap::SCRATCH_PAD) {
                    op->argidx = OpArgMap::SCRATCH_PAD;
                    if (curr_offset >= max_op_scratch_pad_size) {
                        throw std::runtime_error("curr_offset() >= args_map max_op_scratch_pad_size()");
                    }

                    // Note: Internal scratch pad for each op is shared, since it
                    //       is assumed ops will execute sequentially.
                    //       Offset by scratch buffer size since beginning will store
                    //       intermediate outputs, i.e. memory layout will be
                    //       [intermediate_outputs | internal_scratch_pad]
                    op->argplus = curr_offset + intermediate_scratch_size;
                }
                else {
                    const size_t onnx_argidx = op_arg.onnx_arg_idx;
                    const auto& arg_label = ARRAY_AT(args, onnx_argidx);
                    const auto& tensor = MAP_AT(tensor_map, arg_label);

                    const size_t new_argidx = tensor.xrt_arg_idx;
                    size_t block_offset = tensor.offset;
                    if (op_arg.arg_type == OpArgMap::INPUT && !runtime_in_offset.empty()) {
                        if (onnx_argidx >= runtime_in_offset.size()) {
                            throw std::runtime_error(
                                "onnx_argidx is out of range of the runtime input offsets");
                        }
                        block_offset += runtime_in_offset[onnx_argidx];
                    }
                    if (op_arg.arg_type == OpArgMap::OUTPUT && !runtime_out_offset.empty()) {
                        // Outputs are indexed relative to the first output arg;
                        // reject indices that would underflow or overrun.
                        if (onnx_argidx < out_offset ||
                            onnx_argidx - out_offset >= runtime_out_offset.size()) {
                            throw std::runtime_error(
                                "onnx_argidx is out of range of the runtime output offsets");
                        }
                        block_offset += runtime_out_offset[onnx_argidx - out_offset];
                    }
                    const size_t curr_offset_delta = curr_offset - op_arg.offset;
                    // tensor.offset tells where data actually is
                    // op_arg.padding_offset is op requirement on whether it needs address
                    // of actual data or beginning of padding
                    const size_t final_offset =
                        block_offset + curr_offset_delta - op_arg.padding_offset;

                    op->argidx = new_argidx;
                    op->argplus = final_offset;
                }

                ptr = ptr + size;

            } break;
            default:
                // no modification for other ops
                pass_through(&ptr);
                break;
            }
        }
    }

    std::vector<uint8_t>
        txn_util::fuse_txns(const std::vector<std::vector<uint8_t>>& txns) {
        if (txns.empty()) {
            throw std::runtime_error("No transactions to fuse");
        }

        std::vector<uint8_t> fused_txn;

        uint32_t NumOps = uint32_t(0);
        uint32_t TxnSize = uint32_t(sizeof(XAie_TxnHeader));

        // First go through all txn and figure out size to pre-allocate
        // this is to avoid unnecessary vector re-allocation
        for (size_t i = 0; i < txns.size(); ++i) {
            const auto& txn = ARRAY_AT(txns, i);
            const XAie_TxnHeader* txn_hdr = (const XAie_TxnHeader*)txn.data();
            NumOps += txn_hdr->NumOps;

            if (txn_hdr->TxnSize < sizeof(XAie_TxnHeader)) {
                throw std::runtime_error("Size of fused_transaction is smaller than its header");
            }
            uint32_t instr_size = txn_hdr->TxnSize - uint32_t(sizeof(XAie_TxnHeader));
            TxnSize += instr_size;
        }

        fused_txn.reserve(TxnSize);

        // First txn - copy over header too
        const auto& txn1 = ARRAY_AT(txns, 0);
        const XAie_TxnHeader* txn1_hdr = (const XAie_TxnHeader*)txn1.data();
        fused_txn.insert(fused_txn.end(), txn1.data(),
            txn1.data() + txn1_hdr->TxnSize);

        // Rest of txns
        for (size_t i = 1; i < txns.size(); ++i) {
            const auto& txn = ARRAY_AT(txns, i);
            const XAie_TxnHeader* txn_hdr = (const XAie_TxnHeader*)txn.data();
            const uint8_t* instr_ptr = txn.data() + sizeof(XAie_TxnHeader);
            // skip copying over the header for the rest of txns
            size_t instr_size = txn_hdr->TxnSize - sizeof(XAie_TxnHeader);
            fused_txn.insert(fused_txn.end(), instr_ptr, instr_ptr + instr_size);
        }

        // Update the header
        XAie_TxnHeader* txn_vec_hdr = (XAie_TxnHeader*)(fused_txn.data());
        txn_vec_hdr->NumOps = NumOps;
        txn_vec_hdr->TxnSize = TxnSize;
        if (fused_txn.size() != TxnSize) {
            throw std::runtime_error("Size of fused_transaction {} doesn't match the size "
                "in its header {}");
        }
        return fused_txn;
    }

    // Wrap a transaction in the TXN op header.
    //
    // The resulting buffer is laid out as:
    //   [ TXN_OP_CODE (32-bit) | size word (32-bit) | transaction bytes ]
    // where the size word is the TxnSize from the transaction's own header
    // plus TXN_OP_SIZE.
    //
    // Throws std::runtime_error if `txn` is too short to contain an
    // XAie_TxnHeader (its TxnSize field could not be read safely).
    transaction_op::transaction_op(const std::vector<uint8_t>& txn) {
        if (txn.size() < sizeof(XAie_TxnHeader)) {
            throw std::runtime_error(
                "Invalid Transaction Vec : Size of input transaction vector is "
                "smaller than the transaction header.");
        }
        const auto* hdr = reinterpret_cast<const XAie_TxnHeader*>(txn.data());

        txn_op_.resize(txn.size() + TXN_OP_SIZE);

        // Write the two wrapper words with memcpy so no uint32_t is accessed
        // through a pointer into byte storage.
        const uint32_t op_code = TXN_OP_CODE;
        const uint32_t op_size = static_cast<uint32_t>(hdr->TxnSize + TXN_OP_SIZE);
        std::memcpy(txn_op_.data(), &op_code, sizeof(op_code));
        std::memcpy(txn_op_.data() + sizeof(op_code), &op_size, sizeof(op_size));

        // Transaction payload follows the wrapper header unchanged.
        std::memcpy(txn_op_.data() + TXN_OP_SIZE, txn.data(), txn.size());
    }

    // Return the size word of the wrapper header, i.e. the second 32-bit word
    // of the wrapped transaction buffer.
    size_t transaction_op::get_txn_instr_size() {
        const uint32_t* words = (const uint32_t*)txn_op_.data();
        const uint32_t size_word = words[1];
        return size_word;
    }

    // Return a copy of the wrapped transaction (wrapper header + transaction).
    std::vector<uint8_t> transaction_op::get_txn_op() {
        std::vector<uint8_t> wrapped(txn_op_);
        return wrapped;
    }

    // Return the number of bytes needed to hold `txn_str` wrapped in the TXN
    // op header: the string's length plus TXN_OP_SIZE.
    size_t transaction_op::getInstrBufSize(const std::string& txn_str) {
        const size_t payload_bytes = txn_str.size();
        return payload_bytes + TXN_OP_SIZE;
    }
}


