-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local defs = require 'unit_manager.class.logic_fw.comm_defs'
local cmn = require 'common'
local fw_mgmt_client = require 'general_hardware.client'
local chip = require 'unit_manager.class.logic_fw.cpld_chip'
local drivers = require 'unit_manager.class.logic_fw.upgrade.drivers_api'
local process = require 'unit_manager.class.logic_fw.upgrade.process'
local valid = require 'unit_manager.class.logic_fw.upgrade.valid'
local factory = require 'factory'
local utils = require 'mc.utils'
local file_sec = require 'utils.file'
local vos = require 'utils.vos'
local spi_flash = require 'libmgmt_protocol.bios.infrastructure.spi_flash'
local skynet = require 'skynet'
local mtd_api = require 'mtd.drv'
local utils_core = require 'utils.core'

-- Test class for the logic-firmware upgrade process module. It is intentionally
-- a global: luaunit collects test classes by scanning globals whose names start
-- with "Test".
TestProcess = {}

-- Every module field replaced by setupClass, in patch order. Each entry is
-- { target = <patched table>, key = <field name>, value = <previous value> }.
-- The list is file-local rather than one global per saved function, so the
-- saved originals cannot leak into, or be clobbered by, other code sharing
-- the same Lua state.
local patched_fields = {}

