-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
--
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local mdb_service = require 'mc.mdb.mdb_service'
local log = require 'mc.logging'

-- Object paths and interface names passed to the mdb / mdb_service lookups below.
local ETH_OBJ_PATH<const> = '/bmc/kepler/Managers/1/EthernetInterfaces'
local ETH_OBJ_INTERFACE<const> = 'bmc.kepler.Managers.EthernetInterfaces'
local ETH_PORT_INTERFACE<const> = 'bmc.kepler.Managers.EthernetInterfaces.MgmtPort'
local ETH_GROUP_INTERFACE<const> = 'bmc.kepler.Managers.EthernetInterfaces.EthGroup'
-- NOTE(review): ETHERNET_PATH and MGMT_PORT_INTERFACE hold the same values as
-- ETH_OBJ_PATH and ETH_PORT_INTERFACE; only get_port_id_by_type uses them.
local ETHERNET_PATH<const> = '/bmc/kepler/Managers/1/EthernetInterfaces'
local MGMT_PORT_INTERFACE<const> = 'bmc.kepler.Managers.EthernetInterfaces.MgmtPort'

-- Module table; it is returned at the end of the file.
-- NOTE(review): `mdb`, `bus` and `custom_messages` are used by the functions
-- below but are not declared in this file, so they resolve as globals —
-- confirm the runtime provides them.
local m = {}

--- Build a fresh settings template for an Ethernet patch request.
-- The template has three sections: EthernetInterfaces, Ipv4 and Ipv6.
-- Path/Interface start empty and are filled in by the caller. The other
-- string fields appear to be default property/method names, which callers
-- may overwrite.
-- @treturn table a new table on every call (never shared between calls)
local function create_base_setting()
    local interfaces_section = {
        Path = "",
        Interface = "",
        PortId = "PortId",
        VLANEnable = "VLANEnable",
        SetFlag = false,
        EthModeNetMode = "NetMode",
        EthModePortId = 0,
        EthModeVLANEnable = false,
        EthModeVLANId = 0,
        IpVersion = "",
        HostPortPortId = 0
    }

    local ipv4_section = {
        Path = "",
        Interface = "",
        SetDefaultGateway = "SetDefaultGateway"
    }

    local ipv6_section = {
        Path = "",
        Interface = "",
        IpMode = "IpMode",
        IpAddr = "IpAddr",
        DefaultGateway = "DefaultGateway",
        SetIpAddr = "SetIpAddr",
        SetDefaultGateway = "SetDefaultGateway"
    }

    return {
        EthernetInterfaces = interfaces_section,
        Ipv4 = ipv4_section,
        Ipv6 = ipv6_section
    }
end

--- Find the management port of a given type whose DevicePortId is 1.
-- Ports are visited with pairs(), so if several ports match, which one is
-- returned is unspecified.
-- @tparam string port_type value compared against each port's Type field
-- @return the Id of a matching port, or nil when none matches
local function get_port_id_by_type(port_type)
    local candidates = mdb.get_sub_objects(bus, ETHERNET_PATH, MGMT_PORT_INTERFACE)
    for _, candidate in pairs(candidates) do
        -- Default to device port 1 to stay consistent with v2.
        local is_match = candidate.Type == port_type and candidate.DevicePortId == 1
        if is_match then
            return candidate.Id
        end
    end
    return nil
end

