-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require('luaunit')
local om_mgmt = require 'device.class.nic_mgmt.om_mgmt'
local comm_fun = require 'device.class.nic_mgmt.comm_fun'
local comm_defs = require 'device.class.nic_mgmt.comm_defs'
local ctx = require 'mc.context'
local mdb = require 'mc.mdb'
local c_factory = require 'mc.orm.factory'
local c_optical_module = require 'device.class.optical_module'
local port_mgmt = require 'device.class.nic_mgmt.port_mgmt'
local om_init = require 'device.class.nic_mgmt.om.om_init'
local log = require 'mc.logging'

-- Registry of patched fields: origs[tbl][key] holds the value tbl[key] had
-- when it was first saved (NIL_SENTINEL when the field did not exist).
local origs = {}
-- Placeholder for "field was nil": a nil value cannot be stored in a table,
-- so without it an originally-absent field would never be restored.
local NIL_SENTINEL = {}

--- Record the current value of tbl[key] so restore_all() can put it back.
-- Only the first save per (tbl, key) between two restore_all() calls is kept,
-- so saving a field again after it has been mocked cannot overwrite the
-- genuine original with the mock.
-- @param tbl table whose field is about to be replaced
-- @param key field name (any valid table key)
local function save_orig(tbl, key)
    local saved = origs[tbl]
    if not saved then
        saved = {}
        origs[tbl] = saved
    end
    if saved[key] ~= nil then
        return
    end
    local value = tbl[key]
    if value == nil then
        value = NIL_SENTINEL
    end
    saved[key] = value
end

--- Put back every field recorded by save_orig() and empty the registry.
-- Fields that did not exist when they were saved are removed again.
local function restore_all()
    for tbl, keys in pairs(origs) do
        for k, v in pairs(keys) do
            if v == NIL_SENTINEL then
                tbl[k] = nil
            else
                tbl[k] = v
            end
        end
        -- Clearing an existing key while traversing with pairs is permitted.
        origs[tbl] = nil
    end
end

-- Test class holding the om_mgmt test cases below.
-- NOTE(review): declared global rather than local — presumably so the luaunit
-- runner can discover it by name; confirm against how the runner is invoked.
TestOmMgmt = {}

