-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_vlan_filter = require 'ncsi.ncsi_protocol.ncsi_vlan_filter'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

-- Test suite table. It is intentionally global: luaunit's default discovery
-- looks for global tables whose names start with "Test".
TestNCSIVlanFilter = {}

-- NC-SI command type under test and its response type; the response type is
-- the command type with bit 7 set (0x0B | 0x80 == 0x8B).
local SET_VLAN_FILTER = 0x0B
local SET_VLAN_FILTER_RSP = SET_VLAN_FILTER | 0x80

--- Stand-in for ncsi_protocol_intf.send_ncsi_cmd used by the tests.
-- Records the arguments of the most recent call on the suite table so test
-- methods can inspect what would have been sent, then reports success.
-- @param req_data request data passed by the caller
-- @param len length argument passed by the caller
-- @param eth_name interface name the request targets
-- @treturn number always 0
local function mock_send_ncsi_cmd(req_data, len, eth_name)
    local suite = TestNCSIVlanFilter
    suite.last_req_data, suite.last_len, suite.last_eth_name = req_data, len, eth_name
    return 0
end

--- Per-test fixture, run by luaunit before every test method.
-- Re-initialises the NC-SI parameters and saves the real send/cmd_ctrl
-- functions so tearDown can restore them. Then installs the recording sender
-- mock, clears the captured request fields, selects channel 0, and installs a
-- simplified cmd_ctrl that only dispatches the request handler.
function TestNCSIVlanFilter:setUp()
    ncsi_parameter.get_instance():init_ncsi_parameter()

    -- Remember the real implementations so tearDown can put them back.
    self.original_send_ncsi_cmd = ncsi_protocol_intf.send_ncsi_cmd
    self.original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Route outgoing requests to the recording mock.
    ncsi_protocol_intf.send_ncsi_cmd = mock_send_ncsi_cmd

    -- Forget anything captured by a previous test.
    self.last_req_data, self.last_len, self.last_eth_name = nil, nil, nil

    -- Tests talk to channel 0 unless they say otherwise.
    ncsi_parameter.get_instance():get_ncsi_parameter().current_channel = 0

    -- Simplified cmd_ctrl: run the handler registered for the request's packet
    -- type (when there is one) and report success without invoking any
    -- response handler.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth_name, cmd_process_table)
        assert(cmd_process_table, "cmd_process_table is nil")

        local head = req_packet and req_packet.packet_head
        local packet_type = head and head.packet_type
        local handler = packet_type and cmd_process_table[packet_type]
        if handler then
            handler(req_packet, eth_name)
        end

        return ncsi_def.NCSI_SUCCESS
    end
end

--- Per-test cleanup, run by luaunit after every test method.
-- Reinstates the real ncsi_cmd_ctrl and send_ncsi_cmd saved by setUp. This
-- undoes both setUp's mocks and any mock a test installed itself.
function TestNCSIVlanFilter:tearDown()
    ncsi_utils.ncsi_cmd_ctrl = self.original_cmd_ctrl
    ncsi_protocol_intf.send_ncsi_cmd = self.original_send_ncsi_cmd
end

--- Exercises ncsi_set_vlan_filter with several argument combinations.
-- Each call must return NCSI_SUCCESS and reach the mocked sender, which
-- records the request data. The first call also checks the interface name.
function TestNCSIVlanFilter:test_set_vlan_filter()
    local package_id, channel_id = 0, 1
    local eth_name = "eth0"
    local vlan_id = 100
    local vlan_filter = 1  -- filter index (presumably the NC-SI filter selector)
    local vlan_state = 1   -- 1 = enable, 0 = disable

    -- Argument order: package_id, channel_id, vlan_filter, vlan_state, vlan_id, eth_name.
    local function expect_sent(filter, state, vid)
        self.last_req_data = nil
        local rc = ncsi_vlan_filter.ncsi_set_vlan_filter(package_id, channel_id, filter,
            state, vid, eth_name)
        lu.assertEquals(rc, ncsi_def.NCSI_SUCCESS)
        lu.assertNotNil(self.last_req_data)
    end

    -- Baseline: filter 1, enabled, VLAN 100.
    expect_sent(vlan_filter, vlan_state, vlan_id)
    lu.assertEquals(self.last_eth_name, eth_name)

    -- Different filter value.
    expect_sent(2, vlan_state, vlan_id)

    -- Disabled state.
    expect_sent(vlan_filter, 0, vlan_id)

    -- Different VLAN ID.
    expect_sent(vlan_filter, vlan_state, 200)
