-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local circuit_mgmt = require 'circuit_mgmt'
local skynet = require 'skynet'
local power_mgmt_utils = require 'power_mgmt_utils'
local enums = require 'macros.power_mgmt_enums'
local c_psu_object = require 'device.psu'

-- Test suite table for circuit_mgmt. Intentionally global (no `local`):
-- presumably so the luaunit runner can discover it by scanning globals whose
-- names start with "Test" — confirm against the runner configuration.
TestCircuitMgmt = {}

--- Forces power_mgmt_utils to report a CABINET chassis type.
-- Swaps power_mgmt_utils.get_instance for a stub whose instances answer
-- get_chassis_type() with enums.SOFTWARE_TYPE.CABINET.
-- @treturn function restorer that reinstates the original get_instance
local function mock_cabinet_software_type()
    local saved_get_instance = power_mgmt_utils.get_instance

    local function fake_get_chassis_type()
        return enums.SOFTWARE_TYPE.CABINET
    end

    power_mgmt_utils.get_instance = function()
        return { get_chassis_type = fake_get_chassis_type }
    end

    local function restore()
        power_mgmt_utils.get_instance = saved_get_instance
    end
    return restore
end

--- Builds a fake PSU object.
-- Each entry of `replies` maps a getter method name to an `{ first, second }`
-- pair; the generated method (called with `:`) returns those two values.
-- @param main_status value stored in MainCircuitVINStatus
-- @param backup_status value stored in BackupCircuitVINStatus
-- @tparam table replies getter name -> { first_return, second_return }
-- @treturn table fake PSU
local function build_fake_psu(main_status, backup_status, replies)
    local psu = {
        MainCircuitVINStatus = main_status,
        BackupCircuitVINStatus = backup_status
    }
    for getter_name, reply in pairs(replies) do
        psu[getter_name] = function()
            return reply[1], reply[2]
        end
    end
    return psu
end

--- Replaces c_psu_object.collection with a stub whose fold() visits two fake
-- PSUs, in order, calling `visit(nil, psu)` for each.
-- @treturn function restorer that reinstates the original collection
local function mock_psu_collection()
    local saved_collection = c_psu_object.collection

    c_psu_object.collection = {
        fold = function(_, visit)
            local first_psu = build_fake_psu(255, 255, {
                get_main_circuit_input_voltage_type = { nil, 3 },
                get_backup_circuit_input_voltage_type = { nil, 0 },
                get_main_circuit_input_voltage = { nil, 220 },
                get_backup_circuit_input_voltage = { '', 0 }
            })
            local second_psu = build_fake_psu(10, 20, {
                get_main_circuit_input_voltage_type = { nil, 1 },
                get_backup_circuit_input_voltage_type = { nil, 2 },
                get_main_circuit_input_voltage = { nil, 230 },
                get_backup_circuit_input_voltage = { nil, 240 }
            })
            visit(nil, first_psu)
            visit(nil, second_psu)
        end
    }

    return function()
        c_psu_object.collection = saved_collection
    end
end

--- Replaces skynet.sleep with a function that does nothing.
-- @treturn function restorer that reinstates the real skynet.sleep
local function mock_sleep()
    local real_sleep = skynet.sleep
    local function no_op() end

    skynet.sleep = no_op

    local function restore()
        skynet.sleep = real_sleep
    end
    return restore
end

-- Message raised by the branch stub to break out of update_circuit_info's loop.
local STOP_LOOP_MSG = 'stop loop'
-- Upper bound on stub invocations before the run is aborted as non-converging.
local MAX_UPDATE_CALLS = 100

