-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local rpc_service_subhealth = require 'rpc_services.rpc_service_subhealth'
local drive_collection = require 'drive.drive_collection'
local common_def = require 'common_def'
local mdb = require 'mc.mdb'
local json = require 'cjson'
local client = require 'storage.client'
local context = require 'mc.context'
local task_service = require 'task_service'

local nvme_serive = require 'nvme.nvme_mi_protocol.nvme_mi_admin_command'
local command_define = require 'nvme.nvme_mi_protocol.nvme_mi_def'

local drive_obj = require 'drive.drive_object'
local sml = require 'sml'

-- Test class covering IO-latency collection through the NVMe-MI, RAID (sml) and BMA paths.
-- Intentionally global (no `local`): presumably so the LuaUnit runner can discover it by
-- its Test* name. The table is also returned at the end of the file.
TestIOLatencyNvme = {}

--- Replace drive_collection.get_instance with a stub backed by `drives`.
-- Every call to the stubbed get_instance builds a new instance table exposing only
-- get_all_drives, get_drive and get_nvme_by_drive_name.
-- NOTE: get_nvme_by_drive_name ignores the drive it is given and always reports
-- Slot 1, which is parsed from the literal "Disk1".
-- @param drives table of mock drive objects (each with a Name field)
local function mock_drive_collection(drives)
    -- Linear lookup of a mock drive by its Name; nil when absent.
    local function find_drive(_, wanted_name)
        for _, candidate in pairs(drives) do
            if candidate.Name == wanted_name then
                return candidate
            end
        end
        return nil
    end

    drive_collection.get_instance = function()
        return {
            get_all_drives = function()
                return drives
            end,
            get_drive = find_drive,
            get_nvme_by_drive_name = function(_, _)
                local slot_text = ("Disk1"):match("Disk(.*)")
                return { Slot = tonumber(slot_text) }
            end,
        }
    end
end

-- Load-time snapshots of the hooks that the tests below replace.
-- pre_fun4 is the real drive_collection.get_instance and pre_fun is the real
-- task_service.create. pre_fun1..pre_fun3 are fields read from the real instance.
-- mock_drive_collection swaps get_instance wholesale, so restoring get_instance
-- (pre_fun4) is what undoes that mock.
local pre_fun4 = drive_collection.get_instance
local pre_fun = task_service.create
local pre_fun1 = pre_fun4().get_drive
local pre_fun2 = pre_fun4().get_nvme_by_drive_name
local pre_fun3 = pre_fun4().get_all_drives

-- Test the normal flow of collect_drive_diagnose_data for a drive with Protocol = 6
-- whose collect_log_from_disk mock reports success: a task id must be returned.
function TestIOLatencyNvme:test_collect_drive_diagnose_data_success()
    local ctx = { get_initiator = function() return "test_user" end }
    local obj_path = "/mock/path"
    local data_type = { common_def.SUBHEALTH_STR.VENDOR_IO_LATENCY }

    -- mdb.get_object has no file-level snapshot, so save it here for restoration.
    local pre_get_object = mdb.get_object

    -- Mock the drive object returned by mdb.get_object
    local mock_drive = { Name = "Disk1", Presence = 1, Protocol = 6, identify_pd = false,
        collect_log_from_disk = function()
            return true
        end
    }
    mdb.get_object = function(_, path, _)
        lu.assertEquals(path, obj_path)
        return mock_drive
    end

    -- Mock drive_collection and task creation
    mock_drive_collection({mock_drive})
    task_service.create = function()
        return 35
    end

    -- Run the assertions under pcall so the mocks are restored even when one fails.
    local ok, err = pcall(function()
        local task_id = rpc_service_subhealth.get_instance():collect_drive_diagnose_data(obj_path, ctx, data_type)
        lu.assertNotNil(task_id)
    end)

    -- mock_drive_collection replaces get_instance itself, so restoring that function
    -- (not fields of a mocked instance) is what undoes the mock.
    task_service.create = pre_fun
    drive_collection.get_instance = pre_fun4
    mdb.get_object = pre_get_object
    if not ok then
        error(err, 0)
    end
end

