#include "op_common.hpp"
#include "meta_utils.hpp"
namespace waic_runner {
    static std::map<std::string, tidInfo> create_tensor_id_map(const Metadata& meta) {
        std::map<std::string, tidInfo> tensor_id_map;
        int id = 0;
        for (const auto& [t_name, t_info] : meta.tensor_map) {
            if (t_info.parent_name != "const") {
                tidInfo t = { id++, 0 };
                tensor_id_map[t_name] = t;
            }
        }
        // count number of op inputs for each tensor
        for (const auto& op_info : meta.op_list) {
            for (const auto inarg : op_info.in_args) {
                tensor_id_map[inarg].num_usage += 1;
            }
        }
        return tensor_id_map;
    }

    // Writes the planned scratch layout back into the metadata: every tensor
    // packed into a GroupBuff gets that GroupBuff's offset, so all tensors
    // sharing a GroupBuff are placed at the same position in the scratch space.
    // NOTE(review): MAP_AT is assumed to be a checked lookup into tensor_map —
    // confirm its behavior for missing names in meta_utils/op_common.
    static void update_meta_scratch_space(
        Metadata& meta, const std::vector<GroupBuff>& gBuffs) {

        // Iterate by reference: each GroupBuff owns vectors that a by-value
        // loop variable would copy on every iteration.
        for (const auto& gB : gBuffs) {
            for (const auto& tname : gB.packed_tensors) {
                auto& tinfo = MAP_AT(meta.tensor_map, tname);
                tinfo.offset = gB.offset;
            }
        }
    }

    // Plans the layout of the fused "scratch" buffer. Scratch tensors whose
    // consumers have all run hand their memory over to later scratch tensors.
    //
    // Ops are visited in meta.op_list order. A tensor counts as scratch when
    // its parent_name contains "scratch". For every scratch output:
    //   - if each existing GroupBuff is held by a live tensor, a new GroupBuff
    //     is appended at the current end of the scratch space;
    //   - otherwise the output takes over the GroupBuff of the earliest freed
    //     tensor (front of FreeNodeList). If the output is larger than that
    //     GroupBuff, the GroupBuff is enlarged and every following GroupBuff
    //     is shifted so the buffers remain packed back-to-back.
    // Scratch inputs are queued in ReplaceNodeList. Their remaining-use counts
    // (taken from create_tensor_id_map) are decremented only after an op with
    // a scratch output has placed that output, so an op's outputs never take
    // the memory of its own inputs. Inputs queued by ops without a scratch
    // output are released at the next op that has one. A tensor whose count
    // never reaches zero (e.g. a scratch tensor with no consumers) is never
    // freed.
    //
    // On return, each packed tensor's offset has been written to
    // meta.tensor_map, and meta.fused_tensors.at("scratch").size holds the
    // total scratch size.
    //
    // Throws std::runtime_error if a buffer must be reused but none is free,
    // or if a freed tensor id is not found in any GroupBuff. Both indicate
    // broken internal bookkeeping.
    void optimize_scratch_buffer(Metadata& meta) {
        auto tensor_id_map = create_tensor_id_map(meta);
        size_t Total_scratch_buff = 0;
        std::vector<tidInfo> LiveNodeList;  // scratch tensors currently holding a GroupBuff
        std::vector<int> ReplaceNodeList;   // consumed tids awaiting the next release pass
        std::vector<int> FreeNodeList;      // tids whose GroupBuff may be reused, oldest first
        std::vector<GroupBuff> gBuffs;

        // Classifies a tensor without modifying meta.tensor_map (operator[]
        // would default-insert an empty entry for an unknown name). Unknown
        // names are treated as non-scratch, just as such an empty entry was.
        const auto is_scratch = [&meta](const std::string& name) {
            const auto it = meta.tensor_map.find(name);
            return it != meta.tensor_map.end() &&
                it->second.parent_name.find("scratch") != std::string::npos;
        };

        for (const auto& op_info : meta.op_list) {
            // Queue every scratch input that is currently live. A tensor listed
            // twice in in_args is queued twice, mirroring its usage count.
            for (const auto& inarg : op_info.in_args) {
                if (!is_scratch(inarg)) {
                    continue;
                }
                const auto in_tid = tensor_id_map.at(inarg).tid;
                for (const auto& live : LiveNodeList) {
                    if (live.tid == in_tid) {
                        ReplaceNodeList.push_back(in_tid);
                        break;
                    }
                }
            }

            bool update_flag = false;
            for (const auto& outarg : op_info.out_args) {
                if (!is_scratch(outarg)) {
                    continue;
                }
                update_flag = true;
                const auto out_buff_add = meta.tensor_map.find(outarg)->second.size_in_bytes;
                const tidInfo& out_id = tensor_id_map.at(outarg);
                LiveNodeList.push_back(out_id);

                if (LiveNodeList.size() > gBuffs.size()) {
                    // Every GroupBuff is occupied: append a new one at the end.
                    GroupBuff new_gBuff;
                    new_gBuff.offset = Total_scratch_buff;
                    new_gBuff.size = out_buff_add;
                    new_gBuff.packed_tids.push_back(out_id.tid);
                    new_gBuff.packed_tensors.push_back(outarg);
                    gBuffs.push_back(new_gBuff);
                    Total_scratch_buff += out_buff_add;
                    continue;
                }

                if (FreeNodeList.empty()) {
                    throw std::runtime_error("No free node to use");
                }
                const int free_tid = FreeNodeList.front();

                // Find the GroupBuff that holds the freed tensor. temp_size
                // accumulates the sizes of the GroupBuffs in front of it.
                size_t idx = gBuffs.size();
                size_t temp_size = 0;
                for (size_t g = 0; g < gBuffs.size(); g++) {
                    const auto& tids = gBuffs[g].packed_tids;
                    if (std::find(tids.begin(), tids.end(), free_tid) != tids.end()) {
                        idx = g;
                        break;
                    }
                    temp_size += gBuffs[g].size;
                }
                if (idx == gBuffs.size()) {
                    // Reusing an arbitrary GroupBuff here would corrupt the layout.
                    throw std::runtime_error("Free node " + std::to_string(free_tid) +
                        " not found in any group buffer");
                }

                auto& target = gBuffs[idx];
                target.packed_tensors.push_back(outarg);
                target.packed_tids.push_back(out_id.tid);
                if (out_buff_add > target.size) {
                    // Grow the reused GroupBuff and shift every following one.
                    target.size = out_buff_add;
                    for (size_t g = idx + 1; g < gBuffs.size(); g++) {
                        temp_size += gBuffs[g - 1].size;
                        gBuffs[g].offset = temp_size;
                    }
                    Total_scratch_buff = temp_size + gBuffs.back().size;
                }
                FreeNodeList.erase(FreeNodeList.begin());
            }

            if (update_flag) {
                // Release pass: consume one use of each queued input. An input
                // with no uses left frees its GroupBuff for later outputs.
                for (const int rNode : ReplaceNodeList) {
                    for (size_t j = 0; j < LiveNodeList.size(); j++) {
                        auto& live = LiveNodeList[j];
                        if (live.tid != rNode) {
                            continue;
                        }
                        live.num_usage -= 1;
                        if (live.num_usage == 0) {
                            FreeNodeList.push_back(live.tid);
                            LiveNodeList.erase(LiveNodeList.begin() + j);
                        }
                        break;
                    }
                }
                ReplaceNodeList.clear();
            }
        }
        update_meta_scratch_space(meta, gBuffs);
        meta.fused_tensors.at("scratch").size = Total_scratch_buff;
    }
}