-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local bs = require 'mc.bitstring'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_aen = require 'ncsi.ncsi_protocol.ncsi_aen'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

-- Test suite for the ncsi_aen module.
-- Kept global (not `local`), presumably so luaunit's runner can discover it by
-- its "Test" name prefix; the table is also returned at the end of the file.
TestNCSIAen = {}

-- Command/response packet type constants.
-- Packet type of the AEN Enable response. The matching request type (0x08)
-- is special-cased by the ncsi_cmd_ctrl mock installed in setUp.
local AEN_ENABLE_RSP = 0x88

-- bitstring spec used by test_response_processing to pack a mock AEN Enable
-- response payload. Fields in order: response code, reason code, checksum,
-- 22 bytes of filler data, FCS (integer widths as written in the spec,
-- presumably bits -- confirm against mc.bitstring).
local aen_enable_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    check_sum:32,
    data:22/string,
    fcs:32
>>]])

-- Stand-in for ncsi_protocol_intf.send_ncsi_cmd.
-- Nothing is sent: the arguments are recorded on the TestNCSIAen suite table
-- (last_req_data / last_len / last_eth_name) so tests can inspect them, and
-- the call always reports success.
local function mock_send_ncsi_cmd(req_data, len, eth_name)
    local suite = TestNCSIAen
    suite.last_req_data, suite.last_len, suite.last_eth_name = req_data, len, eth_name
    return ncsi_def.NCSI_SUCCESS
end

-- Runs before every test case.
-- Re-initialises the NCSI parameter instance, remembers every function the
-- tests replace (tearDown puts them back from the self.original_* fields)
-- and installs deterministic mocks for transport, checksum/CRC, parameters
-- and command control.
function TestNCSIAen:setUp()
    ncsi_parameter.get_instance():init_ncsi_parameter()

    -- Keep the real implementations so tearDown can restore them.
    self.original_send_ncsi_cmd = ncsi_protocol_intf.send_ncsi_cmd
    self.original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    self.original_get_ncsi_parameter = ncsi_parameter.get_instance().get_ncsi_parameter
    self.original_get_checksum = ncsi_utils.get_checksum
    self.original_get_crc32 = ncsi_utils.get_crc32

    -- Transport: capture outgoing requests instead of sending them.
    ncsi_protocol_intf.send_ncsi_cmd = mock_send_ncsi_cmd

    -- Integrity helpers: ignore the packet and return fixed values.
    ncsi_utils.get_checksum = function(_, _)
        return 0x12345678
    end
    ncsi_utils.get_crc32 = function(_, _)
        return 0x87654321
    end

    -- Forget anything the previous test captured.
    TestNCSIAen.last_req_data, TestNCSIAen.last_len, TestNCSIAen.last_eth_name = nil, nil, nil

    -- Fake NCSI parameters; iid is the packet instance ID the protocol code reads.
    self.mock_ncsi_parameter = { current_channel = 0, iid = 1 }
    ncsi_parameter.get_instance().get_ncsi_parameter = function()
        return self.mock_ncsi_parameter
    end

    -- Command control: dispatch the request straight to its handler from
    -- cmd_process_table (if one is registered) and report success.
    local AEN_ENABLE_REQ = 0x08
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth_name, cmd_process_table)
        assert(cmd_process_table, "cmd_process_table is nil")

        local head = req_packet and req_packet.packet_head
        local packet_type = head and head.packet_type
        local handler = packet_type and cmd_process_table[packet_type]
        if handler then
            if packet_type == AEN_ENABLE_REQ then
                -- AEN enable requests take an extra argument: enable_flag = 1.
                handler(req_packet, eth_name, 1)
            else
                handler(req_packet, eth_name)
            end
        end

        return ncsi_def.NCSI_SUCCESS
    end
end

