-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require('luaunit')
local comm_fun = require 'device.class.nic_mgmt.comm_fun'
local comm_defs = require 'device.class.nic_mgmt.comm_defs'
local ctx = require 'mc.context'
local mdb = require 'mc.mdb'
local skynet = require 'skynet'
local log = require 'mc.logging'

-- Snapshots of patched fields, keyed by the patched table and then by field name.
local origs = {}
-- Placeholder stored for fields that were absent (nil) before patching; a plain
-- nil cannot be kept as a table value, so it would otherwise never be restored.
local NIL_SENTINEL = {}

--- Record the current value of tbl[key] so restore_all() can put it back.
-- Only the first snapshot taken since the last restore_all() is kept, so saving
-- a field again after it has been mocked does not replace the genuine original.
-- @param tbl table whose field is about to be replaced
-- @param key field name (any valid table key)
local function save_orig(tbl, key)
    local saved = origs[tbl]
    if saved == nil then
        saved = {}
        origs[tbl] = saved
    end
    if saved[key] == nil then
        local value = tbl[key]
        if value == nil then
            value = NIL_SENTINEL
        end
        saved[key] = value
    end
end

--- Put back every field recorded by save_orig() and forget the snapshots.
-- Fields that did not exist when they were saved are removed again.
local function restore_all()
    for tbl, keys in pairs(origs) do
        for k, v in pairs(keys) do
            if rawequal(v, NIL_SENTINEL) then
                tbl[k] = nil
            else
                tbl[k] = v
            end
        end
    end
    -- Drop the snapshots so a later restore_all() cannot re-apply stale values.
    origs = {}
end

-- Test suite table for device.class.nic_mgmt.comm_fun.
-- NOTE(review): declared as a global rather than a local, presumably so the
-- luaunit runner can discover it by name — confirm before changing.
TestCommFun = {}

--- Per-test fixture: snapshots the fields this suite patches on comm_defs,
--- ctx, mdb, skynet and log, then replaces them with fixed test doubles.
function TestCommFun:setUp()
    -- Snapshot a field and install its replacement in one step.
    local function stub(tbl, key, replacement)
        save_orig(tbl, key)
        tbl[key] = replacement
    end

    -- comm_defs constants
    stub(comm_defs, 'MACA_SERVICE', 'svc')
    stub(comm_defs, 'MDB_PATH', '/mdb')
    stub(comm_defs, 'MDB_INTERFACE', 'mdb_iface')
    stub(comm_defs, 'COMMON_PROPERTIES_INTERFACE', 'iface_common')

    -- ctx: always hand out an empty context table
    stub(ctx, 'get_context_or_default', function()
        return {}
    end)

    -- mdb: interface registration is a no-op; every object answers
    -- get_property with status 0 and a canned value
    stub(mdb, 'register_interface', function() end)
    stub(mdb, 'get_object', function()
        local canned = { ParentPath = '/parent/path', ObjectName = 'PCIeNicCard1' }
        local obj = {}
        function obj:get_property(prop)
            return 0, canned[prop] or 'val'
        end
        return obj
    end)

    -- skynet.queue: the returned queue runs the given function immediately
    stub(skynet, 'queue', function()
        return function(task)
            task()
        end
    end)

    -- log: silence notice and error output
    stub(log, 'notice', function() end)
    stub(log, 'error', function() end)
end

--- Per-test cleanup: calls restore_all() to undo the field replacements whose
--- previous values were recorded with save_orig().
function TestCommFun:tearDown()
    restore_all()
end

--- Expects get_all_device_paths to return the list produced by bus:call and
--- to pass comm_defs.MDB_PATH as the second argument of that call.
function TestCommFun:test_get_all_device_paths()
    local captured_args
    local fake_bus = {}
    function fake_bus:call(...)
        captured_args = { ... }
        return { '/dev/path1', '/dev/path2' }
    end

    local result = comm_fun.get_all_device_paths(fake_bus, '/pattern', 1, 'iface')

    lu.assertEquals(result, { '/dev/path1', '/dev/path2' })
    lu.assertEquals(captured_args[2], comm_defs.MDB_PATH)
end

--- Happy path: the bus reports one entry and the expected result equals the
--- value the setUp mdb mock returns for 'ParentPath'.
function TestCommFun:test_get_parent_path_normal()
    local fake_bus = {}
    fake_bus.call = function()
        return { { 'iface_common' } }
    end

    local parent_path = comm_fun.get_parent_path(fake_bus, '/dev/path')

    lu.assertEquals(parent_path, '/parent/path')
