-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local bs = require 'mc.bitstring'
local log = require 'mc.logging'
local core = require 'network.core'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_mac = require 'ncsi.ncsi_protocol.ncsi_mac'
local ncsi_vlan_mode = require 'ncsi.ncsi_protocol.ncsi_vlan_mode'
local ncsi_packet = require 'ncsi.ncsi_protocol.ncsi_packet'
local ncsi_oem_response = require 'ncsi.ncsi_protocol.ncsi_oem_response'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

-- Module table; public functions are attached to it and it is returned at the end of the file
local ncsi_oem_lldp = {}

-- NC-SI command type constants: OEM command request and its response type
local OEM_COMMAND = 0x50
local OEM_COMMAND_RSP = 0xD0

-- OEM sub-command IDs and command IDs for getting/setting the LLDP over NC-SI status
local GET_LLDP_OVER_NCSI_STATUS = 0x0B
local GET_LLDP_OVER_NCSI_CMD_ID = 0x04
local SET_LLDP_OVER_NCSI_CMD_ID = 0x04
local SET_LLDP_OVER_NCSI_STATUS = 0x0A
-- NC-SI payload lengths: the 8-byte OEM header (see oem_command_req_bs), plus
-- status(1) + reserved(3) for the SET request
local SET_LLDP_OVER_NCSI_PAYLOAD_LEN = 12
local GET_LLDP_OVER_NCSI_PAYLOAD_LEN = 8
local LLDP_OVER_NCSI_RESERVED_BYTES = 3

-- Huawei manufacturer ID written into the OEM command header
local MANUFACTURE_ID_HUAWEI = 0x000007DB

-- Length in bytes of the OEM data area packed into oem_command_req_bs's `payload` field
local OEM_PAYLOAD_MAX_LEN = 64
-- Accepted interface-name prefix and the exclusive upper bound on the interface number
local ETH_NAME_PREFIX = 'eth'
local MAX_ETH_NUM = 16
-- Exclusive upper bounds on NC-SI channel and package IDs
local NCSI_CHANNEL_MAX_ID = 4
local NCSI_PACKAGE_MAX_ID = 8

-- LLDP over NC-SI status values
local DISABLE = 0
local ENABLE = 1
-- Enable/disable arguments for ncsi_mac.ncsi_set_multicast_mac_filter
local MAC_FILTER_ENABLE = 1
local MAC_FILTER_DISABLE = 0
-- Bits tested in the channel capability `vlan_mode` field
local VLAN_MODE_FIELD_VLAN_NON_VLAN = 0x02
local VLAN_MODE_FIELD_ANYVLAN_NON_VLAN = 0x04
-- VLAN mode values passed to ncsi_vlan_mode.ncsi_enable_vlan_req
local VLAN_NON_VLAN = 0x02
local ANYVLAN_NON_VLAN = 0x03

-- Checksum length in bytes; added to the frame length when a request is sent
local CHECK_SUM_LEN = 4

-- Module-level state: cached LLDP over NC-SI status per interface and channel,
-- indexed as g_lldp_over_ncsi_status[eth_num][channel_id]
local g_lldp_over_ncsi_status = {}

-- Bit-field layouts for OEM requests and responses

-- OEM command payload: 8-byte header (manufacturer ID, command revision, command ID,
-- sub-command, reserved) followed by the OEM data area
local oem_command_req_bs = bs.new([[<<
    manufacture_id:32,
    cmd_rev:8,
    cmd_id:8,
    sub_cmd:8,
    reserved:8,
    payload:64/string
>>]])

-- Start of the SET LLDP over NC-SI data area: status byte, 3 reserved bytes, checksum slot
local set_lldp_over_ncsi_bs = bs.new([[<<
    status:8,
    reserved:3/string,
    check_sum:32
>>]])

-- LLDP over NC-SI GET response: response/reason codes, the OEM header fields,
-- then the status byte, 3 reserved bytes and the checksum
local get_lldp_over_ncsi_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    manufacture_id:32,
    cmd_rev:8,
    cmd_id:8,
    sub_cmd:8,
    reserved:8,
    status:8,
    reserved1:3/string,
    check_sum:32
