-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require('luaunit')
local port_mgmt = require 'device.class.nic_mgmt.port_mgmt'
local comm_fun = require 'device.class.nic_mgmt.comm_fun'
local comm_defs = require 'device.class.nic_mgmt.comm_defs'
local ctx = require 'mc.context'
local mdb = require 'mc.mdb'
local c_factory = require 'mc.orm.factory'
local c_network_port = require 'device.class.network_port'
local card_mgmt = require 'device.class.nic_mgmt.card_mgmt'
local port_init = require 'device.class.nic_mgmt.port.port_init'
local log = require 'mc.logging'

-- Pristine values of every field replaced through `stub`, indexed first by
-- the owning table and then by field name. A field whose original raw value
-- was nil is recorded as ABSENT: a plain nil would be invisible to `pairs`
-- (so restore_all could never remove the stub) and indistinguishable from
-- "not saved yet" (so a second stub would be mistaken for the original).
local origs = {}
local ABSENT = {}

--- Remember the original raw value of `tbl[key]`, once per field.
-- Only the first call for a given field records anything, so stubbing the
-- same field repeatedly still restores the value from before the first stub
-- (this also holds when that value was nil or false).
-- @tparam table tbl table that is about to be patched
-- @param key field name to remember
local function save_orig(tbl, key)
    local saved = origs[tbl]
    if saved == nil then
        saved = {}
        origs[tbl] = saved
    end
    if saved[key] == nil then
        -- rawget so that a value inherited through __index is not copied
        -- into a raw field on restore.
        local value = rawget(tbl, key)
        if value == nil then
            value = ABSENT
        end
        saved[key] = value
    end
end

--- Replace `tbl[key]` with `value`, bypassing metamethods, and remember the
-- original so restore_all() can undo it.
-- @tparam table tbl table to patch
-- @param key field name
-- @param value replacement value
local function stub(tbl, key, value)
    save_orig(tbl, key)
    rawset(tbl, key, value)
end

--- Put every remembered field back and forget all saved originals.
-- Fields that did not exist as raw fields before stubbing are deleted again.
-- Writes use rawset, matching the raw writes made by `stub`.
local function restore_all()
    for tbl, saved in pairs(origs) do
        for key, value in pairs(saved) do
            if value == ABSENT then
                rawset(tbl, key, nil)
            else
                rawset(tbl, key, value)
            end
        end
    end
    origs = {}
end

-- The suite table is global (presumably so luaunit's global-name discovery
-- can find it) and is also returned at the end of this file.
TestPortMgmt = {}

-- Sentinel understood by make_port_obj: an override whose value is NIL
-- deletes that field, since a real nil cannot be stored in a table.
local NIL = {}

-- Each entry names a port_mgmt sync method, the comm_defs key of the
-- interface it is expected to look the device object up by, and the
-- device-object field it is expected to write.
local sync_cases = {
    {
        func = 'set_link_status_to_device_obj',
        interface_key = 'NETWORK_PORT_LINK_INFO_INTERFACE',
        field = 'LinkStatus'
    },
    {
        func = 'set_bandwidth_usage_to_device_obj',
        interface_key = 'NETWORK_PORT_METRICS_INTERFACE',
        field = 'BandwidthPercent'
    }
}

--- Build a port ORM object double for the sync-to-device tests.
-- Defaults: CreatedByDeviceObject = true, device_path = '/dev/port',
-- NodeId = 'node-1'. Each entry in `overrides` replaces the default value;
-- an entry whose value is the NIL sentinel removes that field instead.
-- @tparam[opt] table overrides field -> replacement value (or NIL)
-- @treturn table the port object
local function make_port_obj(overrides)
    local port = {
        CreatedByDeviceObject = true,
        device_path = '/dev/port',
        NodeId = 'node-1'
    }
    for field, value in pairs(overrides or {}) do
        if value ~= NIL then
            port[field] = value
        else
            port[field] = nil
        end
    end
    return port
end