-- Runs after every test case.
-- Puts back every function setUp replaced (from the self.original_* fields)
-- and unregisters the channel-init callback so tests cannot affect each other.
function TestNCSIAen:tearDown()
    ncsi_utils.get_checksum = self.original_get_checksum
    ncsi_utils.get_crc32 = self.original_get_crc32
    ncsi_utils.ncsi_cmd_ctrl = self.original_cmd_ctrl
    ncsi_protocol_intf.send_ncsi_cmd = self.original_send_ncsi_cmd
    ncsi_parameter.get_instance().get_ncsi_parameter = self.original_get_ncsi_parameter

    ncsi_aen.register_channel_init_callback(nil)
end

-- Enabling AEN returns NCSI_SUCCESS. It also pushes a request through the
-- mocked transport, targeting the requested interface.
function TestNCSIAen:test_aen_enable()
    local pkg, chan, iface = 0, 1, "eth0"

    -- Start from a clean capture state.
    TestNCSIAen.last_req_data, TestNCSIAen.last_eth_name = nil, nil

    local rc = ncsi_aen.ncsi_aen_enable(pkg, chan, iface)

    lu.assertEquals(rc, ncsi_def.NCSI_SUCCESS, "AEN enable should return SUCCESS")
    lu.assertNotNil(TestNCSIAen.last_req_data, "Request data should be set by mock function")
    lu.assertEquals(TestNCSIAen.last_eth_name, iface, "Ethernet name should match")
end

-- The AEN Enable response handler maps response codes to results.
-- A CMD_COMPLETED response yields NCSI_SUCCESS, and a non-completed code
-- (0x0123) yields NCSI_FAIL. ncsi_cmd_ctrl is mocked to feed a canned
-- response straight into the module's registered response handler.
function TestNCSIAen:test_response_processing()
    -- Build a fake AEN Enable response packet carrying the given codes.
    local function make_response(rsp_code, reason_code)
        local payload = aen_enable_rsp_bs:pack({
            rsp_code = rsp_code,
            reason_code = reason_code,
            check_sum = 0,
            data = string.rep('\0', 22),
            fcs = 0
        })
        return {
            packet_head = {
                payload_len_hi = 0,
                payload_len_lo = 4,
                packet_type = AEN_ENABLE_RSP,
                package_id = 0,
                channel_id = 1
            },
            payload = payload
        }
    end

    local pkg, chan, iface = 0, 1, "eth0"
    local ok_rsp = make_response(ncsi_def.CMD_COMPLETED, 0)
    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Case 1: completed response -> NCSI_SUCCESS.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        assert(cmd_process_table, "cmd_process_table is nil")
        local on_rsp = cmd_process_table[AEN_ENABLE_RSP]
        assert(on_rsp, "AEN_ENABLE_RSP handler not found")
        return on_rsp(ok_rsp)
    end
    lu.assertEquals(ncsi_aen.ncsi_aen_enable(pkg, chan, iface), ncsi_def.NCSI_SUCCESS)

    -- Case 2: non-completed response code -> NCSI_FAIL.
    local bad_rsp = make_response(0x0123, 1)
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        return cmd_process_table[AEN_ENABLE_RSP](bad_rsp)
    end
    lu.assertEquals(ncsi_aen.ncsi_aen_enable(pkg, chan, iface), ncsi_def.NCSI_FAIL)

    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

-- When command control itself reports NCSI_FAIL, ncsi_aen_enable must
-- return NCSI_FAIL as well.
function TestNCSIAen:test_error_handling()
    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    ncsi_utils.ncsi_cmd_ctrl = function()
        return ncsi_def.NCSI_FAIL
    end

    local status = ncsi_aen.ncsi_aen_enable(0, 1, "eth0")
    lu.assertEquals(status, ncsi_def.NCSI_FAIL)

    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

-- AEN type codes used by the AEN packet-processing tests below.
local CONFIGURATION_REQUIRED = 0x01
local OEM_NIC_CONFIGURATION_REINIT = 0x80

-- Layout of the 12-byte AEN payload that create_mock_aen_packet assembles by
-- hand (kept as documentation only; no bitstring object is needed for it):
--   reserved1:3/string, aen_type:8, optional_aen_data:32, check_sum:32