end

--- Drives the SET_VLAN_FILTER response handler with canned responses.
-- The mocked cmd_ctrl runs the registered request handler, then passes the
-- supplied response to the registered response handler and returns its
-- result. ncsi_set_vlan_filter's return value therefore reflects how the
-- response was handled.
function TestNCSIVlanFilter:test_response_processing()
    -- Fake response: rsp_code and reason_code as big-endian 16-bit values,
    -- followed by zeroed check_sum, pad_data and fcs bytes.
    local function create_mock_response(rsp_code, reason_code)
        local code_bytes = string.char(
            (rsp_code >> 8) & 0xFF, rsp_code & 0xFF,        -- rsp_code (2 bytes)
            (reason_code >> 8) & 0xFF, reason_code & 0xFF   -- reason_code (2 bytes)
        )
        return {
            packet_head = {
                package_id = 0,
                channel_id = 1,
                packet_type = SET_VLAN_FILTER_RSP,
                payload_len_hi = 0,
                payload_len_lo = 4,
                iid = 1
            },
            payload = code_bytes
                .. string.rep("\0", 4)    -- check_sum (4 bytes)
                .. string.rep("\0", 22)   -- pad_data (22 bytes)
                .. string.rep("\0", 4)    -- fcs (4 bytes)
        }
    end

    local package_id, channel_id = 0, 1
    local eth_name = "eth0"
    local vlan_id, vlan_filter, vlan_state = 100, 1, 1

    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Install a cmd_ctrl mock that answers every request with `rsp`.
    local function answer_with(rsp)
        ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth, cmd_process_table)
            local request_handler = cmd_process_table[SET_VLAN_FILTER]
            if request_handler then
                request_handler(req_packet, eth, vlan_filter, vlan_state, vlan_id)
            end

            local response_handler = cmd_process_table[SET_VLAN_FILTER_RSP]
            if not response_handler then
                return ncsi_def.NCSI_FAIL
            end
            return response_handler(rsp)
        end
    end

    local function run()
        return ncsi_vlan_filter.ncsi_set_vlan_filter(package_id, channel_id, vlan_filter,
            vlan_state, vlan_id, eth_name)
    end

    -- A completed response must be reported as success.
    answer_with(create_mock_response(ncsi_def.CMD_COMPLETED, 0))
    lu.assertEquals(run(), ncsi_def.NCSI_SUCCESS)

    -- An unexpected response code (0x0123, reason 1) must be reported as failure.
    answer_with(create_mock_response(0x0123, 1))
    lu.assertEquals(run(), ncsi_def.NCSI_FAIL)

    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

--- Checks that ncsi_set_vlan_filter propagates a failing cmd_ctrl result.
-- Case 1: cmd_ctrl fails immediately without running any handler.
-- Case 2: cmd_ctrl runs the request handler and then reports failure. This
-- case passes the out-of-range VLAN ID 4096 (one past the 12-bit maximum
-- 4095). Because the mock forces the failure, it does NOT verify that
-- ncsi_set_vlan_filter itself rejects out-of-range VLAN IDs.
function TestNCSIVlanFilter:test_error_handling()
    local package_id, channel_id = 0, 1
    local eth_name = "eth0"
    local vlan_id, vlan_filter, vlan_state = 100, 1, 1

    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Case 1: cmd_ctrl fails outright.
    ncsi_utils.ncsi_cmd_ctrl = function()
        return ncsi_def.NCSI_FAIL
    end
    local result = ncsi_vlan_filter.ncsi_set_vlan_filter(package_id, channel_id, vlan_filter,
        vlan_state, vlan_id, eth_name)
    lu.assertEquals(result, ncsi_def.NCSI_FAIL)

    -- Case 2: the request handler runs, then cmd_ctrl still reports failure.
    -- The parameter is named `eth` to avoid shadowing the outer eth_name.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth, cmd_process_table)
        local request_handler = cmd_process_table[SET_VLAN_FILTER]
        if request_handler then
            request_handler(req_packet, eth)
        end
        return ncsi_def.NCSI_FAIL
    end
    result = ncsi_vlan_filter.ncsi_set_vlan_filter(package_id, channel_id, vlan_filter,
        vlan_state, 4096, eth_name)
    lu.assertEquals(result, ncsi_def.NCSI_FAIL)

    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

