-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_channel_init = require 'ncsi.ncsi_protocol.ncsi_channel_init'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_mac = require 'ncsi.ncsi_protocol.ncsi_mac'
local ncsi_broadcast_filter = require 'ncsi.ncsi_protocol.ncsi_broadcast_filter'
local ncsi_multicast_filter = require 'ncsi.ncsi_protocol.ncsi_multicast_filter'
local ncsi_vlan_mode = require 'ncsi.ncsi_protocol.ncsi_vlan_mode'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'
local ncsi_aen = require 'ncsi.ncsi_protocol.ncsi_aen'

-- Test suite. luaunit discovers test classes through global tables whose names
-- start with "Test", so this table is deliberately global.
TestNCSIChannelInit = {}

-- Manufacturer IDs that the tests place in mock_ncsi_parameter.manufacture_id.
-- NOTE(review): MULEX__MANUFACTURE_ID has seven hex digits (0x06c00000) while
-- the other IDs have eight; confirm it matches the constant used by
-- ncsi_channel_init (0x6c000000 may have been intended).
local INTEL_MANUFACTURE_ID, BCM_MANUFACTURE_ID, MULEX__MANUFACTURE_ID =
    0x57010000, 0x3d110000, 0x6c00000

-- Runs before every test case: re-initialises the NCSI parameter singleton,
-- saves the real protocol functions, and replaces them with stubs that count
-- calls, record their arguments, and return a per-test configurable result.
function TestNCSIChannelInit:setUp()
    ncsi_parameter.get_instance():init_ncsi_parameter()

    -- Save the real implementations so tearDown can restore them.
    self.original_set_phy_mac_filter = ncsi_mac.ncsi_set_phy_mac_filter
    self.original_enable_brdcast_filter = ncsi_broadcast_filter.ncsi_enable_brdcast_filter
    self.original_enable_multicast_filter = ncsi_multicast_filter.ncsi_enable_multicast_filter
    self.original_disable_vlan_req = ncsi_vlan_mode.ncsi_disable_vlan_req
    self.original_get_ncsi_parameter = ncsi_parameter.get_instance().get_ncsi_parameter
    self.original_aen_enable = ncsi_aen.ncsi_aen_enable

    -- Number of times each stub has been invoked in the current test.
    self.call_count = {
        set_phy_mac_filter = 0,
        enable_brdcast_filter = 0,
        enable_multicast_filter = 0,
        disable_vlan_req = 0,
        aen_enable = 0
    }

    -- Value each stub returns; individual tests switch entries to NCSI_FAIL.
    self.function_results = {
        set_phy_mac_filter = ncsi_def.NCSI_SUCCESS,
        enable_brdcast_filter = ncsi_def.NCSI_SUCCESS,
        enable_multicast_filter = ncsi_def.NCSI_SUCCESS,
        disable_vlan_req = ncsi_def.NCSI_SUCCESS,
        aen_enable = ncsi_def.NCSI_SUCCESS
    }

    -- Clear the argument snapshots recorded by the stubs. Everything else is
    -- rebuilt above. If these were not reset, and the runner reuses the suite
    -- table as `self` (luaunit does), a test whose stub never ran would still
    -- see arguments recorded by an earlier test.
    self.last_set_phy_mac_params = nil
    self.last_enable_brdcast_params = nil
    self.last_enable_multicast_params = nil
    self.last_disable_vlan_params = nil
    self.last_aen_enable_params = nil

    -- Mocked NCSI parameters returned by get_ncsi_parameter.
    self.mock_ncsi_parameter = {
        manufacture_id = 0,
        multicast_filter_cap = 8,
        iid = 1 -- packet instance ID field required by the NCSI protocol code
    }
    ncsi_parameter.get_instance().get_ncsi_parameter = function()
        return self.mock_ncsi_parameter
    end

    -- Stubs: count the call, snapshot the arguments, return the configured result.
    ncsi_mac.ncsi_set_phy_mac_filter = function(package_id, channel_id, eth_name, filter_num)
        self.call_count.set_phy_mac_filter = self.call_count.set_phy_mac_filter + 1
        self.last_set_phy_mac_params = {package_id, channel_id, eth_name, filter_num}
        return self.function_results.set_phy_mac_filter
    end

    ncsi_broadcast_filter.ncsi_enable_brdcast_filter = function(package_id, channel_id, eth_name)
        self.call_count.enable_brdcast_filter = self.call_count.enable_brdcast_filter + 1
        self.last_enable_brdcast_params = {package_id, channel_id, eth_name}
        return self.function_results.enable_brdcast_filter
    end

    ncsi_multicast_filter.ncsi_enable_multicast_filter = function(package_id, channel_id, eth_name, filter_cap)
        self.call_count.enable_multicast_filter = self.call_count.enable_multicast_filter + 1
        self.last_enable_multicast_params = {package_id, channel_id, eth_name, filter_cap}
        return self.function_results.enable_multicast_filter
    end

    ncsi_vlan_mode.ncsi_disable_vlan_req = function(package_id, channel_id, eth_name)
        self.call_count.disable_vlan_req = self.call_count.disable_vlan_req + 1
        self.last_disable_vlan_params = {package_id, channel_id, eth_name}
        return self.function_results.disable_vlan_req
    end

    ncsi_aen.ncsi_aen_enable = function(package_id, channel_id, eth_name)
        self.call_count.aen_enable = self.call_count.aen_enable + 1
        self.last_aen_enable_params = {package_id, channel_id, eth_name}
        return self.function_results.aen_enable
    end