end

--- When bus:call yields an empty table, get_parent_path is expected to return
--- nil and to call log.notice.
function TestCommFun:test_get_parent_path_no_device_obj()
    local notice_seen = false
    log.notice = function()
        notice_seen = true
    end
    local empty_bus = {
        call = function()
            return {}
        end
    }

    local parent_path = comm_fun.get_parent_path(empty_bus, '/dev/path')

    lu.assertNil(parent_path)
    lu.assertTrue(notice_seen)
end

--- When get_property answers (1, nil), get_parent_path is expected to return
--- nil and to call log.error.
function TestCommFun:test_get_parent_path_get_property_fail()
    local error_seen = false
    log.error = function()
        error_seen = true
    end
    -- Every property read fails with status 1 and no value.
    mdb.get_object = function()
        local failing_obj = {}
        failing_obj.get_property = function()
            return 1, nil
        end
        return failing_obj
    end
    local fake_bus = {}
    fake_bus.call = function()
        return { { 'iface_common' } }
    end

    local parent_path = comm_fun.get_parent_path(fake_bus, '/dev/path')

    lu.assertNil(parent_path)
    lu.assertTrue(error_seen)
end

--- Happy path: the expected name equals the value the setUp mdb mock returns
--- for 'ObjectName'.
function TestCommFun:test_get_object_name_by_device_path_normal()
    local fake_bus = {}
    fake_bus.call = function()
        return { { 'iface_common' } }
    end

    local object_name = comm_fun.get_object_name_by_device_path(fake_bus, '/dev/path')

    lu.assertEquals(object_name, 'PCIeNicCard1')
end

--- With a { from, to } replacement argument, the mocked name 'PCIeNicCard1'
--- is expected to come back as 'NetworkAdapter1'.
function TestCommFun:test_get_object_name_by_device_path_replace()
    local replacement = { from = 'PCIeNicCard', to = 'NetworkAdapter' }
    local fake_bus = {}
    fake_bus.call = function()
        return { { 'iface_common' } }
    end

    local object_name = comm_fun.get_object_name_by_device_path(fake_bus, '/dev/path', replacement)

    lu.assertEquals(object_name, 'NetworkAdapter1')
end

--- When bus:call yields an empty table, get_object_name_by_device_path is
--- expected to return nil and to call log.notice.
function TestCommFun:test_get_object_name_by_device_path_no_device_obj()
    local notice_seen = false
    log.notice = function()
        notice_seen = true
    end
    local empty_bus = {
        call = function()
            return {}
        end
    }

    local object_name = comm_fun.get_object_name_by_device_path(empty_bus, '/dev/path')

    lu.assertNil(object_name)
    lu.assertTrue(notice_seen)
end

--- When get_property answers (1, nil), get_object_name_by_device_path is
--- expected to return nil and to call log.notice.
function TestCommFun:test_get_object_name_by_device_path_get_property_fail()
    local notice_seen = false
    log.notice = function()
        notice_seen = true
    end
    -- Every property read fails with status 1 and no value.
    mdb.get_object = function()
        local failing_obj = {}
        failing_obj.get_property = function()
            return 1, nil
        end
        return failing_obj
    end
    local fake_bus = {}
    fake_bus.call = function()
        return { { 'iface_common' } }
    end

    local object_name = comm_fun.get_object_name_by_device_path(fake_bus, '/dev/path')

    lu.assertNil(object_name)
    lu.assertTrue(notice_seen)
end

--- Happy path: the table served for 'ObjectIdentifier' by the mdb mock is
--- expected to be returned by get_object_identifier_by_device_path.
function TestCommFun:test_get_object_identifier_by_device_path_normal()
    mdb.get_object = function()
        local obj = {}
        function obj:get_property(prop)
            if prop ~= 'ObjectIdentifier' then
                return 0, 'val'
            end
            return 0, { 1, 'mgr', 'chassis', 'pos' }
        end
        return obj
    end
    local fake_bus = {}
    fake_bus.call = function()
        return { { 'iface_common' } }
    end

    local identifier = comm_fun.get_object_identifier_by_device_path(fake_bus, '/dev/path')

    lu.assertEquals(identifier, { 1, 'mgr', 'chassis', 'pos' })
end

--- When bus:call yields an empty table, get_object_identifier_by_device_path
--- is expected to return nil and to call log.notice.
function TestCommFun:test_get_object_identifier_by_device_path_no_device_obj()
    local notice_seen = false
    log.notice = function()
        notice_seen = true
    end
    local empty_bus = {
        call = function()
            return {}
        end
    }

    local identifier = comm_fun.get_object_identifier_by_device_path(empty_bus, '/dev/path')

    lu.assertNil(identifier)
    lu.assertTrue(notice_seen)
