-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require('luaunit')
local om_mgmt = require 'device.class.nic_mgmt.om_mgmt'
local comm_fun = require 'device.class.nic_mgmt.comm_fun'
local comm_defs = require 'device.class.nic_mgmt.comm_defs'
local ctx = require 'mc.context'
local mdb = require 'mc.mdb'
local c_factory = require 'mc.orm.factory'
local c_optical_module = require 'device.class.optical_module'
local port_mgmt = require 'device.class.nic_mgmt.port_mgmt'
local om_init = require 'device.class.nic_mgmt.om.om_init'
local log = require 'mc.logging'

-- Registry of values replaced by stub(): origs[tbl][key] holds what tbl[key]
-- was before the first stub() call made since the last restore_all().
local origs = {}

-- Sentinel stored in the registry when the key had no value before stubbing.
-- A plain nil cannot be stored in a table, so without it such keys would be
-- skipped by restore_all() and the stub would leak into later tests.
local ABSENT = {}

-- Remembers the current value of tbl[key] so restore_all() can put it back.
-- Only the first call per (tbl, key) records anything; a later call would see
-- an already-stubbed value, which must not be mistaken for the original.
-- @param tbl table whose field is about to be replaced
-- @param key field name
local function save_orig(tbl, key)
    local saved = origs[tbl]
    if saved == nil then
        saved = {}
        origs[tbl] = saved
    end
    if saved[key] ~= nil then
        return
    end
    local current = tbl[key]
    if current == nil then
        saved[key] = ABSENT
    else
        saved[key] = current
    end
end

-- Replaces tbl[key] with value after recording the value it had before.
-- @param tbl table to patch
-- @param key field name
-- @param value replacement installed for the duration of the test
local function stub(tbl, key, value)
    save_orig(tbl, key)
    tbl[key] = value
end

-- Puts every recorded original back (removing keys that did not exist before
-- stubbing) and empties the registry so the next test starts from scratch.
local function restore_all()
    for tbl, keys in pairs(origs) do
        for k, v in pairs(keys) do
            if v == ABSENT then
                tbl[k] = nil
            else
                tbl[k] = v
            end
        end
        -- Clearing an existing field while traversing with pairs() is allowed.
        origs[tbl] = nil
    end
end

-- Test-suite table: the setUp/tearDown hooks and every test_* method below are
-- attached to it, and it is returned at the end of this file.
-- NOTE(review): declared as a global, presumably so the luaunit runner can
-- discover the suite by name — confirm before making it local.
TestOmMgmt = {}

-- Table-driven cases shared by the test_sync_device_obj_* tests.
--   func          : name of the om_mgmt method invoked as mgr[func](mgr, obj, value)
--   interface_key : key looked up in comm_defs to get the expected interface name
--   field         : property expected to be written on the device object
local sync_cases = {
    {
        func = 'set_temperature_to_device_obj',
        interface_key = 'OPTICAL_MODULE_COOLING_INTERFACE',
        field = 'TemperatureCelsius',
    },
    {
        func = 'set_presence_to_device_obj',
        interface_key = 'OPTICAL_MODULE_DEVICE_INTERFACE',
        field = 'Presence',
    },
    {
        func = 'set_speed_match_to_device_obj',
        interface_key = 'OPTICAL_MODULE_STATUS_INTERFACE',
        field = 'SpeedMatch',
    },
    {
        func = 'set_type_match_to_device_obj',
        interface_key = 'OPTICAL_MODULE_STATUS_INTERFACE',
        field = 'TypeMatch',
    },
    {
        func = 'set_supply_voltage_to_device_obj',
        interface_key = 'OPTICAL_MODULE_VOLTAGE_INTERFACE',
        field = 'SupplyVoltage',
    },
    {
        func = 'set_voltage_lower_threshold_critical_to_device_obj',
        interface_key = 'OPTICAL_MODULE_VOLTAGE_INTERFACE',
        field = 'VoltageLowerThresholdCritical',
    },
    {
        func = 'set_voltage_upper_threshold_critical_to_device_obj',
        interface_key = 'OPTICAL_MODULE_VOLTAGE_INTERFACE',
        field = 'VoltageUpperThresholdCritical',
    },
}

-- Unique marker a caller can pass as an override value to remove a default
-- field (a literal nil in a table constructor would simply be dropped).
local NIL = {}

