-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local bs = require 'mc.bitstring'
local ncsi_oem_response = require 'ncsi.ncsi_protocol.ncsi_oem_response'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local core = require 'network.core'

-- Test suite table. It is declared as a global (no `local`) on purpose,
-- presumably so the luaunit runner can discover Test* tables by name.
-- NOTE(review): confirm against the test harness.
TestNCSIOemResponse = {}

-- OEM command response layout, kept in sync with the module under test.
-- Fields in order:
--   16-bit response code
--   16-bit reason code
--   32-bit manufacturer ID
--   four 8-bit fields: command revision, command ID, sub-command, reserved byte.
local oem_command_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    manufacture_id:32,
    cmd_rev:8,
    cmd_id:8,
    sub_cmd:8,
    reserved:8
>>]])

-- Generic OEM response layout. In this file it is used to build the
-- non-Huawei responses: the same 8-byte header, then 22 bytes of opaque
-- data and a 32-bit fcs field.
local oem_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    manufacture_id:32,
    data:22/string,
    fcs:32
>>]])

-- Per-test fixture.
-- Replaces the response-code parser, the reason-code parser and core.ntohl
-- with counting fakes, and zeroes every call counter on the suite.
function TestNCSIOemResponse:setUp()
    -- Keep the real implementations; tearDown reinstalls them.
    self.original_common_respcode_parse = ncsi_utils.common_respcode_parse
    self.original_common_reasoncode_parse = ncsi_utils.common_reasoncode_parse
    self.original_ntohl = core.ntohl

    -- Every test starts from zeroed counters.
    local counter_names = {
        'respcode_parse_call_count',
        'reasoncode_parse_call_count',
        'ntohl_call_count',
        'callback_call_count',
    }
    for _, name in ipairs(counter_names) do
        self[name] = 0
    end

    local suite = self

    -- Counting stub for the response-code parser; its argument is ignored.
    ncsi_utils.common_respcode_parse = function(_)
        suite.respcode_parse_call_count = suite.respcode_parse_call_count + 1
    end

    -- Counting stub for the reason-code parser; its argument is ignored.
    ncsi_utils.common_reasoncode_parse = function(_)
        suite.reasoncode_parse_call_count = suite.reasoncode_parse_call_count + 1
    end

    -- Identity stub for ntohl: values packed by the bitstring helpers in
    -- this file are handed back unchanged (no byte swapping in tests).
    core.ntohl = function(value)
        suite.ntohl_call_count = suite.ntohl_call_count + 1
        return value
    end
end

-- Per-test teardown: put back the real functions that setUp saved.
function TestNCSIOemResponse:tearDown()
    core.ntohl = self.original_ntohl
    ncsi_utils.common_reasoncode_parse = self.original_common_reasoncode_parse
    ncsi_utils.common_respcode_parse = self.original_common_respcode_parse
end

