#include "passes.hpp"
#include "utils.hpp"
#include "subgraph_op.hpp"
#include "op_utils.hpp"

// This function handles the I/O buffer requests from the operator.
// 1. For I/O buffer reqs from the kernel, runner allows an operator to request for
// a buffer size different from what is there in the original model, provided
// the requested buffer is always equal-to/larger-than the size in the model.
// 2. If two operators request for different sizes for a tensor shared by them,
// runner allocates the max of them.
namespace waic_runner {
    // Requirement recorded for one IO tensor: (.first = buffer size,
    // .second = padding offset), as filled in by handle_io_tensors().
    using IOBufferInfo = std::pair<size_t, size_t>;
    // Records the IO buffer requirement `req` of operator `op_info` in
    // `io_bufs`.
    //
    // The tensor name is taken from the op's argument list at
    // req.onnx_arg_idx. For that tensor:
    //  - the recorded size becomes the maximum of the size already recorded,
    //    the size requested by the op and the size stored in the model
    //    metadata;
    //  - the recorded padding offset becomes the maximum of the recorded and
    //    the requested padding offset.
    //
    // When `verbose` is set, mismatches between the requested size and the
    // model size are reported on stdout; they are not treated as errors.
    //
    // Throws std::runtime_error when two requests for the same tensor ask for
    // different non-zero padding offsets.
    static void handle_io_tensors(const Metadata::OpInfo& op_info,
        const OpArgMap& req,
        const Metadata& meta,
        std::map<std::string, IOBufferInfo>& io_bufs,
        bool verbose) {
        auto args = get_op_args(op_info);
        auto buf_name = ARRAY_AT(args, req.onnx_arg_idx);
        auto size_in_meta = MAP_AT(meta.tensor_map, buf_name).size_in_bytes;
        auto size_in_op = req.size;
        auto padding_offset = req.padding_offset;

        // size_in_op may be smaller than size_in_meta due to some runtime op,
        // so this is only reported instead of asserted.
        if (verbose && (size_in_op < size_in_meta + padding_offset)) {
            std::cout <<
                "Size of IO buffer required by op (" << size_in_op << ") is less "
                "than the size in the model (" << size_in_meta << ") with padding (" << padding_offset << ")"
                " for the node: " << op_info.name << std::endl;
        }
        if (verbose && (size_in_op > size_in_meta)) {
            std::cout << "[WARNING] Size of IO buffer required by "
                "op spec (" << size_in_op << ") is higher "
                "than the size in the model (" << size_in_meta << ")"
                " for the node: " << op_info.name << std::endl;
        }

        // Single lookup: creates the entry from this request if the tensor has
        // not been seen yet, otherwise returns the existing one untouched.
        IOBufferInfo& entry =
            io_bufs.try_emplace(buf_name, size_in_op, padding_offset).first->second;

        // Never allocate less than what the model itself declares.
        const size_t size_io = (size_in_op > size_in_meta) ? size_in_op : size_in_meta;
        entry.first = (std::max)(entry.first, size_io);

        // Two ops sharing a tensor must agree on its padding, unless one of
        // them requests none. (The previous check was inverted and fired when
        // the offsets were equal, i.e. on the very first padded request.)
        if (entry.second != 0 && padding_offset != 0 &&
            entry.second != padding_offset) {
            throw std::runtime_error("Different padding offset required for same IO buffer!");
        }
        entry.second = (std::max)(entry.second, padding_offset);
    }

    // Lays out the sub-tensors of every packed IO buffer in meta.fused_tensors.
    //
    // The fused tensors named "const", "super_instr", "ctrl_pkt", "in_onnx" and
    // "out_onnx" are left untouched. For every other fused tensor the packed
    // sub-tensors are placed one after another: the first one starts at
    // meta.max_tensor_padding_sz, each sub-tensor gets the size recorded in
    // `io_bufs`, and the running offset is rounded up to TENSOR_PACK_ALIGNMENT
    // after each sub-tensor. The final running offset becomes the fused
    // tensor's size.
    static void
        update_io_buffers(Metadata& meta,
            const std::map<std::string, IOBufferInfo>& io_bufs) {

        const auto is_untouched = [](const auto& fused_name) {
            return fused_name == "const" || fused_name == "super_instr" ||
                fused_name == "ctrl_pkt" || fused_name == "in_onnx" ||
                fused_name == "out_onnx";
        };

        const size_t leading_padding = meta.max_tensor_padding_sz;

        for (auto& fused_entry : meta.fused_tensors) {
            if (is_untouched(fused_entry.first)) {
                continue;
            }
            auto& fused_info = fused_entry.second;

            // Only the leading padding is reserved; the per-tensor padding
            // stored in io_bufs is not applied between sub-tensors here.
            size_t cursor = leading_padding;
            for (const auto& sub_tensor_name : fused_info.packed_tensors) {
                auto& packed = MAP_AT(meta.tensor_map, sub_tensor_name);
                const size_t requested_size = MAP_AT(io_bufs, sub_tensor_name).first;

                packed.offset = cursor;
                packed.size_in_bytes = requested_size;
                cursor = align_to_next(cursor + requested_size, TENSOR_PACK_ALIGNMENT);
            }
            fused_info.size = cursor;
        }
    }

    // Assigns every op a slot inside the packed "super_instr" buffer.
    //
    // super_instr_bufs[i] is the size requested by meta.op_list[i]. Slots are
    // laid out in op order, each start offset rounded up to
    // TENSOR_PACK_ALIGNMENT; the (offset, size) of each slot is stored in
    // meta.super_instr_map under the op name and the total is written to the
    // existing "super_instr" fused tensor.
    static void
        update_superkernel_buffers(Metadata& meta,
            const std::vector<size_t>& super_instr_bufs) {

        size_t cursor = 0;
        size_t op_idx = 0;
        for (const auto& op_info : meta.op_list) {
            const size_t instr_size = super_instr_bufs[op_idx];
            ++op_idx;

            meta.super_instr_map[op_info.name] = {/*offset*/ cursor,
                /*size*/ instr_size };
            cursor = align_to_next(cursor + instr_size, TENSOR_PACK_ALIGNMENT);
        }
        MAP_AT(meta.fused_tensors, "super_instr").size = cursor;
    }