-- Test the normal flow of get_drive_diagnose_data: the drive's io_latency value is
-- expected under the 'VendorDefinedIOLatency' key of the result.
function TestIOLatencyNvme:test_get_drive_diagnose_data_success()
    local ctx = { get_initiator = function() return "test_user" end }
    local obj_path = "/mock/path"
    local data_type = { common_def.SUBHEALTH_STR.VENDOR_IO_LATENCY }

    -- mdb.get_object has no file-level snapshot, so save it here for restoration.
    local pre_get_object = mdb.get_object

    -- Mock the drive object returned by mdb.get_object
    local mock_drive = { Name = "Drive1", Presence = 1, io_latency = 123 }
    mdb.get_object = function(_, path, _)
        lu.assertEquals(path, obj_path)
        return mock_drive
    end

    -- Mock drive_collection
    mock_drive_collection({mock_drive})

    -- Run the assertions under pcall so the mocks are restored even when one fails.
    local ok, err = pcall(function()
        local result = rpc_service_subhealth.get_instance():get_drive_diagnose_data(obj_path, ctx, data_type)
        lu.assertNotNil(result)
        lu.assertEquals(result['VendorDefinedIOLatency'], 123)
    end)

    -- task_service.create is not mocked here. It is still reset to its load-time
    -- value, which also clears any leak from an earlier test.
    task_service.create = pre_fun
    drive_collection.get_instance = pre_fun4
    mdb.get_object = pre_get_object
    if not ok then
        error(err, 0)
    end
end

-- Test get_drive_diagnose_data with an invalid (nil) data_type: the call must raise.
function TestIOLatencyNvme:test_get_drive_diagnose_data_invalid_data_type()
    local caller_ctx = { get_initiator = function() return "test_user" end }
    local succeeded = pcall(function()
        local service = rpc_service_subhealth.get_instance()
        service:get_drive_diagnose_data("/mock/path", caller_ctx, nil)
    end)
    lu.assertEquals(succeeded, false)
end

-- Test collect_drive_diagnose_data with an invalid (nil) data_type: the call must raise.
function TestIOLatencyNvme:test_collect_drive_diagnose_data_invalid_data_type()
    local caller_ctx = { get_initiator = function() return "test_user" end }
    local succeeded = pcall(function()
        local service = rpc_service_subhealth.get_instance()
        service:collect_drive_diagnose_data("/mock/path", caller_ctx, nil)
    end)
    lu.assertEquals(succeeded, false)
end

-- Test check_support_io_latency_log_page: with a packed response whose support_flag
-- is 8, the function is expected to return 8.
function TestIOLatencyNvme:test_check_support_io_latency_log_page()
    local packed_rsp = command_define.io_latency_support_rsp_data:pack({
        status = 1, reserve = 2, dword0 = 3, dword1 = 4, dword3 = 5,
        support_flag = 8,
        rest = "7"
    })

    -- NVMe double: queue runs the job inline; RawRequest hands back packed_rsp.
    local fake_nvme = {
        nvme_mi_mctp_obj = {
            queue = function(job) job() end,
            nvme_mi_obj = {
                RawRequest = function()
                    return { value = function() return packed_rsp end }
                end
            }
        }
    }

    local succeeded, flag = pcall(nvme_serive.check_support_io_latency_log_page, fake_nvme)
    lu.assertIsTrue(succeeded)
    lu.assertEquals(flag, 8)
end

-- Test get_io_latency_log_version: with a packed response whose version is 6, the
-- function is expected to return 6.
function TestIOLatencyNvme:test_get_io_latency_log_version()
    local packed_rsp = command_define.io_latency_version_data:pack({
        status = 1, reserve = 2, dword0 = 3, dword1 = 4, dword3 = 5,
        version = 6,
        rest = "7"
    })

    -- NVMe double: queue runs the job inline; RawRequest hands back packed_rsp.
    local fake_nvme = {
        nvme_mi_mctp_obj = {
            queue = function(job) job() end,
            nvme_mi_obj = {
                RawRequest = function()
                    return { value = function() return packed_rsp end }
                end
            }
        }
    }

    local succeeded, version = pcall(nvme_serive.get_io_latency_log_version, fake_nvme)
    lu.assertIsTrue(succeeded)
    lu.assertEquals(version, 6)
end

-- Test get_io_latency_data: with a packed response whose trailing data is "7", the
-- function is expected to return "7".
function TestIOLatencyNvme:test_get_io_latency_data()
    local packed_rsp = command_define.io_latency_raw_data:pack({
        status = 1, reserve = 2, dword0 = 3, dword1 = 4, dword3 = 5,
        rest = "7"
    })

    -- NVMe double: queue runs the job inline; RawRequest hands back packed_rsp.
    local fake_nvme = {
        nvme_mi_mctp_obj = {
            queue = function(job) job() end,
            nvme_mi_obj = {
                RawRequest = function()
                    return { value = function() return packed_rsp end }
                end
            }
        }
    }

    local succeeded, raw = pcall(nvme_serive.get_io_latency_data, fake_nvme)
    lu.assertIsTrue(succeeded)
    lu.assertEquals(raw, "7")