--- Runs circuit_mgmt:update_circuit_info() with the per-circuit update methods
-- stubbed out, capturing the arguments each stub last received.
--
-- update_circuit_info() runs a `while true` loop, so the branch stub raises
-- STOP_LOOP_MSG once both stubs have been called. A call budget also raises if
-- the loop keeps spinning without ever reporting both circuits, so a regression
-- fails instead of hanging (callers mock skynet.sleep, making the loop tight).
-- All patched fields of circuit_mgmt, including circuit_objs, are restored
-- before returning.
--
-- @treturn boolean ok pcall status (false when the loop was interrupted)
-- @treturn table mains_obj fake Mains circuit object that was installed
-- @treturn table branch_obj fake Branch circuit object that was installed
-- @treturn table|nil mains_called { obj, status, vtype, voltage } of the last mains update
-- @treturn table|nil branch_called { obj, status, vtype, voltage } of the last branch update
-- @return err error raised inside update_circuit_info (STOP_LOOP_MSG on the normal stop)
local function run_update_circuit_and_capture()
    local mains_called, branch_called
    local call_count = 0
    local old_update_mains = circuit_mgmt.update_mains_circuit_info
    local old_update_branch = circuit_mgmt.update_branch_circuit_info
    local old_circuit_objs = circuit_mgmt.circuit_objs

    -- Guard against an update loop that never reaches the stop condition.
    local function count_call()
        call_count = call_count + 1
        if call_count > MAX_UPDATE_CALLS then
            error('update loop did not report both circuits', 0)
        end
    end

    circuit_mgmt.update_mains_circuit_info = function(_, obj, a_status, a_type, a_voltage)
        count_call()
        mains_called = {
            obj = obj,
            status = a_status,
            vtype = a_type,
            voltage = a_voltage
        }
    end
    circuit_mgmt.update_branch_circuit_info = function(_, obj, b_status, b_type, b_voltage)
        count_call()
        branch_called = {
            obj = obj,
            status = b_status,
            vtype = b_type,
            voltage = b_voltage
        }
        -- Mains info not collected yet: let the loop run another round.
        if not mains_called then
            return
        end
        -- Both circuits reported: break out of the `while true` loop.
        error(STOP_LOOP_MSG, 0)
    end

    local mains_obj = { CircuitType = 'Mains' }
    local branch_obj = { CircuitType = 'Branch' }
    circuit_mgmt.circuit_objs = {
        mains = mains_obj,
        branch = branch_obj
    }

    local ok, err = pcall(function()
        circuit_mgmt:update_circuit_info()
    end)

    -- Undo every patch so later tests see the module's real state.
    circuit_mgmt.update_mains_circuit_info = old_update_mains
    circuit_mgmt.update_branch_circuit_info = old_update_branch
    circuit_mgmt.circuit_objs = old_circuit_objs

    return ok, mains_obj, branch_obj, mains_called, branch_called, err
end

-- Scenario: a newly added circuit object is cached under its Id.
function TestCircuitMgmt:test_object_add_callback_add_new_object()
    local circuit_id = 'circuit-1'
    local circuit = { CircuitType = 'Mains', Id = circuit_id }
    circuit_mgmt.circuit_objs = {}

    circuit_mgmt:object_add_callback('Circuit', circuit, 1)

    local cached = circuit_mgmt.circuit_objs[circuit_id]
    lu.assertEquals(cached, circuit)
end

-- Scenario: adding a circuit object whose Id is already cached keeps the
-- original object in place instead of overwriting it.
function TestCircuitMgmt:test_object_add_callback_duplicate_id()
    local origin = { Id = 'dup-id', CircuitType = 'Mains' }
    circuit_mgmt.circuit_objs = { ['dup-id'] = origin }
    local new_obj = { Id = 'dup-id', CircuitType = 'Branch' }
    circuit_mgmt:object_add_callback('Circuit', new_obj, 1)
    -- Identity check: the cache must still hold the very same table, not merely
    -- a structurally equal one (assertEquals compares tables deeply).
    lu.assertIs(circuit_mgmt.circuit_objs['dup-id'], origin)
end

-- Scenario: the A path is a DC input whose voltage is not above the threshold,
-- so both the input-voltage status and the nominal voltage are refreshed.
function TestCircuitMgmt:test_update_mains_dc_voltage_refresh()
    local status, voltage_type, voltage = 1, 1, 260
    local circuit = { NominalVoltage = '', PhaseWiringType = '', InputVoltageStatus = 0 }

    circuit_mgmt:update_mains_circuit_info(circuit, status, voltage_type, voltage)

    lu.assertEquals(circuit.InputVoltageStatus, status)
    lu.assertEquals(circuit.PhaseWiringType, '')
    lu.assertEquals(circuit.NominalVoltage, 'DC240V')