    // Assigns every op a slot inside the packed "ctrl_pkt" buffer.
    //
    // ctrl_pkt_bufs[i] is the size requested by meta.op_list[i]. Slots are
    // placed in op order with each start offset rounded up to
    // TENSOR_PACK_ALIGNMENT and recorded in meta.ctrl_pkt_map by op name.
    static void update_ctrl_pkt_buffers(Metadata& meta,
        const std::vector<size_t>& ctrl_pkt_bufs) {
        size_t cursor = 0;
        size_t op_idx = 0;
        for (const auto& op_info : meta.op_list) {
            const size_t pkt_size = ctrl_pkt_bufs[op_idx];
            ++op_idx;

            meta.ctrl_pkt_map[op_info.name] = {/*offset*/ cursor,
                /*size*/ pkt_size };
            cursor = align_to_next(cursor + pkt_size, TENSOR_PACK_ALIGNMENT);
        }
        // Unlike "super_instr", the "ctrl_pkt" fused tensor is not present by
        // default, so it is created (or overwritten) here with the total size.
        meta.fused_tensors["ctrl_pkt"] = { cursor, TENSOR_PACK_ALIGNMENT, {} };
    }

    // Records which ops need a private scratch pad and how large the biggest
    // one is.
    //
    // scratch_bufs[i] is the scratch size requested by meta.op_list[i]. Ops
    // with a non-zero request are collected in meta.scratch_op_set, and the
    // largest request, rounded up to TENSOR_PACK_ALIGNMENT, is stored in
    // meta.max_op_scratch_pad_size. The size of the "scratch" fused tensor is
    // only rounded up to TENSOR_PACK_ALIGNMENT; the op scratch pad size is not
    // added to it here.
    static void update_op_scratch_buffers(Metadata& meta,
        const std::vector<size_t>& scratch_bufs) {

        meta.scratch_op_set.clear();

        // NOTE: ops are assumed to execute sequentially, so a single scratch
        //       pad sized for the largest request can be shared by all of them.
        size_t largest_request = 0;
        size_t op_idx = 0;
        for (const auto& op_info : meta.op_list) {
            const size_t requested = scratch_bufs[op_idx];
            ++op_idx;
            if (requested == 0) {
                continue;
            }
            meta.scratch_op_set.insert(op_info.name);
            largest_request = (std::max)(largest_request, requested);
        }

        // Kept in the metadata because it is used later for the BO size.
        meta.max_op_scratch_pad_size =
            align_to_next(largest_request, TENSOR_PACK_ALIGNMENT);

        auto& scratch_tensor = MAP_AT(meta.fused_tensors, "scratch");
        scratch_tensor.size =
            align_to_next(scratch_tensor.size, TENSOR_PACK_ALIGNMENT);
    }

    // Lays out the constant buffers of all ops inside the packed "const"
    // buffer.
    //
    // const_bufs[i] lists the (xrt_arg_id, size) constants requested by
    // meta.op_list[i]. Constants are placed in order, each start offset rounded
    // up to TENSOR_PACK_ALIGNMENT, and the total is stored as the size of the
    // "const" fused tensor.
    //
    // NOTE(review): meta.const_map is keyed by op name only, so when an op
    // requests several constants each one overwrites the previous entry and
    // only the last (offset, size) remains. The xrt_arg_id is not used.
    // Confirm this is intended.
    static void update_const_buffers(
        Metadata& meta,
        const std::vector<std::vector<std::pair<size_t, size_t>>>& const_bufs) {

        size_t cursor = 0;
        size_t op_idx = 0;
        for (const auto& op_info : meta.op_list) {
            const auto& op_consts = const_bufs[op_idx];
            ++op_idx;

            for (const auto& const_req : op_consts) {
                const size_t buf_size = const_req.second;
                meta.const_map[op_info.name] = {/*offset*/ cursor,
                    /*size*/ buf_size };
                cursor = align_to_next(cursor + buf_size, TENSOR_PACK_ALIGNMENT);
            }
        }
        meta.fused_tensors["const"].size = cursor;
    }

    // Initial buffer analysis pass.
    //
    // For every op in meta.op_list a subgraph_op is instantiated and asked for
    // its buffer requirements. The requirements are sorted by argument type:
    // IO tensors are merged per tensor name, while constant, super-instruction,
    // scratch-pad and control-packet sizes are collected per op. Afterwards the
    // packed buffer layouts in `meta` are updated from the collected data.
    //
    // Throws std::runtime_error if an op reports an argument type this pass
    // does not handle.
    void analyze_buffer_reqs(Metadata& meta, std::string& binfile_path, std::string& prebuilt_bin_dir, const json& tilings_data, bool use_inmem) {
        // IO tensor name -> (size, padding offset)
        std::map<std::string, IOBufferInfo> io_reqs;
        // One entry per op, in op_list order.
        std::vector<size_t> super_instr_sizes;
        std::vector<size_t> scratch_sizes;
        std::vector<size_t> ctrl_pkt_sizes;
        // [#ops x #consts x (xrt_arg_id, size)]
        std::vector<std::vector<std::pair<size_t, size_t>>> const_reqs;

        size_t widest_padding = 0;

        for (const auto& op_info : meta.op_list) {
            waic_runner::subgraph_op subop_ =
                waic_runner::subgraph_op<int8_t, int8_t, int8_t>(op_info.type, binfile_path, prebuilt_bin_dir, tilings_data, op_info.attr, use_inmem);
            // The requirements are queried with empty input/output tensor lists.
            std::vector<Tensor> tensors;
            auto buf_reqs = subop_.get_buffer_reqs(tensors, tensors, op_info.attr);

            std::vector<std::pair<size_t, size_t>> op_consts;
            size_t op_super_instr = 0;
            size_t op_scratch = 0;
            size_t op_ctrl_pkt = 0;

            for (const auto& req : buf_reqs) {
                switch (req.arg_type) {
                case OpArgMap::OpArgType::INPUT:
                case OpArgMap::OpArgType::OUTPUT:
                    handle_io_tensors(op_info, req, meta, io_reqs, subop_.get_verbose());
                    widest_padding = (std::max)(widest_padding, req.padding_offset);
                    break;
                case OpArgMap::OpArgType::CONST_INPUT:
                    op_consts.emplace_back(req.xrt_arg_idx, req.size);
                    break;
                case OpArgMap::OpArgType::CONST_KERNEL_PARAM_INPUT:
                    op_super_instr = req.size;
                    break;
                case OpArgMap::OpArgType::SCRATCH_PAD:
                    op_scratch = req.size;
                    break;
                case OpArgMap::OpArgType::CTRL_PKT_BIN:
                    op_ctrl_pkt = req.size;
                    break;
                default:
                    throw std::runtime_error("Unhandled OpArgType in buffer requirements");
                }
            }

            const_reqs.push_back(std::move(op_consts));
            super_instr_sizes.push_back(op_super_instr);
            scratch_sizes.push_back(op_scratch);
            ctrl_pkt_sizes.push_back(op_ctrl_pkt);
        }

        // Must be set before update_io_buffers(), which reads it as the offset
        // of the first packed sub-tensor.
        meta.max_tensor_padding_sz =
            align_to_next(widest_padding, TENSOR_PACK_ALIGNMENT);

        update_io_buffers(meta, io_reqs);
        update_superkernel_buffers(meta, super_instr_sizes);
        update_ctrl_pkt_buffers(meta, ctrl_pkt_sizes);
        update_const_buffers(meta, const_reqs);
        update_op_scratch_buffers(meta, scratch_sizes);
    }