--- Compute the NetMode/port switch requested by an eth_mode value.
-- eth_mode 3 requests automatic mode. Modes 1, 2 and 4 request fixed mode on
-- a Dedicated, LOM or ExternalPCIe port respectively. Any other value keeps
-- fixed mode on the current port.
-- @param obj object whose NetMode and VLANId are read
-- @param eth_mode requested mode value
-- @param obj_port_id Id of the currently active management port
-- @param obj_vlan_enable current VLAN enable state, passed through unchanged
-- @treturn table {SetFlag = false} when already automatic and mode 3 is
--   requested; otherwise {SetFlag = true, NetMode, PortId, VLANEnable, VLANId}
local function get_net_mode(obj, eth_mode, obj_port_id, obj_vlan_enable)
    -- 3 means automatic (adaptive) mode
    if eth_mode == 3 then
        -- Already automatic: nothing to change
        if obj.NetMode == 'Automatic' then return {SetFlag = false} end
        return {
            SetFlag = true,
            NetMode = 'Automatic',
            PortId = obj_port_id,
            VLANEnable = obj_vlan_enable,
            VLANId = obj.VLANId
        }
    end
    local port = mdb.get_object(bus, ETH_OBJ_PATH .. '/' .. obj_port_id,
                                ETH_PORT_INTERFACE)
    local port_id = nil
    -- Only switch when the active port is not already of the requested type.
    -- The Type literals must match exactly (no trailing spaces); otherwise an
    -- already-selected port of that type would be replaced.
    if eth_mode == 2 and port.Type ~= 'LOM' then
        port_id = get_port_id_by_type('LOM')
    elseif eth_mode == 1 and port.Type ~= 'Dedicated' then
        -- NOTE(review): hard-coded Id; presumably the dedicated port is Id 1 — confirm.
        port_id = 1
    elseif eth_mode == 4 and port.Type ~= 'ExternalPCIe' then
        port_id = get_port_id_by_type('ExternalPCIe')
    end
    -- Keep the current port when no switch is needed or no target port exists.
    if port_id == nil then port_id = obj_port_id end
    return {
        SetFlag = true,
        NetMode = 'Fixed',
        PortId = port_id,
        VLANEnable = obj_vlan_enable,
        VLANId = obj.VLANId
    }
end

--- Resolve the port to switch to for the given NCSI host port number.
-- Looks for a management port with the same Type as the currently active
-- port and a DevicePortId equal to host_port - 1.
-- @param host_port host port number; compared against DevicePortId after subtracting 1
-- @param obj_port_id Id of the currently active management port
-- @return Id of the matching port
-- @raise the error object built by custom_messages.PortNotExist when no port matches
local function get_port_id_by_host_port(host_port, obj_port_id)
    local active_port = mdb.get_object(bus, ETH_OBJ_PATH .. '/' .. obj_port_id, ETH_PORT_INTERFACE)
    local active_port_type = active_port.Type
    local device_port_id = host_port - 1
    local ports = mdb.get_sub_objects(bus, ETH_OBJ_PATH, ETH_PORT_INTERFACE)
    for _, port in pairs(ports) do
        if port.Type == active_port_type and port.DevicePortId == device_port_id then
            return port.Id
        end
    end
    -- The message object must be raised (as get_ip_version does); building it
    -- without error() would hand nil back to the caller as the port Id.
    error(custom_messages.PortNotExist(active_port_type, device_port_id))
end

--- Derive the new IpVersion from the current one and an IPv4 enable flag.
-- With ipv4_enable == 1, 'IPv4AndIPv6' becomes 'IPv6', and 'IPv4' is refused
-- (the only protocol would be removed). With ipv4_enable == 2, 'IPv6'
-- becomes 'IPv4AndIPv6'. Every other combination returns ip_version unchanged.
-- @param ip_version current IpVersion string
-- @param ipv4_enable requested IPv4 state
-- @return the resulting IpVersion string
-- @raise custom_messages.OperationNotAllowed() when disabling IPv4 on an IPv4-only interface
local function get_ip_version(ip_version, ipv4_enable)
    if ipv4_enable == 1 then
        if ip_version == 'IPv4' then
            error(custom_messages.OperationNotAllowed())
        end
        if ip_version == 'IPv4AndIPv6' then
            return 'IPv6'
        end
    elseif ipv4_enable == 2 and ip_version == 'IPv6' then
        return 'IPv4AndIPv6'
    end
    return ip_version
end

--- Get the EthGroup resource-tree object that matches the current eth_num.
-- @param eth_num position in the list (the caller passes a list it has sorted by GroupId)
-- @tparam table|nil sorted_ethGroup_objs candidate EthGroup objects
-- @return the entry at position eth_num, or nil when the list is missing or
--   that entry is absent or falsy
local function get_eth_num_patched_ethgroup_object(eth_num, sorted_ethGroup_objs)
    if not sorted_ethGroup_objs then
        return nil
    end
    local candidate = sorted_ethGroup_objs[eth_num]
    if not candidate then
        return nil
    end
    return candidate
