-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local bs = require 'mc.bitstring'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_capabilities = require 'ncsi.ncsi_protocol.ncsi_capabilities'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

-- Test suite table. Intentionally global (no `local`): luaunit discovers test
-- classes through global tables whose names start with "Test". The table is
-- also returned at the end of this file.
TestNCSICapabilities = {}

-- Command type constants.
-- 0x96 is the Get Capabilities response packet type (command 0x16 with the
-- 0x80 response bit set, per DMTF DSP0222). The mocks below use it as the key
-- into cmd_process_table to reach the response handler.
local GET_CAPABILITIES_RSP = 0x96

-- Wire layout of a Get Capabilities response payload, used only to pack mock
-- responses in the tests below.
-- Widths after `:` appear to be bit counts (e.g. rsp_code:16). `reserved` is a
-- 2-byte string (packed from '\0\0' below).
-- The fields before check_sum add up to 32 bytes. This matches the
-- payload_len_lo = 32 that the mock packet heads declare. check_sum and fcs
-- follow that payload.
-- Byte order is mc.bitstring's default; TODO confirm against the parser in
-- ncsi_capabilities.
-- NOTE: `capality_flag` (sic) must match the key used in the tables passed to
-- :pack().
local capabilities_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    capality_flag:32,
    broadcast_filter_cap:32,
    multicast_filter_cap:32,
    buffer_cap:32,
    aen_support:32,
    vlan_filter_cnt:8,
    mix_filter_cnt:8,
    mul_filter_cnt:8,
    unicast_filter_cnt:8,
    reserved:2/string,
    vlan_mode:8,
    channel_cnt:8,
    check_sum:32,
    fcs:32
>>]])

--- Test double that setUp installs as ncsi_protocol_intf.send_ncsi_cmd.
-- Stores the arguments of the most recent call on the suite table, so tests
-- can read them back as self.last_req_data / self.last_len /
-- self.last_eth_name. It never touches a real interface.
-- @param req_data request data as handed over by the caller
-- @param len request length as handed over by the caller
-- @param eth_name interface name the request targets
-- @return ncsi_def.NCSI_SUCCESS, always
local function mock_send_ncsi_cmd(req_data, len, eth_name)
    local suite = TestNCSICapabilities
    suite.last_req_data, suite.last_len, suite.last_eth_name = req_data, len, eth_name
    return ncsi_def.NCSI_SUCCESS
end

-- Per-test fixture that luaunit runs before every test case.
-- It resets the real NCSI parameters and saves the production functions the
-- tests replace (tearDown restores them). It then installs these test doubles:
--   * send_ncsi_cmd      -> mock_send_ncsi_cmd (captures the last request)
--   * get_ncsi_parameter -> returns self.mock_ncsi_parameter
--   * ncsi_cmd_ctrl      -> dispatches the request packet to its handler in
--                           cmd_process_table and always reports success
function TestNCSICapabilities:setUp()
    ncsi_parameter.get_instance():init_ncsi_parameter()

    -- Save the production implementations before patching them.
    self.original_send_ncsi_cmd = ncsi_protocol_intf.send_ncsi_cmd
    self.original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    self.original_get_ncsi_parameter = ncsi_parameter.get_instance().get_ncsi_parameter

    ncsi_protocol_intf.send_ncsi_cmd = mock_send_ncsi_cmd

    -- Clear whatever the send mock captured during a previous test.
    self.last_req_data, self.last_len, self.last_eth_name = nil, nil, nil

    -- channel_cap[package][channel] capability records. The index ranges
    -- 0..1 and 0..3 follow NCSI_PACKAGE_MAX_ID = 1 and NCSI_CHANNEL_MAX_ID = 3.
    local channel_cap = {}
    for package_idx = 0, 1 do
        local channels = {}
        for channel_idx = 0, 3 do
            channels[channel_idx] = {
                capbility_flag = {},
                broadcast_filter_capality = {},
                multicast_filter_capality = 0,
                buffer_capality = 0,
                aen_support = 0,
                vlan_filter_cnt = 0,
                mix_filter_cnt = 0,
                mul_filter_cnt = 0,
                unicast_filter_cnt = 0,
                reserved = {0, 0},
                vlan_mode = 0,
                channel_cnt = 0
            }
        end
        channel_cap[package_idx] = channels
    end

    self.mock_ncsi_parameter = {
        current_channel = 0,
        channel_cap = channel_cap,
        multicast_filter_cap = 0,
        channel_cnt = 0,
        iid = 1 -- packet ID (iid) required by the NCSI protocol
    }

    -- From now on, get_ncsi_parameter() returns the mock table. It is read
    -- through self at call time.
    ncsi_parameter.get_instance().get_ncsi_parameter = function()
        return self.mock_ncsi_parameter
    end

    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth_name, cmd_process_table)
        assert(cmd_process_table, "cmd_process_table is nil")

        -- Follow req_packet.packet_head.packet_type defensively. If any link is
        -- missing, or no handler is registered, skip the dispatch.
        local head = req_packet and req_packet.packet_head
        local packet_type = head and head.packet_type
        local handler = packet_type and cmd_process_table[packet_type]
        if handler then
            handler(req_packet, eth_name)
        end

        return ncsi_def.NCSI_SUCCESS
    end
end

-- Per-test cleanup that luaunit runs after every test case.
-- Puts back the three functions that setUp replaced with test doubles.
function TestNCSICapabilities:tearDown()
    local param_instance = ncsi_parameter.get_instance()
    param_instance.get_ncsi_parameter = self.original_get_ncsi_parameter
    ncsi_utils.ncsi_cmd_ctrl = self.original_cmd_ctrl
    ncsi_protocol_intf.send_ncsi_cmd = self.original_send_ncsi_cmd