    // Rewrites the control-packet patch entries of every op so that they refer
    // to the packed buffers described in `meta`.
    //
    // For each op a subgraph_op is instantiated and asked for its patch info
    // and buffer requirements. Every patch is matched to the op argument it
    // points into (by xrt arg index and BO offset) and is then relocated:
    //  - super-instruction, const and control-packet arguments are shifted by
    //    the op's offset in the corresponding packed buffer;
    //  - scratch-pad arguments are shifted by the size of the "scratch" fused
    //    tensor;
    //  - input/output arguments are redirected to the xrt arg index and offset
    //    of the tensor they are bound to, plus the op's optional runtime
    //    in/out offset, minus the argument's padding offset.
    // The relocated list is stored in op_info.ctrl_pkt_patch_info.
    //
    // `elf_flow` is currently unused: control-packet patches are always
    // redirected to CONST_KERNEL_PARAM_INPUT.
    //
    // Throws std::runtime_error for an unknown argument type and
    // std::out_of_range when a runtime in/out offset is missing for the
    // referenced argument.
    void relocate_ctrl_pkt_patch_info(Metadata& meta, std::string& binfile_path, std::string& prebuilt_bin_dir, bool elf_flow, const json& tilings_data, bool use_inmem) {
        auto param_offset = OpArgMap::CONST_KERNEL_PARAM_INPUT; // elf_flow ? OpArgMap::CTRL_PKT_BIN : OpArgMap::CONST_KERNEL_PARAM_INPUT;
        for (size_t i = 0; i < meta.op_list.size(); i++) {
            auto& op_info = meta.op_list.at(i);
            const std::vector<size_t> runtime_in_offset = get_runtime_in_offset(op_info.attr);
            const std::vector<size_t> runtime_out_offset = get_runtime_out_offset(op_info.attr);

            auto args = get_op_args(op_info);
            // Index of the first output inside `args` (inputs and consts come first).
            const size_t out_offset = op_info.in_args.size() + op_info.const_args.size();

            waic_runner::subgraph_op subop_ =
                waic_runner::subgraph_op<int8_t, int8_t, int8_t>(op_info.type, binfile_path, prebuilt_bin_dir, tilings_data, op_info.attr, use_inmem);
            std::vector<Tensor> tensors;
            auto ctrl_pkts_patch_info = subop_.get_ctrl_pkt_patch_info(tensors, tensors, op_info.attr);
            const auto args_map = subop_.get_buffer_reqs(tensors, tensors, op_info.attr);

            // update xrt arg idx based on the args map
            const auto argmap_partition = partition_argmap(args_map);
            for (auto& patch : ctrl_pkts_patch_info) {
                const auto& op_arg = find_op_arg(argmap_partition, patch.xrt_arg_idx, patch.bo_offset);
                if (op_arg.arg_type == OpArgMap::CONST_KERNEL_PARAM_INPUT) {
                    patch.bo_offset += MAP_AT(meta.super_instr_map, op_info.name).offset;
                    patch.xrt_arg_idx = OpArgMap::CONST_KERNEL_PARAM_INPUT;
                }
                else if (op_arg.arg_type == OpArgMap::CONST_INPUT) {
                    patch.bo_offset += MAP_AT(meta.const_map, op_info.name).offset;
                    patch.xrt_arg_idx = OpArgMap::CONST_INPUT;
                }
                else if (op_arg.arg_type == OpArgMap::CTRL_PKT_BIN) {
                    patch.bo_offset += meta.ctrl_pkt_map.at(op_info.name).offset;
                    patch.xrt_arg_idx = param_offset;
                }
                else if (op_arg.arg_type == OpArgMap::SCRATCH_PAD) {
                    patch.bo_offset += MAP_AT(meta.fused_tensors, "scratch").size;
                    patch.xrt_arg_idx = OpArgMap::SCRATCH_PAD;
                }
                else if ((op_arg.arg_type == OpArgMap::INPUT) ||
                    (op_arg.arg_type == OpArgMap::OUTPUT)) {
                    const size_t onnx_argidx = op_arg.onnx_arg_idx;
                    const auto& arg_label = ARRAY_AT(args, onnx_argidx);
                    const auto& tensor = MAP_AT(meta.tensor_map, arg_label);

                    size_t block_offset = tensor.offset;
                    // Bounds-checked: a too-short runtime offset list (or an
                    // output index below out_offset, which wraps around) must
                    // not read out of bounds.
                    if (op_arg.arg_type == OpArgMap::INPUT && !runtime_in_offset.empty()) {
                        block_offset += runtime_in_offset.at(onnx_argidx);
                    }
                    if (op_arg.arg_type == OpArgMap::OUTPUT && !runtime_out_offset.empty()) {
                        block_offset += runtime_out_offset.at(onnx_argidx - out_offset);
                    }

                    // Position of the patch relative to the start of the op argument.
                    const size_t curr_offset_delta = patch.bo_offset - op_arg.offset;

                    patch.xrt_arg_idx = tensor.xrt_arg_idx;
                    patch.bo_offset =
                        block_offset + curr_offset_delta - op_arg.padding_offset;
                }
                else {
                    std::cout << "op type is " << op_info.name << std::endl;
                    throw std::runtime_error("Unknown arg type for op");
                }
            }
            op_info.ctrl_pkt_patch_info = std::move(ctrl_pkts_patch_info);
        }
    }

