-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_channel = require 'ncsi.ncsi_protocol.ncsi_channel'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

-- Test suite table. It must stay a global: luaunit discovers test classes by
-- scanning global tables whose names start with "Test".
TestNCSIChannel = {}

-- NC-SI command packet types exercised by these tests, plus the matching
-- response types (command type with the 0x80 bit set: 0x83 / 0x84).
local ENABLE_CHANNEL = 0x03
local ENABLE_CHANNEL_RSP = ENABLE_CHANNEL | 0x80
local DISABLE_CHANNEL = 0x04
local DISABLE_CHANNEL_RSP = DISABLE_CHANNEL | 0x80
-- Expected payload length of an Enable Channel request (no payload).
local ENABLE_CHANNEL_REQ_LEN = 0

-- Stand-in for ncsi_protocol_intf.send_ncsi_cmd. Instead of transmitting, it
-- records the frame, its length and the interface name on the suite table so
-- tests can inspect them, and always reports success (0).
local function mock_send_ncsi_cmd(req_data, len, eth_name)
    local suite = TestNCSIChannel
    suite.last_req_data, suite.last_len, suite.last_eth_name = req_data, len, eth_name
    return 0
end

-- Per-test fixture, run before every test case:
--   * resets the NC-SI parameter singleton,
--   * snapshots the functions about to be stubbed (tearDown restores them),
--   * installs a frame recorder, a fake parameter block and a cmd_ctrl stub.
function TestNCSIChannel:setUp()
    local param_inst = ncsi_parameter.get_instance()
    param_inst:init_ncsi_parameter()

    -- Snapshot the real implementations so tearDown can put them back.
    self.original_send_ncsi_cmd = ncsi_protocol_intf.send_ncsi_cmd
    self.original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    self.original_get_ncsi_parameter = param_inst.get_ncsi_parameter

    -- Route outgoing frames into the recorder instead of a real interface.
    ncsi_protocol_intf.send_ncsi_cmd = mock_send_ncsi_cmd

    -- Forget whatever the previous test recorded.
    self.last_req_data, self.last_len, self.last_eth_name = nil, nil, nil

    -- Fake NC-SI parameter block handed to the code under test.
    self.mock_ncsi_parameter = {
        current_channel = 0,
        iid = 1,
        channel_cnt = 4,
        recv_buf = '',
    }
    param_inst.get_ncsi_parameter = function()
        return self.mock_ncsi_parameter
    end

    -- Stub the command/response round trip: run only the request handler
    -- registered for the packet's type (presumably the one that ends in
    -- send_ncsi_cmd) and always report success.
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, eth_name, cmd_process_table)
        assert(cmd_process_table, "cmd_process_table is nil")

        local head = req_packet and req_packet.packet_head
        local ptype = head and head.packet_type
        local handler = ptype and cmd_process_table[ptype]
        if handler then
            handler(req_packet, eth_name)
        end

        return ncsi_def.NCSI_SUCCESS
    end
end

-- Per-test cleanup, run after every test case: put back each function that
-- setUp replaced, using the snapshots it stored on self.
function TestNCSIChannel:tearDown()
    ncsi_parameter.get_instance().get_ncsi_parameter = self.original_get_ncsi_parameter
    ncsi_utils.ncsi_cmd_ctrl = self.original_cmd_ctrl
    ncsi_protocol_intf.send_ncsi_cmd = self.original_send_ncsi_cmd
end

-- Enable Channel request format: the frame handed to send_ncsi_cmd must decode
-- to an Enable Channel header addressed to the requested package/channel and
-- declaring ENABLE_CHANNEL_REQ_LEN payload bytes.
function TestNCSIChannel:test_enable_channel_request()
    local eth_name = "eth0"
    local package_id, channel_id = 0, 1

    local result = ncsi_channel.ncsi_enable_channel(package_id, channel_id, eth_name)

    -- The call succeeded and a frame went out on the requested interface.
    lu.assertEquals(result, ncsi_def.NCSI_SUCCESS)
    lu.assertNotNil(self.last_req_data)
    lu.assertEquals(self.last_eth_name, eth_name)

    -- Decode the captured frame and check the header fields.
    local decoded = ncsi_utils.ncsi_packet_bs:unpack(self.last_req_data, true)
    local head = decoded.packet_head
    lu.assertNotNil(head)
    lu.assertEquals(head.packet_type, ENABLE_CHANNEL)
    lu.assertEquals(head.channel_id, channel_id)
    lu.assertEquals(head.package_id, package_id)

    -- The payload length is split into a high part (masked to 4 bits) and a
    -- low byte.
    local expected_hi = (ENABLE_CHANNEL_REQ_LEN >> 8) & 0x0f
    local expected_lo = ENABLE_CHANNEL_REQ_LEN & 0xff
    lu.assertEquals(head.payload_len_hi, expected_hi)
    lu.assertEquals(head.payload_len_lo, expected_lo)