--- Per-test fixture: replaces the collaborators required at the top of this
-- file (comm_defs, ctx, mdb, c_optical_module, c_factory, port_mgmt, comm_fun,
-- om_init, log) with stubs. Every replaced field is recorded through
-- save_orig() first so that tearDown() can undo it.
function TestOmMgmt:setUp()
    -- Record the current field value, then install the replacement.
    local function stub(tbl, key, replacement)
        save_orig(tbl, key)
        tbl[key] = replacement
    end

    -- comm_defs constants
    stub(comm_defs, 'CARD_DEVICE_PATH_PATTERN', '/mock/pattern')
    stub(comm_defs, 'OPTICAL_MODULE_DEVICE_INTERFACE', 'iface_om')
    stub(comm_defs, 'MACA_SERVICE', 'svc')
    stub(comm_defs, 'MDB_PATH', '/mdb')
    stub(comm_defs, 'MDB_INTERFACE', 'mdb_iface')

    -- ctx: always hand back an empty context table
    stub(ctx, 'get_context_or_default', function()
        return {}
    end)

    -- mdb: registration is a no-op; every object answers get_property with (0, 'val')
    stub(mdb, 'register_interface', function() end)
    stub(mdb, 'get_object', function()
        local fake_object = {}
        function fake_object.get_property()
            return 0, 'val'
        end
        return fake_object
    end)

    -- c_optical_module: create_mdb_object echoes its argument; collection:find
    -- returns a record built from the query, with fixed fallbacks
    stub(c_optical_module, 'create_mdb_object', function(tbl)
        return tbl
    end)
    stub(c_optical_module, 'collection', {
        find = function(_, q)
            return {
                children = {},
                PortID = q and q.PortID or 123,
                NetworkAdapterId = q and q.NetworkAdapterId or 'nid',
                parent = { ObjectName = 'obj' },
                ObjectName = 'portobj'
            }
        end
    })

    -- c_factory: instance whose create_object does nothing
    stub(c_factory, 'get_instance', function()
        return { create_object = function() end }
    end)

    -- port_mgmt: only '/parent/path' resolves to a port object
    stub(port_mgmt, 'get_instance', function()
        return {
            get_orm_obj_by_device_path = function(_, path)
                if path == '/parent/path' then
                    return {
                        PortID = 123,
                        NetworkAdapterId = 'nid',
                        parent = { ObjectName = 'obj' },
                        ObjectName = 'portobj',
                        FirmwareVersion = 'fw',
                        children = {}
                    }
                end
            end,
            register_children = function() end
        }
    end)

    -- comm_fun: '/fail' has no parent path, '/fail_obj' has no object name
    stub(comm_fun, 'get_parent_path', function(_, optical_device_path)
        if optical_device_path == '/fail' then
            return nil
        end
        return '/parent/path'
    end)
    stub(comm_fun, 'get_object_name_by_device_path', function(_, optical_device_path)
        if optical_device_path == '/fail_obj' then
            return nil
        end
        return 'OpticalModule1'
    end)
    stub(comm_fun, 'get_all_device_paths', function()
        return { '/om1', '/om2' }
    end)
    stub(comm_fun, 'set_interface_add_signal', function(bus, sig_slot, _, _, cb)
        sig_slot[#sig_slot + 1] = 123
        -- give the bus a match() stub if it does not have one yet
        if not bus.match then
            bus.match = function()
                return 123
            end
        end
        -- fire the callback immediately with a fixed path
        cb('/mock/path')
    end)

    -- om_init: initialisation is a no-op
    stub(om_init, 'init', function() end)

    -- log: silence notice and error output
    stub(log, 'notice', function() end)
    stub(log, 'error', function() end)
end

--- Per-test cleanup: puts back every field that was recorded with save_orig().
function TestOmMgmt:tearDown() restore_all() end

--- Assigns bus/objects/sig_slot on the om_mgmt table and checks they read back.
-- NOTE(review): despite the name, no constructor or init call is made here;
-- confirm whether one was intended.
function TestOmMgmt:test_ctor_and_init()
    local fake_bus = {}

    om_mgmt.bus = fake_bus
    om_mgmt.objects = {}
    om_mgmt.sig_slot = {}

    lu.assertIs(om_mgmt.bus, fake_bus)
    lu.assertIsTable(om_mgmt.objects)
    lu.assertIsTable(om_mgmt.sig_slot)
end

--- init_obj on a path the stubs fully resolve must leave a table in objects.
function TestOmMgmt:test_init_obj_normal()
    local device_path = '/om1'

    om_mgmt.bus = {}
    om_mgmt.objects = {}
    om_mgmt.sig_slot = {}

    om_mgmt:init_obj(device_path)

    lu.assertIsTable(om_mgmt.objects[device_path])
end

--- init_obj on a path that already has an entry must keep that entry's id.
function TestOmMgmt:test_init_obj_already_exists()
    local device_path = '/om1'

    om_mgmt.bus = {}
    om_mgmt.objects = { [device_path] = { id = 1 } }
    om_mgmt.sig_slot = {}

    om_mgmt:init_obj(device_path)

    lu.assertEquals(om_mgmt.objects[device_path].id, 1)
end

--- '/fail' makes the get_parent_path stub return nil; no object may be stored.
function TestOmMgmt:test_init_obj_no_parent_path()
    local device_path = '/fail'

    om_mgmt.bus = {}
    om_mgmt.objects = {}
    om_mgmt.sig_slot = {}

    om_mgmt:init_obj(device_path)

    lu.assertNil(om_mgmt.objects[device_path])
end

--- When port_mgmt cannot resolve the parent port object, nothing is stored.
function TestOmMgmt:test_init_obj_no_parent_orm()
    local device_path = '/om1'

    om_mgmt.bus = {}
    om_mgmt.objects = {}
    om_mgmt.sig_slot = {}

    -- Override the setUp stub: port_mgmt.get_instance now yields an instance
    -- whose lookup never finds a port object.
    local port_stub = {
        get_orm_obj_by_device_path = function()
            return nil
        end,
        register_children = function() end
    }
    port_mgmt.get_instance = function()
        return port_stub
    end

    om_mgmt:init_obj(device_path)

    lu.assertNil(om_mgmt.objects[device_path])
end

--- '/fail_obj' makes the get_object_name_by_device_path stub return nil;
-- no object may be stored for it.
function TestOmMgmt:test_init_obj_no_object_name()
    local device_path = '/fail_obj'

    om_mgmt.bus = {}
    om_mgmt.objects = {}
    om_mgmt.sig_slot = {}

    om_mgmt:init_obj(device_path)

    lu.assertNil(om_mgmt.objects[device_path])
end

--- When c_optical_module.collection:find yields nil, nothing is stored.
function TestOmMgmt:test_init_obj_create_optical_orm_object_fail()
    local device_path = '/om1'

    om_mgmt.bus = {}
    om_mgmt.objects = {}
    om_mgmt.sig_slot = {}

    -- Override the setUp stub: collection:find now returns nil.
    c_optical_module.collection = {
        find = function()
            return nil
        end
    }

    om_mgmt:init_obj(device_path)

    lu.assertNil(om_mgmt.objects[device_path])
end

--- init() must create an object for each path that the get_all_device_paths
-- stub reports ('/om1' and '/om2').
function TestOmMgmt:test_init()
    om_mgmt.bus = {}
    om_mgmt.objects = {}
    om_mgmt.sig_slot = {}

    om_mgmt:init()

    local created = om_mgmt.objects
    lu.assertIsTable(created['/om1'])
    lu.assertIsTable(created['/om2'])
end

--- get_orm_obj_by_device_path returns the stored entry for a known path and
-- nil for an unknown one.
function TestOmMgmt:test_get_orm_obj_by_device_path()
    om_mgmt.bus = {}
    om_mgmt.objects = { ['/om1'] = { id = 1 } }

    local found = om_mgmt:get_orm_obj_by_device_path('/om1')
    lu.assertEquals(found.id, 1)

    local missing = om_mgmt:get_orm_obj_by_device_path('/notfound')
    lu.assertNil(missing)
end