    // Splits meta.op_list into consecutive partitions and stores them in
    // meta.partitions.
    //
    // A new partition starts whenever the pdi_id of an op differs from the
    // pdi_id of the current partition, or after every single op when
    // `eager_mode` is set. Each partition records its pdi_id and the half-open
    // op index range [first, last) it covers. An empty op list yields an empty
    // partition list.
    //
    // Throws std::runtime_error if any partition uses pdi id 255
    // (CONTROL_PDI_ID); meta.partitions is not updated in that case.
    void generate_pdi_partitions_pass(Metadata& meta, bool eager_mode) {

        const size_t num_ops = meta.op_list.size();
        if (num_ops == 0) {
            meta.partitions = std::vector<Partition>{};
            return;
        }

        std::vector<Partition> result;
        std::set<std::uint8_t> seen_pdi_ids;

        // One Partition object is reused and copied into `result` each time a
        // range is closed.
        Partition current;
        auto first_pdi_id = meta.op_list.at(0).pdi_id;
        current.pdi_id = first_pdi_id;
        seen_pdi_ids.insert(first_pdi_id);

        size_t range_begin = 0;
        for (size_t op_id = 1; op_id < num_ops; ++op_id) {
            auto next_pdi_id = meta.op_list.at(op_id).pdi_id;

            const bool stays_in_partition =
                !eager_mode && (current.pdi_id == next_pdi_id);
            if (stays_in_partition) {
                continue;
            }

            // Close the running range and open a new one at this op.
            current.op_range = std::make_pair(range_begin, op_id);
            result.push_back(current);

            range_begin = op_id;
            current.pdi_id = next_pdi_id;
            seen_pdi_ids.insert(next_pdi_id);
        }

        // The last range always extends to the end of the op list.
        current.op_range = std::make_pair(range_begin, meta.op_list.size());
        result.push_back(current);

        if (seen_pdi_ids.count(255) != 0) {
            throw std::runtime_error("Found CONTROL_PDI_ID - this does not belong to any kernel!");
        }

        meta.partitions = std::move(result);
    }

 // Re-packs the sub-tensors of the fused tensor called `fused_tensor_name`.
 // The packed sub-tensors are placed back to back starting at offset 0, using
 // their current size_in_bytes and no alignment, and the fused tensor's size
 // becomes the sum of those sizes. Other fused tensors are not modified.
 static void update_fused_tensor(Metadata& meta, const std::string fused_tensor_name) {

        size_t cursor = 0;
        for (auto& fused_tensor : meta.fused_tensors) {
            if (fused_tensor_name != fused_tensor.first) {
                continue;
            }
            for (auto& tensor : fused_tensor.second.packed_tensors) {
                auto& d_tensor = MAP_AT(meta.tensor_map, tensor);
                d_tensor.offset = cursor;
                cursor += d_tensor.size_in_bytes;
            }
            fused_tensor.second.size = cursor;
        }
    }

    // Returns a copy of `meta` in which no-op ("noop"/"NoOp") operators have
    // been folded away where possible.
    //
    // For each such op the first argument (input) and the last argument
    // (output) are compared by size; the larger one (the input on a tie) is
    // kept as `src`, the other one is `dest`:
    //  - dest in "scratch": every reference to dest in other ops is renamed to
    //    src, dest is removed from all packed tensor lists and from the tensor
    //    map, and the op itself is erased.
    //  - dest in "out", src in "out": dest takes src's ref_idx and is removed
    //    from the "out" packed list; the op is kept.
    //  - dest in "out", src in "scratch": src takes over dest's slot in "out"
    //    (offset, ref_idx, xrt_arg_idx, packed list entry) and leaves
    //    "scratch"; dest is re-parented to "out_onnx"; the op is kept.
    // Whenever the op is kept, dest.additional_offset is reset to 0.
    //
    // NOTE(review): args is assumed to be non-empty; args.size() - 1 would
    // wrap around otherwise.
    Metadata remove_identity_ops(const Metadata& meta) {
        Metadata identity_meta = meta;
        auto it = identity_meta.op_list.begin();
        while (it != identity_meta.op_list.end()) {
            if (it->type.find("noop") != std::string::npos ||
                it->type.find("NoOp") != std::string::npos) {
                // std::cout << it->name << std::endl;
                auto args = get_op_args(*it);
                auto out_idx = args.size() - 1;
                // For an identity op, args[0] is the input tensor name and the
                // last argument is the output tensor name.
                auto& input_tensor = MAP_AT(identity_meta.tensor_map, args[0]);
                auto& output_tensor = MAP_AT(identity_meta.tensor_map, args[out_idx]);

                // Decide which tensor is the source and which one should be
                // replaced: the tensor with the larger size_in_bytes is kept as
                // the source (the input wins when both sizes are equal).
                bool from_input = output_tensor.size_in_bytes <= input_tensor.size_in_bytes;
                {
                    auto& src = from_input ? input_tensor : output_tensor;
                    auto& dest = from_input ? output_tensor : input_tensor;
                    // Name of the tensor being replaced (dest) and of the one
                    // taking its place (src).
                    const std::string replace_arg = from_input ? args[out_idx] : args[0];
                    const std::string with_arg = from_input ? args[0] : args[out_idx];

                    // Propagate offset and parent properties.
                    //dest.xrt_arg_idx = src.xrt_arg_idx;
                    //dest.parent_name = src.parent_name;
                    //dest.offset = src.offset;

                    // Update corresponding references in other operations.
                    // Renaming only happens when dest lives in "scratch".
                    for (auto& other_op : identity_meta.op_list) {
                        for (auto& arg : other_op.in_args) {
                            if (arg == replace_arg && (dest.parent_name == "scratch")) {
                                arg = with_arg;
                            }
                        }
                        for (auto& arg : other_op.out_args) {
                            if (arg == replace_arg && (dest.parent_name == "scratch")) {
                                arg = with_arg;
                            }
                        }
                    }

                    // Remove the tensor name from packed_tensors in fused_tensors.
                    for (auto& fused_tensor : identity_meta.fused_tensors) {
                        auto& packed_tensors = fused_tensor.second.packed_tensors;
                        // dest is out, src is out, remove dest from npu out
                        if (dest.parent_name == "out" && src.parent_name == "out") {
                            if (dest.parent_name == fused_tensor.first) {
                                dest.ref_idx = src.ref_idx;
                                packed_tensors.erase(std::remove(packed_tensors.begin(),
                                    packed_tensors.end(), replace_arg),
                                    packed_tensors.end());
                            }
                        }
                        // dest is out, src is scratch: src takes dest's place
                        // in "out" and is dropped from "scratch".
                        if (dest.parent_name == "out" && src.parent_name == "scratch") {
                            if (dest.parent_name == fused_tensor.first) {
                                // remove src from scratch tensor and add it to in/out packed tensors
                                src.offset = dest.offset;
                                src.ref_idx = dest.ref_idx; //only replace
                                src.xrt_arg_idx = dest.xrt_arg_idx;
                                std::replace(packed_tensors.begin(),
                                    packed_tensors.end(), replace_arg, with_arg);
                            }
                            if ("scratch" == fused_tensor.first) {
                                packed_tensors.erase(std::remove(packed_tensors.begin(),
                                    packed_tensors.end(), with_arg),
                                    packed_tensors.end());
                            }
                            continue;
                        }
                        // if dest in scratch, remove it
                        if (dest.parent_name == "scratch") {
                            packed_tensors.erase(std::remove(packed_tensors.begin(),
                                packed_tensors.end(), replace_arg),
                                packed_tensors.end());
                        }
                    }
                    // The parent names are only swapped after the loop above,
                    // because that loop's conditions read them.
                    if (dest.parent_name == "out" && src.parent_name == "scratch") {
                        src.parent_name = dest.parent_name;
                        dest.parent_name = "out_onnx";
                    }

                    // Only a dest living in "scratch" is fully removed; in all
                    // other cases the op and the dest tensor are kept.
                    if (dest.parent_name != "scratch") {
                        // for output tensor
                        dest.additional_offset = 0;
                        ++it;
                        continue;
                    }
                    // Remove the tensor entry which is no longer referenced.
                    // `dest` dangles after this erase and must not be used.
                    identity_meta.tensor_map.erase(replace_arg);
                    it = identity_meta.op_list.erase(it);
                }
            }
            else {
                ++it;
            }
        }
        return identity_meta;
    }