end

-- Test drive_collection:get_nvme_by_drive_name on the real collection instance:
-- "Disk1" maps to the nvme_list entry with Slot 1, and to nil when no entry has Slot 1.
function TestIOLatencyNvme:test_get_drive_mapped_nvme()
    local mock_drive = { Name = "Disk1" }
    -- Use the load-time get_instance so a get_instance mock leaked by another test
    -- cannot replace the code under test.
    local instance = pre_fun4()
    local pre_data = instance.nvme_list

    -- Run the assertions under pcall so nvme_list is restored even when one fails.
    local ok, err = pcall(function()
        local mock_nvme_list = { { Slot = 1, Test = "OK" } }
        instance.nvme_list = mock_nvme_list
        local ret = instance:get_nvme_by_drive_name(mock_drive)
        lu.assertNotNil(ret)
        lu.assertEquals(ret.Test, "OK")

        -- No entry with Slot 1 any more: the lookup must fail.
        mock_nvme_list[1] = { Slot = 2, Test = "NG" }
        ret = instance:get_nvme_by_drive_name(mock_drive)
        lu.assertNil(ret)
    end)

    instance.nvme_list = pre_data
    if not ok then
        error(err, 0)
    end
end

-- Test the normal flow of RAID-based collection (RefControllerId = 0). sml returns
-- "Test" (read) and "OK" (write). The expected io_latency is 16 'F' characters
-- followed by the hex encoding of "Test" and "OK".
function TestIOLatencyNvme:test_collect_drive_from_raid_data_success()
    local ctx = { get_initiator = function() return "test_user" end }
    local obj_path = "/mock/path"
    local data_type = { common_def.SUBHEALTH_STR.VENDOR_IO_LATENCY }

    -- Save the hooks that have no file-level snapshot.
    local pre_get_object = mdb.get_object
    local pre_read = sml.pd_get_read_io_latency_info
    local pre_write = sml.pd_get_write_io_latency_info

    -- Mock the drive object returned by mdb.get_object
    local mock_drive = { Name = "Disk1", Presence = 1, Protocol = 2, RefControllerId = 0,
        Model = 'HWE72P453T8L007N', Revision = '1069', io_latency = '',
        collect_log_from_raid = drive_obj.collect_log_from_raid
    }
    mdb.get_object = function(_, path, _)
        lu.assertEquals(path, obj_path)
        return mock_drive
    end

    -- Mock drive_collection, task creation and the sml latency readers
    mock_drive_collection({mock_drive})
    task_service.create = function()
        return 35
    end
    sml.pd_get_read_io_latency_info = function()
        return "Test"
    end
    sml.pd_get_write_io_latency_info = function()
        return "OK"
    end

    -- Run the assertions under pcall so the mocks are restored even when one fails.
    local ok, err = pcall(function()
        local task_id = rpc_service_subhealth.get_instance():collect_drive_diagnose_data(obj_path, ctx, data_type)
        lu.assertNotNil(task_id)
        lu.assertEquals(mock_drive.io_latency, "FFFFFFFFFFFFFFFF546573744F4B")
    end)

    task_service.create = pre_fun
    drive_collection.get_instance = pre_fun4
    mdb.get_object = pre_get_object
    sml.pd_get_read_io_latency_info = pre_read
    sml.pd_get_write_io_latency_info = pre_write
    if not ok then
        error(err, 0)
    end
end

