-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.
local lu = require 'luaunit'
local mcu_service = require 'mcu.mcu_service'
local fructrl = require 'mcu.upgrade.fructl_handler'
local cmn = require 'common'
local log = require 'mc.logging'
local utils = require 'mc.utils'
local utils_core = require 'utils.core'
local file_sec = require 'utils.file'
local MCU_ENUMS = require 'mcu.enum.mcu_enums'
local smc_mcu_object = require 'mcu.smc_mcu_object'
local smc_interface = require 'mcu.upgrade.smc_interface'
local unit_manager = require 'unit_manager.unit_manager'
local skynet = require 'skynet'
local vos = require 'utils.vos'
local gen_hw_bus = require 'general_hardware_bus'

-- Test suite table for the MCU service; the test cases below are attached to it as methods.
-- NOTE(review): declared as a global rather than a local -- presumably so the LuaUnit
-- runner can discover the suite by name; confirm against the runner setup.
TestMcuService = {}
-- Checks that mcu_service.get_vrd_info succeeds for id '010101' and returns the
-- VRD info exposed by the single stub MCU placed in bcu_mcu_collection.
function TestMcuService:test_get_vrd_info()
    -- Stub MCU whose get_vrd_info simply hands back its own sub-component list.
    local fake_mcu = {
        sub_component_info_list = { vrd = 'test' },
    }
    function fake_mcu.get_vrd_info(mcu_obj)
        return mcu_obj.sub_component_info_list
    end

    mcu_service.bcu_mcu_collection = { fake_mcu }

    local ok, vrd_info = pcall(mcu_service.get_vrd_info, mcu_service, '010101')
    lu.assertEquals(ok, true)
    lu.assertEquals(vrd_info.vrd, 'test')
end

-- Checks that mcu_service.listen_vrd_abnormal runs without raising when given a
-- stub MCU whose system-event reads yield three 0s, three 8s and then nil.
function TestMcuService:test_listen_vrd_abnormal()
    -- Scripted event readings; any read past the end of the list returns nil.
    local events = {0, 0, 0, 8, 8, 8}
    local cursor = 0

    local fake_mcu = { interface = {} }

    function fake_mcu.get_id()
        return "MCU123"
    end

    -- Each call consumes the next scripted reading.
    function fake_mcu.interface.get_system_event()
        cursor = cursor + 1
        return events[cursor]
    end

    function fake_mcu.get_vrd_dump_info(id)
        print(string.format("[Test] get_vrd_dump_info called with mcu_id: %s", id))
    end

    local ok = pcall(mcu_service.listen_vrd_abnormal, mcu_service, fake_mcu)
    lu.assertEquals(ok, true)
end

-- Checks that mcu_service.on_dump_vrd_cb(1, '/vrd_reg_dump') completes without
-- raising when bcu_mcu_collection holds a single stub MCU.
function TestMcuService:test_on_dump_vrd_cb()
    local fake_mcu = {}

    function fake_mcu.get_id()
        return "MCU123"
    end

    function fake_mcu.get_vrd_dump_info(id)
        print(string.format("[Test] get_vrd_dump_info called with mcu_id: %s", id))
    end

    function fake_mcu.get_vrd_log()
        return {1, 0}
    end

    mcu_service.bcu_mcu_collection = { fake_mcu }

    local ok = pcall(mcu_service.on_dump_vrd_cb, mcu_service, 1, '/vrd_reg_dump')
    lu.assertEquals(ok, true)
end