end

-- Enable Channel: for each channel id the command must succeed, a fresh frame
-- must be sent on the requested interface, and that frame must decode to an
-- Enable Channel header carrying the requested package and channel ids.
function TestNCSIChannel:test_enable_channel()
    local package_id = 0
    local eth_name = "eth0"

    for _, channel_id in ipairs({1, 2}) do
        -- Clear the recorder so each iteration proves a new frame was sent,
        -- rather than passing on data left over from the previous call.
        self.last_req_data = nil
        self.last_eth_name = nil

        local result = ncsi_channel.ncsi_enable_channel(package_id, channel_id, eth_name)
        lu.assertEquals(result, ncsi_def.NCSI_SUCCESS)
        lu.assertNotNil(self.last_req_data)
        lu.assertEquals(self.last_eth_name, eth_name)

        -- The requested channel id must actually be encoded in the header.
        local decoded = ncsi_utils.ncsi_packet_bs:unpack(self.last_req_data, true)
        local head = decoded.packet_head
        lu.assertNotNil(head)
        lu.assertEquals(head.packet_type, ENABLE_CHANNEL)
        lu.assertEquals(head.package_id, package_id)
        lu.assertEquals(head.channel_id, channel_id)
    end
end

-- Disable Channel: for each channel id the command must succeed, a fresh frame
-- must be sent on the requested interface, and that frame must decode to a
-- Disable Channel header carrying the requested package and channel ids.
function TestNCSIChannel:test_disable_channel()
    local package_id = 0
    local eth_name = "eth0"

    for _, channel_id in ipairs({1, 2}) do
        -- Clear the recorder so each iteration proves a new frame was sent,
        -- rather than passing on data left over from the previous call.
        self.last_req_data = nil
        self.last_eth_name = nil

        local result = ncsi_channel.ncsi_disable_channel(package_id, channel_id, eth_name)
        lu.assertEquals(result, ncsi_def.NCSI_SUCCESS)
        lu.assertNotNil(self.last_req_data)
        lu.assertEquals(self.last_eth_name, eth_name)

        -- Verify it really is a Disable Channel request for this channel.
        local decoded = ncsi_utils.ncsi_packet_bs:unpack(self.last_req_data, true)
        local head = decoded.packet_head
        lu.assertNotNil(head)
        lu.assertEquals(head.packet_type, DISABLE_CHANNEL)
        lu.assertEquals(head.package_id, package_id)
        lu.assertEquals(head.channel_id, channel_id)
    end
end

