-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local bs = require 'mc.bitstring'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_version = require 'ncsi.ncsi_protocol.ncsi_version'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

-- Test suite table. It is assigned without `local`, so it becomes a global;
-- the same table is also returned at the end of this file.
TestNCSIVersion = {}

-- Command type constant: the packet type this file uses to look up the
-- Get Version ID response handler in the command process table.
local GET_VERSION_ID_RSP = 0x95

-- Bitstring layout the tests use to pack (and round-trip unpack) a version
-- response payload: response/reason codes, three version bytes plus two alpha
-- bytes, a 12-byte name, a 4-byte firmware version, four 16-bit PCI IDs, a
-- 32-bit manufacturer id, and trailing 32-bit check_sum and fcs fields.
-- NOTE(review): presumably mirrors the layout parsed inside ncsi_version —
-- confirm against that module.
local version_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    major_ver:8,
    minor_ver:8,
    update_ver:8,
    alpha1:8,
    reserved1:3/string,
    alpha2:8,
    name_string:12/string,
    firmware_ver:4/string,
    pci_did:16,
    pci_vid:16,
    pci_ssid:16,
    pci_svid:16,
    manufacturer_id:32,
    check_sum:32,
    fcs:32
>>]])

-- Test double installed in place of ncsi_protocol_intf.send_ncsi_cmd.
-- Records the arguments of the most recent call on the suite table so that
-- test cases can inspect them, and always reports success.
local function mock_send_ncsi_cmd(req_data, len, eth_name)
    local suite = TestNCSIVersion
    suite.last_req_data, suite.last_len, suite.last_eth_name = req_data, len, eth_name
    return ncsi_def.NCSI_SUCCESS
end

-- Per-test fixture, executed before every test case. Saves the real send
-- function, command controller and parameter accessor on `self`, then swaps
-- each of them for a test double.
function TestNCSIVersion:setUp()
    ncsi_parameter.get_instance():init_ncsi_parameter()

    -- Start each test with nothing captured by the mocked send function.
    self.last_req_data, self.last_len, self.last_eth_name = nil, nil, nil

    -- Remember the real implementations so tearDown can put them back.
    self.original_send_ncsi_cmd = ncsi_protocol_intf.send_ncsi_cmd
    self.original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    self.original_get_ncsi_parameter = ncsi_parameter.get_instance().get_ncsi_parameter

    -- Transport: record the request instead of sending it.
    ncsi_protocol_intf.send_ncsi_cmd = mock_send_ncsi_cmd

    -- Parameter accessor: hand out a canned table. The closure reads
    -- self.mock_ncsi_parameter on every call, so a test may replace the table.
    self.mock_ncsi_parameter = {current_channel = 0, iid = 1, channel_cnt = 4, recv_buf = ''}
    ncsi_parameter.get_instance().get_ncsi_parameter = function()
        return self.mock_ncsi_parameter
    end

    -- Command controller: look up the handler registered for the request's
    -- packet type (if there is one), invoke it, and always report success.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth_name, cmd_process_table)
        assert(cmd_process_table, "cmd_process_table is nil")

        local head = req_packet and req_packet.packet_head
        local packet_type = head and head.packet_type
        local handler = packet_type and cmd_process_table[packet_type]
        if handler then
            handler(req_packet, eth_name)
        end

        return ncsi_def.NCSI_SUCCESS
    end
end

-- Per-test cleanup, executed after every test case: reinstalls the three
-- functions that setUp saved on `self` before replacing them.
function TestNCSIVersion:tearDown()
    local param_instance = ncsi_parameter.get_instance()

    param_instance.get_ncsi_parameter = self.original_get_ncsi_parameter
    ncsi_utils.ncsi_cmd_ctrl = self.original_cmd_ctrl
    ncsi_protocol_intf.send_ncsi_cmd = self.original_send_ncsi_cmd
end

-- Verifies that ncsi_get_version_id reports success and hands a request to the
-- (mocked) send function, for two different channel ids.
function TestNCSIVersion:test_get_version()
    local PACKAGE_ID, ETH_NAME = 0, "eth0"

    -- Issue a Get Version ID request on the given channel.
    local function query(channel_id)
        return ncsi_version.ncsi_get_version_id(PACKAGE_ID, channel_id, ETH_NAME)
    end

    -- Channel 1: a request must have been captured for the right interface.
    lu.assertEquals(query(1), ncsi_def.NCSI_SUCCESS)
    lu.assertNotNil(self.last_req_data)
    lu.assertEquals(self.last_eth_name, ETH_NAME)

    -- Channel 2: clear the capture first so a stale request cannot satisfy
    -- the check.
    self.last_req_data = nil
    lu.assertEquals(query(2), ncsi_def.NCSI_SUCCESS)
    lu.assertNotNil(self.last_req_data)
end

