-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_broadcast_filter = require 'ncsi.ncsi_protocol.ncsi_broadcast_filter'
local core = require 'network.core'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

-- Test suite table. It is assigned as a global rather than a local;
-- presumably so the luaunit runner can look it up by name (TODO confirm).
TestNCSIBroadcastFilter = {}

-- Command type constants used by the test cases in this file:
-- 0x10 is the Enable Broadcast Filter request type, 0x90 the matching response type.
local ENABLE_BROADCAST_FILTER = 0x10
local ENABLE_BROADCAST_FILTER_RSP = 0x90

--- Stand-in for ncsi_protocol_intf.send_ncsi_cmd.
-- Nothing is transmitted: the arguments of the most recent call are stored on
-- the suite table so that test cases can inspect them afterwards.
-- @param req_data request data passed by the caller
-- @param len length argument passed by the caller
-- @param eth_name interface name passed by the caller
-- @return always 0 (presumably the sender's success code - confirm against the real sender)
local function mock_send_ncsi_cmd(req_data, len, eth_name)
    local suite = TestNCSIBroadcastFilter
    suite.last_req_data, suite.last_len, suite.last_eth_name = req_data, len, eth_name
    return 0
end

--- Fixture executed before every test case.
-- Re-initialises the NCSI parameters, remembers the real sender and
-- command-control functions, and replaces both with in-process mocks.
function TestNCSIBroadcastFilter:setUp()
    ncsi_parameter.get_instance():init_ncsi_parameter()

    -- Remember the real implementations so tearDown can put them back.
    self.original_send_ncsi_cmd, self.original_cmd_ctrl =
        ncsi_protocol_intf.send_ncsi_cmd, ncsi_utils.ncsi_cmd_ctrl

    -- Route outgoing commands to the recording mock.
    ncsi_protocol_intf.send_ncsi_cmd = mock_send_ncsi_cmd

    -- Forget whatever an earlier test case captured.
    self.last_req_data, self.last_len, self.last_eth_name = nil, nil, nil

    -- Every test case starts with channel 0 as the current channel.
    ncsi_parameter.get_instance():get_ncsi_parameter().current_channel = 0

    -- Mock of ncsi_cmd_ctrl: looks up the handler registered for the request's
    -- packet type, runs it if there is one, and always reports success.
    ncsi_utils.ncsi_cmd_ctrl = function(package_id, channel_id, req_packet, eth_name, cmd_process_table)
        -- The dispatch table is mandatory.
        assert(cmd_process_table, "cmd_process_table is nil")

        -- Walk down to the packet type one step at a time; any missing level
        -- simply means no handler is invoked.
        local head = req_packet and req_packet.packet_head
        local packet_type = head and head.packet_type
        local request_handler = packet_type and cmd_process_table[packet_type]
        if request_handler then
            request_handler(req_packet, eth_name)
        end

        return ncsi_def.NCSI_SUCCESS
    end
end

--- Cleanup executed after every test case.
-- Reinstates the command-control and sender functions that setUp saved on the
-- suite before installing its mocks.
function TestNCSIBroadcastFilter:tearDown()
    ncsi_utils.ncsi_cmd_ctrl = self.original_cmd_ctrl
    ncsi_protocol_intf.send_ncsi_cmd = self.original_send_ncsi_cmd
end

--- Enabling the broadcast filter succeeds and hands a request to the sender.
-- Exercised for the base package/channel pair, then with a different channel
-- id, then with a different package id.
function TestNCSIBroadcastFilter:test_enable_broadcast_filter()
    local eth_name = "eth0"

    -- Clears the captured request, issues the command and checks that it
    -- succeeded and that the mocked sender received request data.
    local function enable_and_verify(package_id, channel_id)
        self.last_req_data = nil
        local status = ncsi_broadcast_filter.ncsi_enable_brdcast_filter(package_id, channel_id, eth_name)
        lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
        lu.assertNotNil(self.last_req_data)
    end

    -- Base case; the interface name is only checked here.
    enable_and_verify(0, 1)
    lu.assertEquals(self.last_eth_name, eth_name)

    -- Different channel id.
    enable_and_verify(0, 2)

    -- Different package id.
    enable_and_verify(1, 1)
end

--- Disabling the broadcast filter succeeds and hands a request to the sender.
-- Exercised for the base package/channel pair, then with a different channel
-- id, then with a different package id.
function TestNCSIBroadcastFilter:test_disable_broadcast_filter()
    local eth_name = "eth0"

    -- Clears the captured request, issues the command and checks that it
    -- succeeded and that the mocked sender received request data.
    local function disable_and_verify(package_id, channel_id)
        self.last_req_data = nil
        local status = ncsi_broadcast_filter.ncsi_disable_brdcast_filter(package_id, channel_id, eth_name)
        lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
        lu.assertNotNil(self.last_req_data)
    end

    -- Base case; the interface name is only checked here.
    disable_and_verify(0, 1)
    lu.assertEquals(self.last_eth_name, eth_name)

    -- Different channel id.
    disable_and_verify(0, 2)

    -- Different package id.
    disable_and_verify(1, 1)
end

--- The enable command's result follows what the response handler returns.
-- ncsi_cmd_ctrl is replaced by a stub that skips the request and feeds a
-- prepared response straight into the handler registered for type 0x90.
function TestNCSIBroadcastFilter:test_response_processing()
    -- Builds a minimal response packet. The payload is a 2-byte response code,
    -- a 2-byte reason code (both high byte first), 4 zero bytes in the
    -- checksum position and 22 zero bytes of padding.
    local function build_response(rsp_code, reason_code)
        local body = string.char(
            (rsp_code >> 8) & 0xFF, rsp_code & 0xFF,
            (reason_code >> 8) & 0xFF, reason_code & 0xFF,
            0, 0, 0, 0
        ) .. string.rep("\0", 22)

        return {
            packet_head = { payload_len_hi = 0, payload_len_lo = 4 },
            payload = body,
        }
    end

    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Installs a stub that answers with `rsp`, runs the enable command with
    -- package 0 / channel 1 on "eth0" and returns the command's status.
    local function enable_with_response(rsp)
        ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
            return cmd_process_table[ENABLE_BROADCAST_FILTER_RSP](rsp)
        end
        local status = ncsi_broadcast_filter.ncsi_enable_brdcast_filter(0, 1, "eth0")
        return status
    end

    -- A completed response yields success.
    local completed = build_response(ncsi_def.CMD_COMPLETED, 0)
    lu.assertEquals(enable_with_response(completed), ncsi_def.NCSI_SUCCESS)

    -- Response code 0x0123 with reason code 1 yields failure.
    local rejected = build_response(0x0123, 1)
    lu.assertEquals(enable_with_response(rejected), ncsi_def.NCSI_FAIL)

    -- Put back the function that was active when this test case started.
    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

--- A failure reported by ncsi_cmd_ctrl is returned by the enable command.
-- Two stubs are tried: one that fails immediately, and one that first runs the
-- request handler and then fails.
function TestNCSIBroadcastFilter:test_error_handling()
    local eth_name = "eth0"
    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Installs `stub` as ncsi_cmd_ctrl, enables the filter on package 0 with
    -- the given channel id and expects NCSI_FAIL.
    local function expect_failure(channel_id, stub)
        ncsi_utils.ncsi_cmd_ctrl = stub
        local status = ncsi_broadcast_filter.ncsi_enable_brdcast_filter(0, channel_id, eth_name)
        lu.assertEquals(status, ncsi_def.NCSI_FAIL)
    end

    -- Command control fails without touching the dispatch table.
    expect_failure(1, function()
        return ncsi_def.NCSI_FAIL
    end)

    -- Channel id 4: the request handler runs, but command control still
    -- reports failure.
    expect_failure(4, function(_, _, req_packet, iface, cmd_process_table)
        local request_handler = cmd_process_table[ENABLE_BROADCAST_FILTER]
        if request_handler then
            request_handler(req_packet, iface)
        end
        return ncsi_def.NCSI_FAIL
    end)

    -- Put back the function that was active when this test case started.
    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

--- The enable request carries packet type 0x10 and the ids passed by the caller.
-- Checked twice: on the request table given to ncsi_cmd_ctrl, and on the raw
-- data given to the sender.
function TestNCSIBroadcastFilter:test_command_type()
    local package_id, channel_id, eth_name = 0, 1, "eth0"

    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    local seen_packet = nil

    -- Stub that records the request table and then runs the request handler.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, _, cmd_process_table)
        seen_packet = req_packet
        local request_handler = cmd_process_table[ENABLE_BROADCAST_FILTER]
        if request_handler then
            request_handler(req_packet, eth_name)
        end
        return ncsi_def.NCSI_SUCCESS
    end

    local status = ncsi_broadcast_filter.ncsi_enable_brdcast_filter(package_id, channel_id, eth_name)

    -- Structured view of the request.
    lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
    lu.assertNotNil(seen_packet, "应该捕获到请求包")
    lu.assertNotNil(seen_packet.packet_head, "请求包应该有packet_head")
    lu.assertEquals(seen_packet.packet_head.packet_type, ENABLE_BROADCAST_FILTER, 
        "请求包的类型应该是ENABLE_BROADCAST_FILTER (0x10)")

    -- Swap the sender for one that keeps only the raw request data.
    local saved_send = ncsi_protocol_intf.send_ncsi_cmd
    local raw_request = nil
    ncsi_protocol_intf.send_ncsi_cmd = function(req_data)
        raw_request = req_data
        return 0
    end

    -- Issue the command a second time so the new sender sees the data.
    ncsi_broadcast_filter.ncsi_enable_brdcast_filter(package_id, channel_id, eth_name)

    lu.assertNotNil(raw_request, "应该捕获到完整的请求数据")

    -- Bytes base+3 .. base+5 (1-based) are read as packet type, channel id and
    -- package id. NOTE(review): the base of 16 was labelled as an Ethernet
    -- header length in the original comment - confirm it against the real
    -- serialised layout.
    local base = 16
    local type_byte, channel_byte, package_byte = string.byte(raw_request, base + 3, base + 5)
    lu.assertEquals(type_byte, ENABLE_BROADCAST_FILTER, "请求包的类型应该是ENABLE_BROADCAST_FILTER (0x10)")
    lu.assertEquals(channel_byte, channel_id, "channel_id应该与输入参数匹配")
    lu.assertEquals(package_byte, package_id, "package_id应该与输入参数匹配")

    -- Put back the functions that were active when this test case started.
    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
    ncsi_protocol_intf.send_ncsi_cmd = saved_send
