-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local log = require 'mc.logging'
local skynet = require 'skynet'
local class = require 'mc.class'
local signal = require 'mc.signal'
local singleton = require 'mc.singleton'

-- Default delay between two iterations of a looping task, in milliseconds.
-- It is divided by 10 before being handed to skynet.sleep (which counts in 1/100 s ticks).
local DEFAULT_TIMEOUT<const> = 1000

--- A named unit of periodic work: `loop` forks a skynet coroutine that calls the
--- task callback repeatedly, sleeping `timeout_ms` between iterations, until the
--- task is stopped or the callback has failed too many times.
---@class c_task: c_basic_class
---@field on_task_start c_basic_signal emitted once when the worker coroutine starts
---@field on_before_run c_basic_signal emitted before each callback invocation
---@field on_after_run c_basic_signal emitted after each callback invocation
---@field on_task_exit c_basic_signal emitted by `stop`
---@field is_exit boolean set by `stop`; the worker loop ends once it is true
---@field is_running boolean true only while the callback is executing
---@field name string task name, used in log messages
---@field timeout_ms number delay between iterations, in milliseconds
---@field co thread|nil worker coroutine returned by skynet.fork in `loop`
---@field cb fun(task: c_task)|nil callback invoked on every iteration
local c_task = class()

--- Initialise a task object. Nothing runs until `loop` is called.
---@param name string task name, used in log messages
function c_task:ctor(name)
    self.name = name

    -- lifecycle flags
    self.is_running = false
    self.is_exit = false

    -- callback and worker coroutine, both filled in by `loop`
    self.cb = nil
    self.co = nil

    -- delay between two iterations, in milliseconds
    self.timeout_ms = DEFAULT_TIMEOUT

    -- one independent signal per lifecycle event
    for _, signal_name in ipairs({ 'on_task_start', 'on_before_run', 'on_after_run', 'on_task_exit' }) do
        self[signal_name] = signal.new()
    end
end

--- Set the delay between two iterations of the task loop.
---@param timeout_ms number|nil delay in milliseconds, rounded down; a nil/false value keeps the current delay
---@return c_task self, so calls can be chained
function c_task:set_timeout_ms(timeout_ms)
    if not timeout_ms then
        return self
    end

    self.timeout_ms = math.floor(timeout_ms)
    return self
end

--- Invoke `cb(self)` once in protected mode, bracketed by the before/after signals.
--- `is_running` is true only while `cb` executes.
---@param cb fun(task: c_task)
---@return boolean|nil ok false when `cb` raised; no values at all when `cb` is nil/false
---@return any err the error raised by `cb`, or its first return value on success
function c_task:run_task(cb)
    if cb then
        self.is_running = true
        self.on_before_run:emit(self)
        local status, result = pcall(cb, self)
        self.is_running = false
        self.on_after_run:emit(self)
        return status, result
    end
end

--- Start a background coroutine that repeatedly calls the task callback.
--- If the worker is already alive, only the callback is replaced; the worker
--- picks the new one up on its next iteration.
--- The worker ends when `is_exit` becomes true, or after more than 10 failed
--- runs (failures are counted cumulatively, not consecutively).
---@param cb fun(task: c_task)
---@return c_task self
function c_task:loop(cb)
    self.cb = cb
    if self.co then
        return self
    end

    self.co = skynet.fork(function()
        log:notice('task [%s] start', self.name)
        self.on_task_start:emit(self)
        local error_times = 0

        while not self.is_exit do
            local ok, err = self:run_task(self.cb)
            if not ok then
                error_times = error_times + 1
                log:error('task [%s] error: %s', self.name, err)
            end
            if error_times > 10 then
                break
            end
            -- skynet.sleep takes an integral tick count (1 tick = 10 ms); a fractional
            -- value such as 5.5 makes it raise, so round down to a whole tick.
            skynet.sleep(math.floor(self.timeout_ms / 10))
        end

        log:notice('task [%s] exit', self.name)
        -- skynet returns finished coroutines to a pool and reuses them for later
        -- forks. Drop the stale handle so `stop` cannot kill an unrelated thread
        -- and a later `loop` call can start the task again.
        self.co = nil
    end)

    return self
end

--- Stop the task: kill its worker coroutine (if one was started), mark it as
--- exited and emit `on_task_exit`. Calling it again on an exited task does nothing.
function c_task:stop()
    if self.is_exit then
        return
    end

    local co = self.co
    -- Forget the handle: skynet may recycle the coroutine object for other work.
    self.co = nil
    -- A task that was never passed to `loop` has no coroutine to kill.
    if co then
        skynet.killthread(co)
    end

    self.is_exit = true
    self.on_task_exit:emit(self)
end

--- Registry of named tasks plus a small "next tick" deferred-callback queue.
---@class c_tasks
local c_tasks = class()