-- Build a minimal raw AEN frame in the format ncsi_aen_packet_proc expects:
-- Ethernet header (14 bytes) + NC-SI header (16 bytes) + AEN payload (12 bytes).
-- @param package_id   package ID, masked to 3 bits
-- @param channel_id   channel ID, masked to 5 bits
-- @param aen_type     AEN type byte
-- @param optional_data optional 32-bit AEN data (defaults to 0), written big-endian
-- @return the frame as a byte string
local function create_mock_aen_packet(package_id, channel_id, aen_type, optional_data)
    local opt = optional_data or 0

    -- Ethernet header: zeroed destination/source MACs, then EtherType 0x88F8.
    local eth_header = string.rep('\0', 12) .. '\x88\xF8'

    -- NC-SI header: MC ID, header revision 1, reserved, IID 0 (AEN),
    -- packet type 0xFF (AEN), package ID in bits 7..5 / channel ID in bits 4..0,
    -- payload length 0x00C (12), then 8 reserved bytes.
    local id_byte = ((package_id & 0x07) << 5) | (channel_id & 0x1F)
    local packet_header = string.char(
        ncsi_def.NCSI_MC_ID, 0x01, 0x00, 0x00,
        0xFF, id_byte, 0x00, 0x0C,
        0x00, 0x00, 0x00, 0x00,
        0x00, 0x00, 0x00, 0x00
    )

    -- AEN payload: 3 reserved bytes, AEN type, optional data (big-endian),
    -- 4-byte checksum left as zero.
    local type_byte = string.char(aen_type)
    local opt_bytes = string.char(
        (opt >> 24) & 0xFF,
        (opt >> 16) & 0xFF,
        (opt >> 8) & 0xFF,
        opt & 0xFF
    )
    local aen_payload = table.concat({
        string.rep('\0', 3),
        type_byte,
        opt_bytes,
        string.rep('\0', 4),
    })

    return table.concat({ eth_header, packet_header, aen_payload })
end