-- Builds a fake optical-module object for the sync tests.
-- @param overrides optional table of field -> value applied on top of the
--        defaults; a value equal to NIL clears that field instead
-- @return a new table on every call
local function make_optical_obj(overrides)
    local optical = {
        CreatedByDeviceObject = true,
        device_path = '/dev/optical',
        NetworkAdapterId = 'nid',
        PortID = '1'
    }
    for field, value in pairs(overrides or {}) do
        -- An explicit branch is used (not `and/or`) because false is a
        -- legitimate override value that must be kept as false.
        local resolved = value
        if value == NIL then
            resolved = nil
        end
        optical[field] = resolved
    end
    return optical
end

-- Installs fn as comm_fun.get_device_obj_by_path_interface, writing the field
-- with rawset so any metamethods on comm_fun are bypassed.
-- @param fn replacement lookup function
local function set_device_obj_stub(fn)
    local target = 'get_device_obj_by_path_interface'
    rawset(comm_fun, target, fn)
end

-- Puts the default device-object lookup double back in place: a function that
-- returns a fresh empty table on every call.
local function reset_device_obj_stub()
    local function empty_device_obj()
        return {}
    end
    rawset(comm_fun, 'get_device_obj_by_path_interface', empty_device_obj)
end

-- Per-test fixture: replaces fields of the modules required at the top of this
-- file (comm_defs, ctx, mdb, c_optical_module, c_factory, port_mgmt, comm_fun,
-- om_init, log) with test doubles through stub(), so tearDown can undo them.
function TestOmMgmt:setUp()
    -- comm_defs constants, applied in this fixed order
    local fake_defs = {
        { 'CARD_DEVICE_PATH_PATTERN', '/mock/pattern' },
        { 'OPTICAL_MODULE_DEVICE_INTERFACE', 'iface_om' },
        { 'MACA_SERVICE', 'svc' },
        { 'MDB_PATH', '/mdb' },
        { 'MDB_INTERFACE', 'mdb_iface' },
    }
    for _, def in ipairs(fake_defs) do
        stub(comm_defs, def[1], def[2])
    end

    -- ctx: the context lookup always yields a new empty table
    stub(ctx, 'get_context_or_default', function()
        return {}
    end)

    -- mdb: registering is a no-op; every fetched object answers get_property
    -- with the pair (0, 'val')
    stub(mdb, 'register_interface', function() end)
    stub(mdb, 'get_object', function()
        local fake_mdb_obj = {}
        function fake_mdb_obj.get_property()
            return 0, 'val'
        end
        return fake_mdb_obj
    end)

    -- c_optical_module: create_mdb_object echoes its first argument and the
    -- collection is swapped for a fake with find/del_object
    stub(c_optical_module, 'create_mdb_object', function(tbl)
        return tbl
    end)
    local fake_collection = {}
    function fake_collection.find(_, query)
        -- PortID / NetworkAdapterId come from the query when it supplies
        -- them, otherwise they fall back to fixed defaults
        local port_id, adapter_id = 123, 'nid'
        if query then
            port_id = query.PortID or port_id
            adapter_id = query.NetworkAdapterId or adapter_id
        end
        return {
            children = {},
            PortID = port_id,
            NetworkAdapterId = adapter_id,
            parent = { ObjectName = 'obj' },
            ObjectName = 'portobj',
        }
    end
    function fake_collection.del_object() end
    stub(c_optical_module, 'collection', fake_collection)

    -- c_factory: the instance exposes a no-op create_object
    stub(c_factory, 'get_instance', function()
        return { create_object = function() end }
    end)

    -- port_mgmt: only '/parent/path' resolves to a port object
    stub(port_mgmt, 'get_instance', function()
        local fake_port_mgmt = {}
        function fake_port_mgmt.get_orm_obj_by_device_path(_, path)
            if path ~= '/parent/path' then
                return
            end
            return {
                PortID = 123,
                NetworkAdapterId = 'nid',
                parent = { ObjectName = 'obj' },
                ObjectName = 'portobj',
                FirmwareVersion = 'fw',
                children = {},
            }
        end
        function fake_port_mgmt.register_children() end
        return fake_port_mgmt
    end)

    -- comm_fun: '/fail' has no parent path and '/fail_obj' has no object name
    stub(comm_fun, 'get_parent_path', function(_, optical_device_path)
        if optical_device_path ~= '/fail' then
            return '/parent/path'
        end
        return nil
    end)
    stub(comm_fun, 'get_object_name_by_device_path', function(_, optical_device_path)
        if optical_device_path ~= '/fail_obj' then
            return 'OpticalModule1'
        end
        return nil
    end)
    stub(comm_fun, 'get_object_identifier_by_device_path', function()
        return { 1, 'test_manager', 'test_chassis', 'test_position' }
    end)
    stub(comm_fun, 'get_all_device_paths', function()
        return { '/om1', '/om2' }
    end)
    stub(comm_fun, 'set_interface_add_signal',
        function(bus, sig_slot, _path_pattern, _interface_name, on_added)
            sig_slot[#sig_slot + 1] = 123
            -- give the bus a match() double only when it has none yet
            if not bus.match then
                bus.match = function()
                    return 123
                end
            end
            -- fire the "interface added" callback immediately with a fixed path
            on_added('/mock/path')
        end)
    stub(comm_fun, 'get_device_obj_by_path_interface', function()
        return {}
    end)

    -- om_init: initialisation is a no-op
    stub(om_init, 'init', function() end)

    -- log: silence the levels stubbed here
    for _, level in ipairs({ 'notice', 'error', 'info' }) do
        stub(log, level, function() end)
    end

    reset_device_obj_stub()
end

-- Per-test cleanup: writes back every value that was recorded by stub(), which
-- also reverts fields a test overwrote directly after setUp had stubbed them.
function TestOmMgmt:tearDown()
    restore_all()
end

-- Assigns bus/objects/sig_slot on the om_mgmt table and checks they read back
-- with the same identity (bus) and as tables (objects, sig_slot).
function TestOmMgmt:test_ctor_and_init()
    local fake_bus = {}
    local manager = om_mgmt
    manager.sig_slot = {}
    manager.objects = {}
    manager.bus = fake_bus

    lu.assertIs(manager.bus, fake_bus)
    lu.assertIsTable(manager.objects)
    lu.assertIsTable(manager.sig_slot)
end

-- With the default setUp doubles, init_obj('/om1') must leave a table stored
-- under that device path in manager.objects.
function TestOmMgmt:test_init_obj_normal()
    local device_path = '/om1'
    local manager = om_mgmt
    manager.bus = {}
    manager.objects = {}
    manager.sig_slot = {}

    manager:init_obj(device_path)

    lu.assertIsTable(manager.objects[device_path])
end

-- When an entry for the device path is already cached, init_obj must leave it
-- in place: the pre-seeded object's id is still 1 afterwards.
function TestOmMgmt:test_init_obj_already_exists()
    local device_path = '/om1'
    local manager = om_mgmt
    manager.bus = {}
    manager.objects = { [device_path] = { id = 1 } }
    manager.sig_slot = {}

    manager:init_obj(device_path)

    lu.assertEquals(manager.objects[device_path].id, 1)
end

-- '/fail' is the path for which the stubbed comm_fun.get_parent_path returns
-- nil; init_obj must not cache anything for it.
function TestOmMgmt:test_init_obj_no_parent_path()
    local device_path = '/fail'
    local manager = om_mgmt
    manager.bus = {}
    manager.objects = {}
    manager.sig_slot = {}

    manager:init_obj(device_path)

    lu.assertNil(manager.objects[device_path])
end

-- When port_mgmt cannot resolve any ORM object for the parent path, init_obj
-- must not cache an entry for the optical module.
function TestOmMgmt:test_init_obj_no_parent_orm()
    local device_path = '/om1'
    local manager = om_mgmt
    manager.bus = {}
    manager.objects = {}
    manager.sig_slot = {}

    -- Make port_mgmt.get_instance hand out an instance whose lookup never
    -- finds a port object. Assigned directly; tearDown reverts it because
    -- setUp already recorded the original through stub().
    local function port_mgmt_without_orm()
        return {
            get_orm_obj_by_device_path = function()
                return nil
            end,
            register_children = function() end
        }
    end
    port_mgmt.get_instance = port_mgmt_without_orm

    manager:init_obj(device_path)

    lu.assertNil(manager.objects[device_path])
end

-- '/fail_obj' is the path for which the stubbed
-- comm_fun.get_object_name_by_device_path returns nil; init_obj must not
-- cache anything for it.
function TestOmMgmt:test_init_obj_no_object_name()
    local device_path = '/fail_obj'
    local manager = om_mgmt
    manager.bus = {}
    manager.objects = {}
    manager.sig_slot = {}

    manager:init_obj(device_path)

    lu.assertNil(manager.objects[device_path])
end

-- When the optical-module collection cannot find the created object,
-- init_obj must not cache an entry for the device path.
function TestOmMgmt:test_init_obj_create_optical_orm_object_fail()
    local device_path = '/om1'
    local manager = om_mgmt
    manager.bus = {}
    manager.objects = {}
    manager.sig_slot = {}

    -- Replace the collection with one whose find() always returns nil
    -- (tearDown reverts it because setUp stubbed `collection` via stub()).
    local empty_collection = {}
    function empty_collection.find()
        return nil
    end
    c_optical_module.collection = empty_collection

    manager:init_obj(device_path)

    lu.assertNil(manager.objects[device_path])
end

-- init() must end up with a cached table for '/om1' and '/om2', the two paths
-- the stubbed comm_fun.get_all_device_paths returns.
function TestOmMgmt:test_init()
    local manager = om_mgmt
    manager.bus = {}
    manager.objects = {}
    manager.sig_slot = {}

    manager:init()

    local cached = manager.objects
    lu.assertIsTable(cached['/om1'])
    lu.assertIsTable(cached['/om2'])
end

-- get_orm_obj_by_device_path returns the cached object for a known path and
-- nil for a path that is not in manager.objects.
function TestOmMgmt:test_get_orm_obj_by_device_path()
    local manager = om_mgmt
    manager.bus = {}
    manager.objects = { ['/om1'] = { id = 1 } }

    local found = manager:get_orm_obj_by_device_path('/om1')
    lu.assertEquals(found.id, 1)

    local missing = manager:get_orm_obj_by_device_path('/notfound')
    lu.assertNil(missing)
end

-- For every setter in sync_cases: an optical object with
-- CreatedByDeviceObject = false must not trigger any device-object lookup.
function TestOmMgmt:test_sync_device_obj_skip_when_not_created_by_device()
    local manager = om_mgmt
    manager.bus = {}
    for _, case in ipairs(sync_cases) do
        local lookups = 0
        local function counting_lookup()
            lookups = lookups + 1
            return {}
        end
        set_device_obj_stub(counting_lookup)

        local optical = make_optical_obj({ CreatedByDeviceObject = false })
        local setter = manager[case.func]
        setter(manager, optical, 'val')

        lu.assertEquals(lookups, 0)
    end
    reset_device_obj_stub()
end

-- For every setter in sync_cases: an optical object whose device_path is the
-- empty string must not trigger any device-object lookup.
function TestOmMgmt:test_sync_device_obj_skip_when_device_path_missing()
    local manager = om_mgmt
    manager.bus = {}
    for _, case in ipairs(sync_cases) do
        local lookups = 0
        local function counting_lookup()
            lookups = lookups + 1
            return {}
        end
        set_device_obj_stub(counting_lookup)

        local optical = make_optical_obj({ device_path = '' })
        local setter = manager[case.func]
        setter(manager, optical, 'val')

        lu.assertEquals(lookups, 0)
    end

    reset_device_obj_stub()
end

-- When the lookup yields no device object, each setter in sync_cases must
-- still return normally after performing exactly one lookup, so the shared
-- counter ends at #sync_cases.
function TestOmMgmt:test_sync_device_obj_skip_when_device_obj_missing()
    local manager = om_mgmt
    manager.bus = {}

    local lookups = 0
    local function failing_lookup()
        lookups = lookups + 1
        return nil
    end
    set_device_obj_stub(failing_lookup)

    for _, case in ipairs(sync_cases) do
        local setter = manager[case.func]
        setter(manager, make_optical_obj(), 'val')
    end

    lu.assertEquals(lookups, #sync_cases)
    reset_device_obj_stub()
end

-- Happy path for every setter in sync_cases: the device object is looked up
-- once with the optical object's device_path and the interface name taken
-- from comm_defs, and the given value is written to the case's field.
function TestOmMgmt:test_sync_device_obj_success_sets_field()
    local manager = om_mgmt
    manager.bus = {}
    for _, case in ipairs(sync_cases) do
        local target = {}
        local lookups = 0
        local function checking_lookup(_, path, interface_name)
            lookups = lookups + 1
            -- Arguments are verified at call time, inside the double
            lu.assertEquals(path, '/dev/optical')
            lu.assertEquals(interface_name, comm_defs[case.interface_key])
            return target
        end
        set_device_obj_stub(checking_lookup)

        local setter = manager[case.func]
        setter(manager, make_optical_obj(), 'new-value')

        lu.assertEquals(lookups, 1)
        lu.assertEquals(target[case.field], 'new-value')
    end
    reset_device_obj_stub()
end

-- del_obj must pass the cached object to the collection's del_object and drop
-- the entry from manager.objects.
function TestOmMgmt:test_del_obj_remove_optical_module()
    local device_path = '/mock/optical'
    local manager = om_mgmt
    manager.objects = { [device_path] = { id = 1 } }

    -- Swap in a recording del_object, keeping the previous one to put back
    local saved_del_object = c_optical_module.collection.del_object
    local deleted
    c_optical_module.collection.del_object = function(_, obj)
        deleted = obj
    end

    manager:del_obj(device_path)

    lu.assertEquals(deleted.id, 1)
    lu.assertNil(manager.objects[device_path])

    c_optical_module.collection.del_object = saved_del_object
end

return TestOmMgmt