-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.


local lu = require 'luaunit'
local unit = require 'unit_manager.class.unit.unit'
local unit_manager = require 'unit_manager.unit_manager'
local npu_board = require 'unit_manager.class.unit.acu.npu_board'
local cmn = require 'common'
local log = require 'mc.logging'
local c_riser_card = require 'unit_manager.class.unit.ieu.riser_card'
local base_messages = require 'messages.base'

-- Test suite for unit_manager and related unit classes. Deliberately global (no
-- `local`): luaunit discovers test classes from global tables named "Test...".
TestUnitManager = {}

function TestUnitManager:test_pcbid_to_pcbver()
    -- A valid PcbID (2) maps to the '.B' suffix; an invalid one (27) yields nil.
    lu.assertEquals(unit.pcbid_to_pcbver(2), '.B')
    lu.assertEquals(unit.pcbid_to_pcbver(27), nil)
end

function TestUnitManager:test_logic_version_id_to_version()
    -- Logic version id 1 is rendered as the string '0.01'.
    lu.assertEquals(unit.logic_version_id_to_version(1), '0.01')
end

function TestUnitManager:test_get_device_name()
    -- With no device names registered, get_device_name('0101') yields nil.
    unit_manager.device_names = {}
    lu.assertEquals(unit_manager:get_device_name('0101'), nil)
end

-- A CollectMCULogFlag change on an NPU board (Slot 1) should leave
-- npu_need_collect_count_table[1] == 1 after the monitor functions run.
function TestUnitManager:test_log_dump_mcu_error_task()
    -- On failure pcall returns (false, errmsg). The error string is truthy, so the
    -- skip decision must be based on `ok`, not on the module value.
    local ok, skynet = pcall(require, 'skynet')
    if not ok then
        return
    end

    local original_fork = skynet.fork
    -- Run forked callbacks synchronously so their effects are visible to the asserts.
    skynet.fork = function(cb)
        cb()
    end

    local mds_obj = {
        CollectMCULogFlag = 0,
        Slot = 1,
        property_changed = {
            -- Immediately report a CollectMCULogFlag change to value 1.
            on = function(self, cb)
                cb('CollectMCULogFlag', 1)
            end
        }
    }
    local obj = {
        mds_obj = mds_obj,
        get_prop = function(self, name)
            if not self.mds_obj then
                return nil
            elseif self.mds_obj[name] then
                return self.mds_obj[name]
            end
        end,
        update_dump_log = function(self, cb)
            return cb
        end
    }

    -- Run under pcall so skynet.fork is restored even if an assertion fails.
    local test_ok, test_err = pcall(function()
        unit_manager.npu_boards = { obj }
        unit_manager.npu_need_collect_count_table = {}
        unit_manager.npu_collecting_count_table = {}
        unit_manager.npu_collecting_count = 0
        unit_manager:log_dump_mcu_error_task()
        unit_manager.npu_collecting_count_table = { [mds_obj.Slot] = 1 }
        unit_manager:monitor_collect_mcu_log_flag(obj)
        unit_manager:monitor_need_collect_mcu()
        lu.assertEquals(unit_manager.npu_need_collect_count_table[1], 1)
    end)

    skynet.fork = original_fork

    if not test_ok then
        -- Level 0 re-raises the failure unchanged (no extra location prefix).
        error(test_err, 0)
    end
end

local client = require 'general_hardware.client'
local c_logic_fw = require 'unit_manager.class.logic_fw.fw_init'

-- unregister_cpld_firmware_info should clear the cached firmware id when the
-- firmware-inventory delete call succeeds, and keep it when the call fails.
function TestUnitManager:test_unregister_cpld_firmware_info()
    local original_delete = client.PFirmwareInventoryFirmwareInventoryDelete

    -- Run under pcall so the client stub is restored even if an assertion fails.
    local ok, err = pcall(function()
        unit_manager.logic_fw = { { position = 1, id = 1 } }

        -- Delete succeeds: the id is cleared.
        client.PFirmwareInventoryFirmwareInventoryDelete = function(...)
            return true
        end
        unit_manager:unregister_cpld_firmware_info(1)
        lu.assertEquals(unit_manager.logic_fw[1].id, nil)

        -- Delete fails: the id is kept.
        unit_manager.logic_fw[1].id = 2
        client.PFirmwareInventoryFirmwareInventoryDelete = function(...)
            return false
        end
        unit_manager:unregister_cpld_firmware_info(1)
        lu.assertEquals(unit_manager.logic_fw[1].id, 2)
    end)

    client.PFirmwareInventoryFirmwareInventoryDelete = original_delete

    if not ok then
        -- Level 0 re-raises the failure unchanged (no extra location prefix).
        error(err, 0)
    end