-- Tests the early-return path of start_update_mcu_version when the platform is
-- reported as a multi-host type.
function TestMcuService:test_start_update_mcu_version_multihost_early_return()
    -- Remember the real implementation so it can be put back afterwards.
    local saved_is_multihost_type = fructrl.is_multihost_type

    -- Pretend the platform is multi-host.
    fructrl.is_multihost_type = function()
        return true
    end

    -- Minimal stand-in for the service instance.
    local fake_service = { bus = 'test_bus', mcu_collection = {} }

    -- The scenario runs under pcall so the stub is always removed again.
    local function scenario()
        local ok, result = pcall(mcu_service.start_update_mcu_version, fake_service)
        lu.assertEquals(ok, true)
        lu.assertEquals(result, nil) -- expected to return straight away with no value
    end
    local passed, failure = pcall(scenario)

    -- Put the real function back whether or not the scenario passed.
    fructrl.is_multihost_type = saved_is_multihost_type

    -- Re-raise any failure so the test runner sees it.
    if not passed then
        error(failure)
    end
end

-- Tests the per-MCU version update loop of start_update_mcu_version on a
-- non-multi-host platform.
function TestMcuService:test_start_update_mcu_version_update_loop()
    -- Keep the real implementations so they can be restored afterwards.
    local saved = {
        is_multihost_type = fructrl.is_multihost_type,
        sleep = cmn.skynet.sleep,
        notice = log.notice,
        info = log.info,
    }

    local sleep_calls, log_calls, update_calls = {}, {}, {}

    -- Pretend the platform is not multi-host.
    fructrl.is_multihost_type = function()
        return false
    end

    -- Record every sleep and raise on the fourth one so the loop cannot run forever.
    cmn.skynet.sleep = function(ms)
        sleep_calls[#sleep_calls + 1] = ms
        if #sleep_calls >= 4 then
            error("test_exit")
        end
    end

    -- Build a logger stub that records the formatted message under the given type.
    local function make_logger(kind)
        return function(_, msg, ...)
            log_calls[#log_calls + 1] = { type = kind, msg = string.format(msg, ...) }
        end
    end
    log.notice = make_logger('notice')
    log.info = make_logger('info')

    -- Build a stub MCU that records which position was updated and with which bus.
    local function make_mcu(position)
        return {
            update_version_by_mcu = function(_, bus)
                update_calls[#update_calls + 1] = { position = position, bus = bus }
            end
        }
    end

    local fake_service = {
        bus = 'test_bus',
        mcu_collection = {
            pos1 = make_mcu('pos1'),
            pos2 = make_mcu('pos2'),
        },
    }

    -- The scenario runs under pcall so the stubs are always removed again.
    local function scenario()
        local _, err = pcall(mcu_service.start_update_mcu_version, fake_service)
        -- The call is expected to end through the error raised by the sleep stub.
        lu.assertTrue(string.find(err, "test_exit") ~= nil)

        -- Sleep sequence: initial delay, one gap per MCU, then the loop delay.
        local expected_sleeps = { 12000, 20, 20, 6000 }
        lu.assertEquals(#sleep_calls, 4)
        for idx = 1, 4 do
            lu.assertEquals(sleep_calls[idx], expected_sleeps[idx])
        end

        -- Both MCUs must have been updated, each with the service bus.
        lu.assertEquals(#update_calls, 2)
        local seen = {}
        for _, call in ipairs(update_calls) do
            seen[call.position] = true
            lu.assertEquals(call.bus, 'test_bus')
        end
        lu.assertTrue(seen['pos1'])
        lu.assertTrue(seen['pos2'])

        -- Exactly one log entry: the notice announcing the task start.
        lu.assertEquals(#log_calls, 1)
        local first_log = log_calls[1]
        lu.assertEquals(first_log.type, 'notice')
        lu.assertEquals(first_log.msg, 'start serial task for update mcu version')
    end
    local passed, failure = pcall(scenario)

    -- Put the real functions back whether or not the scenario passed.
    fructrl.is_multihost_type = saved.is_multihost_type
    cmn.skynet.sleep = saved.sleep
    log.notice = saved.notice
    log.info = saved.info

    -- Re-raise any failure so the test runner sees it.
    if not passed then
        error(failure)
    end
end

-- Tests the error handling of start_update_mcu_version: one stub MCU raises
-- during its update and exactly one log.info entry describing it is expected.
function TestMcuService:test_start_update_mcu_version_error_handling()
    -- Keep the real implementations so they can be restored afterwards.
    local saved_is_multihost_type = fructrl.is_multihost_type
    local saved_sleep = cmn.skynet.sleep
    local saved_info = log.info

    local sleep_calls, log_calls = {}, {}

    -- Pretend the platform is not multi-host.
    fructrl.is_multihost_type = function()
        return false
    end

    -- Record every sleep and raise on the third one to break out of the loop.
    cmn.skynet.sleep = function(ms)
        sleep_calls[#sleep_calls + 1] = ms
        if #sleep_calls >= 3 then
            error("test_exit")
        end
    end

    -- Capture formatted info-level messages.
    log.info = function(_, msg, ...)
        log_calls[#log_calls + 1] = string.format(msg, ...)
    end

    -- pos1 simulates a failing MCU; pos2 updates without doing anything.
    local failing_mcu = {
        update_version_by_mcu = function()
            error("MCU communication failed")
        end
    }
    local healthy_mcu = {
        update_version_by_mcu = function()
        end
    }
    local fake_service = {
        bus = 'test_bus',
        mcu_collection = {
            pos1 = failing_mcu,
            pos2 = healthy_mcu,
        },
    }

    -- The scenario runs under pcall so the stubs are always removed again.
    local function scenario()
        local _, err = pcall(mcu_service.start_update_mcu_version, fake_service)
        -- The call is expected to end through the error raised by the sleep stub.
        lu.assertTrue(string.find(err, "test_exit") ~= nil)

        -- A single log entry must name the failing position and carry the error text.
        lu.assertEquals(#log_calls, 1)
        local entry = log_calls[1]
        lu.assertTrue(string.find(entry, "Update MCU version failed for position pos1") ~= nil)
        lu.assertTrue(string.find(entry, "MCU communication failed") ~= nil)
    end
    local passed, failure = pcall(scenario)

    -- Put the real functions back whether or not the scenario passed.
    fructrl.is_multihost_type = saved_is_multihost_type
    cmn.skynet.sleep = saved_sleep
    log.info = saved_info

    -- Re-raise any failure so the test runner sees it.
    if not passed then
        error(failure)
    end
end

-- Drives an SMC MCU object's dump collection against stubbed SMC block reads and
-- writes, then calls mcu_service:on_dump_mcu_log_cb(1, '/test') and, if
-- /test/pll_reg_dump exists, checks that it holds a timestamped header followed
-- by the stubbed log bytes rendered as '0xNN ' tokens.
--
-- The scenario runs under pcall so that the stubs installed on smc_interface and
-- skynet are restored and the temporary paths removed even when an assertion or
-- the code under test raises.
function TestMcuService:test_on_dump_mcu_log_cb()
    local type_list = string.char(3, 0, 0, 0)
    local type_info = string.char(1, 5, 1)
    local log_data = string.char(100, 100, 100, 100, 100)

    -- NOTE(review): these two stubs are intentionally left in place after the
    -- test, exactly as before; other tests may rely on them, so confirm before
    -- adding a restore for them.
    fructrl.get_power_status = function() return 'ON' end
    file_sec.check_realpath_before_open_s = function ()
        return MCU_ENUMS.RET_OK
    end

    -- Save everything that is restored at the end before any stub is installed.
    local blkread = smc_interface.blkread
    local blkwrite = smc_interface.blkwrite
    local old_loop = skynet.fork_loop

    -- Canned replies keyed by the SMC read command; unknown commands return nothing.
    smc_interface.blkread = function (_, cmd)
        if cmd == 0x1C100 then
            return true, type_list
        end
        if cmd == 0x1C501 then
            return true, type_info
        end
        if cmd == 0x1C901 then
            return true, log_data
        end
    end

    smc_interface.blkwrite = function(_, cmd)
        if cmd == 0x1CC00 then
            return true
        end
    end

    local test_ok, test_err = pcall(function()
        local app = {
            bus = 'test_bus'
        }
        gen_hw_bus.new(app)
        unit_manager.get_instance().device_names = {'test'}
        local obj = smc_mcu_object.new({['BoardType'] = 'BCU', ['UID'] = '1825'}, '0101',
            MCU_ENUMS.SMC_CHANNEL, false, {})

        -- Run the forked loop body once, synchronously.
        skynet.fork_loop = function(_, cb) cb() end
        obj.collect_mcu_dump_info_task(obj)

        local header = string.format('=======================MCU_1825_LOG Time:%s=================\n',
            os.date("%Y%m%d%H%M%S", os.time()))
        utils_core.mkdir_with_parents('/test', utils.S_IRWXU | utils.S_IRGRP | utils.S_IXGRP)
        mcu_service:on_dump_mcu_log_cb(1, '/test')

        -- Expected content: header, each log byte as '0xNN ', then a newline.
        local parts = { header }
        for i = 1, #log_data do
            parts[#parts + 1] = string.format('0x%02x ', string.byte(log_data, i))
        end
        parts[#parts + 1] = '\n'
        local hex_data = table.concat(parts)

        local fp = io.open('/test/pll_reg_dump', 'r')
        if not fp then
            -- NOTE(review): a missing dump file is only reported, not treated as
            -- a failure; kept as in the original test.
            print('cannot open file /test/pll_reg_dump')
        else
            local reg_info = fp:read('a')
            -- Close before asserting so a failed comparison cannot leak the handle.
            fp:close()
            lu.assertEquals(reg_info, hex_data)
        end
    end)

    -- Always undo the stubs, whatever happened above.
    skynet.fork_loop = old_loop
    smc_interface.blkread = blkread
    smc_interface.blkwrite = blkwrite

    if test_ok then
        utils.remove_file(MCU_ENUMS.MCU_LOG_PATH)
        utils.remove_file('/test')
    else
        -- Best-effort removal: a cleanup error must not hide the real failure.
        pcall(utils.remove_file, MCU_ENUMS.MCU_LOG_PATH)
        pcall(utils.remove_file, '/test')
        -- Level 0 re-raises the original error value without adding position info.
        error(test_err, 0)
    end
end

-- Checks that smc_mcu_object:get_loop_time() returns 360000 on its third call,
-- both when the stubbed tick source reports 6000 and when it reports 903000,
-- starting each round from an empty mcu_log_type_record.
--
-- The body runs under pcall so the vos.vos_tick_get stub is removed even when an
-- assertion fails.
function TestMcuService:test_get_loop_time()
    local old_tick_get = vos.vos_tick_get
    local time = 3000

    -- Install a fixed tick value, reset the record and return the third result.
    local function third_loop_time(tick)
        vos.vos_tick_get = function()
            return tick
        end
        smc_mcu_object.mcu_log_type_record = {}
        local loop_time
        for _ = 1, 3 do
            loop_time = smc_mcu_object:get_loop_time()
        end
        return loop_time
    end

    local test_ok, test_err = pcall(function()
        lu.assertEquals(third_loop_time(time + 3000), 360000)
        lu.assertEquals(third_loop_time(time + 900000), 360000)
    end)

    -- Always put the real tick source back.
    vos.vos_tick_get = old_tick_get

    if not test_ok then
        -- Level 0 re-raises the original error value without adding position info.
        error(test_err, 0)
    end
end

-- Checks that smc_mcu_object.save_log_file('/test', 'test_log') does not raise
-- while file_sec.move_file_s is replaced by a stub that always throws.
function TestMcuService:test_save_log_file()
    local saved_move_file_s = file_sec.move_file_s

    -- Stub that fails on every call.
    file_sec.move_file_s = function()
        error('test')
    end

    local succeeded = pcall(smc_mcu_object.save_log_file, smc_mcu_object, '/test', 'test_log')

    -- Put the real function back before asserting so a failure cannot leak the stub.
    file_sec.move_file_s = saved_move_file_s
    lu.assertEquals(succeeded, true)
end