end

--- The handler registered for response type 0x90 drives the command result.
-- Part one feeds a completed response into the handler and expects success.
-- Part two feeds a response with code 0x0123 / reason 1 and expects failure.
--
-- In part two the handler is called directly rather than through pcall: with
-- pcall, an error raised inside the handler was converted into NCSI_FAIL,
-- which is exactly what the test expects, so a crashing handler would have
-- made the test pass.
function TestNCSIBroadcastFilter:test_response_command_type()
    local package_id = 0
    local channel_id = 1
    local eth_name = "eth0"

    -- Builds a response packet with a populated header. The payload is a
    -- 2-byte response code, a 2-byte reason code (both high byte first),
    -- 4 zero bytes in the checksum position and 22 zero bytes of padding.
    local function create_mock_response(rsp_code, reason_code, packet_type)
        local payload = string.char(
            (rsp_code >> 8) & 0xFF, rsp_code & 0xFF,
            (reason_code >> 8) & 0xFF, reason_code & 0xFF,
            0, 0, 0, 0
        ) .. string.rep("\0", 22)

        return {
            frame_head = {},
            packet_head = {
                mc_id = 0,
                header_rev = 1,
                reserved = 0,
                packet_type = packet_type,
                channel_id = channel_id,
                package_id = package_id,
                payload_len_hi = 0,
                payload_len_lo = 4
            },
            payload = payload
        }
    end

    -- Observations made by the first stub.
    local response_handler_called = false
    local captured_response_type = nil

    local original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Part one: skip the request and pass a completed response to the handler.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        local success_rsp = create_mock_response(ncsi_def.CMD_COMPLETED, 0, ENABLE_BROADCAST_FILTER_RSP)

        local rsp_handler = cmd_process_table and cmd_process_table[ENABLE_BROADCAST_FILTER_RSP]
        if rsp_handler then
            response_handler_called = true
            -- Record the type carried by the packet actually handed over.
            captured_response_type = success_rsp.packet_head.packet_type
            return rsp_handler(success_rsp)
        end

        return ncsi_def.NCSI_SUCCESS
    end

    local result = ncsi_broadcast_filter.ncsi_enable_brdcast_filter(package_id, channel_id, eth_name)

    lu.assertEquals(result, ncsi_def.NCSI_SUCCESS)
    lu.assertTrue(response_handler_called, "应该调用响应处理函数")
    lu.assertEquals(captured_response_type, ENABLE_BROADCAST_FILTER_RSP,
        "响应命令类型应该是ENABLE_BROADCAST_FILTER_RSP (0x90)")

    -- Part two: a response of the correct type whose response code is 0x0123
    -- and whose reason code is 1.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        local rsp_handler = cmd_process_table[ENABLE_BROADCAST_FILTER_RSP]
        lu.assertNotNil(rsp_handler, "响应处理函数应该存在")

        local failed_rsp = create_mock_response(0x0123, 1, ENABLE_BROADCAST_FILTER_RSP)

        -- Unprotected call: an error raised by the handler must fail the test.
        local status = rsp_handler(failed_rsp)
        return status
    end

    result = ncsi_broadcast_filter.ncsi_enable_brdcast_filter(package_id, channel_id, eth_name)

    lu.assertEquals(result, ncsi_def.NCSI_FAIL, "无效响应应该返回失败")

    -- Put back the function that was active when this test case started.
    ncsi_utils.ncsi_cmd_ctrl = original_cmd_ctrl
end
