/* Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * openUBMC is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */
#include "l_callbacks.h"

#include <dlfcn.h>

namespace sml {

// Binds the object to the main thread of L's Lua state (as returned by luawrap::main_thread), not to L itself.
l_callbacks::l_callbacks(lua_State *L) : m_L(luawrap::main_thread(L)) {}

// Nothing to release explicitly; members clean up after themselves.
l_callbacks::~l_callbacks() = default;

// Lua-facing factory registered as `callbacks.new`: constructs an l_callbacks userdata
// (forwarding L to the constructor) and reports one return value to Lua.
luawrap::n_ret l_callbacks::new_callbacks(lua_State *L)
{
    lua_State *const ctor_arg = L;  // the constructor receives the calling state
    luawrap::lua_class<l_callbacks>::new_object(L, ctor_arg);

    return 1;  // the new object
}

// Exposes the l_callbacks class to Lua under t.callbacks, with the methods
// new / init_handle / reply_write_read / get_wait_tid / set_mctp_writeread_cb.
void l_callbacks::register_to(lua_State *L, luawrap::stack_table &t)
{
    // Presumably resets the Lua stack top to its current value when this scope ends.
    luawrap::restore_stack_top stack_guard(L);

    luawrap::lua_class<l_callbacks> klass(L);

    // Static entry points are wrapped for Lua; the member function is bound directly.
    klass.def("new", c_func_wrap(L, l_callbacks::new_callbacks));
    klass.def("init_handle", c_func_wrap(L, l_callbacks::init_handle));
    klass.def("reply_write_read", c_func_wrap(L, l_callbacks::reply_write_read));
    klass.def("get_wait_tid", c_func_wrap(L, l_callbacks::get_wait_tid));
    klass.def("set_mctp_writeread_cb", &l_callbacks::set_mctp_writeread_cb);

    t.set("callbacks", klass);
}

// MCTP write/read hook: hands `request` to the Lua-level callback stored in the params for obj_index.
// When the caller supplied a response buffer, it also copies the callback's reply into that buffer.
//
//   obj_index        key into s_params_map; an empty params entry is created if none exists yet
//   request          request payload of request_length bytes (may be null only when request_length is 0)
//   response_length  in: capacity of `response`; out: number of bytes written (0 when there is no reply)
//   response         destination buffer for the reply
//   timeout          passed through to the Lua callback unchanged
//
// A reply is requested only when `response` and `response_length` are non-null and *response_length != 0.
// Returns RET_OK on success, or the callback's status when that status is not RET_OK.
// Returns RET_ERR when the request is malformed, no callback is set for obj_index, or the reply
// does not fit into `response`.
gint32 l_callbacks::mctp_writeread_cb(guint8 obj_index, guint32 request_length, const guint8 *request,
                                      guint32 *response_length, guint8 *response, guint32 timeout)
{
    // A string_view over (nullptr, non-zero length) would be undefined behaviour.
    if (request == nullptr && request_length != 0) {
        debug_log(DLOG_DEBUG, "%s: request is null, request_length = %u", __FUNCTION__, request_length);
        return RET_ERR;
    }

    auto it = s_params_map.find(obj_index);
    if (it == s_params_map.end()) {
        // Keep the original side effect of registering an empty entry for an unseen index.
        // emplace() already yields the iterator, so no second lookup is needed.
        it = s_params_map.emplace(obj_index, std::make_unique<l_callbacks::params>()).first;
    }

    auto &s_params = *it->second;
    std::unique_lock<std::mutex> guard(s_params.write_read_mutex);
    if (!s_params.mctp_writeread_cb) {
        debug_log(DLOG_DEBUG, "%s: mctp_writeread_cb is null", __FUNCTION__);
        return RET_ERR;
    }

    std::string_view data(reinterpret_cast<const char *>(request), request_length);
    guint8           wait_response = (response != nullptr && response_length != nullptr && *response_length != 0);

    int                             send_result;
    std::optional<std::string_view> rsp_data;
    // s_params.mctp_writeread_cb is the per-index Lua-level mctp_writeread function
    // (presumably installed through set_mctp_writeread_cb).
    std::tie(send_result, rsp_data) = (*s_params.mctp_writeread_cb)(obj_index, data, wait_response, timeout);

    if (send_result != RET_OK) {
        return send_result;
    }

    if (!wait_response) {
        return RET_OK;
    }

    // No reply, or an empty one: report zero bytes. This avoids passing a possibly-null source to memcpy_s.
    if (!rsp_data || rsp_data->empty()) {
        *response_length = 0;
        return RET_OK;
    }

    errno_t securec_rv = memcpy_s(response, *response_length, rsp_data->data(), rsp_data->size());
    if (securec_rv != EOK) {
        // Most likely the reply is larger than the caller's buffer.
        debug_log(DLOG_DEBUG, "%s: memcpy_s failed, ret = %d", __FUNCTION__, securec_rv);
        *response_length = 0;
        return RET_ERR;
    }

    // memcpy_s succeeded, so size() <= the original *response_length and fits in guint32.
    *response_length = static_cast<guint32>(rsp_data->size());
    return RET_OK;
}

// I2C write-only hook: forwards pWritebuf (write_length bytes) through the params registered for
// obj_index, with no read buffer. Returns RET_ERR when obj_index has no entry in s_params_map.
gint32 l_callbacks::i2c_write_cb(guint8 obj_index, guint8 *pWritebuf, guint8 write_length)
{
    const auto found = s_params_map.find(obj_index);
    if (found != s_params_map.end()) {
        // Write-only transfer: null read buffer, zero read length.
        return found->second->block_write_read(obj_index, pWritebuf, write_length, nullptr, 0);
    }

    debug_log(DLOG_DEBUG, "%s: i2c_write_cb is null, unknow ctrl_id=%d", __FUNCTION__, obj_index);
    return RET_ERR;
}

// I2C write-then-read hook: sends pWritebuf (write_length bytes) through the params registered for
// obj_index, then fills up to read_length bytes of pReadbuf with the reply.
// Returns RET_ERR when obj_index has no usable entry in s_params_map; otherwise returns
// block_write_read()'s result.
gint32 l_callbacks::i2c_writeread_cb(guint8 obj_index, guint8 *pWritebuf, guint8 write_length, guint8 *pReadbuf,
                                     guint8 read_length)
{
    auto it = s_params_map.find(obj_index);
    if (it == s_params_map.end() || !it->second) {
        // The log used to name i2c_write_cb here (copy-paste); name this callback instead.
        debug_log(DLOG_DEBUG, "%s: i2c_writeread_cb is null, unknow ctrl_id=%d", __FUNCTION__, obj_index);
        return RET_ERR;
    }

    auto &s_params = *it->second;
    return s_params.block_write_read(obj_index, pWritebuf, write_length, pReadbuf, read_length);
}

int l_callbacks::params::block_write_read(guint8 obj_index, const guint8 *pWritebuf, guint8 write_length,
                                          guint8 *pReadbuf, guint8 read_length)
{
    std::unique_lock<std::mutex> guard(write_read_mutex);

    // source 为 0 是 skynet_send 要求必须提供 context，不然会崩溃，或上 0xFF0000 可以确保 source 有值
    int source = read_length | (obj_index << 8) | 0xFF0000;
    skynet_send(nullptr, source, handle, msg_tag, 0, pWritebuf, write_length);
    read_buf = pReadbuf;
    read_len = read_length;
    result   = RET_ERR;
    wait_tid = gettid();
    condition.wait(guard);
    wait_tid = 0;

    return result;
}

void l_callbacks::params::reply_write_read(int rsp, std::optional<std::string_view> rsp_data)
{
    std::unique_lock<std::mutex> guard(write_read_mutex);

    result = rsp;
    if (read_buf && read_len > 0 && rsp_data) {
        errno_t securec_rv = memcpy_s(read_buf, read_len, rsp_data->data(), rsp_data->size());
        if (securec_rv != EOK) {
            debug_log(DLOG_INFO, "%s: memcpy_s failed, ret = %d", __FUNCTION__, securec_rv);
            result = RET_ERR;
        }
    }
    condition.notify_one();
}

// Lua entry point: routes a reply (status rsp plus optional payload) to the params registered for
// ctrl_id. A reply for an unregistered ctrl_id is ignored.
void l_callbacks::reply_write_read(int ctrl_id, int rsp, std::optional<std::string_view> rsp_data)
{
    const auto entry = s_params_map.find(ctrl_id);
    if (entry == s_params_map.end()) {
        return;  // no params for this ctrl_id, so nothing to deliver to
    }

    entry->second->reply_write_read(rsp, rsp_data);
}

bool l_callbacks::init_handle(int ctrl_id, int handle, int tag)
{
    void *d           = dlopen(nullptr, RTLD_LAZY);
    auto  skynet_send = reinterpret_cast<t_skynet_send>(dlsym(d, "skynet_send"));
    dlclose(d);

    if (!skynet_send) {
        return false;
    }

    std::unique_ptr<l_callbacks::params> s_params = std::make_unique<l_callbacks::params>();
    s_params->handle      = handle;
    s_params->skynet_send = skynet_send;
    s_params->msg_tag     = tag;

    s_params_map.emplace(ctrl_id, std::move(s_params));
    return true;
}

std::map<int, std::unique_ptr<l_callbacks::params>> l_callbacks::s_params_map;

}  // namespace sml