-- Test RAID-based collection when reading the read-latency data fails: a task id is
-- still returned and io_latency stays empty.
function TestIOLatencyNvme:test_collect_drive_read_data_fail()
    local ctx = { get_initiator = function() return "test_user" end }
    local obj_path = "/mock/path"
    local data_type = { common_def.SUBHEALTH_STR.VENDOR_IO_LATENCY }

    -- Save the hooks that have no file-level snapshot.
    local pre_get_object = mdb.get_object
    local pre_read = sml.pd_get_read_io_latency_info
    local pre_write = sml.pd_get_write_io_latency_info

    -- Mock the drive object returned by mdb.get_object
    local mock_drive = { Name = "Disk1", Presence = 1, Protocol = 2, RefControllerId = 0,
        Model = 'HWE72P453T8L007N', Revision = '1069', io_latency = '',
        collect_log_from_raid = drive_obj.collect_log_from_raid
    }
    mdb.get_object = function(_, path, _)
        lu.assertEquals(path, obj_path)
        return mock_drive
    end

    -- Mock drive_collection, task creation and a failing read-latency query
    mock_drive_collection({mock_drive})
    task_service.create = function()
        return 35
    end
    sml.pd_get_read_io_latency_info = function()
        error("mock fail")
    end

    -- Run the assertions under pcall so the mocks are restored even when one fails.
    local ok, err = pcall(function()
        local task_id = rpc_service_subhealth.get_instance():collect_drive_diagnose_data(obj_path, ctx, data_type)
        lu.assertNotNil(task_id)
        lu.assertEquals(mock_drive.io_latency, '')
    end)

    task_service.create = pre_fun
    drive_collection.get_instance = pre_fun4
    mdb.get_object = pre_get_object
    sml.pd_get_read_io_latency_info = pre_read
    sml.pd_get_write_io_latency_info = pre_write
    if not ok then
        error(err, 0)
    end
end

-- Test RAID-based collection when the read succeeds but reading the write-latency
-- data fails: a task id is still returned and io_latency stays empty.
function TestIOLatencyNvme:test_collect_drive_write_data_fail()
    local ctx = { get_initiator = function() return "test_user" end }
    local obj_path = "/mock/path"
    local data_type = { common_def.SUBHEALTH_STR.VENDOR_IO_LATENCY }

    -- Save the hooks that have no file-level snapshot.
    local pre_get_object = mdb.get_object
    local pre_read = sml.pd_get_read_io_latency_info
    local pre_write = sml.pd_get_write_io_latency_info

    -- Mock the drive object returned by mdb.get_object
    local mock_drive = { Name = "Disk1", Presence = 1, Protocol = 2, RefControllerId = 0,
        Model = 'HWE72P453T8L007N', Revision = '1069', io_latency = '',
        collect_log_from_raid = drive_obj.collect_log_from_raid
    }
    mdb.get_object = function(_, path, _)
        lu.assertEquals(path, obj_path)
        return mock_drive
    end

    -- Mock drive_collection, task creation, a working read and a failing write query
    mock_drive_collection({mock_drive})
    task_service.create = function()
        return 35
    end
    sml.pd_get_read_io_latency_info = function()
        return "Test"
    end
    sml.pd_get_write_io_latency_info = function()
        error("mock fail")
    end

    -- Run the assertions under pcall so the mocks are restored even when one fails.
    local ok, err = pcall(function()
        local task_id = rpc_service_subhealth.get_instance():collect_drive_diagnose_data(obj_path, ctx, data_type)
        lu.assertNotNil(task_id)
        lu.assertEquals(mock_drive.io_latency, '')
    end)

    task_service.create = pre_fun
    drive_collection.get_instance = pre_fun4
    mdb.get_object = pre_get_object
    sml.pd_get_read_io_latency_info = pre_read
    sml.pd_get_write_io_latency_info = pre_write
    if not ok then
        error(err, 0)
    end
end

-- Test the normal flow of BMA-based collection (RefControllerId = 255). The BMA
-- response carries ReadRawData "54657374" and WriteRawData "4F4B". The expected
-- io_latency is 16 'F' characters followed by both raw-data strings.
function TestIOLatencyNvme:test_collect_drive_from_bma_success()
    local ctx = { get_initiator = function() return "test_user" end }
    local obj_path = "/mock/path"
    local data_type = { common_def.SUBHEALTH_STR.VENDOR_IO_LATENCY }

    -- Save the hooks that have no file-level snapshot.
    local pre_get_object = mdb.get_object
    local pre_get_context = context.get_context
    local pre_forward_request = client.PSmsSmsForwardRequest

    -- Mock the drive object returned by mdb.get_object
    local mock_drive = { Name = "Disk1", Presence = 1, RefControllerId = 255, Protocol = 2, identify_pd = false, MediaType = 1,
        Model = 'HWE72P453T8L007N', Revision = '1069', io_latency = '', bma_id = 'PCH_0000:38:05.0_ata1',
        collect_log_from_bma = drive_obj.collect_log_from_bma
    }
    mdb.get_object = function(_, path, _)
        lu.assertEquals(path, obj_path)
        return mock_drive
    end

    -- Mock drive_collection, task creation, the request context and the BMA forwarder
    mock_drive_collection({mock_drive})
    task_service.create = function()
        return 35
    end

    context.get_context = function()
        return { get_initiator = function() return "test_user" end }
    end

    -- Parameters are unused; `_` also avoids shadowing the `client` module.
    client.PSmsSmsForwardRequest = function(_, _, _, _, _)
        local mock_response = {
            Response = json.encode({
                ResponseStatusCode = 200,
                ResponseBody = json.encode({
                    IOLatency = {
                        ReadRawData = "54657374",
                        WriteRawData = "4F4B"
                    }
                })
            })
        }
        return true, mock_response
    end

    -- Run the assertions under pcall so the mocks are restored even when one fails.
    local ok, err = pcall(function()
        local task_id = rpc_service_subhealth.get_instance():collect_drive_diagnose_data(obj_path, ctx, data_type)
        lu.assertNotNil(task_id)
        lu.assertEquals(mock_drive.io_latency, "FFFFFFFFFFFFFFFF546573744F4B")
    end)

    -- Restore every replaced hook, including get_instance, so no mock leaks into later tests.
    task_service.create = pre_fun
    drive_collection.get_instance = pre_fun4
    mdb.get_object = pre_get_object
    context.get_context = pre_get_context
    client.PSmsSmsForwardRequest = pre_forward_request
    if not ok then
        error(err, 0)
    end