end

-- Runs after every test case: reinstalls each real function that setUp saved
-- and replaced with a stub.
function TestNCSIChannelInit:tearDown()
    local saved = self

    ncsi_mac.ncsi_set_phy_mac_filter = saved.original_set_phy_mac_filter
    ncsi_broadcast_filter.ncsi_enable_brdcast_filter = saved.original_enable_brdcast_filter
    ncsi_multicast_filter.ncsi_enable_multicast_filter = saved.original_enable_multicast_filter
    ncsi_vlan_mode.ncsi_disable_vlan_req = saved.original_disable_vlan_req

    local param_instance = ncsi_parameter.get_instance()
    param_instance.get_ncsi_parameter = saved.original_get_ncsi_parameter

    ncsi_aen.ncsi_aen_enable = saved.original_aen_enable
end

-- Happy path: with every step succeeding, initial_channel must invoke each
-- stub exactly once and forward the package ID, channel ID and interface name.
function TestNCSIChannelInit:test_initial_channel_success()
    local pkg, chan, ifname = 0, 1, "eth0"

    ncsi_channel_init.initial_channel(pkg, chan, ifname)

    -- Every step ran exactly once.
    for _, step in ipairs({
        'set_phy_mac_filter',
        'enable_brdcast_filter',
        'enable_multicast_filter',
        'disable_vlan_req',
        'aen_enable',
    }) do
        lu.assertEquals(self.call_count[step], 1)
    end

    -- Arguments forwarded to each step. The MAC filter also receives 1. The
    -- multicast filter receives the mocked multicast_filter_cap (8), since the
    -- default manufacture_id 0 is not one of the special-cased vendors.
    local base = {pkg, chan, ifname}
    lu.assertEquals(self.last_set_phy_mac_params, {pkg, chan, ifname, 1})
    lu.assertEquals(self.last_enable_brdcast_params, base)
    lu.assertEquals(self.last_enable_multicast_params, {pkg, chan, ifname, 8})
    lu.assertEquals(self.last_disable_vlan_params, base)
    lu.assertEquals(self.last_aen_enable_params, base)
end

-- AEN enablement during channel initialisation is unconditional. Despite its
-- historical name, this test therefore verifies that AEN IS enabled: every step
-- runs once, and ncsi_aen_enable receives the channel's package ID, channel ID
-- and interface name. The name is kept unchanged so existing test selections
-- still resolve.
function TestNCSIChannelInit:test_initial_channel_without_aen()
    local package_id = 0
    local channel_id = 1
    local eth_name = "eth0"

    ncsi_channel_init.initial_channel(package_id, channel_id, eth_name)

    -- Every step runs, including AEN enablement.
    lu.assertEquals(self.call_count.set_phy_mac_filter, 1)
    lu.assertEquals(self.call_count.enable_brdcast_filter, 1)
    lu.assertEquals(self.call_count.enable_multicast_filter, 1)
    lu.assertEquals(self.call_count.disable_vlan_req, 1)
    lu.assertEquals(self.call_count.aen_enable, 1)

    -- AEN is enabled on the same channel that was just initialised.
    lu.assertEquals(self.last_aen_enable_params, {package_id, channel_id, eth_name})
end

--- Run initial_channel on package 0 / channel 1 / "eth0" with the mocked NCSI
-- parameters reporting `manufacture_id`, and return the arguments the
-- multicast-filter stub received.
-- It first asserts that the stub ran exactly once in this test. This guards
-- against reading an argument snapshot left over from an earlier test.
-- @param suite the test-suite instance (`self` of the calling test)
-- @param manufacture_id value placed in mock_ncsi_parameter.manufacture_id
-- @return the argument table recorded by the multicast-filter stub
local function run_multicast_case(suite, manufacture_id)
    suite.mock_ncsi_parameter.manufacture_id = manufacture_id

    ncsi_channel_init.initial_channel(0, 1, "eth0")

    lu.assertEquals(suite.call_count.enable_multicast_filter, 1)
    return suite.last_enable_multicast_params