    // Returns a copy of `meta` in which "concat_runtime" operators have been
    // removed by making their producers write directly into the concat output.
    //
    // For each concat op whose output tensor is not parented by "in":
    //  - every input tensor is replaced by the concat output tensor in the
    //    in_args/out_args of all ops;
    //  - for every op that produced such an input, the output argument index
    //    and the byte offset of that input inside the concat output are
    //    remembered;
    //  - the input tensor is removed from all packed tensor lists and from the
    //    tensor map;
    //  - the concat op itself is erased.
    // Finally each producer op gets a "runtime_out_offset" attribute (a
    // std::vector<int> with one entry per output argument, 0 by default)
    // holding the remembered byte offsets.
    //
    // The byte size of one input is the product of its dimensions, with the
    // innermost dimension rounded up to a multiple of 8, times the element size
    // ("float" is sized as "bfloat16").
    //
    // Throws std::runtime_error if a concat op has no arguments, an input has
    // an empty shape, or an accumulated offset is not a multiple of 4 bytes.
    Metadata remove_concat_runtime_ops(const Metadata& meta) {
        Metadata identity_meta = meta;

        // op name -> list of (output arg index, byte offset in concat output)
        std::map<std::string, std::vector<std::pair<int, int>>> name_offset_map;

        auto it = identity_meta.op_list.begin();
        while (it != identity_meta.op_list.end()) {
            const bool is_concat =
                it->type.find("concat_runtime") != std::string::npos ||
                it->type.find("Concat_runtime") != std::string::npos;
            if (!is_concat) {
                ++it;
                continue;
            }

            auto args = get_op_args(*it);
            if (args.empty()) {
                // args.size() - 1 below would wrap around.
                throw std::runtime_error("concat_runtime op has no arguments");
            }
            const size_t out_idx = args.size() - 1;
            auto& output_tensor = MAP_AT(identity_meta.tensor_map, args[out_idx]);

            // A concat whose output is parented by "in" is left in place.
            if (output_tensor.parent_name == "in") {
                ++it;
                continue;
            }

            const std::string with_arg = args[out_idx];
            int runtime_out_offset = 0;
            for (size_t idx = 0; idx < out_idx; ++idx) {
                const std::string replace_arg = args[idx];
                auto& dest = MAP_AT(identity_meta.tensor_map, args[idx]);

                // Propagate offset and parent properties.
                dest.xrt_arg_idx = output_tensor.xrt_arg_idx;
                dest.parent_name = output_tensor.parent_name;
                dest.offset = output_tensor.offset;

                // Update corresponding references in other operations and
                // remember where each producer has to write inside the concat
                // output.
                for (auto& other_op : identity_meta.op_list) {
                    for (auto& arg : other_op.in_args) {
                        if (arg == replace_arg) {
                            arg = with_arg;
                        }
                    }
                    int idx_tensor = 0;
                    for (auto& arg : other_op.out_args) {
                        if (arg == replace_arg) {
                            arg = with_arg;
                            name_offset_map[other_op.name].push_back(std::make_pair(idx_tensor, runtime_out_offset));
                        }
                        idx_tensor++;
                    }
                }

                // Calculate offset for next input.
                if (dest.shape.size() == 0) {
                    // shape.size() - 1 below would wrap around.
                    throw std::runtime_error("concat_runtime input tensor has an empty shape");
                }
                const size_t last_dim = dest.shape.size() - 1;
                int row_size = 1;
                for (size_t i = 0; i < last_dim; ++i) {
                    row_size *= dest.shape[i];
                }
                // Round up innermost dimension to be the next multiple of 8 to
                // match with hw buffer size.
                int inner_dim = dest.shape[last_dim];
                if (inner_dim % 8 == 0) {
                    row_size *= inner_dim;
                }
                else {
                    row_size *= (inner_dim / 8 + 1) * 8;
                }
                // If dtype is float for hw it will be bfloat.
                std::string dtype = dest.dtype;
                if (dest.dtype == "float") {
                    dtype = "bfloat16";
                }
                runtime_out_offset += row_size * get_size_of_type(dtype);

                if (runtime_out_offset % 4) {
                    throw std::runtime_error("runtime out offset is not 32bit aligned!");
                }

                // Remove the tensor name from packed_tensors in fused_tensors.
                for (auto& fused_tensor : identity_meta.fused_tensors) {
                    auto& packed_tensors = fused_tensor.second.packed_tensors;
                    packed_tensors.erase(std::remove(packed_tensors.begin(),
                        packed_tensors.end(), replace_arg),
                        packed_tensors.end());
                }

                // Remove the tensor entry which is no longer referenced.
                // `dest` dangles after this erase and must not be used.
                identity_meta.tensor_map.erase(replace_arg);
            }
            it = identity_meta.op_list.erase(it);
        }

        // Attach the collected offsets to the producer ops that still exist.
        for (const auto& [op_name, out_offsets] : name_offset_map) {
            for (auto& op : identity_meta.op_list) {
                if (op.name != op_name) {
                    continue;
                }
                std::vector<int> runtime_out_offset_vec(op.out_args.size(), 0);
                for (const auto& [out_arg_idx, byte_offset] : out_offsets) {
                    runtime_out_offset_vec.at(static_cast<size_t>(out_arg_idx)) = byte_offset;
                }
                op.attr["runtime_out_offset"] = runtime_out_offset_vec;
            }
        }
        return identity_meta;
    }