--- Checks that ncsi_set_vlan_filter builds and sends a request for both a
-- typical argument set and a boundary set (filter 0, disabled, VLAN 4095).
-- The captured request packet and request data are cleared before each call.
-- Each assertion therefore reflects that specific call; previously the second
-- assertion could pass on data left over from the first call.
function TestNCSIVlanFilter:test_parameter_validation()
    local package_id, channel_id = 0, 1
    local eth_name = "eth0"
    local vlan_id, vlan_filter, vlan_state = 100, 1, 1

    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Record the request packet handed to cmd_ctrl and run its request handler.
    -- The parameter is named `eth` to avoid shadowing the outer eth_name.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth, cmd_process_table)
        self.last_req_packet = req_packet
        local request_handler = cmd_process_table[SET_VLAN_FILTER]
        if request_handler then
            request_handler(req_packet, eth)
        end
        return ncsi_def.NCSI_SUCCESS
    end

    -- Issue one request and require that it reached cmd_ctrl and the sender.
    local function expect_request(filter, state, vid)
        self.last_req_data = nil
        self.last_req_packet = nil
        ncsi_vlan_filter.ncsi_set_vlan_filter(package_id, channel_id, filter, state, vid, eth_name)
        lu.assertNotNil(self.last_req_packet)
        lu.assertNotNil(self.last_req_data)
    end

    expect_request(vlan_filter, vlan_state, vlan_id)

    -- Boundary values: filter 0, disabled, highest 12-bit VLAN ID.
    expect_request(0, 0, 4095)

    -- Do not leak the captured packet into other tests (self is shared).
    self.last_req_packet = nil
    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