>>]])

-- Reset the cached LLDP over NC-SI status: every interface index 0..MAX_ETH_NUM-1 gets a
-- fresh per-channel table in which channels 0..NCSI_CHANNEL_MAX_ID-1 are all set to 0.
local function init_lldp_over_ncsi_status()
    local eth_idx = 0
    while eth_idx < MAX_ETH_NUM do
        local per_channel = {}
        for chan = 0, NCSI_CHANNEL_MAX_ID - 1 do
            per_channel[chan] = 0
        end
        g_lldp_over_ncsi_status[eth_idx] = per_channel
        eth_idx = eth_idx + 1
    end
end

-- Apply the common NC-SI command settings to req_packet, then store payload_len in the split
-- length field: the upper 4 bits go to payload_len_hi and the low 8 bits to payload_len_lo.
-- Returns the length decoded back from those fields. Returns nil, with an error log, when that
-- length does not fit in a packet after the Ethernet and NC-SI headers.
local function configure_request_packet(req_packet, payload_len)
    ncsi_utils.ncsi_cmd_common_config(req_packet)

    local head = req_packet.packet_head
    head.payload_len_hi = (payload_len >> 8) & 0x0f
    head.payload_len_lo = payload_len & 0xff

    local req_len = (head.payload_len_hi << 8) | head.payload_len_lo
    local max_len = ncsi_def.PACKET_ALL_LEN - ncsi_def.ETHERNET_HEAD_LEN - ncsi_def.PACKET_HEAD_LEN
    if req_len < max_len then
        return req_len
    end

    log:error('Request length[%u] is too long', req_len)
    return nil
end

-- Pack the Huawei OEM command payload and embed the NC-SI checksum.
--
-- The payload is first packed without a checksum, so that ncsi_utils.get_checksum can cover
-- the NC-SI header plus the first `req_len` payload bytes. The checksum is then written directly
-- after those `req_len` bytes, at offset `req_len - OEM header length` inside the OEM data area.
-- Callers send a frame that covers exactly `req_len + CHECK_SUM_LEN` payload bytes, so this is
-- the only position where the checksum is actually transmitted. The command data that precedes
-- it (e.g. the SET status byte) is preserved.
--
-- @param req_packet NC-SI request packet; its `payload` field is overwritten with the
--                   checksum-less OEM payload as a side effect
-- @param sub_cmd    OEM sub-command
-- @param cmd_id     OEM command ID
-- @param req_len    NC-SI payload length in bytes (8-byte OEM header + command data)
-- @param padding    OEM data area (OEM_PAYLOAD_MAX_LEN bytes): command data first, zero-filled after
-- @return packed OEM payload string with the checksum in place
local function fill_oem_ncsi_payload(req_packet, sub_cmd, cmd_id, req_len, padding)
    -- manufacture_id(4) + cmd_rev(1) + cmd_id(1) + sub_cmd(1) + reserved(1)
    local OEM_HEAD_LEN = 8

    local payload_data = {
        manufacture_id = core.htonl(MANUFACTURE_ID_HUAWEI),
        cmd_rev = 0,
        cmd_id = cmd_id,
        sub_cmd = sub_cmd,
        reserved = 0,
        payload = padding
    }
    req_packet.payload = oem_command_req_bs:pack(payload_data)

    -- Checksum over the NC-SI header and the req_len payload bytes (checksum slot still zero)
    local check_sum = ncsi_utils.get_checksum(req_packet, ncsi_def.PACKET_HEAD_LEN + req_len)

    local offset = req_len - OEM_HEAD_LEN
    if offset < 0 or offset + CHECK_SUM_LEN > #padding then
        log:error('OEM payload length[%u] leaves no room for the checksum', req_len)
        return req_packet.payload
    end

    -- Keep the command data and overwrite only the 4 checksum bytes that follow it
    payload_data.payload = string.sub(padding, 1, offset) ..
        string.pack('I4', core.htonl(check_sum)) ..
        string.sub(padding, offset + CHECK_SUM_LEN + 1)

    return oem_command_req_bs:pack(payload_data)
end