    // Returns a copy of `meta` in which the tensors attached to
    // "split_runtime" operators have been given consecutive offsets.
    //
    // For each split op, the offset of the tensor named by the last argument
    // is used as the starting offset; the tensors named by args[1] .. args[last]
    // are then assigned consecutive offsets from there, each advancing by its
    // size_in_bytes. The ops themselves are not removed from the op list.
    //
    // NOTE(review): the base offset is read from args[out_idx] (the last
    // argument), which is itself part of the loop below and therefore gets its
    // own offset overwritten. If args[0] is the split input, the base tensor
    // was presumably meant to be args[0] - confirm before changing.
    //
    // Throws std::runtime_error if a split op has no arguments.
    Metadata remove_split_runtime_ops(const Metadata& meta) {
        Metadata identity_meta = meta;
        for (auto& op : identity_meta.op_list) {
            const bool is_split =
                op.type.find("split_runtime") != std::string::npos ||
                op.type.find("Split_runtime") != std::string::npos;
            if (!is_split) {
                continue;
            }
            std::cout << "SKIPPING SPLIT" << op.name << std::endl;

            auto args = get_op_args(op);
            if (args.empty()) {
                // args.size() - 1 below would wrap around.
                throw std::runtime_error("split_runtime op has no arguments");
            }
            const size_t out_idx = args.size() - 1;

            // Keep the offset in the tensor's own type instead of narrowing
            // it to int.
            auto starting_offset = MAP_AT(identity_meta.tensor_map, args[out_idx]).offset;
            for (size_t i = 1; i < args.size(); ++i) {
                auto& output_tensor = MAP_AT(identity_meta.tensor_map, args[i]);
                output_tensor.offset = starting_offset;
                starting_offset += output_tensor.size_in_bytes;
            }
        }
        return identity_meta;
    }