end

--- Fill a settings table so that it targets the given EthGroup object.
-- Every section's Path/Interface is pointed at the group object. The optional
-- eth_mode, ipv4_enable and host_port requests are resolved, in that order,
-- via get_net_mode, get_ip_version and get_port_id_by_host_port. Property and
-- method names are then switched to the EthGroup-specific ones.
-- @param ethGroup_obj EthGroup object (its path, ActivePortId, VLANEnabled and IpVersion are read)
-- @param ipv4_enable optional IPv4 enable request (skipped when nil)
-- @param host_port optional host port number (skipped when nil)
-- @param eth_mode optional mode request (skipped when nil)
-- @tparam table eth_group_setting table with EthernetInterfaces/Ipv4/Ipv6 sections, modified in place
-- @return true
local function get_eth_group_setting(ethGroup_obj, ipv4_enable, host_port, eth_mode, eth_group_setting)
    local group_path = ethGroup_obj.path
    local eth = eth_group_setting.EthernetInterfaces
    local ipv4 = eth_group_setting.Ipv4
    local ipv6 = eth_group_setting.Ipv6

    eth.Path = group_path
    eth.Interface = ETH_GROUP_INTERFACE

    if eth_mode ~= nil then
        local mode = get_net_mode(ethGroup_obj, eth_mode, ethGroup_obj.ActivePortId, ethGroup_obj.VLANEnabled)
        eth.SetFlag = mode.SetFlag
        eth.EthModeNetMode = mode.NetMode
        eth.EthModePortId = mode.PortId
        eth.EthModeVLANEnable = mode.VLANEnable
        eth.EthModeVLANId = mode.VLANId
    end

    if ipv4_enable ~= nil then
        eth.IpVersion = get_ip_version(ethGroup_obj.IpVersion, ipv4_enable)
    end

    if host_port ~= nil then
        eth.HostPortPortId = get_port_id_by_host_port(host_port, ethGroup_obj.ActivePortId)
    end

    -- EthGroup objects use group-specific property and method names.
    eth.PortId = "ActivePortId"
    eth.VLANEnable = "VLANEnabled"

    ipv4.Path = group_path
    ipv4.Interface = ETH_GROUP_INTERFACE

    ipv6.Path = group_path
    ipv6.Interface = ETH_GROUP_INTERFACE
    ipv6.IpMode = "Ipv6Mode"
    ipv6.IpAddr = "Ipv6Addr"
    ipv6.DefaultGateway = "Ipv6DefaultGateway"
    ipv6.SetIpAddr = "SetIpv6Addr"
    ipv6.SetDefaultGateway = "SetIpv6DefaultGateway"

    return true
end