end

function TestUnitManager:test_post_reset()
    -- Region id -> { reset property, lock property, region label }.
    local REGION_PROP<const> = {
        [0x0] = { 'GlobalReset', 'GlobalResetLocked', 'Global' },
        [0x1] = { 'ComputingUnitReset', 'ComputingUnitResetLocked', 'ComputingUnit' }
    }

    -- Call post_reset on a fake board whose set_prop is `setter`.
    local function post_with(setter)
        return unit_manager:post_reset({ set_prop = setter }, REGION_PROP)
    end

    -- set_prop succeeds: post_reset reports true.
    local result = post_with(function()
        return true
    end)
    lu.assertEquals(result, true)

    -- set_prop raises: post_reset reports false.
    result = post_with(function()
        error('set fail')
    end)
    lu.assertEquals(result, false)
end

-- Build a fresh minimal request context: ChanType 1 plus a get_initiator
-- function that returns a new empty table on every call.
local function construct_ctx()
    return {
        ChanType = 1,
        get_initiator = function()
            return {}
        end
    }
end

function TestUnitManager:test_reset_npu_region()
    -- Fresh fake board (named `board` to avoid shadowing the npu_board module).
    local function board_with(setter)
        return { set_prop = setter }
    end
    local function succeed()
        return true
    end
    local function fail()
        error('set fail')
    end

    -- Invoke reset_npu_region under pcall; yields pcall's status and error.
    local function try_reset(board, region)
        return pcall(function()
            unit_manager:reset_npu_region(board, construct_ctx(), region)
        end)
    end

    -- Region 0xff: raises an error named PropertyValueNotInList.
    local ok, err = try_reset(board_with(succeed), 0xff)
    lu.assertEquals(ok, false)
    lu.assertEquals(err.name, base_messages.PropertyValueNotInListMessage.Name)

    -- set_prop raises: surfaced as an InternalError.
    ok, err = try_reset(board_with(fail), 0x0)
    lu.assertEquals(ok, false)
    lu.assertEquals(err.name, base_messages.InternalErrorMessage.Name)

    -- Regions 0x0 and 0x1 on a working board reset without error.
    local board = board_with(succeed)
    ok = try_reset(board, 0x0)
    lu.assertEquals(ok, true)

    ok = try_reset(board, 0x1)
    lu.assertEquals(ok, true)
end

function TestUnitManager:test_reset_npu_device()
    -- Fake NPU board: set_prop always succeeds, get_prop always returns `prop_value`.
    local function fake_board(prop_value)
        return {
            set_prop = function()
                return true
            end,
            get_prop = function()
                return prop_value
            end
        }
    end

    -- Path 'npu1' is not in npu_board_map: reset_npu_device returns false.
    unit_manager.npu_board_map = { npu = fake_board(1) }
    local result = unit_manager:reset_npu_device({ path = 'npu1' })
    lu.assertEquals(result, false)

    -- Known path, get_prop returning 2: returns nil.
    unit_manager.npu_board_map = { npu = fake_board(2) }
    result = unit_manager:reset_npu_device({ path = 'npu' }, construct_ctx(), 0x01)
    lu.assertEquals(result, nil)

    -- Known path, get_prop returning 1: returns nil.
    unit_manager.npu_board_map = { npu = fake_board(1) }
    result = unit_manager:reset_npu_device({ path = 'npu' }, construct_ctx(), 0x01)
    lu.assertEquals(result, nil)

    -- The metric-collection helpers lack the `test` prefix, so luaunit does not
    -- run them on its own; drive them from here.
    self:get_npu_metric_collection_data()
    self:get_npu_metric_collection_items()
end