-- Build and send the OEM "get LLDP over NC-SI status" request, whose OEM data area is all zeros.
-- Returns ncsi_def.NCSI_FAIL when the length check fails, otherwise the result of send_ncsi_cmd.
local function write_get_lldp_status_req(req_packet, eth_name, sub_cmd, cmd_id, payload_len)
    local req_len = configure_request_packet(req_packet, payload_len)
    if req_len == nil then
        return ncsi_def.NCSI_FAIL
    end

    local zero_data = string.rep('\0', OEM_PAYLOAD_MAX_LEN)
    req_packet.payload = fill_oem_ncsi_payload(req_packet, sub_cmd, cmd_id, req_len, zero_data)

    -- Frame length = payload + Ethernet header + NC-SI header + checksum
    local frame_len = req_len + ncsi_def.ETHERNET_HEAD_LEN + ncsi_def.PACKET_HEAD_LEN + CHECK_SUM_LEN
    local frame = ncsi_utils.ncsi_packet_bs:pack(req_packet)
    return ncsi_protocol_intf.send_ncsi_cmd(frame, frame_len, eth_name)
end

-- Build and send the OEM "set LLDP over NC-SI status" request.
-- The OEM data area starts with set_lldp_over_ncsi_bs (status, 3 zero reserved bytes, zero
-- checksum field) and is zero-padded up to OEM_PAYLOAD_MAX_LEN bytes.
-- Returns ncsi_def.NCSI_FAIL when the length check fails, otherwise the result of send_ncsi_cmd.
local function write_set_lldp_status_req(req_packet, eth_name, sub_cmd, cmd_id, payload_len, status)
    local req_len = configure_request_packet(req_packet, payload_len)
    if req_len == nil then
        return ncsi_def.NCSI_FAIL
    end

    local set_fields = {
        status = status,
        reserved = string.rep('\0', LLDP_OVER_NCSI_RESERVED_BYTES),
        check_sum = 0
    }
    local data_head = set_lldp_over_ncsi_bs:pack(set_fields)
    local data_area = data_head .. string.rep('\0', OEM_PAYLOAD_MAX_LEN - #data_head)
    req_packet.payload = fill_oem_ncsi_payload(req_packet, sub_cmd, cmd_id, req_len, data_area)

    -- Frame length = payload + Ethernet header + NC-SI header + checksum
    local frame_len = req_len + ncsi_def.ETHERNET_HEAD_LEN + ncsi_def.PACKET_HEAD_LEN + CHECK_SUM_LEN
    local frame = ncsi_utils.ncsi_packet_bs:pack(req_packet)
    return ncsi_protocol_intf.send_ncsi_cmd(frame, frame_len, eth_name)
end

-- Validate an interface name of the form "eth<N>" and return N.
-- @param eth_name interface name, e.g. "eth0"
-- @return the interface number N (0 <= N < MAX_ETH_NUM), or nil with an error log when:
--         eth_name is nil, it lacks the ETH_NAME_PREFIX prefix, no digits directly follow
--         the prefix, or N is out of range
local function validate_eth_name(eth_name)
    local fn_name = 'validate_eth_name'

    if not eth_name then
        log:error('%s : eth_name is nil', fn_name)
        return nil
    end

    if string.sub(eth_name, 1, #ETH_NAME_PREFIX) ~= ETH_NAME_PREFIX then
        log:error('%s : Invalid eth name[%s]', fn_name, eth_name)
        return nil
    end

    -- Anchor at the start so the digits must directly follow the prefix. An unanchored
    -- search would accept e.g. "ethX_eth3" by matching the later "eth3".
    local digits = string.match(eth_name, '^' .. ETH_NAME_PREFIX .. '(%d+)')
    local eth_num = digits and tonumber(digits)
    if not eth_num then
        log:error('%s : Failed to parse eth number from eth name[%s]', fn_name, eth_name)
        return nil
    end

    -- The status cache only holds interfaces 0..MAX_ETH_NUM-1
    if eth_num >= MAX_ETH_NUM then
        log:error('%s : Get eth num[%u] from eth name[%s] is larger than %u',
            fn_name, eth_num, eth_name, MAX_ETH_NUM)
        return nil
    end

    return eth_num
end

-- Read the cached LLDP over NC-SI status for (eth_name, channel_id).
-- Returns 0xff when channel_id is out of range or validate_eth_name rejects eth_name.
local function get_lldp_over_ncsi_status_rsp(eth_name, channel_id)
    local INVALID_STATUS = 0xff

    if channel_id < NCSI_CHANNEL_MAX_ID then
        local eth_num = validate_eth_name(eth_name)
        if eth_num then
            return g_lldp_over_ncsi_status[eth_num][channel_id]
        end
    end

    return INVALID_STATUS
end

-- Huawei OEM response callback for GET_LLDP_OVER_NCSI_STATUS.
-- Decodes the response payload with get_lldp_over_ncsi_rsp_bs and caches its `status` field
-- in g_lldp_over_ncsi_status[eth_num][channel_id], using the channel ID from the response header.
-- The response is logged and ignored, without touching the cache, when:
--   * the eth name is invalid;
--   * the channel ID is outside 0..NCSI_CHANNEL_MAX_ID-1, which get_lldp_over_ncsi_status_rsp
--     never reads;
--   * the decode result is nil.
-- @param rsp_packet decoded NC-SI response packet (uses packet_head.channel_id and payload)
-- @param eth_name   interface the response arrived on
local function hw_get_lldp_over_ncsi_status_rsp(rsp_packet, eth_name)
    local fn_name = 'hw_get_lldp_over_ncsi_status_rsp'

    local eth_num = validate_eth_name(eth_name)
    if not eth_num then
        return
    end

    local channel_id = rsp_packet.packet_head.channel_id
    if type(channel_id) ~= 'number' or channel_id < 0 or channel_id >= NCSI_CHANNEL_MAX_ID then
        log:error('%s : invalid channel_id[%s] in response', fn_name, tostring(channel_id))
        return
    end

    -- Guard against a nil decode result (e.g. truncated payload) before dereferencing it
    local lldp_over_ncsi_rsp = get_lldp_over_ncsi_rsp_bs:unpack(rsp_packet.payload, true)
    if not lldp_over_ncsi_rsp then
        log:error('%s : failed to parse LLDP over NCSI response payload', fn_name)
        return
    end

    g_lldp_over_ncsi_status[eth_num][channel_id] = lldp_over_ncsi_rsp.status

    log:debug('%s : get_lldp_over_ncsi_rsp, eth_num[%u] channel_id[%u] status[%u]',
        fn_name, eth_num, channel_id, lldp_over_ncsi_rsp.status)
end

-- Build the OEM callback list used when reading a GET_LLDP_OVER_NCSI_STATUS response.
-- It holds a single entry created for (GET_LLDP_OVER_NCSI_CMD_ID, GET_LLDP_OVER_NCSI_STATUS).
-- That entry's handler passes the response and the captured eth_name to
-- hw_get_lldp_over_ncsi_status_rsp.
local function create_lldp_callback_table(eth_name)
    local on_status_rsp = function(rsp)
        return hw_get_lldp_over_ncsi_status_rsp(rsp, eth_name)
    end

    return {
        ncsi_oem_response.create_callback_entry(GET_LLDP_OVER_NCSI_CMD_ID, GET_LLDP_OVER_NCSI_STATUS,
            on_status_rsp)
    }
end

-- Request writer registered under OEM_COMMAND: dispatch on sub_cmd to the GET or SET builder.
-- SET requires a status value. An unknown sub-command is logged and rejected with NCSI_FAIL.
local function write_oem_command_req(req_packet, eth_name, sub_cmd, status)
    if sub_cmd == GET_LLDP_OVER_NCSI_STATUS then
        return write_get_lldp_status_req(req_packet, eth_name, sub_cmd, GET_LLDP_OVER_NCSI_CMD_ID,
            GET_LLDP_OVER_NCSI_PAYLOAD_LEN)
    elseif sub_cmd == SET_LLDP_OVER_NCSI_STATUS then
        if status then
            return write_set_lldp_status_req(req_packet, eth_name, sub_cmd, SET_LLDP_OVER_NCSI_CMD_ID,
                SET_LLDP_OVER_NCSI_PAYLOAD_LEN, status)
        end
        log:error('write oem command req: status parameter is required for SET_LLDP_OVER_NCSI_STATUS')
        return ncsi_def.NCSI_FAIL
    end

    log:error('write_oem_command_req: Unknown sub_cmd: %s', tostring(sub_cmd))
    return ncsi_def.NCSI_FAIL
end

-- Response reader registered under OEM_COMMAND_RSP.
-- A GET_LLDP_OVER_NCSI_STATUS response is passed with a callback table that caches the status.
-- Any other sub-command gets a nil callback table, so only the response code is checked.
local function read_oem_command_rsp(rsp, eth_name, sub_cmd)
    local callback_table = nil
    if sub_cmd == GET_LLDP_OVER_NCSI_STATUS then
        callback_table = create_lldp_callback_table(eth_name)
    end

    return ncsi_oem_response.read_oem_command_rsp(rsp, callback_table)
end

-- Default dispatch table keyed by NC-SI command type:
-- OEM_COMMAND maps to the request writer and OEM_COMMAND_RSP to the response reader.
-- It is passed as the base table to ncsi_utils.create_custom_cmd_table.
local oem_command_table = {}
oem_command_table[OEM_COMMAND] = write_oem_command_req
oem_command_table[OEM_COMMAND_RSP] = read_oem_command_rsp

-- Query the LLDP over NC-SI status of one channel with the Huawei OEM command
-- GET_LLDP_OVER_NCSI_STATUS.
-- @return ncsi_def.NCSI_SUCCESS and the status (DISABLE or ENABLE), or ncsi_def.NCSI_FAIL when the
--         command fails or the cached status afterwards is neither DISABLE nor ENABLE
local function get_port_lldp_over_ncsi_status_hw(eth_name, package_id, channel_id)
    local req_packet = ncsi_packet.create_request_packet(package_id, channel_id, OEM_COMMAND)

    -- Request side: OEM_COMMAND -> write_oem_command_req with GET_LLDP_OVER_NCSI_STATUS
    local cmd_table = ncsi_utils.create_custom_cmd_table(oem_command_table, OEM_COMMAND,
        write_oem_command_req, GET_LLDP_OVER_NCSI_STATUS)
    -- Response side: OEM_COMMAND_RSP -> read_oem_command_rsp with eth_name and the same sub-command
    cmd_table = ncsi_utils.create_custom_rsp_table(cmd_table, OEM_COMMAND_RSP,
        read_oem_command_rsp, eth_name, GET_LLDP_OVER_NCSI_STATUS)

    local ret = ncsi_utils.ncsi_cmd_ctrl(package_id, channel_id, req_packet, eth_name, cmd_table)
    if ret ~= ncsi_def.NCSI_SUCCESS then
        log:info('Failed to get LLDP over NCSI status, ret = %d, eth_name[%s], channel_id[%u]',
            ret, eth_name, channel_id)
        return ncsi_def.NCSI_FAIL
    end

    -- The response callback stores the status in the module cache; read it back from there
    local status = get_lldp_over_ncsi_status_rsp(eth_name, channel_id)
    local is_valid = (status == DISABLE) or (status == ENABLE)
    if not is_valid then
        log:error('Invalid LLDP Over NCSI status: %u', status)
        return ncsi_def.NCSI_FAIL
    end

    return ncsi_def.NCSI_SUCCESS, status
end

-- Set the LLDP over NC-SI status of one channel with the Huawei OEM command
-- SET_LLDP_OVER_NCSI_STATUS.
-- @param eth_name   interface the command is sent through
-- @param package_id NC-SI package ID
-- @param channel_id NC-SI channel ID
-- @param status     status value to write (expected to be DISABLE or ENABLE)
-- @return ncsi_def.NCSI_SUCCESS or ncsi_def.NCSI_FAIL
local function set_port_lldp_over_ncsi_status_hw(eth_name, package_id, channel_id, status)
    local req_packet = ncsi_packet.create_request_packet(package_id, channel_id, OEM_COMMAND)

    -- Request writer for OEM_COMMAND, carrying the sub-command and the status to set
    local custom_cmd_table = ncsi_utils.create_custom_cmd_table(
        oem_command_table, OEM_COMMAND, write_oem_command_req, SET_LLDP_OVER_NCSI_STATUS, status
    )

    -- Register the response reader under the response type OEM_COMMAND_RSP, as
    -- get_port_lldp_over_ncsi_status_hw does. Registering it under OEM_COMMAND (the request type)
    -- would presumably shadow the request writer configured just above.
    custom_cmd_table = ncsi_utils.create_custom_rsp_table(
        custom_cmd_table, OEM_COMMAND_RSP, read_oem_command_rsp, eth_name, SET_LLDP_OVER_NCSI_STATUS
    )

    local ret = ncsi_utils.ncsi_cmd_ctrl(package_id, channel_id, req_packet, eth_name, custom_cmd_table)
    if ret ~= ncsi_def.NCSI_SUCCESS then
        log:error('Set %s LLDP over NCSI status(%u) failed(%d)', eth_name, status, ret)
        return ncsi_def.NCSI_FAIL
    end

    return ncsi_def.NCSI_SUCCESS
end

-- Return the `vlan_mode` capability field of (pkg_id, chan_id).
-- The field is read from ncsi_parameter.get_instance():get_ncsi_parameter().channel_cap.
-- Returns 0 when either ID is at or above its NCSI_*_MAX_ID bound.
local function ncsi_get_capability_vlan_mode(pkg_id, chan_id)
    local in_range = pkg_id < NCSI_PACKAGE_MAX_ID and chan_id < NCSI_CHANNEL_MAX_ID
    if not in_range then
        return 0
    end

    local channel_cap = ncsi_parameter.get_instance():get_ncsi_parameter().channel_cap
    return channel_cap[pkg_id][chan_id].vlan_mode
end

-- Enable a VLAN mode that also lets untagged (non-VLAN) frames through, based on the channel's
-- vlan_mode capability bits:
--   * prefer "specific VLAN + non-VLAN" (VLAN_NON_VLAN);
--   * otherwise use "any VLAN + non-VLAN" (ANYVLAN_NON_VLAN).
-- Returns ncsi_def.NCSI_FAIL when neither capability bit is set, otherwise the result of
-- ncsi_enable_vlan_req.
local function __set_vlan_mode_for_non_tag(package_id, channel_id, eth_name)
    local fn_name = '__set_vlan_mode_for_non_tag'
    local vlan_mode = ncsi_get_capability_vlan_mode(package_id, channel_id)

    local enable_mode, debug_fmt
    if (vlan_mode & VLAN_MODE_FIELD_VLAN_NON_VLAN) ~= 0 then
        enable_mode = VLAN_NON_VLAN
        debug_fmt = '%s: use specific vlan and non vlan, vlan_mode=%u'
    elseif (vlan_mode & VLAN_MODE_FIELD_ANYVLAN_NON_VLAN) ~= 0 then
        enable_mode = ANYVLAN_NON_VLAN
        debug_fmt = '%s: use any vlan and non vlan, vlan_mode=%u'
    else
        log:error('%s: non vlan mode fit, vlan_mode=%u', fn_name, vlan_mode)
        return ncsi_def.NCSI_FAIL
    end

    log:debug(debug_fmt, fn_name, vlan_mode)
    return ncsi_vlan_mode.ncsi_enable_vlan_req(package_id, channel_id, eth_name, enable_mode)
end

-- Standard-command path for enabling LLDP forwarding:
--   1. add the LLDP multicast MAC 01:80:C2:00:00:0E to the channel's multicast MAC filter;
--   2. set a VLAN mode that lets untagged frames through.
-- Returns the first non-success result (logged), or the VLAN-mode result on success.
local function __set_port_std_lldp_forward_enable(package_id, channel_id, eth_name)
    local fn_name = '__set_port_std_lldp_forward_enable'
    local lldp_multicast_mac = {0x01, 0x80, 0xc2, 0x00, 0x00, 0x0e}

    local ret = ncsi_mac.ncsi_set_multicast_mac_filter(package_id, channel_id, eth_name,
        lldp_multicast_mac, MAC_FILTER_ENABLE)
    if ret ~= ncsi_def.NCSI_SUCCESS then
        log:error('%s: ncsi_set_multicast_mac_filter failed, ret=%d', fn_name, ret)
        return ret
    end

    ret = __set_vlan_mode_for_non_tag(package_id, channel_id, eth_name)
    if ret ~= ncsi_def.NCSI_SUCCESS then
        log:error('%s: __set_vlan_mode_for_non_tag failed, ret=%d', fn_name, ret)
    end
    return ret
end

-- Standard-command path for disabling LLDP forwarding: remove the LLDP multicast MAC
-- 01:80:C2:00:00:0E from the channel's multicast MAC filter.
-- Returns the result of the filter call; a non-success result is logged.
local function __set_port_std_lldp_forward_disable(package_id, channel_id, eth_name)
    local lldp_multicast_mac = {0x01, 0x80, 0xc2, 0x00, 0x00, 0x0e}

    local ret = ncsi_mac.ncsi_set_multicast_mac_filter(package_id, channel_id, eth_name,
        lldp_multicast_mac, MAC_FILTER_DISABLE)
    if ret == ncsi_def.NCSI_SUCCESS then
        return ret
    end

    log:error('%s: ncsi_set_multicast_mac_filter failed, ret = %d', '__set_port_std_lldp_forward_disable', ret)
    return ret
end

-- Set the LLDP over NC-SI state using standard NC-SI commands only.
-- ENABLE and DISABLE select the matching forward helper.
-- Any other status returns ncsi_def.NCSI_FAIL without sending anything.
local function __set_port_lldp_over_ncsi_capability_std(eth_name, package_id, channel_id, status)
    local apply = nil
    if status == ENABLE then
        apply = __set_port_std_lldp_forward_enable
    elseif status == DISABLE then
        apply = __set_port_std_lldp_forward_disable
    end

    if apply == nil then
        return ncsi_def.NCSI_FAIL
    end
    return apply(package_id, channel_id, eth_name)
end

--- Bring the LLDP over NC-SI state of a channel to `status`.
-- First tries the Huawei OEM commands: read the current state, set it if different, then read it
-- back to verify. If the initial OEM read fails, falls back to the standard NC-SI commands.
-- NOTE(review): the OEM commands are always sent to package 0, while the fallback uses package_id.
-- Confirm that this is intentional.
-- @param eth_name   interface name ("eth<N>")
-- @param package_id NC-SI package ID (used by the standard-command fallback only)
-- @param channel_id NC-SI channel ID
-- @param status     desired state (DISABLE or ENABLE)
-- @return ncsi_def.NCSI_SUCCESS or ncsi_def.NCSI_FAIL, or the standard-command result on fallback
function ncsi_oem_lldp.update_lldp_over_ncsi_status(eth_name, package_id, channel_id, status)
    local HW_PACKAGE_ID = 0

    -- Reset the cached per-interface/per-channel status before querying
    init_lldp_over_ncsi_status()

    local ret, current = get_port_lldp_over_ncsi_status_hw(eth_name, HW_PACKAGE_ID, channel_id)
    if ret ~= ncsi_def.NCSI_SUCCESS then
        -- The OEM command failed: try the standard NC-SI commands instead
        return __set_port_lldp_over_ncsi_capability_std(eth_name, package_id, channel_id, status)
    end

    if current == status then
        return ncsi_def.NCSI_SUCCESS
    end

    ret = set_port_lldp_over_ncsi_status_hw(eth_name, HW_PACKAGE_ID, channel_id, status)
    if ret ~= ncsi_def.NCSI_SUCCESS then
        log:error('Set %s LLDP over NCSI status(%u) failed(%d)', eth_name, status, ret)
        return ncsi_def.NCSI_FAIL
    end

    -- Read back to confirm that the new state took effect
    local verify_ret, verified = get_port_lldp_over_ncsi_status_hw(eth_name, HW_PACKAGE_ID, channel_id)
    if verify_ret == ncsi_def.NCSI_SUCCESS and verified == status then
        return ncsi_def.NCSI_SUCCESS
    end
    return ncsi_def.NCSI_FAIL
end

return ncsi_oem_lldp