end

-- Test BMA-based collection failure paths: io_latency must stay empty in two groups
-- of cases.
--   * The BMA response is unusable: no IOLatency section, a 404 status, no body,
--     or a failed forward request.
--   * The drive attributes are varied: empty bma_id, another Model, another
--     Revision, another Protocol.
function TestIOLatencyNvme:test_collect_drive_from_bma_fail()
    local ctx = { get_initiator = function() return "test_user" end }
    local obj_path = "/mock/path"
    local data_type = { common_def.SUBHEALTH_STR.VENDOR_IO_LATENCY }

    -- Save the hooks that have no file-level snapshot.
    local pre_get_object = mdb.get_object
    local pre_get_context = context.get_context
    local pre_forward_request = client.PSmsSmsForwardRequest

    -- Mock the drive object returned by mdb.get_object
    local mock_drive = { Name = "Disk1", Presence = 1, RefControllerId = 255, Protocol = 2, identify_pd = false, MediaType = 1,
        Model = 'HWE72P453T8L007N', Revision = '1069', io_latency = '', bma_id = 'PCH_0000:38:05.0_ata1',
        collect_log_from_bma = drive_obj.collect_log_from_bma
    }
    mdb.get_object = function(_, path, _)
        lu.assertEquals(path, obj_path)
        return mock_drive
    end

    -- Mock drive_collection, task creation and the request context
    mock_drive_collection({mock_drive})
    task_service.create = function()
        return 35
    end

    context.get_context = function()
        return { get_initiator = function() return "test_user" end }
    end

    -- Make client.PSmsSmsForwardRequest return the given (succeeded, response) pair.
    local function mock_forward_request(succeeded, response)
        client.PSmsSmsForwardRequest = function(_, _, _, _, _)
            return succeeded, response
        end
    end

    -- Run one collection and check that a task id is returned and io_latency stays empty.
    local function assert_not_collected()
        local task_id = rpc_service_subhealth.get_instance():collect_drive_diagnose_data(obj_path, ctx, data_type)
        lu.assertNotNil(task_id)
        lu.assertEquals(mock_drive.io_latency, "")
    end

    -- Run the cases under pcall so the mocks are restored even when one fails.
    local ok, err = pcall(function()
        -- 200 response whose body has no IOLatency section.
        mock_forward_request(true, {
            Response = json.encode({
                ResponseStatusCode = 200,
                ResponseBody = json.encode({
                    IOLatency = nil
                })
            })
        })
        assert_not_collected()

        -- 404 response with an empty IOLatency table.
        mock_forward_request(true, {
            Response = json.encode({
                ResponseStatusCode = 404,
                ResponseBody = {
                    IOLatency = {}
                }
            })
        })
        assert_not_collected()

        -- 404 response without a body.
        mock_forward_request(true, {
            Response = json.encode({
                ResponseStatusCode = 404,
                ResponseBody = nil
            })
        })
        assert_not_collected()

        -- The forward request itself fails; this mock stays active for the cases below.
        mock_forward_request(false, { Response = {} })
        assert_not_collected()

        -- Drive attribute variations; each field is put back after its case.
        local attribute_cases = {
            { field = "bma_id", value = '' },
            { field = "Model", value = 'HWE52P453T8L007N' },
            { field = "Model", value = 'SAMSUNG MZ7LH5T9HMJA-00005' },
            { field = "Revision", value = '1059' },
            { field = "Protocol", value = 3 },
        }
        for _, case in ipairs(attribute_cases) do
            local saved = mock_drive[case.field]
            mock_drive[case.field] = case.value
            assert_not_collected()
            mock_drive[case.field] = saved
        end
    end)

    -- Restore every replaced hook, including get_instance, so no mock leaks into later tests.
    task_service.create = pre_fun
    drive_collection.get_instance = pre_fun4
    mdb.get_object = pre_get_object
    context.get_context = pre_get_context
    client.PSmsSmsForwardRequest = pre_forward_request
    if not ok then
        error(err, 0)
    end