end

-- Request path of Get Capabilities.
-- With setUp's ncsi_cmd_ctrl mock, the request handler runs and the request
-- should reach send_ncsi_cmd (captured by mock_send_ncsi_cmd). Each call is
-- expected to succeed, produce request data, and carry the caller's interface
-- name. This is checked for more than one channel id.
function TestNCSICapabilities:test_get_capabilities()
    local package_id = 0
    local eth_name = "eth0"

    -- Clear every field the send mock captures before each call. Otherwise a
    -- value left over from an earlier call could satisfy the assertions.
    local function reset_captured()
        self.last_req_data = nil
        self.last_len = nil
        self.last_eth_name = nil
    end

    for _, channel_id in ipairs({1, 2}) do
        reset_captured()
        local context = "channel_id=" .. channel_id

        local result = ncsi_capabilities.ncsi_get_capabilities(package_id, channel_id, eth_name)

        lu.assertEquals(result, ncsi_def.NCSI_SUCCESS, context)
        lu.assertNotNil(self.last_req_data, context)
        lu.assertEquals(self.last_eth_name, eth_name, context)
    end
end

-- Response path of Get Capabilities.
-- On a CMD_COMPLETED response, the handler registered for
-- GET_CAPABILITIES_RSP is expected to store the parsed values in the NCSI
-- parameter table. On an error response code, ncsi_get_capabilities is
-- expected to return NCSI_FAIL.
function TestNCSICapabilities:test_response_processing()
    -- Build a Get Capabilities response packet for package 0 / channel 1
    -- that carries the given response and reason codes.
    local function build_response(rsp_code, reason_code)
        local payload = capabilities_rsp_bs:pack({
            rsp_code = rsp_code,
            reason_code = reason_code,
            capality_flag = 0x78563412,
            broadcast_filter_cap = 0x21436587,
            multicast_filter_cap = 0x44332211,
            buffer_cap = 0x11223344,
            aen_support = 0xDDCCBBAA,
            vlan_filter_cnt = 16,
            mix_filter_cnt = 8,
            mul_filter_cnt = 4,
            unicast_filter_cnt = 2,
            reserved = '\0\0',
            vlan_mode = 1,
            channel_cnt = 4,
            check_sum = 0,
            fcs = 0
        })
        return {
            packet_head = {
                payload_len_hi = 0,
                payload_len_lo = 32,
                packet_type = GET_CAPABILITIES_RSP,
                package_id = 0,
                channel_id = 1
            },
            payload = payload
        }
    end

    -- Make ncsi_cmd_ctrl hand `rsp` straight to the registered Get
    -- Capabilities response handler and return that handler's result.
    local function feed_response(rsp)
        ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
            return cmd_process_table[GET_CAPABILITIES_RSP](rsp)
        end
    end

    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    local package_id, channel_id, eth_name = 0, 1, "eth0"

    -- Successful response: the parsed values end up in the parameter table.
    feed_response(build_response(ncsi_def.CMD_COMPLETED, 0))
    lu.assertEquals(ncsi_capabilities.ncsi_get_capabilities(package_id, channel_id, eth_name),
        ncsi_def.NCSI_SUCCESS)

    local params = ncsi_parameter.get_instance():get_ncsi_parameter()
    lu.assertEquals(params.multicast_filter_cap, 0x44332211)
    lu.assertEquals(params.channel_cnt, 4)

    -- Error response code: the call reports failure.
    feed_response(build_response(0x0123, 1))
    lu.assertEquals(ncsi_capabilities.ncsi_get_capabilities(package_id, channel_id, eth_name),
        ncsi_def.NCSI_FAIL)

    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

-- Error propagation: when ncsi_cmd_ctrl itself fails, ncsi_get_capabilities
-- is expected to return NCSI_FAIL as well.
function TestNCSICapabilities:test_error_handling()
    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Every command fails at the control layer; all arguments are ignored.
    ncsi_utils.ncsi_cmd_ctrl = function()
        return ncsi_def.NCSI_FAIL
    end

    local package_id, channel_id, eth_name = 0, 1, "eth0"
    lu.assertEquals(ncsi_capabilities.ncsi_get_capabilities(package_id, channel_id, eth_name),
        ncsi_def.NCSI_FAIL)

    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

-- Truncated response.
-- The packet head declares a 32-byte payload (payload_len_lo = 32), but the
-- payload holds a single zero byte. ncsi_get_capabilities is expected to
-- reject this response with NCSI_FAIL rather than accept it.
-- (The response object itself is not nil; only its payload is too short.)
function TestNCSICapabilities:test_empty_response()
    local package_id = 0
    local channel_id = 1
    local eth_name = "eth0"

    -- Build a Get Capabilities response whose one-byte payload is much
    -- shorter than the length declared in its head.
    local function create_truncated_response()
        return {
            packet_head = {
                payload_len_hi = 0,
                payload_len_lo = 32,
                packet_type = GET_CAPABILITIES_RSP,
                package_id = 0,
                channel_id = 1
            },
            payload = '\0'
        }
    end

    -- Route the truncated response straight into the registered handler.
    local original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        return cmd_process_table[GET_CAPABILITIES_RSP](create_truncated_response())
    end

    local result = ncsi_capabilities.ncsi_get_capabilities(package_id, channel_id, eth_name)
    lu.assertEquals(result, ncsi_def.NCSI_FAIL,
        "Should fail when response payload is shorter than its declared length")

    ncsi_utils.ncsi_cmd_ctrl = original_cmd_ctrl
end

-- Return the suite table so this file can also be loaded with require().
return TestNCSICapabilities