--- Create an empty registry.
function c_tasks:ctor()
    -- registered tasks, keyed by task name
    self.tasks = {}
    -- emitted with a task after it has been removed from `tasks`
    self.on_remove_task = signal.new()
    -- coroutine currently scheduled to drain the queue (nil when idle)
    self.next_tick_co = nil
    -- callbacks queued by next_tick, in FIFO order
    self.next_tick_cbs = {}
end

--- Return the task registered under `name`, creating and registering it when absent.
--- An already registered task has its `is_exit` flag cleared before being returned.
--- A newly created task unregisters itself when it emits `on_task_exit`, and the
--- removal is announced through `on_remove_task`.
---@param name string
---@return c_task
function c_tasks:new_task(name)
    local existing = self.tasks[name]
    if existing then
        existing.is_exit = false
        return existing
    end

    local created = c_task.new(name)
    self.tasks[name] = created

    local function unregister()
        if not self.tasks[name] then
            return
        end
        self.tasks[name] = nil
        self.on_remove_task:emit(created)
    end
    created.on_task_exit:on(unregister)

    return created
end

--- Run every callback queued by next_tick, then re-arm if more were queued meanwhile.
--- Intended to run inside the coroutine forked by start_next_tick. It does nothing
--- unless that coroutine handle is set and the queue is non-empty.
--- Each callback runs under pcall; a failure is logged and the rest still run.
function c_tasks:process_next_tick()
    local pending = self.next_tick_cbs
    local has_work = #pending ~= 0 and self.next_tick_co
    if not has_work then
        return
    end

    -- Swap in an empty queue before running anything, so callbacks that call
    -- next_tick append to the new queue rather than the one being iterated.
    self.next_tick_cbs = {}
    for _, callback in ipairs(pending) do
        local ok, err = pcall(callback)
        if not ok then
            log:error('process_next_tick failed: %s', err)
        end
    end

    -- Release the drain slot, then schedule another drain if callbacks arrived meanwhile.
    self.next_tick_co = nil
    self:start_next_tick()
end

--- Fork a coroutine that drains the next_tick queue, but only when the queue is
--- non-empty and no drain coroutine is already scheduled.
function c_tasks:start_next_tick()
    if #self.next_tick_cbs > 0 and not self.next_tick_co then
        self.next_tick_co = skynet.fork(function()
            self:process_next_tick()
        end)
    end
end

--- Queue `cb` to be called later from a coroutine started with skynet.fork.
--- Queued callbacks run in the order they were added.
---@param cb fun()
function c_tasks:next_tick(cb)
    local queue = self.next_tick_cbs
    queue[#queue + 1] = cb
    self:start_next_tick()
end

--- Run `cb(...)` in a new coroutine created with skynet.fork.
--- Fire-and-forget: nothing is returned and the coroutine is not tracked in `tasks`.
---@param cb function
---@param ... any arguments forwarded to `cb`
function c_tasks:spawn(cb, ...)
    local fork = skynet.fork
    fork(cb, ...)
end

--- Look up a registered task by name.
---@param task_name string
---@return c_task|nil the task, or nil when no task with that name is registered
function c_tasks:get_task(task_name)
    local registry = self.tasks
    return registry[task_name]
end

--- Suspend the calling coroutine for about `timeout_ms` milliseconds.
--- skynet.sleep counts in ticks of 10 ms and only accepts an integral value, so
--- the delay is rounded down to whole ticks (e.g. 15 ms -> 1 tick). Passing the
--- raw quotient (1.5) would make skynet.sleep raise.
---@param timeout_ms number
function c_tasks.sleep_ms(timeout_ms)
    skynet.sleep(math.floor(timeout_ms / 10))
end

--- Collect the registered tasks whose name matches a Lua pattern.
---@param name_pattn string Lua pattern (not a regex), applied to each task name with string.match
---@return c_task[] matching tasks, in unspecified order
function c_tasks:match_tasks(name_pattn)
    local matched = {}
    for task_name, task in pairs(self.tasks) do
        if string.match(task_name, name_pattn) then
            table.insert(matched, task)
        end
    end
    return matched
end

--- Stop every registered task.
--- Iterates over a snapshot list (match_tasks builds a new table), because stopping
--- a task created by new_task removes it from `self.tasks`.
function c_tasks:stop_all()
    local all_tasks = self:match_tasks('.*')
    for i = 1, #all_tasks do
        all_tasks[i]:stop()
    end
end

-- Expose the task class on c_tasks so callers can reach it through this module.
c_tasks.c_task = c_task

---@class c_task_singleton
---@field new fun(...): c_tasks
---@field get_instance fun(...): c_tasks
---@field destroy fun()

-- The module returns the result of mc.singleton(c_tasks), not the class itself.
-- NOTE(review): per the annotations it exposes new/get_instance/destroy; the exact
-- instance-sharing semantics live in mc.singleton and are not visible here.
---@type c_task_singleton
local tasks = singleton(c_tasks)
return tasks