-- Drives the Get Version ID response handler directly with a successful and a
-- failing response, and checks the values it stores in the NCSI parameter
-- table.
function TestNCSIVersion:test_response_processing()
    -- Build a mock response packet whose payload is packed with
    -- version_rsp_bs; only the response and reason codes vary between calls.
    local function build_version_response(rsp_code, reason_code)
        local payload = version_rsp_bs:pack({
            rsp_code = rsp_code,
            reason_code = reason_code,
            -- The three version bytes and two alpha characters are expected
            -- to surface as '12.34.56AB' (see the ncsi_ver assertion below).
            major_ver = 0x12,
            minor_ver = 0x34,
            update_ver = 0x56,
            alpha1 = string.byte('A'),
            reserved1 = '\0\0\0',
            alpha2 = string.byte('B'),
            name_string = 'TestFirmware',
            firmware_ver = '\x01\x02\x03\x04',
            pci_did = 0x1234,
            pci_vid = 0x5678,
            pci_ssid = 0x9ABC,
            pci_svid = 0xDEF0,
            manufacturer_id = 0x12345678,
            check_sum = 0,
            fcs = 0
        })

        -- Round-trip sanity check: the packed bytes must decode back to the
        -- version numbers that were put in.
        local decoded = version_rsp_bs:unpack(payload, true)
        assert(decoded.major_ver == 0x12, "major_ver not packed correctly")
        assert(decoded.minor_ver == 0x34, "minor_ver not packed correctly")
        assert(decoded.update_ver == 0x56, "update_ver not packed correctly")

        return {
            packet_head = {
                payload_len_hi = 0,
                payload_len_lo = 40,
                packet_type = GET_VERSION_ID_RSP,
                package_id = 0,
                channel_id = 1
            },
            payload = payload
        }
    end

    local PACKAGE_ID, CHANNEL_ID, ETH_NAME = 0, 1, "eth0"

    -- Fresh parameter table: the version-related fields (ncsi_ver,
    -- firmware_name, firmware_ver, manufacture_id) start out unset so the
    -- test can prove the handler fills them in. iid is included as the
    -- packet id the protocol code needs.
    self.mock_ncsi_parameter = {
        current_channel = 0,
        iid = 1,
        pcie_device_ids = {}
    }
    ncsi_parameter.get_instance().get_ncsi_parameter = function()
        return self.mock_ncsi_parameter
    end

    -- Successful response --------------------------------------------------
    local success_rsp = build_version_response(ncsi_def.CMD_COMPLETED, 0)

    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        assert(cmd_process_table, "cmd_process_table is nil")
        local handler = cmd_process_table[GET_VERSION_ID_RSP]
        assert(handler, "GET_VERSION_ID_RSP handler not found")

        -- Feed the canned response straight into the response handler.
        local ret = handler(success_rsp)

        -- The handler must have populated the parameter table by now.
        assert(self.mock_ncsi_parameter.ncsi_ver ~= nil, "ncsi_ver not set")
        assert(self.mock_ncsi_parameter.firmware_name ~= nil, "firmware_name not set")

        return ret
    end

    local status = ncsi_version.ncsi_get_version_id(PACKAGE_ID, CHANNEL_ID, ETH_NAME)
    lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)

    -- Check every value the handler is expected to have stored.
    local params = ncsi_parameter.get_instance():get_ncsi_parameter()
    lu.assertNotNil(params.ncsi_ver, "ncsi_ver is nil")
    lu.assertEquals(params.ncsi_ver, '12.34.56AB')
    lu.assertEquals(params.firmware_name, 'TestFirmware')
    lu.assertEquals(params.firmware_ver, '01:02:03:04')
    lu.assertEquals(params.manufacture_id, 0x12345678)

    -- The PCI ids are expected byte-swapped relative to the packed values.
    local ids = params.pcie_device_ids
    lu.assertEquals(ids.pci_did, 0x3412)
    lu.assertEquals(ids.pci_vid, 0x7856)
    lu.assertEquals(ids.pci_ssid, 0xBC9A)
    lu.assertEquals(ids.pci_svid, 0xF0DE)

    -- Failing response -------------------------------------------------------
    local fail_rsp = build_version_response(0x0123, 1)

    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        return cmd_process_table[GET_VERSION_ID_RSP](fail_rsp)
    end

    status = ncsi_version.ncsi_get_version_id(PACKAGE_ID, CHANNEL_ID, ETH_NAME)
    lu.assertEquals(status, ncsi_def.NCSI_FAIL)

    -- Put back the controller that was active when this test started.
    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

-- Verifies that a failure reported by the command controller is passed
-- through as the result of ncsi_get_version_id.
function TestNCSIVersion:test_error_handling()
    -- Replace the controller with one that fails unconditionally.
    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    ncsi_utils.ncsi_cmd_ctrl = function()
        return ncsi_def.NCSI_FAIL
    end

    local PACKAGE_ID, CHANNEL_ID, ETH_NAME = 0, 1, "eth0"
    local status = ncsi_version.ncsi_get_version_id(PACKAGE_ID, CHANNEL_ID, ETH_NAME)
    lu.assertEquals(status, ncsi_def.NCSI_FAIL)

    -- Put back the controller that was active when this test started.
    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end


-- Verifies that the Get Version ID response handler rejects a response whose
-- payload is too short: the packet header declares a 40-byte payload, but the
-- payload actually carries a single zero byte.
function TestNCSIVersion:test_empty_response()
    -- Build a response packet with a one-byte payload.
    local function create_mock_response()
        return {
            packet_head = {
                payload_len_hi = 0,
                payload_len_lo = 40,
                packet_type = GET_VERSION_ID_RSP
            },
            payload = '\0'
        }
    end

    -- Test parameters
    local package_id = 0
    local channel_id = 1
    local eth_name = "eth0"

    -- Make the command controller feed the truncated response straight into
    -- the response handler and return whatever the handler returns.
    local original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        local truncated_rsp = create_mock_response()
        return cmd_process_table[GET_VERSION_ID_RSP](truncated_rsp)
    end

    -- The failure message describes the real scenario: the response is not
    -- nil, its payload is merely too short.
    local result = ncsi_version.ncsi_get_version_id(package_id, channel_id, eth_name)
    lu.assertEquals(result, ncsi_def.NCSI_FAIL, "Should fail when response payload is truncated")

    -- Put back the controller that was active when this test started.
    ncsi_utils.ncsi_cmd_ctrl = original_cmd_ctrl
end

-- Return the suite table as this file's module value.
-- NOTE(review): presumably consumed by a runner that `require`s this file — confirm.
return TestNCSIVersion