--- Fill eth_group_setting from the EthGroup object selected by eth_num.
-- Lists the EthGroup sub-paths under ETH_OBJ_PATH .. "/EthGroup" and keeps the
-- objects whose OutType is 2 and whose Status is truthy. It sorts them by
-- GroupId and picks the eth_num-th one (position in that sorted list).
-- Lookup failures are tolerated: the path listing failing, a missing
-- SubPaths field, or an individual object failing to load.
-- @param eth_num position of the wanted group in the sorted list
-- @param ipv4_enable forwarded to get_eth_group_setting
-- @param host_port forwarded to get_eth_group_setting
-- @param eth_mode forwarded to get_eth_group_setting
-- @tparam table eth_group_setting settings table, modified in place only on success
-- @treturn boolean true when a group was found and the setting was filled
local function patch_eth_num_to_ethgroup(eth_num, ipv4_enable, host_port, eth_mode, eth_group_setting)
    local ok, rsp = pcall(mdb_service.get_sub_paths, bus, ETH_OBJ_PATH .. "/EthGroup", 1, {ETH_GROUP_INTERFACE})
    if not ok then
        log:debug('Incorrect parent path or interface')
        return false
    end

    -- Guard against a response without SubPaths; pairs(nil) would raise.
    local sub_paths = rsp and rsp.SubPaths
    if not sub_paths then
        return false
    end

    local ethGroup_objs = {}
    for _, sub_path in pairs(sub_paths) do
        local got, ethGroup_obj = pcall(mdb.get_object, bus, sub_path, ETH_GROUP_INTERFACE)
        -- Only groups with OutType 2 and a truthy Status are eligible.
        if got and ethGroup_obj and ethGroup_obj.OutType == 2 and ethGroup_obj.Status then
            ethGroup_objs[#ethGroup_objs + 1] = ethGroup_obj
        end
    end

    -- Nothing eligible: no group can match any eth_num.
    if #ethGroup_objs == 0 then
        return false
    end

    table.sort(ethGroup_objs, function(a, b)
        return a.GroupId < b.GroupId
    end)

    local ethGroup_obj = get_eth_num_patched_ethgroup_object(eth_num, ethGroup_objs)
    if ethGroup_obj == nil then
        return false
    end

    get_eth_group_setting(ethGroup_obj, ipv4_enable, host_port, eth_mode, eth_group_setting)
    return true
end

-- Build the setting for the main EthernetInterfaces object (eth_num 1).
-- The requests are resolved in the order eth_mode, host_port, ipv4_enable.
local function build_interfaces_setting(ipv4_enable, host_port, eth_mode)
    local setting = create_base_setting()
    local eth = setting.EthernetInterfaces
    eth.Path = ETH_OBJ_PATH
    eth.Interface = ETH_OBJ_INTERFACE

    local eth_obj = mdb.get_object(bus, ETH_OBJ_PATH, ETH_OBJ_INTERFACE)
    if eth_mode ~= nil then
        local mode = get_net_mode(eth_obj, eth_mode, eth_obj.PortId, eth_obj.VLANEnable)
        eth.SetFlag = mode.SetFlag
        eth.EthModeNetMode = mode.NetMode
        eth.EthModePortId = mode.PortId
        eth.EthModeVLANEnable = mode.VLANEnable
        eth.EthModeVLANId = mode.VLANId
    end
    if host_port ~= nil then
        eth.HostPortPortId = get_port_id_by_host_port(host_port, eth_obj.PortId)
    end
    if ipv4_enable ~= nil then
        eth.IpVersion = get_ip_version(eth_obj.IpVersion, ipv4_enable)
    end

    setting.Ipv4.Path = ETH_OBJ_PATH .. "/Ipv4"
    setting.Ipv4.Interface = ETH_OBJ_INTERFACE .. ".Ipv4"
    setting.Ipv6.Path = ETH_OBJ_PATH .. "/Ipv6"
    setting.Ipv6.Interface = ETH_OBJ_INTERFACE .. ".Ipv6"
    return setting
end

--- Build the MDB setting table for the Ethernet interface selected by eth_num.
-- eth_num 1 addresses the main EthernetInterfaces object. Larger values
-- address EthGroup objects (OutType 2, truthy Status) in GroupId order;
-- eth_num 2 is the first such group.
-- @param eth_num interface selector; the decremented value is logged
-- @param ipv4_enable optional IPv4 enable request
-- @param host_port optional host port number
-- @param eth_mode optional mode request
-- @treturn table a settings table; an untouched base template when no matching group exists.
--   Errors raised by the MDB lookups or helpers are not caught here.
function m.patch_eth_info(eth_num, ipv4_enable, host_port, eth_mode)
    local index = eth_num - 1
    log:notice("patch_eth_info eth_num = %s", index)

    -- eth_num that maps to the EthernetInterfaces object
    if index == 0 then
        return build_interfaces_setting(ipv4_enable, host_port, eth_mode)
    end

    -- eth_num that maps to an EthGroup object
    local group_setting = create_base_setting()
    if patch_eth_num_to_ethgroup(index, ipv4_enable, host_port, eth_mode, group_setting) then
        return group_setting
    end
    return create_base_setting()
end

return m