-- Remember target[key] in patched_fields, then replace it with stub.
local function patch_field(target, key, stub)
    patched_fields[#patched_fields + 1] = { target = target, key = key, value = target[key] }
    target[key] = stub
end

-- Replace the sleep, RPC-reply, archive, file, SPI-driver and MTD helpers used
-- by the process module with stubs so it can be exercised in isolation.
function TestProcess:setupClass()
    -- Common helpers: sleeping becomes a no-op
    patch_field(cmn.skynet, 'sleep', function()
    end)

    patch_field(skynet, 'sleep', function()
    end)

    -- Capture the version and result passed to the prepare reply in globals
    -- (cleared again in teardownClass)
    patch_field(fw_mgmt_client, 'UpdateServiceUpdateServicePrepareReply',
        function(_, _, _, _, version, ret, parameters)
            fw_version_ret = version
            prepare_ret = ret
        end)

    patch_field(fw_mgmt_client, 'UpdateServiceUpdateServiceProcessReply', function()
    end)

    patch_field(fw_mgmt_client, 'UpdateServiceUpdateServiceFinishReply', function()
    end)

    -- The framework-provided unzip misbehaves in the test environment, so it
    -- is replaced by a plain os.execute call. NOTE: paths are interpolated
    -- unquoted into a shell command; only pass test-controlled paths.
    patch_field(utils, 'secure_tar_unzip', function(file_path, release_path)
        local cmd = string.format('/usr/bin/unzip %s -d %s > /dev/null 2>&1', file_path, release_path)
        os.execute(cmd)
        release_path_ret = release_path
    end)

    patch_field(utils, 'safe_close_file', function(file_fd, fun)
    end)

    -- Every path except /dev/mtd0 is reported accessible (0); /dev/mtd0 yields nil
    patch_field(vos, 'get_file_accessible', function(file_path)
        if file_path ~= '/dev/mtd0' then
            return 0
        end
    end)

    patch_field(spi_flash, 'insmod_driver', function(driver_name, mhz, device_name, client, insmod_cb)
    end)

    patch_field(spi_flash, 'check_device_ready', function(device_name)
    end)

    patch_field(spi_flash, 'rmmod_driver', function(client, rmmod_cb, ctx, driver_name)
    end)

    -- Erase fails (-1) only when asked to erase a length of -1
    patch_field(mtd_api, 'erase', function(device_fd, FLASH_START_ADDRESS, erase_length)
        if erase_length == -1 then
            return -1
        end
    end)

    patch_field(mtd_api, 'write', function(device_fd, FLASH_START_ADDRESS, data, write_length)
    end)

    patch_field(mtd_api, 'read', function(device_fd, FLASH_START_ADDRESS, write_length)
    end)

    patch_field(mtd_api, 'open', function()
    end)

    -- Every file reports a size of 1 byte
    patch_field(utils_core, 'stat', function(file_path)
        local res = {
            st_size = 1
        }
        return res
    end)

    -- Only Firmware1 and fpga.bin can be opened; the returned handle is also
    -- stored in the global file_fd (cleared in teardownClass)
    patch_field(file_sec, 'open_s', function(file_path, mode)
        if file_path ~= "/dev/shm/upgrade/Firmware1" and
            file_path ~= "/dev/shm/upgrade/fpga.bin" then
            return
        end
        file_fd = {
            seek = function(mode, start)
            end,

            read = function(write_length)
            end
        }
        return file_fd
    end)
end

-- Undo every patch made by setupClass and clear the globals the stubs wrote.
function TestProcess:teardownClass()
    -- Restore in reverse patch order. If one table is reachable under two
    -- names (e.g. if cmn.skynet is the same table as skynet), it still ends up
    -- with its true original. Originals that were nil are restored to nil,
    -- so no stub is left behind.
    for i = #patched_fields, 1, -1 do
        local entry = patched_fields[i]
        entry.target[entry.key] = entry.value
        patched_fields[i] = nil
    end

    -- Clear the globals written by the stubs during the tests
    fw_version_ret = nil
    prepare_ret = nil
    release_path_ret = nil
    file_fd = nil
end

-- extract_upgrade_file should split a firmware path into its directory (with a
-- trailing slash) and its file name.
function TestProcess:test_extract_upgrade_file()
    local dir, name = process:extract_upgrade_file("/dev/shm/upgrade/Firmware1")

    lu.assertEquals(dir, "/dev/shm/upgrade/")
    lu.assertEquals(name, "Firmware1")
end

-- Selecting the flash through the CPLD should not raise and should leave
-- fw.csr.Routes set to the requested value (1).
function TestProcess:test_set_fpga_flash_success()
    local fw = { csr = { Routes = 0 } }
    local route = 1
    local function select_flash()
        process:set_fpga_flash(fw, route)
    end

    local ok = pcall(select_flash)
    lu.assertEquals(ok, true)
    lu.assertEquals(fw.csr.Routes, 1)
end

-- Selecting the flash through the CPLD must fail when the firmware object has
-- no csr table.
-- Method syntax is used, as in the success test. The previous
-- pcall(process.set_fpga_flash, fw, data) bound fw to self and the number 1 to
-- the fw parameter, so the test could pass for an unrelated reason.
function TestProcess:test_set_fpga_flash_failed()
    local fw = {}
    local data = 1
    local ok = pcall(function()
        process:set_fpga_flash(fw, data)
    end)
    lu.assertEquals(ok, false)
end

-- Installing the SPI driver should leave fw.csr.Routes set to 1.
-- Only the route value is asserted; the pcall status is deliberately ignored.
function TestProcess:test_install_spi_driver()
    local fw = { csr = { Routes = 0 } }
    local function install()
        process:install_spi_driver(fw)
    end

    pcall(install)
    lu.assertEquals(fw.csr.Routes, 1)
end

-- Installing the SPI driver without a firmware object must raise an error.
function TestProcess:test_install_spi_driver_failed()
    local function install_without_fw()
        process:install_spi_driver(nil)
    end

    local ok = pcall(install_without_fw)
    lu.assertEquals(ok, false)
end

-- Writing a bin file to flash should not raise for a path accepted by the
-- open_s stub (Firmware1) and a positive flash size.
function TestProcess:test_write_file_to_flash()
    local fw = { csr = { FlashSizeKiB = 1000 } }
    local function write()
        process:write_file_to_flash(fw, "/dev/shm/upgrade/Firmware1", nil)
    end

    local ok = pcall(write)
    lu.assertEquals(ok, true)
end

-- Writing a bin file to flash must fail when the flash size is 0.
function TestProcess:test_write_file_to_flash_failed()
    local fw = { csr = { FlashSizeKiB = 0 } }
    local function write()
        process:write_file_to_flash(fw, "/dev/shm/upgrade/Firmware1", nil)
    end

    local ok = pcall(write)
    lu.assertEquals(ok, false)
end

-- Writing a bin file to flash must fail when the file cannot be opened: the
-- open_s stub returns nil for any path other than Firmware1 / fpga.bin.
function TestProcess:test_write_file_to_flash_failed_case_fd()
    local fw = { csr = { FlashSizeKiB = 1000 } }
    local function write()
        process:write_file_to_flash(fw, "/dev/shm/upgrade/Firmware2", nil)
    end

    local ok = pcall(write)
    lu.assertEquals(ok, false)
end

-- Writing a bin file to flash must fail for a negative flash size.
-- NOTE(review): presumably meant to drive the MTD failure path (the erase stub
-- returns -1 for an erase length of -1); confirm how FlashSizeKiB maps to it.
function TestProcess:test_write_file_to_flash_failed_case_mtd()
    local fw = { csr = { FlashSizeKiB = -1 } }
    local function write()
        process:write_file_to_flash(fw, "/dev/shm/upgrade/Firmware1", nil)
    end

    local ok = pcall(write)
    lu.assertEquals(ok, false)
end

-- Upgrade an FPGA through the flash path (UpgradeType 'SFC') when an upgrade
-- config with a matching name exists.
-- NOTE(review): despite the "_fail" suffix, this test only asserts that the
-- call does not raise; confirm whether a failure return value should be checked.
function TestProcess:test_upgrade_fpga_by_flash_fail()
    local fpga_fw = {
        system_id = 1,
        csr = {
            Name = 'FPGA',
            UpgradeType = 'SFC',
            ValidAction = 0,
            uid = 111,
            FlashSizeKiB = 1000
        }
    }
    local fpga_cfg = {
        name = "Firmware1",
        uid = 111,
        -- 0 is truthy in Lua, so every firmware counts as matching
        check_fw_uid_exist = function(_)
            return 0
        end
    }
    local fpga_signal = {
        fw_list = { fpga_fw },
        upg_cfg_list = { fpga_cfg }
    }

    local function upgrade()
        process:upgrade_fpga_by_flash(fpga_signal, "/dev/shm/upgrade/Firmware1")
    end

    local ok = pcall(upgrade)
    lu.assertEquals(ok, true)
end

-- Looking up FPGA firmware by UID must fail when no upgrade config is named
-- after the requested firmware (only "Firmware2" exists; "Firmware1" is asked for).
function TestProcess:test_find_fpga_fw_with_uid_no_cfg()
    local fpga_signal = {
        upg_cfg_list = { { name = "Firmware2" } }
    }
    local function find()
        process:find_fpga_fw_with_uid(fpga_signal, "Firmware1")
    end

    local ok = pcall(find)
    lu.assertEquals(ok, false)
end

-- Looking up FPGA firmware by UID should return the list of matching firmware
-- objects when a config named after the firmware exists and its UID check passes.
function TestProcess:test_find_fpga_fw_with_uid()
    local fpga_signal = {
        fw_list = {
            {
                system_id = 1,
                csr = {
                    Name = 'FPGA'
                }
            }
        },
        upg_cfg_list = {
            {
                name = "Firmware1",
                uid = 111,
                -- 0 is truthy in Lua, so the UID check passes
                check_fw_uid_exist = function(fw)
                    return 0
                end
            }
        }
    }
    local fw_name = "Firmware1"
    -- Local (was an accidental global that leaked into other tests)
    local fpga_fw = process:find_fpga_fw_with_uid(fpga_signal, fw_name)
    lu.assertEquals(fpga_fw[1].csr.Name, 'FPGA')
end

-- Looking up FPGA firmware by UID must fail when fw_list is empty, even though
-- a matching upgrade config exists.
function TestProcess:test_find_fpga_fw_with_uid_no_complist()
    local matching_cfg = {
        name = "Firmware1",
        uid = 111,
        check_fw_uid_exist = function(_)
            return 0
        end
    }
    local fpga_signal = {
        fw_list = {},
        upg_cfg_list = { matching_cfg }
    }
    local function find()
        process:find_fpga_fw_with_uid(fpga_signal, "Firmware1")
    end

    local ok = pcall(find)
    lu.assertEquals(ok, false)
end

-- upgrade_component_cpld should return defs.RET.MATCH_FAIL, with
-- is_need_valid == false, when no firmware in fw_list matches any upgrade config
-- (covers the no-matching-component branch).
function TestProcess:test_upgrade_component_cpld_no_match_component()
    local cpld_fw = {
        system_id = 1,
        uid = 100,
        name = 'test_fw',
        csr = { Name = 'CPLD' },
        switch_to_firmware_route = function() end,
        switch_to_default_route = function() end,
        update_chip_lock = function() return true, 0 end,
        update_chip_unlock = function() return true, 0 end,
        chip_info = {
            SetBypassMode = function() end
        }
    }
    local unmatched_cfg = {
        name = 'Firmware1',
        uid = 200, -- differs from the firmware's uid (100)
        -- Rejecting every firmware leaves no component to upgrade
        check_fw_uid_exist = function(_)
            return false
        end,
        file_type = 'vme',
        chip_num = 1
    }

    local ret, is_need_valid = process:upgrade_component_cpld(
        {},                           -- db
        1,                            -- system_id
        { cpld_fw },                  -- fw_list
        { unmatched_cfg },            -- cfg_list
        '/dev/shm/upgrade/Firmware1', -- file_path
        { 1 },                        -- upgrade_list
        false                         -- hot_upgrade
    )
    lu.assertEquals(ret, defs.RET.MATCH_FAIL)
    lu.assertEquals(is_need_valid, false)
end