    Metadata remove_slice_runtime_ops(const Metadata& meta) {
        Metadata identity_meta = meta;
        
        std::map<std::string, std::vector<std::pair<int, int>>> name_offset_map;

        auto it = identity_meta.op_list.begin();
        while (it != identity_meta.op_list.end()) {
            if (it->type.find("slice_runtime") != std::string::npos ||
                it->type.find("Slice_runtime") != std::string::npos) {
                auto args = get_op_args(*it);
                auto out_idx = args.size() - 1;
                auto& input_tensor = MAP_AT(identity_meta.tensor_map, args[0]);
                auto& output_tensor = MAP_AT(identity_meta.tensor_map, args[out_idx]);
                // Get start attribute
                int start = 0;
                if (it->attr.count("start") &&
                    it->attr.at("start").type() == typeid(std::vector<int>)) {
                    const auto& start_vector = std::any_cast<const std::vector<int> &>(it->attr.at("start"));
                    if (start_vector.size() != 1) {
                        std::cout << "Warning: slice_runtime op start attribute has more then one value, will take first one" << std::endl;
                    }
                    start = start_vector[0];
                } else {
                    std::cout << "Warning: There is no start attr for slice_runtime op, set it to 0" << std::endl;
                    start = 0;
                }

                // Get axes attribute
                int axes = 0;
                if (it->attr.count("axes") &&
                    it->attr.at("axes").type() == typeid(std::vector<int>)) {
                    const auto& axes_vector = std::any_cast<const std::vector<int> &>(it->attr.at("axes"));
                    if (axes_vector.size() != 1) {
                        std::cout << "Warning: slice_runtime op axes attribute has more then one value, will take first one" << std::endl;
                    }
                    axes = axes_vector[0];
                } else {
                    std::cout << "Warning: There is no axes attr for slice_runtime op, set it to 0" << std::endl;
                    axes = 0;
                }

                bool from_input = output_tensor.size_in_bytes <= input_tensor.size_in_bytes;
                {
                    auto& src = from_input ? input_tensor : output_tensor;
                    auto& dest = from_input ? output_tensor : input_tensor;
                    const std::string replace_arg = from_input ? args[out_idx] : args[0];
                    const std::string with_arg = from_input ? args[0] : args[out_idx];
                    // Propagate offset and parent properties.
                    //dest.xrt_arg_idx = src.xrt_arg_idx;
                    //dest.parent_name = src.parent_name;

                    // Calculate buffer size for [axes+1:-1] dimensions of input shape
                    int row_size = 1;
                    for (int i = (axes + 1); i < src.shape.size() - 1; ++i) {
                            row_size *= src.shape[i];
                    }
                    // Round up innermost dimension to be the next multiple of 8 to match with hw buffer size
                    int inner_dim = src.shape[src.shape.size() - 1];
                    if (inner_dim % 8 == 0) {
                        row_size *= inner_dim;
                    } else {
                        row_size *= (inner_dim / 8 + 1) * 8;
                    }
                    // If dtype is float for hw it will be bfloat
                    std::string dtype = src.dtype;
                    if (src.dtype == "float") {
                        dtype = "bfloat16";
                    }
                    int runtime_in_offset = start * row_size * get_size_of_type(dtype);
                    if (runtime_in_offset % 4) {
                        throw std::runtime_error("runtime in offset is not 32bit aligned!");
                    }
                    // std::cout << "Slice runtime info" << std::endl;
                    // std::cout << row_size << std::endl;
                    // std::cout << start << std::endl;
                    // std::cout << runtime_in_offset << std::endl;



                    //dest.offset = src.offset;
                    // Update corresponding references in other operations.
                    for (auto& other_op : identity_meta.op_list) {
                        int idx_tensor = 0;
                        for (auto& arg : other_op.in_args) {
                            if (arg == replace_arg && (dest.parent_name == "scratch")) {
                                arg = with_arg;
                                // std::cout << other_op.name << std::endl;
                                // std::cout << idx_tensor << std::endl;
                                name_offset_map[other_op.name].push_back(std::make_pair(idx_tensor, runtime_in_offset));
                            }
                            idx_tensor++;
                        }
                        for (auto& arg : other_op.out_args) {
                            if (arg == replace_arg && (dest.parent_name == "scratch")) {
                                arg = with_arg;
                            }
                        }
                    }

                    // Remove the tensor name from packed_tensors in fused_tensors.
                    for (auto& fused_tensor : identity_meta.fused_tensors) {
                        auto& packed_tensors = fused_tensor.second.packed_tensors;
                        // dest is out, src is out, remove it from npu out
                        if (dest.parent_name == "out" && src.parent_name == "out") {
                            if (dest.parent_name == fused_tensor.first) {
                                dest.ref_idx = src.ref_idx;
                                packed_tensors.erase(std::remove(packed_tensors.begin(),
                                    packed_tensors.end(), replace_arg),
                                    packed_tensors.end());
                            }
                        }
                        if (dest.parent_name == "out" && src.parent_name == "scratch") {
                            if (dest.parent_name == fused_tensor.first) {
                                // remove src from scratch tensor and add it to in/out packed tensors
                                src.offset = dest.offset;
                                src.ref_idx = dest.ref_idx;
                                src.xrt_arg_idx = dest.xrt_arg_idx;
                                std::replace(packed_tensors.begin(),
                                    packed_tensors.end(), replace_arg, with_arg);
                            }
                            if ("scratch" == fused_tensor.first) {
                                packed_tensors.erase(std::remove(packed_tensors.begin(),
                                    packed_tensors.end(), with_arg),
                                    packed_tensors.end());
                            }
                            continue;
                        }
                        // if dest in scratch, remove it
                        if (dest.parent_name == "scratch") {
                            packed_tensors.erase(std::remove(packed_tensors.begin(),
                                packed_tensors.end(), replace_arg),
                                packed_tensors.end());
                        }
                    }
                    if (dest.parent_name == "out" && src.parent_name == "scratch") {
                        src.parent_name = dest.parent_name;
                        dest.parent_name = "out_onnx";
                        // update the fused tensor and some tensor offset due to src tensor size is larger than dest tensor.
                        update_fused_tensor(identity_meta, "out");
                    }

                    // Remove the tensor entry which is no longer referenced.                  
                    if (dest.parent_name != "scratch") {
                        // for output tensor
                        dest.additional_offset = runtime_in_offset;
                        ++it;
                        continue;
                    }
                    identity_meta.tensor_map.erase(replace_arg);
                    it = identity_meta.op_list.erase(it);
                }
            }
            else {
                ++it;
            }
        }

        for (const auto& pair : name_offset_map) {
            const std::string& key = pair.first;
            std::vector<std::pair<int, int>> value = pair.second;
 
            for (auto& op : identity_meta.op_list) {
                if (op.name == key) {
                    std::vector<int> runtime_in_offset_vec(op.in_args.size(), 0);
                    for (int i = 0; i < value.size(); i++) {
                        runtime_in_offset_vec[value[i].first] = value[i].second;
                    }
                    op.attr["runtime_in_offset"] = runtime_in_offset_vec;
                }
            }
        }
        return identity_meta;
    }