--- Replace comm_fun.get_device_obj_by_path_interface with `fn` for a test.
-- The field is also registered through save_orig, so restore_all() reverts
-- it even if an assertion fails before the returned restore function runs.
-- Without that registration a failing test would leak the stub into later
-- tests.
-- @tparam function fn replacement lookup
-- @treturn function reinstates the implementation active before this call
local function set_device_obj_stub(fn)
    local key = 'get_device_obj_by_path_interface'
    local original = comm_fun[key]
    save_orig(comm_fun, key)
    rawset(comm_fun, key, fn)
    return function()
        rawset(comm_fun, key, original)
    end
end

--- Install the module stubs shared by every test in this suite.
-- Every replacement goes through `stub`, so tearDown's restore_all() reverts
-- it. Also sets self.bus to a minimal bus table whose `call` method always
-- returns { { 'iface_port' } }.
function TestPortMgmt:setUp()
    -- mock comm_defs: fixed path pattern, interface, service and MDB names
    stub(comm_defs, 'CARD_DEVICE_PATH_PATTERN', '/mock/pattern')
    stub(comm_defs, 'NETWORK_PORT_DEVICE_INTERFACE', 'iface_port')
    stub(comm_defs, 'NETWORK_PORT_LINK_INFO_INTERFACE', 'link_info_iface')
    stub(comm_defs, 'MACA_SERVICE', 'svc')
    stub(comm_defs, 'MDB_PATH', '/mdb')
    stub(comm_defs, 'MDB_INTERFACE', 'mdb_iface')
    -- mock ctx: get_context_or_default always returns an empty table
    stub(ctx, 'get_context_or_default', function() return {} end)
    -- mock mdb: register_interface is a no-op; get_object returns an object
    -- whose get_property yields (0, 123) for 'PortId' and (0, 'val') for
    -- any other property
    stub(mdb, 'register_interface', function() end)
    stub(mdb, 'get_object', function()
        return {
            get_property = function(self, prop)
                if prop == 'PortId' then return 0, 123 end
                return 0, 'val'
            end
        }
    end)
    -- mock c_network_port: create_mdb_object echoes its argument; the
    -- collection's find builds a fresh port ORM object on every call, taking
    -- PortID / NetworkAdapterId from the query when present (else 123 / 'nid')
    stub(c_network_port, 'create_mdb_object', function(tbl) return tbl end)
    stub(c_network_port, 'collection', {
        find = function(self, q)
            return {
                children = {},
                PortID = q and q.PortID or 123,
                NetworkAdapterId = q and q.NetworkAdapterId or 'nid',
                FirmwareVersion = 'fw',
                ObjectName = 'portobj',
                parent = { ObjectName = 'parentobj' },
                id = 1
            }
        end,
        del_object = function() end
    })
    -- mock c_factory: get_instance returns a factory whose create_object
    -- does nothing
    stub(c_factory, 'get_instance', function() return { create_object = function() end } end)
    -- mock card_mgmt.get_instance: only '/parent/path' resolves to a parent
    -- ORM object; every other path yields nil
    stub(card_mgmt, 'get_instance', function() return {
        get_orm_obj_by_device_path = function(_, path)
            if path == '/parent/path' then
                return {
                    NodeId = 'nid',
                    ObjectName = 'obj',
                    FirmwareVersion = 'fw',
                    children = {},
                    NetworkPortCount = 1,
                    parent = { ObjectName = 'parentobj' }
                }
            end
        end,
        register_children = function() end
    } end)
    -- mock comm_fun: '/fail' has no parent path and every other port maps to
    -- '/parent/path'; two port device paths ('/port1', '/port2') are listed
    stub(comm_fun, 'get_parent_path', function(bus, port_device_path)
        if port_device_path == '/fail' then return nil end
        return '/parent/path'
    end)
    stub(comm_fun, 'get_all_device_paths', function() return { '/port1', '/port2' } end)
    -- The add-signal helper records a slot id, installs a default bus.match
    -- when the bus has none, and immediately invokes the callback with
    -- '/mock/path'.
    stub(comm_fun, 'set_interface_add_signal', function(bus, sig_slot, path_pattern, interface_name, cb)
        sig_slot[#sig_slot + 1] = 123
        if not bus.match then
            bus.match = function(self, sig, cb)
                return 123
            end
        end
        cb('/mock/path')
    end)
    stub(comm_fun, 'get_object_name_by_device_path', function(bus, port_device_path, rep)
        return 'NetworkPort1'
    end)
    -- mock port_init: init is a no-op
    stub(port_init, 'init', function() end)
    -- mock log: notice/error become no-ops
    stub(log, 'notice', function() end)
    stub(log, 'error', function() end)
    -- NOTE(review): the nested list presumably mimics a bus reply that lists
    -- 'iface_port' as an interface name; confirm against port_mgmt.
    self.bus = { call = function(self, ...) return { { 'iface_port' } } end }
end

--- Revert every field replaced through stub/save_orig since the last reset.
TestPortMgmt.tearDown = function(_)
    restore_all()
end

--- port_mgmt keeps the bus and empty registries assigned to it.
-- NOTE(review): despite its name, this test calls no constructor or init
-- routine; it only checks the field assignments below.
function TestPortMgmt:test_ctor_and_init()
    local bus = self.bus
    port_mgmt.bus = bus
    port_mgmt.objects = {}
    port_mgmt.sig_slot = {}

    lu.assertIs(port_mgmt.bus, bus)
    for _, field in ipairs({ 'objects', 'sig_slot' }) do
        lu.assertIsTable(port_mgmt[field])
    end
end

--- init_obj stores a table under the port path when its parent resolves.
function TestPortMgmt:test_init_obj_normal()
    local port_path = '/port1'
    port_mgmt.bus = self.bus
    port_mgmt.objects = {}
    port_mgmt.sig_slot = {}

    port_mgmt:init_obj(port_path)

    lu.assertIsTable(port_mgmt.objects[port_path])
end

--- init_obj keeps a pre-registered entry for the same path (checked by id).
function TestPortMgmt:test_init_obj_already_exists()
    local port_path = '/port1'
    port_mgmt.bus = self.bus
    port_mgmt.objects = { [port_path] = { id = 1 } }
    port_mgmt.sig_slot = {}

    port_mgmt:init_obj(port_path)

    local entry = port_mgmt.objects[port_path]
    lu.assertEquals(entry.id, 1)
end

--- init_obj stores nothing when get_parent_path yields no parent
-- (the setUp stub returns nil for '/fail').
function TestPortMgmt:test_init_obj_no_parent_path()
    local port_path = '/fail'
    port_mgmt.bus = self.bus
    port_mgmt.objects = {}
    port_mgmt.sig_slot = {}

    port_mgmt:init_obj(port_path)

    lu.assertNil(port_mgmt.objects[port_path])
end

--- init_obj stores nothing when the parent path resolves to no ORM object.
function TestPortMgmt:test_init_obj_no_parent_orm()
    local mgr = port_mgmt
    mgr.bus = self.bus
    mgr.objects = {}
    mgr.sig_slot = {}
    -- Mock card_mgmt.get_instance so that no device path resolves to an ORM
    -- object. Going through `stub`, as setUp does, records the original so
    -- restore_all() reverts it; no bare field write is left behind.
    stub(card_mgmt, 'get_instance', function() return {
        get_orm_obj_by_device_path = function(_, path) return nil end,
        register_children = function() end
    } end)
    mgr:init_obj('/port1')
    lu.assertNil(mgr.objects['/port1'])
end

--- init() creates an entry for each stubbed device path ('/port1', '/port2').
function TestPortMgmt:test_init()
    port_mgmt.bus = self.bus
    port_mgmt.objects = {}
    port_mgmt.sig_slot = {}

    port_mgmt:init()

    for _, port_path in ipairs({ '/port1', '/port2' }) do
        lu.assertIsTable(port_mgmt.objects[port_path])
    end
end

--- get_orm_obj_by_device_path returns the stored entry, or nil when absent.
function TestPortMgmt:test_get_orm_obj_by_device_path()
    port_mgmt.bus = self.bus
    port_mgmt.objects = { ['/port1'] = { id = 1 } }

    local found = port_mgmt:get_orm_obj_by_device_path('/port1')
    lu.assertEquals(found.id, 1)

    local missing = port_mgmt:get_orm_obj_by_device_path('/notfound')
    lu.assertNil(missing)
end

--- register_children puts the optical module first in a childless port.
function TestPortMgmt:test_register_children()
    local optical_module = { id = 1 }
    local port = { children = {} }

    port_mgmt:register_children(optical_module, port)

    lu.assertEquals(port.children[1], optical_module)
end

--- Sync methods perform no device-object lookup for a port whose
-- CreatedByDeviceObject is false.
function TestPortMgmt:test_sync_port_device_obj_skip_when_not_created()
    local manager = port_mgmt
    manager.bus = self.bus
    for _, case in ipairs(sync_cases) do
        local lookups = 0
        local restore = set_device_obj_stub(function()
            lookups = lookups + 1
            return {}
        end)
        local sync = manager[case.func]
        local port = make_port_obj({ CreatedByDeviceObject = false })
        sync(manager, port, 'value')
        lu.assertEquals(lookups, 0)
        restore()
    end
end

--- Sync methods perform no device-object lookup for a port whose
-- device_path is the empty string.
function TestPortMgmt:test_sync_port_device_obj_skip_when_device_path_missing()
    local manager = port_mgmt
    manager.bus = self.bus
    for _, case in ipairs(sync_cases) do
        local lookups = 0
        local restore = set_device_obj_stub(function()
            lookups = lookups + 1
            return {}
        end)
        local sync = manager[case.func]
        local port = make_port_obj({ device_path = '' })
        sync(manager, port, 'value')
        lu.assertEquals(lookups, 0)
        restore()
    end
end

--- When the lookup yields nil, each sync method performs exactly one lookup
-- and does not raise.
function TestPortMgmt:test_sync_port_device_obj_skip_when_device_obj_missing()
    local manager = port_mgmt
    manager.bus = self.bus
    local lookups = 0
    local restore = set_device_obj_stub(function()
        lookups = lookups + 1
        return nil
    end)
    for _, case in ipairs(sync_cases) do
        local sync = manager[case.func]
        sync(manager, make_port_obj(), 'value')
    end
    local expected_lookups = #sync_cases
    lu.assertEquals(lookups, expected_lookups)
    restore()
end

--- Each sync method looks the device object up once, by the port's
-- device_path and the expected interface, and writes the value to its field.
function TestPortMgmt:test_sync_port_device_obj_success_sets_field()
    local manager = port_mgmt
    manager.bus = self.bus
    for _, case in ipairs(sync_cases) do
        local target = {}
        local lookups = 0
        local restore = set_device_obj_stub(function(_, device_path, iface)
            lookups = lookups + 1
            lu.assertEquals(device_path, '/dev/port')
            lu.assertEquals(iface, comm_defs[case.interface_key])
            return target
        end)
        local sync = manager[case.func]
        sync(manager, make_port_obj(), 'new-value')
        lu.assertEquals(lookups, 1)
        lu.assertEquals(target[case.field], 'new-value')
        restore()
    end
end

--- A LinkStatus change on the link-info interface reaches the port ORM
-- object. Both setters are invoked, and the raw value 1 is passed to
-- set_link_status as "Connected".
function TestPortMgmt:test_link_status_property_change()
    local called = { link_status = false, link_status_numeric = false }
    local reported = nil

    local port_orm_obj = {
        PortID = 123,
        NetworkAdapterId = 'nid',
        FirmwareVersion = 'fw',
        ObjectName = 'portobj',
        parent = { ObjectName = 'parentobj' },
        id = 1,
        set_link_status = function(_, value)
            called.link_status = true
            reported = value
        end,
        set_link_status_numeric = function()
            called.link_status_numeric = true
        end
    }

    -- The bus hands back whatever callback port_mgmt registers via match().
    local on_properties_changed = nil
    local fake_bus = {
        match = function(_, _rule, callback)
            on_properties_changed = callback
            return {}
        end,
        call = function()
            return { { 'iface_port' } }
        end
    }

    local saved_find = c_network_port.collection.find
    c_network_port.collection.find = function()
        return port_orm_obj
    end

    local manager = port_mgmt
    manager.bus = fake_bus
    manager.objects = {}
    manager.sig_slot = {}

    -- Signal message carrying LinkStatus = 1 on the link-info interface.
    local fake_msg = {
        read = function()
            return comm_defs.NETWORK_PORT_LINK_INFO_INTERFACE, {
                LinkStatus = { value = function() return 1 end }
            }
        end
    }

    manager:init_obj('/port1')

    lu.assertNotNil(on_properties_changed, "Should capture property change callback")
    if not on_properties_changed then
        error('property change callback missing')
    end
    on_properties_changed(fake_msg)

    lu.assertTrue(called.link_status, "Should call set_link_status")
    lu.assertTrue(called.link_status_numeric, "Should call set_link_status_numeric")
    lu.assertEquals(reported, "Connected", "Should receive correct value")

    c_network_port.collection.find = saved_find
end

--- Raw LinkStatus values map to the strings passed to set_link_status:
-- 0 -> Disconnected, 1 -> Connected, 255 and unknown values -> N/A.
function TestPortMgmt:test_link_status_all_values()
    local cases = {
        { input = 0, expected = 'Disconnected', desc = 'LinkStatus=0 should map to Disconnected' },
        { input = 1, expected = 'Connected', desc = 'LinkStatus=1 should map to Connected' },
        { input = 255, expected = 'N/A', desc = 'LinkStatus=255 should map to N/A' },
        { input = 99, expected = 'N/A', desc = 'Unknown LinkStatus should map to N/A' }
    }

    -- Initialise a fresh port, deliver one LinkStatus signal carrying
    -- `raw_status`, and return what set_link_status received.
    local function reported_status_for(raw_status)
        local reported = nil
        local port_orm_obj = {
            PortID = 123,
            set_link_status = function(_, value)
                reported = value
            end,
            set_link_status_numeric = function() end
        }

        local on_properties_changed = nil
        local fake_bus = {
            match = function(_, _rule, callback)
                on_properties_changed = callback
                return {}
            end,
            call = function()
                return { { 'iface_port' } }
            end
        }

        local saved_find = c_network_port.collection.find
        c_network_port.collection.find = function()
            return port_orm_obj
        end

        port_mgmt.bus = fake_bus
        port_mgmt.objects = {}
        port_mgmt.sig_slot = {}

        local fake_msg = {
            read = function()
                return comm_defs.NETWORK_PORT_LINK_INFO_INTERFACE, {
                    LinkStatus = { value = function() return raw_status end }
                }
            end
        }

        port_mgmt:init_obj('/port1')
        if not on_properties_changed then
            error('property change callback missing')
        end
        on_properties_changed(fake_msg)

        c_network_port.collection.find = saved_find
        return reported
    end

    for _, tc in ipairs(cases) do
        lu.assertEquals(reported_status_for(tc.input), tc.expected, tc.desc)
    end
end

--- del_obj passes the stored object to collection:del_object and removes
-- its path from port_mgmt.objects.
function TestPortMgmt:test_del_obj_remove_port()
    local port_path = '/mock/port'
    port_mgmt.objects = { [port_path] = { id = 1 } }

    local deleted = nil
    local saved_del = c_network_port.collection.del_object
    c_network_port.collection.del_object = function(_, obj)
        deleted = obj
    end

    port_mgmt:del_obj(port_path)

    lu.assertEquals(deleted.id, 1)
    lu.assertNil(port_mgmt.objects[port_path])

    c_network_port.collection.del_object = saved_del
end

-- Hand the suite table back to the code that loads this file.
return TestPortMgmt