--- Verifies the request frame emitted by ncsi_set_vlan_filter.
-- cmd_ctrl is mocked to run only the SET_VLAN_FILTER request handler.
-- send_ncsi_cmd is mocked to capture every frame. The first captured frame
-- must carry packet type SET_VLAN_FILTER plus the requested channel/package.
function TestNCSIVlanFilter:test_command_type()
    local package_id, channel_id = 0, 1
    local eth_name = "eth0"
    local vlan_id, vlan_filter, vlan_state = 100, 1, 1

    -- Save the currently installed functions so they can be put back.
    local original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    local original_send_ncsi_cmd = ncsi_protocol_intf.send_ncsi_cmd

    -- Run only the request handler, forwarding the interface name cmd_ctrl receives.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth, cmd_process_table)
        local request_handler = cmd_process_table[SET_VLAN_FILTER]
        if request_handler then
            request_handler(req_packet, eth)
        end
        return ncsi_def.NCSI_SUCCESS
    end

    -- Capture every frame handed to the sender instead of transmitting it.
    local captured_payloads = {}
    ncsi_protocol_intf.send_ncsi_cmd = function(req_data, _, _)
        captured_payloads[#captured_payloads + 1] = req_data
        return 0
    end

    -- A single call suffices: the capture mock is installed before the request
    -- is sent. The previous extra call made before capturing was unused.
    ncsi_vlan_filter.ncsi_set_vlan_filter(package_id, channel_id, vlan_filter,
        vlan_state, vlan_id, eth_name)

    local frame = captured_payloads[1]
    lu.assertNotNil(frame, "应该捕获到SET_VLAN_FILTER命令的请求数据")

    -- The checked 1-based byte positions are payload_offset + 3/4/5, i.e.
    -- bytes 19, 20 and 21 of the frame.
    -- NOTE(review): 16 is not the Ethernet header length (that is 14 bytes).
    -- Byte 19 corresponds to the 5th NC-SI header byte (control packet type)
    -- after a 14-byte Ethernet header. Confirm bytes 20/21 against the
    -- project's packet layout.
    local payload_offset = 16
    local packet_type = string.byte(frame, payload_offset + 3)
    lu.assertEquals(packet_type, SET_VLAN_FILTER, "第一个请求包的类型应该是SET_VLAN_FILTER (0x0B)")

    local extracted_channel_id = string.byte(frame, payload_offset + 4)
    local extracted_package_id = string.byte(frame, payload_offset + 5)
    lu.assertEquals(extracted_channel_id, channel_id, "channel_id应该与输入参数匹配")
    lu.assertEquals(extracted_package_id, package_id, "package_id应该与输入参数匹配")

    ncsi_utils.ncsi_cmd_ctrl = original_cmd_ctrl
    ncsi_protocol_intf.send_ncsi_cmd = original_send_ncsi_cmd
end

--- Checks that ncsi_set_vlan_filter registers a SET_VLAN_FILTER_RSP (0x8B)
-- response handler and that the handler's verdict becomes the return value:
-- CMD_COMPLETED yields NCSI_SUCCESS; unexpected response code 0x0123 yields NCSI_FAIL.
-- The handler is called directly rather than through pcall. A crash inside
-- it therefore fails the test instead of being mistaken for a rejected
-- response.
function TestNCSIVlanFilter:test_response_command_type()
    local package_id, channel_id = 0, 1
    local eth_name = "eth0"
    local vlan_id, vlan_filter, vlan_state = 100, 1, 1

    -- Fake response: rsp_code and reason_code as big-endian 16-bit values,
    -- a zeroed 4-byte check_sum and 22 zero pad bytes.
    local function create_mock_response(rsp_code, reason_code, packet_type)
        return {
            frame_head = {},
            packet_head = {
                mc_id = 0,
                header_rev = 1,
                reserved = 0,
                packet_type = packet_type,
                channel_id = channel_id,
                package_id = package_id,
                payload_len_hi = 0,
                payload_len_lo = 4
            },
            payload = string.char(
                (rsp_code >> 8) & 0xFF, rsp_code & 0xFF,        -- rsp_code (2 bytes)
                (reason_code >> 8) & 0xFF, reason_code & 0xFF,  -- reason_code (2 bytes)
                0, 0, 0, 0                                      -- check_sum (4 bytes)
            ) .. string.rep("\0", 22)                           -- pad_data (22 bytes)
        }
    end

    local function run()
        return ncsi_vlan_filter.ncsi_set_vlan_filter(package_id, channel_id, vlan_filter,
            vlan_state, vlan_id, eth_name)
    end

    -- Install a cmd_ctrl mock that skips the request and feeds `rsp` straight
    -- to the handler registered under SET_VLAN_FILTER_RSP. `handler_found`
    -- records whether such a handler was registered.
    local handler_found = false
    local function answer_with(rsp)
        handler_found = false
        ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
            local rsp_handler = cmd_process_table and cmd_process_table[SET_VLAN_FILTER_RSP]
            if not rsp_handler then
                return ncsi_def.NCSI_FAIL
            end
            handler_found = true
            return rsp_handler(rsp)
        end
    end

    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Response type 0x8B with CMD_COMPLETED: success.
    answer_with(create_mock_response(ncsi_def.CMD_COMPLETED, 0, SET_VLAN_FILTER_RSP))
    local result = run()
    lu.assertEquals(result, ncsi_def.NCSI_SUCCESS)
    lu.assertTrue(handler_found, "应该调用响应处理函数")

    -- Response type 0x8B with unexpected response code 0x0123 (reason 1): failure.
    answer_with(create_mock_response(0x0123, 1, SET_VLAN_FILTER_RSP))
    result = run()
    lu.assertTrue(handler_found, "应该调用响应处理函数")
    lu.assertEquals(result, ncsi_def.NCSI_FAIL, "无效响应应该返回失败")

    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end