-- Response handling: a CMD_COMPLETED response must make both entry points
-- return NCSI_SUCCESS, and a non-completed response code must make them
-- return NCSI_FAIL.
function TestNCSIChannel:test_response_processing()
    -- Build a canned response: 16-bit big-endian rsp_code and reason_code, a
    -- zeroed 4-byte checksum, then 22 bytes of zero padding (30 bytes total).
    local function make_rsp(rsp_code, reason_code, packet_type)
        local body = table.concat({
            string.char((rsp_code >> 8) & 0xFF, rsp_code & 0xFF),
            string.char((reason_code >> 8) & 0xFF, reason_code & 0xFF),
            string.char(0, 0, 0, 0),
            string.rep("\0", 22),
        })
        return {
            packet_head = {payload_len_hi = 0, payload_len_lo = 4, packet_type = packet_type},
            payload = body,
        }
    end

    -- Install a cmd_ctrl stub that skips transmission and hands the canned
    -- response for the request's command straight to the matching response
    -- handler, returning whatever that handler returns.
    local function feed(enable_rsp, disable_rsp)
        ncsi_utils.ncsi_cmd_ctrl = function(_, _, req_packet, _, cmd_process_table)
            if req_packet.packet_head.packet_type ~= ENABLE_CHANNEL then
                return cmd_process_table[DISABLE_CHANNEL_RSP](disable_rsp)
            end
            return cmd_process_table[ENABLE_CHANNEL_RSP](enable_rsp)
        end
    end

    local package_id, channel_id, eth_name = 0, 1, "eth0"
    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl

    -- Completed responses are reported as success.
    feed(make_rsp(ncsi_def.CMD_COMPLETED, 0, ENABLE_CHANNEL_RSP),
         make_rsp(ncsi_def.CMD_COMPLETED, 0, DISABLE_CHANNEL_RSP))
    lu.assertEquals(ncsi_channel.ncsi_enable_channel(package_id, channel_id, eth_name),
        ncsi_def.NCSI_SUCCESS)
    lu.assertEquals(ncsi_channel.ncsi_disable_channel(package_id, channel_id, eth_name),
        ncsi_def.NCSI_SUCCESS)

    -- A non-completed response code (0x0123, reason 1) is reported as failure.
    feed(make_rsp(0x0123, 1, ENABLE_CHANNEL_RSP),
         make_rsp(0x0123, 1, DISABLE_CHANNEL_RSP))
    lu.assertEquals(ncsi_channel.ncsi_enable_channel(package_id, channel_id, eth_name),
        ncsi_def.NCSI_FAIL)
    lu.assertEquals(ncsi_channel.ncsi_disable_channel(package_id, channel_id, eth_name),
        ncsi_def.NCSI_FAIL)

    -- Put back the stub that was active when this test started.
    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

-- Error path: when the command/response exchange itself reports NCSI_FAIL,
-- both enable and disable must propagate NCSI_FAIL to the caller.
function TestNCSIChannel:test_error_handling()
    local saved_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    -- Every exchange fails, whatever arguments it is given.
    ncsi_utils.ncsi_cmd_ctrl = function()
        return ncsi_def.NCSI_FAIL
    end

    local package_id, channel_id, eth_name = 0, 1, "eth0"

    -- Entry points are looked up by name at call time, enable first.
    for _, entry_name in ipairs({"ncsi_enable_channel", "ncsi_disable_channel"}) do
        local outcome = ncsi_channel[entry_name](package_id, channel_id, eth_name)
        lu.assertEquals(outcome, ncsi_def.NCSI_FAIL)
    end

    -- Put back the stub that was active when this test started.
    ncsi_utils.ncsi_cmd_ctrl = saved_cmd_ctrl
end

-- Truncated response: the Enable Channel response handler receives a response
-- whose payload is a single NUL byte (too short to hold rsp_code/reason_code),
-- even though the header declares 4 payload bytes. ncsi_enable_channel must
-- report NCSI_FAIL instead of succeeding.
function TestNCSIChannel:test_empty_payload()
    -- Build the truncated Enable Channel response.
    local function create_mock_response()
        local rsp = {packet_head = {}, payload = ''}
        -- Header deliberately claims more payload than is actually present.
        rsp.packet_head.payload_len_hi = 0
        rsp.packet_head.payload_len_lo = 4
        rsp.packet_head.packet_type = ENABLE_CHANNEL_RSP

        -- One-byte payload: not enough for any response field.
        rsp.payload = '\0'
        return rsp
    end

    local package_id = 0
    local channel_id = 1
    local eth_name = "eth0"

    -- Stub cmd_ctrl so the response handler is fed the truncated response.
    local original_cmd_ctrl = ncsi_utils.ncsi_cmd_ctrl
    ncsi_utils.ncsi_cmd_ctrl = function(_, _, _, _, cmd_process_table)
        local short_rsp = create_mock_response()
        return cmd_process_table[ENABLE_CHANNEL_RSP](short_rsp)
    end

    -- Enabling the channel must fail on the truncated response.
    local result = ncsi_channel.ncsi_enable_channel(package_id, channel_id, eth_name)
    lu.assertEquals(result, ncsi_def.NCSI_FAIL, "Should fail when response payload is truncated")

    -- Put back the stub that was active when this test started.
    ncsi_utils.ncsi_cmd_ctrl = original_cmd_ctrl
end