end


-- Extra coverage for the failure branches of drive_object collect_log_from_disk:
-- an NVMe object without nvme_mi_mctp_obj must yield false without raising.
function TestIOLatencyNvme:test_collect_log_from_disk_no_mctp()
    local nvme_without_mctp = {}
    local drive = { Name = "DiskX", io_latency = '' }

    local succeeded, result = pcall(drive_obj.collect_log_from_disk, drive, nvme_without_mctp)
    lu.assertIsTrue(succeeded)
    lu.assertEquals(result, false)
end

-- collect_log_from_disk must return false when the support check reports 0.
function TestIOLatencyNvme:test_collect_log_from_disk_support_flag_fail()
    local saved_check = nvme_serive.check_support_io_latency_log_page
    -- Report the IO-latency log page as unsupported (flag 0).
    nvme_serive.check_support_io_latency_log_page = function(_)
        return 0
    end

    local drive = { Name = "DiskX", io_latency = '' }
    local nvme = { nvme_mi_mctp_obj = { nvme_mi_obj = {} } }
    local succeeded, result = pcall(drive_obj.collect_log_from_disk, drive, nvme)

    -- Restore before asserting so a failed assertion cannot leave the mock installed.
    nvme_serive.check_support_io_latency_log_page = saved_check

    lu.assertIsTrue(succeeded)
    lu.assertEquals(result, false)
end

-- collect_log_from_disk must return false when the support check passes but the
-- log version is 0xffffffff.
function TestIOLatencyNvme:test_collect_log_from_disk_version_invalid()
    local saved_check = nvme_serive.check_support_io_latency_log_page
    local saved_version = nvme_serive.get_io_latency_log_version
    nvme_serive.check_support_io_latency_log_page = function(_) return 1 end
    nvme_serive.get_io_latency_log_version = function(_) return 0xffffffff end

    local drive = { Name = "DiskX", io_latency = '' }
    local nvme = { nvme_mi_mctp_obj = { nvme_mi_obj = {} } }
    local succeeded, result = pcall(drive_obj.collect_log_from_disk, drive, nvme)

    -- Restore before asserting so a failed assertion cannot leave the mocks installed.
    nvme_serive.check_support_io_latency_log_page = saved_check
    nvme_serive.get_io_latency_log_version = saved_version

    lu.assertIsTrue(succeeded)
    lu.assertEquals(result, false)
end

-- Covers the branch where to_hex receives a non-string and returns nil, reached
-- indirectly through collect_log_from_disk. The test expects io_latency to end up nil.
function TestIOLatencyNvme:test_collect_log_from_disk_raw_data_nonstring()
    local saved_check = nvme_serive.check_support_io_latency_log_page
    local saved_version = nvme_serive.get_io_latency_log_version
    local saved_data = nvme_serive.get_io_latency_data
    nvme_serive.check_support_io_latency_log_page = function(_) return 1 end
    nvme_serive.get_io_latency_log_version = function(_) return 1 end
    -- Raw data deliberately returned as a number instead of a string.
    nvme_serive.get_io_latency_data = function(_, _)
        return 12345
    end

    local drive = { Name = "DiskX", io_latency = '' }
    local nvme = { nvme_mi_mctp_obj = { nvme_mi_obj = {} } }
    local succeeded = pcall(drive_obj.collect_log_from_disk, drive, nvme)

    -- Restore before asserting so a failed assertion cannot leave the mocks installed.
    nvme_serive.check_support_io_latency_log_page = saved_check
    nvme_serive.get_io_latency_log_version = saved_version
    nvme_serive.get_io_latency_data = saved_data

    lu.assertIsTrue(succeeded)
    lu.assertNil(drive.io_latency)
end


-- Return the test class so that require()-ing this file also yields it.
return TestIOLatencyNvme