end

-- Scenario: only an A-path status is reported (no type, no voltage); the status
-- is updated and every other field is left untouched.
function TestCircuitMgmt:test_update_mains_only_status()
    local untouched = 'Keep'
    local circuit = {
        InputVoltageStatus = 0,
        PhaseWiringType = untouched,
        NominalVoltage = untouched
    }

    circuit_mgmt:update_mains_circuit_info(circuit, 5, nil, nil)

    lu.assertEquals(circuit.InputVoltageStatus, 5)
    lu.assertEquals(circuit.PhaseWiringType, untouched)
    lu.assertEquals(circuit.NominalVoltage, untouched)
end

-- Scenario: the A path is a DC input whose voltage is above the threshold.
-- The status is refreshed, and no DC nominal voltage is assigned: the stale
-- 'DC240V' value is cleared to ''.
function TestCircuitMgmt:test_update_mains_dc_voltage_over_threshold()
    local reported_status = 1
    local circuit = {
        NominalVoltage = 'DC240V',
        PhaseWiringType = '',
        InputVoltageStatus = 0
    }

    circuit_mgmt:update_mains_circuit_info(circuit, reported_status, 1, 300)

    lu.assertEquals(circuit.InputVoltageStatus, reported_status)
    lu.assertEquals(circuit.NominalVoltage, '')
end

-- Scenario: the A path is an AC input (type 0) and no status is reported.
-- The existing status is kept, the wiring type becomes 'ThreePhase5Wire', and
-- the stale nominal voltage is cleared to ''.
function TestCircuitMgmt:test_update_mains_ac_only_refresh_phase()
    local previous_status = 2
    local circuit = {
        NominalVoltage = 'DC240V',
        PhaseWiringType = '',
        InputVoltageStatus = previous_status
    }

    circuit_mgmt:update_mains_circuit_info(circuit, nil, 0, 300)

    lu.assertEquals(circuit.InputVoltageStatus, previous_status)
    lu.assertEquals(circuit.PhaseWiringType, 'ThreePhase5Wire')
    lu.assertEquals(circuit.NominalVoltage, '')
end

-- Scenario: no new B-path information at all (status, type and voltage are all
-- nil); every field keeps its previous value.
function TestCircuitMgmt:test_update_branch_skip_nil_inputs()
    local wiring, nominal = 'ThreePhase5Wire', 'DC240V'
    local circuit = {
        NominalVoltage = nominal,
        PhaseWiringType = wiring,
        InputVoltageStatus = 3
    }

    circuit_mgmt:update_branch_circuit_info(circuit, nil, nil, nil)

    lu.assertEquals(circuit.InputVoltageStatus, 3)
    lu.assertEquals(circuit.PhaseWiringType, wiring)
    lu.assertEquals(circuit.NominalVoltage, nominal)
end

-- Scenario: the B path is a DC input whose voltage is not above the threshold,
-- so both the input-voltage status and the nominal voltage are refreshed.
function TestCircuitMgmt:test_update_branch_dc_voltage_refresh()
    local status, voltage_type, voltage = 2, 2, 280
    local circuit = { NominalVoltage = '', PhaseWiringType = '', InputVoltageStatus = 0 }

    circuit_mgmt:update_branch_circuit_info(circuit, status, voltage_type, voltage)

    lu.assertEquals(circuit.InputVoltageStatus, status)
    lu.assertEquals(circuit.PhaseWiringType, '')
    lu.assertEquals(circuit.NominalVoltage, 'DC240V')
end

-- Scenario: the B path is an AC input (type 0) with a reported voltage and no
-- status. Only the wiring type is refreshed, and the nominal voltage is cleared.
function TestCircuitMgmt:test_update_branch_ac_voltage_over_threshold()
    local circuit = {
        NominalVoltage = 'DC240V',
        PhaseWiringType = '',
        InputVoltageStatus = 0
    }
    local ac_type, reported_voltage = 0, 310

    circuit_mgmt:update_branch_circuit_info(circuit, nil, ac_type, reported_voltage)

    lu.assertEquals(circuit.InputVoltageStatus, 0)
    lu.assertEquals(circuit.PhaseWiringType, 'ThreePhase5Wire')
    lu.assertEquals(circuit.NominalVoltage, '')