    // Returns a copy of `meta` in which gather_runtime ops have been processed.
    //
    // For every op whose type contains "gather_runtime" or "Gather_runtime":
    //  1. The "indices" and "axis" attributes are read (first element of a
    //     std::vector<int>; 0 when missing, of another type, or empty).
    //  2. The larger of the op's input/output tensors (by size_in_bytes) is
    //     picked as `src`, the other one as `dest`.
    //  3. A byte offset `indices * row_size * sizeof(dtype)` is computed, where
    //     row_size is the product of src.shape[axis+1 .. last-1] times the last
    //     dimension rounded up to a multiple of 8.
    //  4. If `dest` lives in "scratch", every reference to it in the op list is
    //     redirected to `src`, the offset is recorded for the consuming ops, the
    //     dest tensor entry is erased and the gather op itself is removed.
    //     Otherwise the op is kept and the offset is stored in
    //     dest.additional_offset.
    // Finally each op that had an input redirected receives a
    // "runtime_in_offset" attribute (one int per input argument, 0 by default).
    //
    // Throws std::runtime_error if the computed offset is not a multiple of 4 or
    // if the source tensor has an empty shape.
    Metadata remove_gather_runtime_ops(const Metadata& meta) {
        Metadata identity_meta = meta;

        // op name -> list of (input argument index, byte offset) collected while
        // redirecting inputs; applied to the ops after the main loop.
        std::map<std::string, std::vector<std::pair<int, int>>> name_offset_map;

        // Reads the first element of the std::vector<int> attribute `key`.
        // Falls back to 0 (with a warning) when the attribute is missing, holds a
        // different type, or is an empty vector. The empty case used to index
        // element 0 of an empty vector, which is undefined behaviour.
        const auto first_int_attr = [](const auto& attrs, const std::string& key) -> int {
            if (attrs.count(key) &&
                attrs.at(key).type() == typeid(std::vector<int>)) {
                const auto& values = std::any_cast<const std::vector<int> &>(attrs.at(key));
                if (values.empty()) {
                    std::cout << "Warning: gather_runtime op " << key << " attribute is empty, set it to 0" << std::endl;
                    return 0;
                }
                if (values.size() != 1) {
                    std::cout << "Warning: gather_runtime op " << key << " attribute has more then one value, will take first one" << std::endl;
                }
                return values[0];
            }
            std::cout << "Warning: There is no " << key << " attr for gather_runtime op, set it to 0" << std::endl;
            return 0;
        };

        auto it = identity_meta.op_list.begin();
        while (it != identity_meta.op_list.end()) {
            // Only gather_runtime ops are touched; everything else is skipped.
            if (it->type.find("gather_runtime") == std::string::npos &&
                it->type.find("Gather_runtime") == std::string::npos) {
                ++it;
                continue;
            }

            auto args = get_op_args(*it);
            assert(args.size() == 3);
            auto out_idx = args.size() - 1;
            // For Gather_runtime op, args[0] is the input, args[1] is the index,
            // args[2] is the output tensor name.
            auto& input_tensor = MAP_AT(identity_meta.tensor_map, args[0]);
            auto& output_tensor = MAP_AT(identity_meta.tensor_map, args[out_idx]);

            const int indices = first_int_attr(it->attr, "indices");
            const int axis = first_int_attr(it->attr, "axis");

            // The larger tensor is kept (`src`); the smaller one (`dest`) is the
            // candidate for being replaced by a view into `src`.
            bool from_input = output_tensor.size_in_bytes < input_tensor.size_in_bytes;
            auto& src = from_input ? input_tensor : output_tensor;
            auto& dest = from_input ? output_tensor : input_tensor;
            const std::string replace_arg = from_input ? args[out_idx] : args[0];
            const std::string with_arg = from_input ? args[0] : args[out_idx];

            // An empty shape would make `size() - 1` wrap around and read out of
            // bounds below, so reject it explicitly.
            if (src.shape.size() == 0) {
                throw std::runtime_error("gather_runtime op source tensor has an empty shape!");
            }

            // Calculate buffer size for [axis+1:-1] dimensions of input shape.
            // NOTE(review): a negative axis is not normalised here; axis == -1
            // makes the loop start at dimension 0 - confirm whether negative
            // axes can reach this pass.
            int row_size = 1;
            for (int i = (axis + 1); static_cast<size_t>(i) < src.shape.size() - 1; ++i) {
                row_size *= src.shape[i];
            }
            // Round up innermost dimension to be the next multiple of 8 to match with hw buffer size
            int inner_dim = src.shape[src.shape.size() - 1];
            if (inner_dim % 8 == 0) {
                row_size *= inner_dim;
            } else {
                row_size *= (inner_dim / 8 + 1) * 8;
            }
            // If dtype is float for hw it will be bfloat
            std::string dtype = src.dtype;
            if (src.dtype == "float") {
                dtype = "bfloat16";
            }
            int runtime_in_offset = indices * row_size * get_size_of_type(dtype);
            if (runtime_in_offset % 4) {
                throw std::runtime_error("runtime in offset is not 32bit aligned!");
            }

            // Update corresponding references in other operations. Redirection
            // only happens when dest is a scratch tensor; for inputs the byte
            // offset is remembered together with the argument position.
            for (auto& other_op : identity_meta.op_list) {
                int idx_tensor = 0;
                for (auto& arg : other_op.in_args) {
                    if (arg == replace_arg && (dest.parent_name == "scratch")) {
                        arg = with_arg;
                        name_offset_map[other_op.name].push_back(std::make_pair(idx_tensor, runtime_in_offset));
                    }
                    idx_tensor++;
                }
                for (auto& arg : other_op.out_args) {
                    if (arg == replace_arg && (dest.parent_name == "scratch")) {
                        arg = with_arg;
                    }
                }
            }

            // Remove the tensor name from packed_tensors in fused_tensors.
            for (auto& fused_tensor : identity_meta.fused_tensors) {
                auto& packed_tensors = fused_tensor.second.packed_tensors;
                // dest is out, src is out, remove it from npu out
                if (dest.parent_name == "out" && src.parent_name == "out") {
                    if (dest.parent_name == fused_tensor.first) {
                        dest.ref_idx = src.ref_idx;
                        packed_tensors.erase(std::remove(packed_tensors.begin(),
                            packed_tensors.end(), replace_arg),
                            packed_tensors.end());
                    }
                }
                if (dest.parent_name == "out" && src.parent_name == "scratch") {
                    if (dest.parent_name == fused_tensor.first) {
                        // remove src from scratch tensor and add it to in/out packed tensors
                        src.offset = dest.offset;
                        src.ref_idx = dest.ref_idx;
                        src.xrt_arg_idx = dest.xrt_arg_idx;
                        std::replace(packed_tensors.begin(),
                            packed_tensors.end(), replace_arg, with_arg);
                    }
                    if ("scratch" == fused_tensor.first) {
                        packed_tensors.erase(std::remove(packed_tensors.begin(),
                            packed_tensors.end(), with_arg),
                            packed_tensors.end());
                    }
                    continue;
                }
                // if dest in scratch, remove it
                if (dest.parent_name == "scratch") {
                    packed_tensors.erase(std::remove(packed_tensors.begin(),
                        packed_tensors.end(), replace_arg),
                        packed_tensors.end());
                }
            }
            if (dest.parent_name == "out" && src.parent_name == "scratch") {
                // src takes dest's place in "out"; dest is re-parented to "out_onnx".
                src.parent_name = dest.parent_name;
                dest.parent_name = "out_onnx";
                // update the fused tensor and some tensor offset due to src tensor size is larger than dest tensor.
                update_fused_tensor(identity_meta, "out");
            }

            // When dest is not a scratch tensor the op stays in the list and the
            // offset is stored on the tensor itself.
            if (dest.parent_name != "scratch") {
                // for output tensor
                dest.additional_offset = runtime_in_offset;
                ++it;
                continue;
            }
            // Remove the tensor entry which is no longer referenced. `dest` must
            // not be used past this point: it refers to the erased entry.
            identity_meta.tensor_map.erase(replace_arg);
            it = identity_meta.op_list.erase(it);
        }

        // Attach the recorded offsets to the consuming ops.
        // NOTE(review): this replaces any "runtime_in_offset" attribute the op
        // already carries - confirm that is intended if another pass sets it.
        for (const auto& [op_name, offsets] : name_offset_map) {
            for (auto& op : identity_meta.op_list) {
                if (op.name == op_name) {
                    std::vector<int> runtime_in_offset_vec(op.in_args.size(), 0);
                    for (const auto& entry : offsets) {
                        runtime_in_offset_vec[entry.first] = entry.second;
                    }
                    op.attr["runtime_in_offset"] = runtime_in_offset_vec;
                }
            }
        }
        return identity_meta;
    }
}
