-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local c_tasks = require 'tasks'
local utils = require 'mc.utils'

---@class c_hook_tasks: c_tasks
---@field run_until fun(self: c_hook_tasks, cb: fun(): boolean, timeout_sec?: number)
---@field run_all_task fun(self: c_hook_tasks)
---@field run_task fun(self: c_hook_tasks, name_or_task: any)

--- Patch the tasks singleton so a unit test can step tasks synchronously.
---
--- After hooking:
---   * `c_task:loop(cb)` just stores `cb` in `self.loop_cb` and returns the task;
---   * the callback is executed only when `c_task:run()` is invoked, which the
---     test triggers through `run_all_task`, `run_until` or `run_task`.
--- The patches are written onto the shared `c_task` table and onto the instance
--- returned by `c_tasks.get_instance()`.
---@return c_hook_tasks
local function hook()
    ---@type any
    local my_tasks = c_tasks.get_instance()

    local c_task = my_tasks.c_task

    -- Default upper bound (seconds) for run_until before it fails the test.
    local DEFAULT_RUN_UNTIL_TIMEOUT_SEC = 5

    -- Record the task body; it is executed later by c_task.run.
    c_task.loop = function(self, cb)
        self.loop_cb = cb
        return self
    end

    -- Execute one step of a task synchronously.
    -- NOTE(review): calling run() on a task whose is_exit is already set emits
    -- on_task_exit again on every call; confirm listeners tolerate that.
    c_task.run = function(self)
        if self.is_exit then
            self.on_task_exit:emit(self)
            return
        end

        -- Field name `is_fist_run` (sic) is kept as-is; presumably it is shared
        -- with the real tasks module, so confirm before renaming it.
        if not self.is_fist_run then
            self.is_fist_run = true
            self.on_task_start:emit(self)
        end
        -- Skip re-entrant runs; is_running is presumably set by run_task while
        -- the loop callback executes (TODO confirm).
        if self.is_running then
            return
        end

        local ok, err = self:run_task(self.loop_cb)
        if not ok then
            error(err)
        end

        if self.is_exit then
            self.on_task_exit:emit(self)
        end
    end

    --- Run every task once, then process the pending next-tick callbacks.
    function my_tasks:run_all_task()
        for _, t in ipairs(self:match_tasks('.*')) do
            t:run()
        end
        self:process_next_tick()
    end

    -- Raise an error when more than `timeout_sec` whole seconds have elapsed since `st`.
    local function check_timeout(name, st, timeout_sec)
        local diff_sec = utils.time().tv_sec - st.tv_sec
        if diff_sec > timeout_sec then
            error(string.format('tasks.run_until %s timeout %s sec', name, diff_sec))
        end
    end

    --- Keep stepping all tasks (and next-tick callbacks) until `cb` returns true.
    ---@param cb fun(): boolean predicate polled after each task step and after each tick
    ---@param timeout_sec? number maximum seconds to spin before raising (default 5)
    function my_tasks:run_until(cb, timeout_sec)
        timeout_sec = timeout_sec or DEFAULT_RUN_UNTIL_TIMEOUT_SEC
        local st = utils.time()
        while true do
            for _, t in ipairs(self:match_tasks('.*')) do
                t:run()
                if cb() then
                    return
                end

                -- In tests a task may run for at most `timeout_sec` seconds.
                -- If business code really sleeps, stub the sleep out so this does not time out.
                check_timeout(t.name, st, timeout_sec)
            end
            self:process_next_tick()

            -- Poll again after the tick: when no task is registered the loop above
            -- never executes, and without this check run_until would spin forever
            -- without ever evaluating cb or the timeout.
            if cb() then
                return
            end
            check_timeout('next_tick', st, timeout_sec)
        end
    end

    --- Run a single task object, or every task returned by match_tasks(name_or_task)
    --- (presumably a name pattern), then process the next tick.
    ---@param name_or_task any a task instance (metatable == c_task) or a name/pattern
    function my_tasks:run_task(name_or_task)
        if getmetatable(name_or_task) == c_task then
            name_or_task:run()
        else
            for _, t in ipairs(self:match_tasks(name_or_task)) do
                t:run()
            end
        end
        self:process_next_tick()
    end

    return my_tasks
end

-- Module API: `hook` patches the tasks singleton for tests; `unhook` is c_tasks.destroy.
return {
    hook = hook,
    unhook = c_tasks.destroy,
}