-- Build a mock NC-SI response whose payload uses the OEM command layout
-- (oem_command_rsp_bs).
-- Any argument passed as nil falls back to these defaults:
--   rsp_code       = ncsi_def.CMD_COMPLETED
--   reason_code    = 0x0000
--   manufacture_id = ncsi_oem_response.MANUFACTURE_ID_HUAWEI
--   cmd_rev        = 0x00
--   cmd_id         = 0x04
--   sub_cmd        = 0x0B
--   reserved       = 0x00
-- Returns:
--   { packet_head = { payload_len_hi = 0, payload_len_lo = <payload length> },
--     payload = <packed payload> }
function TestNCSIOemResponse:create_mock_response(rsp_code, reason_code,
    manufacture_id, cmd_rev, cmd_id, sub_cmd, reserved)
    local payload = oem_command_rsp_bs:pack({
        rsp_code = rsp_code or ncsi_def.CMD_COMPLETED,
        reason_code = reason_code or 0x0000,
        manufacture_id = manufacture_id or ncsi_oem_response.MANUFACTURE_ID_HUAWEI,
        cmd_rev = cmd_rev or 0x00,
        cmd_id = cmd_id or 0x04,
        sub_cmd = sub_cmd or 0x0B,
        reserved = reserved or 0x00,
    })

    local head = {payload_len_lo = #payload, payload_len_hi = 0}
    return {payload = payload, packet_head = head}
end

-- Build a mock response in the generic OEM layout (oem_rsp_bs).
-- The payload is the common 8-byte header, 22 zero bytes of data and a
-- zero fcs.
-- rsp_code defaults to ncsi_def.CMD_COMPLETED and reason_code to 0x0000
-- when nil. manufacture_id has no default and must be supplied.
-- Returns:
--   { packet_head = { payload_len_hi = 0, payload_len_lo = <payload length> },
--     payload = <packed payload> }
function TestNCSIOemResponse:create_mock_common_response(rsp_code, reason_code,
    manufacture_id)
    -- The former `manufacture_id = manufacture_id` was a no-op that looked
    -- like a default. Fail fast with a clear message instead of passing nil
    -- into pack(), whose handling of a missing field is not visible here.
    assert(manufacture_id ~= nil, 'create_mock_common_response: manufacture_id is required')

    rsp_code = rsp_code or ncsi_def.CMD_COMPLETED
    reason_code = reason_code or 0x0000

    -- Same bitstring layout as the module under test.
    local payload = oem_rsp_bs:pack({
        rsp_code = rsp_code,
        reason_code = reason_code,
        manufacture_id = manufacture_id,
        data = string.rep('\0', 22),
        fcs = 0
    })

    return {
        packet_head = {
            payload_len_hi = 0,
            payload_len_lo = #payload
        },
        payload = payload
    }
end

-- Return a response callback that records, on the suite, how many times it
-- ran (callback_call_count) and the last response it received
-- (last_callback_rsp).
-- expected_cmd_id and expected_sub_cmd only document intent at call sites;
-- the callback does not check them.
function TestNCSIOemResponse:create_test_callback(expected_cmd_id, expected_sub_cmd)
    local suite = self
    local function record(rsp)
        suite.callback_call_count = suite.callback_call_count + 1
        suite.last_callback_rsp = rsp
    end
    return record
end

-- Huawei response, completed command, matching (0x04, 0x0B) entry.
-- Expects: the call succeeds, the callback runs exactly once with the same
-- response table, and ntohl is called.
function TestNCSIOemResponse:test_read_oem_command_rsp_huawei_success()
    local response = self:create_mock_response()
    local handler = self:create_test_callback(0x04, 0x0B)
    local entries = {ncsi_oem_response.create_callback_entry(0x04, 0x0B, handler)}

    local status = ncsi_oem_response.read_oem_command_rsp(response, entries)

    lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
    lu.assertEquals(self.callback_call_count, 1)
    lu.assertEquals(self.last_callback_rsp, response)
    lu.assertTrue(self.ntohl_call_count > 0)
end

-- Huawei response with no callback table at all (nil).
-- Expects success, and no callback fires.
function TestNCSIOemResponse:test_read_oem_command_rsp_huawei_no_callback()
    local status = ncsi_oem_response.read_oem_command_rsp(
        self:create_mock_response(ncsi_def.CMD_COMPLETED, 0x0000,
            ncsi_oem_response.MANUFACTURE_ID_HUAWEI, 0x00, 0x04, 0x0B, 0x00),
        nil)

    lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
    lu.assertEquals(self.callback_call_count, 0)
end

-- Huawei response with an empty callback table.
-- Expects success, and no callback fires.
function TestNCSIOemResponse:test_read_oem_command_rsp_huawei_empty_callback()
    local response = self:create_mock_response()
    local status = ncsi_oem_response.read_oem_command_rsp(response, {})

    lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
    lu.assertEquals(self.callback_call_count, 0)
end

-- Huawei response for (0x04, 0x0B); only a (0x05, 0x0C) entry is
-- registered. Expects success, and the non-matching callback never runs.
function TestNCSIOemResponse:test_read_oem_command_rsp_huawei_no_match_callback()
    local response = self:create_mock_response(ncsi_def.CMD_COMPLETED, 0x0000,
        ncsi_oem_response.MANUFACTURE_ID_HUAWEI, 0x00, 0x04, 0x0B, 0x00)

    local unrelated = self:create_test_callback(0x05, 0x0C)
    local entries = {ncsi_oem_response.create_callback_entry(0x05, 0x0C, unrelated)}

    local status = ncsi_oem_response.read_oem_command_rsp(response, entries)

    lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
    lu.assertEquals(self.callback_call_count, 0)
end

-- Huawei response with failure response code 0x8001 and reason 0x0005.
-- Expects:
--   - the call fails;
--   - the matching callback is NOT run;
--   - both the response-code and reason-code parsers are invoked.
function TestNCSIOemResponse:test_read_oem_command_rsp_huawei_command_failed()
    local failed_code, failed_reason = 0x8001, 0x0005
    local response = self:create_mock_response(failed_code, failed_reason)
    local entries = {
        ncsi_oem_response.create_callback_entry(0x04, 0x0B, self:create_test_callback(0x04, 0x0B)),
    }

    local status = ncsi_oem_response.read_oem_command_rsp(response, entries)

    lu.assertEquals(status, ncsi_def.NCSI_FAIL)
    lu.assertEquals(self.callback_call_count, 0)
    lu.assertTrue(self.respcode_parse_call_count > 0)
    lu.assertTrue(self.reasoncode_parse_call_count > 0)
end

-- Completed response from a non-Huawei manufacturer (0x12345678), built
-- with the generic layout. Expects success, and no callback fires.
function TestNCSIOemResponse:test_read_oem_command_rsp_other_vendor_success()
    local other_vendor_id = 0x12345678
    local response = self:create_mock_common_response(ncsi_def.CMD_COMPLETED, 0x0000, other_vendor_id)

    lu.assertEquals(ncsi_oem_response.read_oem_command_rsp(response, nil), ncsi_def.NCSI_SUCCESS)
    lu.assertEquals(self.callback_call_count, 0)
end

-- Non-Huawei response (0x12345678) with failure code 0x8001 and reason
-- 0x0005. Expects failure, with both code parsers invoked.
function TestNCSIOemResponse:test_read_oem_command_rsp_other_vendor_failed()
    local other_vendor_id = 0x12345678
    local failed_code, failed_reason = 0x8001, 0x0005
    local response = self:create_mock_common_response(failed_code, failed_reason, other_vendor_id)

    lu.assertEquals(ncsi_oem_response.read_oem_command_rsp(response, nil), ncsi_def.NCSI_FAIL)
    lu.assertTrue(self.respcode_parse_call_count > 0)
    lu.assertTrue(self.reasoncode_parse_call_count > 0)
end

-- A nil response must be rejected with NCSI_FAIL.
function TestNCSIOemResponse:test_read_oem_command_rsp_nil_rsp()
    lu.assertEquals(ncsi_oem_response.read_oem_command_rsp(nil, nil), ncsi_def.NCSI_FAIL)
end

-- A response with a header but no payload field must be rejected with
-- NCSI_FAIL.
function TestNCSIOemResponse:test_read_oem_command_rsp_nil_payload()
    local head = {payload_len_hi = 0, payload_len_lo = 0}
    -- `payload` is deliberately left unset (nil).
    local response = {packet_head = head}

    local status = ncsi_oem_response.read_oem_command_rsp(response, nil)
    lu.assertEquals(status, ncsi_def.NCSI_FAIL)
end

-- Zero-length payload string.
-- The call must return one of the two status codes; which one depends on
-- how unpacking treats a short buffer. If the call raised, this test would
-- error.
function TestNCSIOemResponse:test_read_oem_command_rsp_empty_payload()
    local response = {
        packet_head = {payload_len_hi = 0, payload_len_lo = 0},
        payload = '',
    }

    local status = ncsi_oem_response.read_oem_command_rsp(response, nil)

    local is_known_status = (status == ncsi_def.NCSI_SUCCESS) or (status == ncsi_def.NCSI_FAIL)
    lu.assertTrue(is_known_status)
end

-- A 4-byte payload ("abcd") that does not match any OEM response layout.
-- Raising is acceptable. If the call returns, it must return NCSI_SUCCESS
-- or NCSI_FAIL.
function TestNCSIOemResponse:test_read_oem_command_rsp_invalid_payload()
    local response = {
        packet_head = {payload_len_hi = 0, payload_len_lo = 4},
        payload = 'abcd',
    }

    local ok, status = pcall(ncsi_oem_response.read_oem_command_rsp, response, nil)

    if ok then
        lu.assertTrue(status == ncsi_def.NCSI_SUCCESS or status == ncsi_def.NCSI_FAIL)
    end
end

-- create_callback_entry exposes its three arguments as the fields
-- cmd_id, sub_cmd and func.
function TestNCSIOemResponse:test_create_callback_entry()
    local noop = function() end
    local entry = ncsi_oem_response.create_callback_entry(0x04, 0x0B, noop)

    lu.assertEquals({entry.cmd_id, entry.sub_cmd}, {0x04, 0x0B})
    lu.assertEquals(entry.func, noop)
end

-- A nil handler is stored as nil; the IDs are still recorded.
function TestNCSIOemResponse:test_create_callback_entry_nil_func()
    local entry = ncsi_oem_response.create_callback_entry(0x04, 0x0B, nil)

    lu.assertEquals({entry.cmd_id, entry.sub_cmd}, {0x04, 0x0B})
    lu.assertNil(entry.func)
end

-- Each value is used for both cmd_id and sub_cmd.
-- The values are the 8-bit limits (0, 255) plus out-of-range values
-- (-1, 999). The entry must store every one unchanged.
function TestNCSIOemResponse:test_create_callback_entry_boundary_values()
    local noop = function() end

    for _, value in ipairs({0, 255, -1, 999}) do
        local entry = ncsi_oem_response.create_callback_entry(value, value, noop)

        lu.assertEquals(entry.cmd_id, value)
        lu.assertEquals(entry.sub_cmd, value)
        lu.assertEquals(entry.func, noop)
    end
end

-- Huawei responses with various response codes.
-- CMD_COMPLETED must succeed. Every other code must fail and invoke the
-- response-code parser; counters are reset per iteration.
function TestNCSIOemResponse:test_different_response_codes()
    local codes = {
        ncsi_def.CMD_COMPLETED,
        0x8001,  -- generic error (per original author)
        0x8002,  -- invalid command
        0x8003,  -- invalid parameter
        0x8004,  -- unsupported command
        0xFFFF,  -- largest 16-bit code
    }

    for _, code in ipairs(codes) do
        self.respcode_parse_call_count = 0
        self.reasoncode_parse_call_count = 0

        local response = self:create_mock_response(code, 0x0000, ncsi_oem_response.MANUFACTURE_ID_HUAWEI)
        local status = ncsi_oem_response.read_oem_command_rsp(response, nil)

        local expect_success = (code == ncsi_def.CMD_COMPLETED)
        if not expect_success then
            lu.assertEquals(status, ncsi_def.NCSI_FAIL)
            lu.assertTrue(self.respcode_parse_call_count > 0, "respcode_parse should be called for error response")
        else
            lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
        end
    end
end

-- Each listed reason code is paired with the failure response code 0x8001.
-- Every case must fail and must invoke the reason-code parser.
function TestNCSIOemResponse:test_different_reason_codes()
    local reason_codes = {0x0000, 0x0001, 0x0005, 0x00FF, 0xFFFF}

    for _, reason_code in ipairs(reason_codes) do
        -- Reset per iteration. Previously the counter accumulated across
        -- iterations, so the parser assertion was vacuously true after the
        -- first pass. test_different_response_codes resets the same way.
        self.respcode_parse_call_count = 0
        self.reasoncode_parse_call_count = 0

        local rsp = self:create_mock_response(0x8001, reason_code)
        local ret = ncsi_oem_response.read_oem_command_rsp(rsp, nil)

        lu.assertEquals(ret, ncsi_def.NCSI_FAIL)
        lu.assertTrue(self.reasoncode_parse_call_count > 0,
            string.format('reasoncode_parse should be called for reason code 0x%04X', reason_code))
    end
end

-- Completed responses carrying different manufacturer IDs must all succeed.
-- - Huawei responses use the OEM command layout; all others use the
--   generic layout.
-- - ntohl is re-stubbed per case to return that case's ID.
-- - In the Huawei case, the (0x04, 0x0B) callback must have run.
function TestNCSIOemResponse:test_different_manufacture_ids()
    local huawei_id = ncsi_oem_response.MANUFACTURE_ID_HUAWEI
    local ids = {huawei_id, 0x00000000, 0x12345678, 0xFFFFFFFF}

    for _, id in ipairs(ids) do
        local is_huawei = (id == huawei_id)
        -- Both builders are functions (truthy), so the and/or choice is safe.
        local build = is_huawei and self.create_mock_response or self.create_mock_common_response
        local response = build(self, ncsi_def.CMD_COMPLETED, 0x0000, id)

        -- Make ntohl report this iteration's manufacturer ID.
        core.ntohl = function(_)
            self.ntohl_call_count = self.ntohl_call_count + 1
            return id
        end

        local entries = {
            ncsi_oem_response.create_callback_entry(0x04, 0x0B, self:create_test_callback(0x04, 0x0B)),
        }

        lu.assertEquals(ncsi_oem_response.read_oem_command_rsp(response, entries), ncsi_def.NCSI_SUCCESS)

        if is_huawei then
            lu.assertTrue(self.callback_call_count > 0)
        end
    end
end

-- For each (cmd_id, sub_cmd) pair, a Huawei response carrying that pair
-- must trigger the entry registered for exactly that pair.
function TestNCSIOemResponse:test_different_cmd_combinations()
    local pairs_to_try = {{0x04, 0x0B}, {0x05, 0x0C}, {0x00, 0x00}, {0xFF, 0xFF}}

    for _, pair in ipairs(pairs_to_try) do
        local cmd_id, sub_cmd = pair[1], pair[2]
        local response = self:create_mock_response(ncsi_def.CMD_COMPLETED, 0x0000,
            ncsi_oem_response.MANUFACTURE_ID_HUAWEI, 0x00, cmd_id, sub_cmd, 0x00)

        local hit = false
        local function on_match(_)
            hit = true
        end

        local status = ncsi_oem_response.read_oem_command_rsp(response,
            {ncsi_oem_response.create_callback_entry(cmd_id, sub_cmd, on_match)})

        lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
        lu.assertTrue(hit)
    end
end

-- A 4-byte payload (response and reason code only) is shorter than either
-- OEM response layout defined in this file.
-- Raising is tolerated. If the call returns, it must return NCSI_SUCCESS or
-- NCSI_FAIL, the same contract test_read_oem_command_rsp_invalid_payload
-- enforces for a 4-byte payload.
function TestNCSIOemResponse:test_payload_length_mismatch()
    local short_payload = string.pack(">HH", 0x0000, 0x0000)  -- 4 bytes only
    local rsp = {
        packet_head = {
            payload_len_hi = 0,
            payload_len_lo = #short_payload
        },
        payload = short_payload
    }

    local ok, ret = pcall(function()
        return ncsi_oem_response.read_oem_command_rsp(rsp, nil)
    end)

    -- The former `lu.assertTrue(success or not success)` could never fail.
    if ok then
        lu.assertTrue(ret == ncsi_def.NCSI_SUCCESS or ret == ncsi_def.NCSI_FAIL,
            'short payload should yield NCSI_SUCCESS or NCSI_FAIL')
    end
end

-- Table of 100 decoy entries, (1,1) .. (100,100), plus one entry for
-- (0xFF, 0xFF). A response for (0xFF, 0xFF) must run only that last entry,
-- exactly once.
-- No timing is measured; this checks dispatch correctness on a larger table.
function TestNCSIOemResponse:test_large_callback_table_performance()
    self.callback_call_count = 0

    -- 0xFF (255) is outside the decoy range 1..100, so only the final entry matches.
    local target = 0xFF
    local response = self:create_mock_response(ncsi_def.CMD_COMPLETED, 0x0000,
        ncsi_oem_response.MANUFACTURE_ID_HUAWEI, 0x00, target, target, 0x00)

    local entries = {}
    for n = 1, 100 do
        entries[#entries + 1] = ncsi_oem_response.create_callback_entry(n, n, function(_) end)
    end
    entries[#entries + 1] = ncsi_oem_response.create_callback_entry(target, target, function(_)
        self.callback_call_count = self.callback_call_count + 1
    end)

    local status = ncsi_oem_response.read_oem_command_rsp(response, entries)

    lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
    lu.assertTrue(self.callback_call_count > 0, "Matching callback should be called")
    lu.assertEquals(self.callback_call_count, 1, "Only one callback should be called")
end

-- Ten back-to-back calls, cycling cmd_id/sub_cmd through 2,3,4,5,1.
-- Each call registers a matching entry, so the callback must run ten times.
-- The calls are sequential; no real concurrency is involved.
function TestNCSIOemResponse:test_concurrent_calls_simulation()
    self.callback_call_count = 0

    for i = 1, 10 do
        local id = i % 5 + 1
        local response = self:create_mock_response(ncsi_def.CMD_COMPLETED, 0x0000,
            ncsi_oem_response.MANUFACTURE_ID_HUAWEI, 0x00, id, id, 0x00)

        local entries = {
            ncsi_oem_response.create_callback_entry(id, id, function(_)
                self.callback_call_count = self.callback_call_count + 1
            end),
        }

        lu.assertEquals(ncsi_oem_response.read_oem_command_rsp(response, entries), ncsi_def.NCSI_SUCCESS)
    end

    lu.assertTrue(self.callback_call_count > 0, "At least one callback should be called")
    lu.assertEquals(self.callback_call_count, 10)
end

-- create_callback_entry must accept numeric IDs far outside the 8-bit
-- range (infinity, negatives, 2^31) without raising, and store them
-- unchanged. test_create_callback_entry_boundary_values asserts the same
-- pass-through behaviour for -1 and 999.
function TestNCSIOemResponse:test_extreme_values()
    local extreme_values = {
        {math.huge, 0},      -- infinity
        {0, math.huge},      -- infinity
        {-1, -1},            -- negative
        {2^31, 2^31}         -- large number
    }

    -- The former NaN filter (`x == x`) was dead code: none of the values
    -- above is NaN. The former `success or not success` assertion could
    -- never fail.
    for _, values in ipairs(extreme_values) do
        local cmd_id, sub_cmd = values[1], values[2]
        local ok, entry = pcall(ncsi_oem_response.create_callback_entry, cmd_id, sub_cmd, function() end)

        lu.assertTrue(ok, 'create_callback_entry raised: ' .. tostring(entry))
        lu.assertEquals(entry.cmd_id, cmd_id)
        lu.assertEquals(entry.sub_cmd, sub_cmd)
    end
end

-- Two entries, (0x04, 0x0B) and (0x05, 0x0C); the response carries
-- (0x04, 0x0B). Only the first entry may run.
function TestNCSIOemResponse:test_multiple_callback_entries()
    self.callback_call_count = 0

    local response = self:create_mock_response(ncsi_def.CMD_COMPLETED, 0x0000,
        ncsi_oem_response.MANUFACTURE_ID_HUAWEI, 0x00, 0x04, 0x0B, 0x00)

    local hit_matching, hit_other = false, false

    local entries = {
        ncsi_oem_response.create_callback_entry(0x04, 0x0B, function(_)
            hit_matching = true
            self.callback_call_count = self.callback_call_count + 1
        end),
        ncsi_oem_response.create_callback_entry(0x05, 0x0C, function(_)
            hit_other = true
            self.callback_call_count = self.callback_call_count + 1
        end),
    }

    local status = ncsi_oem_response.read_oem_command_rsp(response, entries)

    lu.assertEquals(status, ncsi_def.NCSI_SUCCESS)
    lu.assertTrue(hit_matching, "Matching callback should be called")
    lu.assertFalse(hit_other)
    lu.assertTrue(self.callback_call_count > 0, "Callback count should be incremented")
end

-- A matching callback that raises must not abort the test run.
-- The call is wrapped in pcall; either outcome is accepted: the error
-- propagates out of read_oem_command_rsp, or the module catches it.
function TestNCSIOemResponse:test_callback_function_exception()
    local response = self:create_mock_response()

    local function exploding_callback(_)
        error("Callback function error")
    end
    local entries = {ncsi_oem_response.create_callback_entry(0x04, 0x0B, exploding_callback)}

    local ok = pcall(ncsi_oem_response.read_oem_command_rsp, response, entries)

    -- pcall always yields a boolean status; no particular outcome is required.
    lu.assertEquals(type(ok), 'boolean')
end

-- Minimal malformed response: an empty packet_head and a payload of a
-- single '\0' byte. Despite the test name, the payload is one byte, not
-- zero. The test passes as long as read_oem_command_rsp does not raise.
function TestNCSIOemResponse:test_empty_payload()
    local one_byte_rsp = {packet_head = {}, payload = '\0'}

    -- Pass nil explicitly. The previous code passed `_`, which here is an
    -- undeclared global: it silently picks up whatever any other code
    -- stored in a global `_`.
    ncsi_oem_response.read_oem_command_rsp(one_byte_rsp, nil)
end

-- Return the suite table to whoever loads this file. This does not run the
-- tests itself; presumably a luaunit-based harness collects and runs them.
-- NOTE(review): confirm against the test runner.
return TestNCSIOemResponse