end

-- Intel NICs: the multicast filter is enabled without a capability value (nil).
function TestNCSIChannelInit:test_intel_manufacturer_multicast_filter()
    local args = run_multicast_case(self, INTEL_MANUFACTURE_ID)
    lu.assertEquals(args, {0, 1, "eth0", nil})
end

-- BCM NICs: the multicast filter is enabled without a capability value (nil).
function TestNCSIChannelInit:test_bcm_manufacturer_multicast_filter()
    local args = run_multicast_case(self, BCM_MANUFACTURE_ID)
    lu.assertEquals(args, {0, 1, "eth0", nil})
end

-- MULEX NICs: the multicast filter is enabled without a capability value (nil).
function TestNCSIChannelInit:test_mulex_manufacturer_multicast_filter()
    local args = run_multicast_case(self, MULEX__MANUFACTURE_ID)
    lu.assertEquals(args, {0, 1, "eth0", nil})
end

-- Any other manufacturer: the reported multicast_filter_cap (8 in the mock) is
-- passed through unchanged.
function TestNCSIChannelInit:test_other_manufacturer_multicast_filter()
    local args = run_multicast_case(self, 0x12345678)
    lu.assertEquals(args, {0, 1, "eth0", 8})
end

-- Order in which initial_channel invokes the stubbed steps, as pinned by the
-- failure tests below: a failing step must stop every step after it.
local INIT_STEPS = {
    'set_phy_mac_filter',
    'enable_brdcast_filter',
    'enable_multicast_filter',
    'disable_vlan_req',
    'aen_enable',
}

--- Make `failing_step` return NCSI_FAIL, run initial_channel on package 0 /
-- channel 1 / "eth0", then check the call counts in INIT_STEPS order: every
-- step up to and including the failing one ran once, every later step never ran.
-- @param suite the test-suite instance (`self` of the calling test)
-- @param failing_step key into suite.function_results / suite.call_count
local function expect_abort_after(suite, failing_step)
    suite.function_results[failing_step] = ncsi_def.NCSI_FAIL

    ncsi_channel_init.initial_channel(0, 1, "eth0")

    local still_running = true
    for _, step in ipairs(INIT_STEPS) do
        local expected = still_running and 1 or 0
        lu.assertEquals(suite.call_count[step], expected)
        if step == failing_step then
            still_running = false
        end
    end
end

-- MAC filter setup fails: broadcast, multicast, VLAN and AEN never run.
function TestNCSIChannelInit:test_mac_filter_failure()
    expect_abort_after(self, 'set_phy_mac_filter')
end

-- Broadcast filter fails: multicast, VLAN and AEN never run.
function TestNCSIChannelInit:test_broadcast_filter_failure()
    expect_abort_after(self, 'enable_brdcast_filter')
end

-- Multicast filter fails: VLAN and AEN never run.
function TestNCSIChannelInit:test_multicast_filter_failure()
    expect_abort_after(self, 'enable_multicast_filter')
end

-- Disabling VLAN fails: AEN is never enabled.
function TestNCSIChannelInit:test_vlan_disable_failure()
    expect_abort_after(self, 'disable_vlan_req')
end

-- Large package/channel IDs and a two-digit interface name must be forwarded
-- unchanged to every step.
-- NOTE(review): NC-SI (DSP0222) encodes the package ID in 3 bits and the
-- internal channel ID in 5 bits. 31 is therefore the channel maximum, but 255
-- is outside the package range; confirm whether 7 was intended.
function TestNCSIChannelInit:test_boundary_conditions()
    local pkg, chan, ifname = 255, 31, "eth15"

    ncsi_channel_init.initial_channel(pkg, chan, ifname)

    local base = {pkg, chan, ifname}
    lu.assertEquals(self.last_set_phy_mac_params, {pkg, chan, ifname, 1})
    lu.assertEquals(self.last_enable_brdcast_params, base)
    lu.assertEquals(self.last_enable_multicast_params, {pkg, chan, ifname, 8})
    lu.assertEquals(self.last_disable_vlan_params, base)
    lu.assertEquals(self.last_aen_enable_params, base)
end

-- An empty interface name must be forwarded as-is (the third argument) to every
-- step of the initialisation sequence.
function TestNCSIChannelInit:test_empty_eth_name()
    ncsi_channel_init.initial_channel(0, 1, "")

    for _, snapshot_field in ipairs({
        'last_set_phy_mac_params',
        'last_enable_brdcast_params',
        'last_enable_multicast_params',
        'last_disable_vlan_params',
        'last_aen_enable_params',
    }) do
        lu.assertEquals(self[snapshot_field][3], "")
    end
end

-- Also return the suite table so the file can be loaded with require/dofile
-- as well as being discovered through the global name.
return TestNCSIChannelInit