end

-- Scenario: only a B-path status is reported (no type, no voltage); the status
-- is updated and every other field is left untouched.
function TestCircuitMgmt:test_update_branch_only_status()
    local untouched = 'Keep'
    local new_status = 7
    local circuit = {
        InputVoltageStatus = 1,
        PhaseWiringType = untouched,
        NominalVoltage = untouched
    }

    circuit_mgmt:update_branch_circuit_info(circuit, new_status, nil, nil)

    lu.assertEquals(circuit.InputVoltageStatus, new_status)
    lu.assertEquals(circuit.PhaseWiringType, untouched)
    lu.assertEquals(circuit.NominalVoltage, untouched)
end

-- Scenario: on a non-CABINET software type, update_circuit_info returns
-- immediately without entering its update loop.
function TestCircuitMgmt:test_update_circuit_info_software_type_not_cabinet()
    local old_get_instance = power_mgmt_utils.get_instance
    power_mgmt_utils.get_instance = function()
        return {
            get_chassis_type = function()
                -- assumes enums.SOFTWARE_TYPE.CABINET ~= 0 — TODO confirm
                return 0
            end
        }
    end

    local ok, err = pcall(function()
        circuit_mgmt:update_circuit_info()
    end)

    -- Restore before asserting: a failed assertion raises, and a restore placed
    -- after it would be skipped, leaking the mock into every later test.
    power_mgmt_utils.get_instance = old_get_instance

    -- On failure this surfaces the actual error raised by update_circuit_info.
    lu.assertNil(err)
    lu.assertTrue(ok)
end

-- Scenario: with a CABINET software type, A/B-path information is aggregated
-- from the PSU objects and pushed to the Mains/Branch circuit objects.
function TestCircuitMgmt:test_update_circuit_info_collects_and_updates_objs()
    local restore_type = mock_cabinet_software_type()
    local restore_sleep = mock_sleep()
    local restore_collection = mock_psu_collection()

    local ok, mains_obj, branch_obj, mains_called, branch_called =
        run_update_circuit_and_capture()

    -- Undo the mocks before asserting: a failed assertion raises, and restores
    -- placed after it would be skipped, leaking a no-op skynet.sleep and the fake
    -- PSU collection/chassis type into later tests.
    restore_type()
    restore_sleep()
    restore_collection()

    -- The loop only ends via the capture helper's stop error.
    lu.assertFalse(ok)

    lu.assertNotNil(mains_called)
    lu.assertEquals(mains_called.obj, mains_obj)
    -- psu1 reports 255 / type 3 here; the expected status and type come from psu2.
    lu.assertEquals(mains_called.status, 10)
    lu.assertEquals(mains_called.vtype, 1)
    -- The A-path voltage is first provided by psu1 (220) and is not overwritten later.
    lu.assertEquals(mains_called.voltage, 220)

    lu.assertNotNil(branch_called)
    lu.assertEquals(branch_called.obj, branch_obj)
    lu.assertEquals(branch_called.status, 20)
    -- The B-path type is first provided by psu1 as 0 (AC) and is not overwritten later.
    lu.assertEquals(branch_called.vtype, 0)
    -- The Branch update also receives the aggregated A-path voltage, expected 220.
    -- NOTE(review): psu2 reports a B-path voltage of 240 (psu1's B-path read
    -- returns '' as its first value); confirm that reusing the A-path voltage for
    -- the Branch circuit is intended rather than a defect being pinned here.
    lu.assertEquals(branch_called.voltage, 220)
end

-- Scenario: init starts the asynchronous update task via skynet.fork.
-- skynet.thread_id is presumably a counter bumped by the test double's fork —
-- TODO confirm against the skynet mock.
function TestCircuitMgmt:test_init_will_fork_update_task()
    local thread_id_before = skynet.thread_id
    circuit_mgmt:init()
    local forked_count = skynet.thread_id - thread_id_before
    lu.assertEquals(forked_count, 1)
end