-- A CONFIGURATION_REQUIRED AEN makes ncsi_aen_packet_proc return (true, 0).
-- It also invokes the registered channel-init callback exactly once, passing
-- the packet's package/channel IDs and the receiving interface name.
function TestNCSIAen:test_aen_packet_proc_configuration_required()
    -- Drop any callback left over from an earlier test.
    ncsi_aen.register_channel_init_callback(nil)

    -- Record every callback invocation in order.
    local calls = {}
    ncsi_aen.register_channel_init_callback(function(package_id, channel_id, eth_name)
        calls[#calls + 1] = {package_id, channel_id, eth_name}
    end)

    local pkg, chan, iface = 1, 2, "eth0"
    local frame = create_mock_aen_packet(pkg, chan, CONFIGURATION_REQUIRED)

    local handled, rc = ncsi_aen.ncsi_aen_packet_proc(frame, iface)

    lu.assertTrue(handled)
    lu.assertEquals(rc, 0)
    lu.assertEquals(#calls, 1)
    local last = calls[#calls]
    lu.assertNotNil(last)
    lu.assertEquals(last[1], pkg)
    lu.assertEquals(last[2], chan)
    lu.assertEquals(last[3], iface)
end

-- An OEM_NIC_CONFIGURATION_REINIT AEN makes ncsi_aen_packet_proc return
-- (true, 1). Unlike CONFIGURATION_REQUIRED it must not invoke the
-- channel-init callback (consistent with test_aen_packet_proc_multiple_types).
function TestNCSIAen:test_aen_packet_proc_oem_reinit()
    -- Count callback invocations so an unexpected call is detected.
    local callback_call_count = 0
    ncsi_aen.register_channel_init_callback(function()
        callback_call_count = callback_call_count + 1
    end)

    -- Test parameters
    local package_id = 1
    local channel_id = 2
    local eth_name = "eth1"
    local aen_type = OEM_NIC_CONFIGURATION_REINIT

    -- Build the mock AEN frame
    local data = create_mock_aen_packet(package_id, channel_id, aen_type)

    -- Run
    local success, result = ncsi_aen.ncsi_aen_packet_proc(data, eth_name)

    -- Verify
    lu.assertTrue(success)
    lu.assertEquals(result, 1)
    lu.assertEquals(callback_call_count, 0,
        "channel init callback must not run for OEM_NIC_CONFIGURATION_REINIT")
end

-- An AEN with an unrecognised type (0x99) is still accepted as an AEN
-- packet: ncsi_aen_packet_proc returns (true, 0). It must not invoke the
-- channel-init callback (consistent with test_aen_packet_proc_multiple_types).
function TestNCSIAen:test_aen_packet_proc_unknown_type()
    -- Count callback invocations so an unexpected call is detected.
    local callback_call_count = 0
    ncsi_aen.register_channel_init_callback(function()
        callback_call_count = callback_call_count + 1
    end)

    -- Test parameters
    local package_id = 0
    local channel_id = 1
    local eth_name = "eth2"
    local aen_type = 0x99  -- unknown type

    -- Build the mock AEN frame
    local data = create_mock_aen_packet(package_id, channel_id, aen_type)

    -- Run
    local success, result = ncsi_aen.ncsi_aen_packet_proc(data, eth_name)

    -- Verify
    lu.assertTrue(success)
    lu.assertEquals(result, 0)
    lu.assertEquals(callback_call_count, 0,
        "channel init callback must not run for an unknown AEN type")
end

-- Boundary IDs survive the round trip: with the largest 3-bit package ID (7)
-- and the largest 5-bit channel ID (31), a CONFIGURATION_REQUIRED AEN returns
-- (true, 0). The callback receives exactly those IDs plus the interface name.
function TestNCSIAen:test_aen_packet_proc_boundary_conditions()
    -- Drop any callback left over from an earlier test.
    ncsi_aen.register_channel_init_callback(nil)

    local seen = nil
    ncsi_aen.register_channel_init_callback(function(package_id, channel_id, eth_name)
        seen = {package_id, channel_id, eth_name}
    end)

    -- { package_id, channel_id, eth_name }
    local expected = {7, 31, "eth15"}
    local frame = create_mock_aen_packet(expected[1], expected[2], CONFIGURATION_REQUIRED)

    local ok, rc = ncsi_aen.ncsi_aen_packet_proc(frame, expected[3])

    lu.assertTrue(ok)
    lu.assertEquals(rc, 0)
    lu.assertNotNil(seen)
    for i = 1, 3 do
        lu.assertEquals(seen[i], expected[i])
    end
end

-- An empty interface name is passed through unchanged. The
-- CONFIGURATION_REQUIRED AEN still returns (true, 0), and the callback
-- receives "" as its eth_name argument.
function TestNCSIAen:test_aen_packet_proc_empty_eth_name()
    -- Drop any callback left over from an earlier test.
    ncsi_aen.register_channel_init_callback(nil)

    local captured = nil
    ncsi_aen.register_channel_init_callback(function(package_id, channel_id, eth_name)
        captured = {package_id, channel_id, eth_name}
    end)

    local frame = create_mock_aen_packet(0, 1, CONFIGURATION_REQUIRED)
    local ok, rc = ncsi_aen.ncsi_aen_packet_proc(frame, "")

    lu.assertTrue(ok)
    lu.assertEquals(rc, 0)
    lu.assertNotNil(captured)
    lu.assertEquals(captured[3], "")
end

-- Several AEN types are processed back to back on the same channel. Each
-- returns its expected result code, and only CONFIGURATION_REQUIRED reaches
-- the channel-init callback.
function TestNCSIAen:test_aen_packet_proc_multiple_types()
    -- Drop any callback left over from an earlier test.
    ncsi_aen.register_channel_init_callback(nil)

    local callback_hits = 0
    ncsi_aen.register_channel_init_callback(function()
        callback_hits = callback_hits + 1
    end)

    local package_id, channel_id, eth_name = 0, 1, "eth0"

    -- Each case: AEN type to send and the result code expected back.
    local cases = {
        { aen_type = CONFIGURATION_REQUIRED, want = 0 },
        { aen_type = OEM_NIC_CONFIGURATION_REINIT, want = 1 },
        { aen_type = 0x99, want = 0 },  -- unknown type
    }

    -- Feed every packet first, then check the outcomes.
    for _, case in ipairs(cases) do
        local frame = create_mock_aen_packet(package_id, channel_id, case.aen_type)
        case.ok, case.rc = ncsi_aen.ncsi_aen_packet_proc(frame, eth_name)
    end

    for _, case in ipairs(cases) do
        lu.assertTrue(case.ok)
        lu.assertEquals(case.rc, case.want)
    end

    -- Only the CONFIGURATION_REQUIRED packet triggers the callback.
    lu.assertEquals(callback_hits, 1)
end

-- A packet carrying non-zero optional AEN data (0x12345678) is still
-- processed as an OEM_NIC_CONFIGURATION_REINIT AEN and returns (true, 1).
function TestNCSIAen:test_aen_packet_proc_with_optional_data()
    local frame = create_mock_aen_packet(1, 2, OEM_NIC_CONFIGURATION_REINIT, 0x12345678)

    local ok, rc = ncsi_aen.ncsi_aen_packet_proc(frame, "eth0")

    lu.assertTrue(ok)
    lu.assertEquals(rc, 1)
end

-- A frame whose NC-SI packet type is not the AEN type (0xFF) must be rejected:
-- ncsi_aen_packet_proc returns (false, 0).
-- Apart from the packet type, the frame is byte-for-byte what
-- create_mock_aen_packet would build, so the packet type is the only defect.
function TestNCSIAen:test_aen_packet_proc_invalid_packet()
    local package_id = 0
    local channel_id = 1

    -- Ethernet header (14 bytes): zeroed MACs + EtherType 0x88F8.
    local eth_header = string.rep('\0', 6) .. string.rep('\0', 6) .. '\x88\xF8'

    -- NC-SI header (16 bytes) with a wrong packet type.
    local packet_header = string.char(
        ncsi_def.NCSI_MC_ID,           -- mc_id
        0x01,                          -- header_revision
        0x00,                          -- reserved
        0x00,                          -- iid
        0x80,                          -- packet_type (deliberately wrong, not 0xFF)
        -- Package ID in bits 7..5, channel ID in bits 4..0 -- the same layout
        -- create_mock_aen_packet uses. The previous `(1 << 3) | 0` put the
        -- package ID in the low bits and actually encoded channel 8.
        ((package_id & 0x07) << 5) | (channel_id & 0x1F),
        0x00,                          -- payload_len_hi + reserved1
        0x0C,                          -- payload_len_lo (12-byte AEN payload)
        0x00, 0x00, 0x00, 0x00,        -- reserved2
        0x00, 0x00, 0x00, 0x00         -- reserved3
    )

    -- AEN payload (12 bytes)
    local aen_payload = string.rep('\0', 3) ..     -- reserved1 (3 bytes)
        string.char(CONFIGURATION_REQUIRED) ..     -- aen_type (1 byte)
        string.rep('\0', 4) ..                     -- optional_aen_data (4 bytes)
        string.rep('\0', 4)                        -- check_sum (4 bytes)

    local invalid_data = eth_header .. packet_header .. aen_payload

    local success, result = ncsi_aen.ncsi_aen_packet_proc(invalid_data, "eth0")

    -- Not a valid AEN packet: expect (false, 0).
    lu.assertFalse(success)
    lu.assertEquals(result, 0)
end

-- A truncated frame (10 zero bytes, far shorter than an Ethernet header plus
-- NC-SI header) is rejected: ncsi_aen_packet_proc returns (false, 0).
function TestNCSIAen:test_empty_response()
    local truncated = ('\0'):rep(10)

    local ok, rc = ncsi_aen.ncsi_aen_packet_proc(truncated, "eth0")

    lu.assertFalse(ok)
    lu.assertEquals(rc, 0)
end

-- Return the suite table so loading this file via require/dofile yields it.
return TestNCSIAen