-- Helper (called from test_reset_npu_device): checks how many metric entries
-- get_npu_metric_collection_data returns for various board/item combinations.
function TestUnitManager:get_npu_metric_collection_data()
    local mgr = unit_manager.get_instance()

    -- Fake board whose get_prop always returns `value`.
    local function board_returning(value)
        return {
            get_prop = function()
                return value
            end
        }
    end

    -- Empty board map: nothing collected.
    mgr.npu_board_map = {}
    local entries = mgr:get_npu_metric_collection_data({})
    lu.assertEquals(#entries, 0)

    -- Board present but no items, or an unrecognised item name: nothing collected.
    mgr.npu_board_map = { npu = {} }
    entries = mgr:get_npu_metric_collection_data({ path = 'npu' }, {})
    lu.assertEquals(#entries, 0)

    entries = mgr:get_npu_metric_collection_data({ path = 'npu' }, { 'npuboard.power' })
    lu.assertEquals(#entries, 0)

    -- get_prop returning 32768 (presumably an invalid-reading sentinel): nothing collected.
    mgr.npu_board_map = { npu = board_returning(32768) }
    entries = mgr:get_npu_metric_collection_data({ path = 'npu' }, { 'npuboard.powerwatts' })
    lu.assertEquals(#entries, 0)

    -- get_prop returning 66: one entry collected.
    mgr.npu_board_map = { npu = board_returning(66) }
    entries = mgr:get_npu_metric_collection_data({ path = 'npu' }, { 'npuboard.powerwatts' })
    lu.assertEquals(#entries, 1)
end

-- Helper (called from test_reset_npu_device; luaunit does not auto-run it because
-- the name lacks the `test` prefix): checks the item names reported for NPU
-- metric collection.
function TestUnitManager:get_npu_metric_collection_items()
    -- Named `mgr` so it does not shadow the module-level `unit` require.
    local mgr = unit_manager.get_instance()

    -- Empty board map: an empty string is reported.
    mgr.npu_board_map = {}
    local res = mgr:get_npu_metric_collection_items({})
    lu.assertEquals(res, '')

    mgr.npu_board_map = {
        npu = {
            get_prop = function()
                return 66
            end
        }
    }
    -- Reassign `res` rather than redeclaring a second `local res` in the same scope.
    res = mgr:get_npu_metric_collection_items({ path = 'npu' })
    lu.assertEquals(res, 'NpuBoard')
end

-- Test the CPLD self-test loop of start_serial_cpld_status_task: only units that
-- expose cpld_self_test are tested, with the expected delays in between.
function TestUnitManager:test_start_serial_cpld_status_task()
    local original_sleep = cmn.skynet.sleep
    local original_notice = log.notice
    local original_info = log.info

    local sleep_calls = {}
    local log_calls = {}
    local cpld_test_calls = {}

    -- Record every sleep and stop the otherwise endless task loop on the 4th call.
    cmn.skynet.sleep = function(ms)
        table.insert(sleep_calls, ms)
        if #sleep_calls >= 4 then
            error("test_exit")
        end
    end

    -- Capture formatted log lines instead of emitting them.
    log.notice = function(self, msg, ...)
        table.insert(log_calls, { type = 'notice', msg = string.format(msg, ...) })
    end
    log.info = function(self, msg, ...)
        table.insert(log_calls, { type = 'info', msg = string.format(msg, ...) })
    end

    -- Fake manager: pos1 and pos2 support the CPLD self test, pos3 does not.
    local test_manager = {
        unit_collection = {
            ['pos1'] = {
                cpld_self_test = function(self)
                    table.insert(cpld_test_calls, { position = 'pos1' })
                end
            },
            ['pos2'] = {
                cpld_self_test = function(self)
                    table.insert(cpld_test_calls, { position = 'pos2' })
                end
            },
            ['pos3'] = {}
        }
    }

    -- Run the checks under pcall so the stubs are always restored.
    local test_ok, test_err = pcall(function()
        local _, err = pcall(unit_manager.start_serial_cpld_status_task, test_manager)
        -- The task only ends through the sleep stub's "test_exit" error.
        -- tostring() turns an unexpected nil into a readable assertion failure.
        lu.assertStrContains(tostring(err), "test_exit")

        -- Exactly one start-up notice is logged.
        lu.assertEquals(#log_calls, 1)
        lu.assertEquals(log_calls[1].type, 'notice')
        lu.assertStrContains(log_calls[1].msg, "start serial CPLD status update task")

        lu.assertEquals(#sleep_calls, 4)
        lu.assertEquals(sleep_calls[1], 18000) -- initial delay
        lu.assertEquals(sleep_calls[2], 10)    -- gap between units
        lu.assertEquals(sleep_calls[3], 10)    -- gap between units
        lu.assertEquals(sleep_calls[4], 6000)  -- delay between loop rounds

        -- Only the units that have cpld_self_test were tested.
        lu.assertEquals(#cpld_test_calls, 2)
        local positions = {}
        for _, call in ipairs(cpld_test_calls) do
            positions[call.position] = true
        end
        lu.assertTrue(positions['pos1'])
        lu.assertTrue(positions['pos2'])
    end)

    cmn.skynet.sleep = original_sleep
    log.notice = original_notice
    log.info = original_info

    if not test_ok then
        -- Level 0 re-raises the failure unchanged (no extra location prefix).
        error(test_err, 0)
    end
end

-- Test that start_serial_cpld_status_task does not abort when a unit's CPLD self
-- test raises, and that it logs the failure together with the error text.
function TestUnitManager:test_start_serial_cpld_status_task_error_handling()
    local original_sleep = cmn.skynet.sleep
    local original_info = log.info

    local sleep_calls = {}
    local log_calls = {}

    -- Record every sleep and stop the task loop on the 3rd call.
    cmn.skynet.sleep = function(ms)
        table.insert(sleep_calls, ms)
        if #sleep_calls >= 3 then
            error("test_exit")
        end
    end

    -- Capture formatted info log lines.
    log.info = function(self, msg, ...)
        table.insert(log_calls, string.format(msg, ...))
    end

    -- pos1's self test raises. The message is deliberately different from the log
    -- prefix, so the assertion below really proves the error text reaches the log.
    local test_manager = {
        unit_collection = {
            ['pos1'] = {
                cpld_self_test = function(self)
                    error("simulated CPLD fault")
                end
            },
            ['pos2'] = {
                cpld_self_test = function(self)
                end
            }
        }
    }

    -- Run the checks under pcall so the stubs are always restored.
    local test_ok, test_err = pcall(function()
        local _, err = pcall(unit_manager.start_serial_cpld_status_task, test_manager)
        -- The loop survived the self-test error and ended only via the sleep stub.
        lu.assertStrContains(tostring(err), "test_exit")

        -- One failure line naming the position and carrying the raised message.
        lu.assertEquals(#log_calls, 1)
        lu.assertStrContains(log_calls[1], "CPLD self test failed for position pos1")
        lu.assertStrContains(log_calls[1], "simulated CPLD fault")
    end)

    cmn.skynet.sleep = original_sleep
    log.info = original_info

    if not test_ok then
        -- Level 0 re-raises the failure unchanged (no extra location prefix).
        error(test_err, 0)
    end
end

-- Test the Riser MCU status loop of start_serial_riser_mcu_status_task: only
-- riser cards whose MCUVersion is not 'N/A' are updated, with the expected delays.
function TestUnitManager:test_start_serial_riser_mcu_status_task()
    local original_sleep = cmn.skynet.sleep
    local original_notice = log.notice
    local original_info = log.info
    local original_update_mcu_status = c_riser_card.update_mcu_status

    local sleep_calls = {}
    local log_calls = {}
    local update_calls = {}

    -- Record every sleep and stop the otherwise endless task loop on the 4th call.
    cmn.skynet.sleep = function(ms)
        table.insert(sleep_calls, ms)
        if #sleep_calls >= 4 then
            error("test_exit")
        end
    end

    -- Capture formatted log lines instead of emitting them.
    log.notice = function(self, msg, ...)
        table.insert(log_calls, { type = 'notice', msg = string.format(msg, ...) })
    end
    log.info = function(self, msg, ...)
        table.insert(log_calls, { type = 'info', msg = string.format(msg, ...) })
    end

    -- Record which riser cards the task asks to update.
    c_riser_card.update_mcu_status = function(riser_obj)
        table.insert(update_calls, { device_name = riser_obj:get_prop('DeviceName') })
    end

    -- Fake riser card answering only the MCUVersion and DeviceName properties.
    local function fake_riser(mcu_version, device_name)
        return {
            get_prop = function(self, prop)
                if prop == 'MCUVersion' then
                    return mcu_version
                elseif prop == 'DeviceName' then
                    return device_name
                end
            end
        }
    end

    local test_manager = {
        riser_cards = {
            fake_riser('1.0.0', 'RiserCard1'), -- has an MCU version: should be updated
            fake_riser('N/A', 'RiserCard2'),   -- 'N/A': should be skipped
            fake_riser('2.0.0', 'RiserCard3')  -- has an MCU version: should be updated
        }
    }

    -- Run the checks under pcall so the stubs are always restored.
    local test_ok, test_err = pcall(function()
        local _, err = pcall(unit_manager.start_serial_riser_mcu_status_task, test_manager)
        -- The task only ends through the sleep stub's "test_exit" error.
        lu.assertStrContains(tostring(err), "test_exit")

        -- Exactly one start-up notice is logged.
        lu.assertEquals(#log_calls, 1)
        lu.assertEquals(log_calls[1].type, 'notice')
        lu.assertStrContains(log_calls[1].msg, "start serial Riser MCU status update task")

        lu.assertEquals(#sleep_calls, 4)
        lu.assertEquals(sleep_calls[1], 12000) -- initial delay
        lu.assertEquals(sleep_calls[2], 5)     -- gap between riser cards
        lu.assertEquals(sleep_calls[3], 5)     -- gap between riser cards
        lu.assertEquals(sleep_calls[4], 500)   -- delay between loop rounds

        -- Only riser cards whose MCUVersion is not 'N/A' were updated.
        lu.assertEquals(#update_calls, 2)
        local device_names = {}
        for _, call in ipairs(update_calls) do
            device_names[call.device_name] = true
        end
        lu.assertTrue(device_names['RiserCard1'])
        lu.assertTrue(device_names['RiserCard3'])
        lu.assertNil(device_names['RiserCard2'])
    end)

    cmn.skynet.sleep = original_sleep
    log.notice = original_notice
    log.info = original_info
    c_riser_card.update_mcu_status = original_update_mcu_status

    if not test_ok then
        -- Level 0 re-raises the failure unchanged (no extra location prefix).
        error(test_err, 0)
    end
end

-- Test that start_serial_riser_mcu_status_task does not abort when updating a
-- riser card's MCU status raises, and that it logs the failure with the error text.
function TestUnitManager:test_start_serial_riser_mcu_status_task_error_handling()
    local original_sleep = cmn.skynet.sleep
    local original_info = log.info
    local original_update_mcu_status = c_riser_card.update_mcu_status

    local sleep_calls = {}
    local log_calls = {}

    -- Record every sleep and stop the task loop on the 3rd call.
    cmn.skynet.sleep = function(ms)
        table.insert(sleep_calls, ms)
        if #sleep_calls >= 3 then
            error("test_exit")
        end
    end

    -- Capture formatted info log lines.
    log.info = function(self, msg, ...)
        table.insert(log_calls, string.format(msg, ...))
    end

    -- Every MCU status update raises.
    c_riser_card.update_mcu_status = function(riser_obj)
        error("Riser MCU communication failed")
    end

    -- One riser card with a real MCU version, so the task tries to update it.
    local test_manager = {
        riser_cards = {
            {
                get_prop = function(self, prop)
                    if prop == 'MCUVersion' then
                        return '1.0.0'
                    elseif prop == 'DeviceName' then
                        return 'TestRiserCard'
                    end
                end
            }
        }
    }

    -- Run the checks under pcall so the stubs are always restored.
    local test_ok, test_err = pcall(function()
        local _, err = pcall(unit_manager.start_serial_riser_mcu_status_task, test_manager)
        -- The loop survived the update error and ended only via the sleep stub.
        lu.assertStrContains(tostring(err), "test_exit")

        -- One failure line naming the card and carrying the raised message.
        lu.assertEquals(#log_calls, 1)
        lu.assertStrContains(log_calls[1], "Riser MCU status update failed for TestRiserCard")
        lu.assertStrContains(log_calls[1], "Riser MCU communication failed")
    end)

    cmn.skynet.sleep = original_sleep
    log.info = original_info
    c_riser_card.update_mcu_status = original_update_mcu_status

    if not test_ok then
        -- Level 0 re-raises the failure unchanged (no extra location prefix).
        error(test_err, 0)
    end
end

-- on_add_object_complete('TestPosition2') should call
-- client.PFirmwareInventoryFirmwareInventoryAdd for the logic firmware at that position.
function TestUnitManager:test_on_add_object_complete()
    local original_add = client.PFirmwareInventoryFirmwareInventoryAdd
    local add_called = false

    unit_manager.cpu_board_position = {}
    unit_manager.exp_board_position = {
        [1] = 'TestPosition2'
    }
    unit_manager.logic_fw = {
        [1] = {
            position = 'TestPosition2',
            csr = {
                Name = 'EXU_CPLD_Test'
            },
            get_fw_version = function()
                return '0.00'
            end,
            get_fw_location = function()
                return 'TestPosition2'
            end
        }
    }
    unit_manager.npu_boards = {}
    client.PFirmwareInventoryFirmwareInventoryAdd = function()
        add_called = true
        return true
    end

    -- Run under pcall so the client stub is restored even if the assertion fails.
    local ok, err = pcall(function()
        unit_manager:on_add_object_complete('TestPosition2')
        lu.assertEquals(add_called, true)
    end)

    client.PFirmwareInventoryFirmwareInventoryAdd = original_add

    if not ok then
        -- Level 0 re-raises the failure unchanged (no extra location prefix).
        error(err, 0)
    end
end