end

--- When only the 'ObjectIdentifier' read fails with (1, nil), the function is
--- expected to return nil and to call log.notice.
function TestCommFun:test_get_object_identifier_by_device_path_get_property_fail()
    local notice_seen = false
    log.notice = function()
        notice_seen = true
    end
    -- Other properties still succeed; only 'ObjectIdentifier' reports status 1.
    mdb.get_object = function()
        local obj = {}
        function obj:get_property(prop)
            if prop ~= 'ObjectIdentifier' then
                return 0, 'val'
            end
            return 1, nil
        end
        return obj
    end
    local fake_bus = {}
    fake_bus.call = function()
        return { { 'iface_common' } }
    end

    local identifier = comm_fun.get_object_identifier_by_device_path(fake_bus, '/dev/path')

    lu.assertNil(identifier)
    lu.assertTrue(notice_seen)
end

--- For an underscore-separated name, the last segment is the expected result.
function TestCommFun:test_get_position_by_object_name_normal()
    local position = comm_fun.get_position_by_object_name('NetworkAdapter_1_2')

    lu.assertEquals(position, '2')
end

--- A name without any underscore is expected to be returned unchanged.
function TestCommFun:test_get_position_by_object_name_single_segment()
    local object_name = 'NetworkAdapter'
    local position = comm_fun.get_position_by_object_name(object_name)

    lu.assertEquals(position, 'NetworkAdapter')
end

--- A trailing underscore is expected to be ignored: the result is the last
--- non-empty segment.
function TestCommFun:test_get_position_by_object_name_with_trailing_underscore()
    -- Per the original author's note, the implementation splits with
    -- string.gmatch and "[^_]+", so a trailing underscore yields no empty
    -- segment and the last non-empty segment is still returned.
    local position = comm_fun.get_position_by_object_name('NetworkAdapter_3_')

    lu.assertEquals(position, '3')
end

--- Expects set_interface_add_signal to store the handle returned by bus:match
--- in sig_slot and to invoke the callback with the object path read from the
--- signal message.
---
--- The fake 'sd_bus.org_freedesktop_dbus' module is placed in package.loaded
--- only for the duration of the call; the previous entry is put back
--- afterwards so the fake cannot leak into later tests or later require calls.
function TestCommFun:test_set_interface_add_signal()
    local DBUS_MODULE = 'sd_bus.org_freedesktop_dbus'
    local fake_match_rule = {
        signal = function() return { with_path_namespace = function() return 'sig' end } end
    }
    local fake_interface_add = { name = 'n', interface = 'i' }
    local bus = {
        -- Fires the registered callback immediately with a canned message.
        match = function(self, sig, cb)
            local msg = {
                read = function() return '/dev/path', { iface = true } end
            }
            cb(msg)
            return 123
        end
    }
    local sig_slot = {}
    local cb_called = false

    local previous_module = package.loaded[DBUS_MODULE]
    package.loaded[DBUS_MODULE] = {
        ObjMgrInterfacesAdded = fake_interface_add,
        MatchRule = fake_match_rule
    }
    -- pcall so the previous package.loaded entry is restored even if the call raises.
    local ok, err = pcall(comm_fun.set_interface_add_signal, bus, sig_slot, '/pattern', 'iface',
        function(path)
            cb_called = path
        end)
    package.loaded[DBUS_MODULE] = previous_module
    if not ok then
        -- Re-raise the original error value unchanged (level 0 adds no position prefix).
        error(err, 0)
    end

    lu.assertEquals(sig_slot[1], 123)
    lu.assertEquals(cb_called, '/dev/path')
end

--- Expects set_interface_del_signal to store the handle returned by bus:match
--- in sig_slot and to invoke the callback with the object path read from the
--- signal message.
function TestCommFun:test_set_interface_del_signal()
    local fake_bus = {}
    -- Fires the registered callback immediately with a canned message.
    function fake_bus:match(_, on_signal)
        on_signal({
            read = function()
                return '/dev/path', { 'iface' }
            end
        })
        return 456
    end
    local slots = {}
    local removed_path = nil

    comm_fun.set_interface_del_signal(fake_bus, slots, '/pattern', 'iface', function(path)
        removed_path = path
    end)

    lu.assertEquals(slots[1], 456)
    lu.assertEquals(removed_path, '/dev/path')
end