-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local bs = require 'mc.bitstring'
local log = require 'mc.logging'
local core = require 'network.core'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_packet = require 'ncsi.ncsi_protocol.ncsi_packet'

-- NC-SI command type codes handled by this module. Here each response code is the
-- request code with the high bit (0x80) set.
local NCSI_RSP_FLAG = 0x80
local ENABLE_VLAN_REQ = 0x0C
local DISABLE_VLAN_REQ = 0x0D
local ENABLE_VLAN_RSP = ENABLE_VLAN_REQ | NCSI_RSP_FLAG   -- 0x8C
local DISABLE_VLAN_RSP = DISABLE_VLAN_REQ | NCSI_RSP_FLAG -- 0x8D

-- VLAN filtering modes carried in the vlan_mode byte of the Enable VLAN request payload.
local VLAN_ONLY = 0x01        -- accept only VLAN-tagged packets matching the configured VLAN
local VLAN_NON_VLAN = 0x02    -- accept matching VLAN-tagged packets and untagged packets
local ANYVLAN_NON_VLAN = 0x03 -- accept any VLAN-tagged packet (if it passes MAC filtering) and untagged packets

-- Module table; the VLAN mode constants are exported for use by other modules.
local ncsi_vlan_mode = {
    VLAN_ONLY = VLAN_ONLY,
    VLAN_NON_VLAN = VLAN_NON_VLAN,
    ANYVLAN_NON_VLAN = ANYVLAN_NON_VLAN,
}

-- Request payload lengths in bytes (the NC-SI payload only, excluding checksum/pad/FCS).
local ENABLE_VLAN_REQ_LEN = 4 -- reserved:24 + vlan_mode:8 in enable_vlan_req_bs
local DISABLE_VLAN_REQ_LEN = 0

-- Widths of the NC-SI checksum field and the trailing FCS field, in bytes.
local CHECKSUM_LEN = 4
local FCS_LEN = 4

-- Zero-padding lengths in bytes, placed between the checksum and the FCS.
-- NOTE(review): presumably chosen so each request fills a 64-byte minimum Ethernet frame
-- together with ETHERNET_HEAD_LEN and PACKET_HEAD_LEN — confirm against ncsi_def.
local ENABLE_VLAN_REQ_PAD_LEN = 22
local DISABLE_VLAN_REQ_PAD_LEN = 26

-- Bytes following the NC-SI packet header: payload + checksum + pad + FCS.
local ENABLE_VLAN_PKT_SIZE = ENABLE_VLAN_REQ_LEN + CHECKSUM_LEN + ENABLE_VLAN_REQ_PAD_LEN + FCS_LEN
local DISABLE_VLAN_PKT_SIZE = DISABLE_VLAN_REQ_LEN + CHECKSUM_LEN + DISABLE_VLAN_REQ_PAD_LEN + FCS_LEN

-- Enable VLAN request body layout. The pad width is formatted in from
-- ENABLE_VLAN_REQ_PAD_LEN, so the layout cannot drift from the size constants above or
-- from the padding string built in fill_enable_vlan_req_payload.
local enable_vlan_req_bs = bs.new(string.format([[<<
    reserved:24,
    vlan_mode:8,
    check_sum:32,
    pad_data:%d/string,
    fcs:32
>>]], ENABLE_VLAN_REQ_PAD_LEN))

-- Disable VLAN request body layout (no payload bytes). The pad width is formatted in
-- from DISABLE_VLAN_REQ_PAD_LEN for the same reason.
local disable_vlan_req_bs = bs.new(string.format([[<<
    check_sum:32,
    pad_data:%d/string,
    fcs:32
>>]], DISABLE_VLAN_REQ_PAD_LEN))

-- Build the Enable VLAN request body: reserved/mode payload, checksum, zero padding and FCS.
-- A body with check_sum = 0 and fcs = 0 is first stored in req_packet.payload, so the
-- checksum/CRC helpers (which take req_packet) operate on a packet that already carries it.
-- The checksum is taken over the first PACKET_HEAD_LEN + ENABLE_VLAN_REQ_LEN bytes. The CRC
-- is taken over everything before the 4-byte FCS field. Both values are passed through
-- core.htonl before the final pack.
-- NOTE(review): get_crc32 runs while the staged body still has check_sum = 0, so the FCS
-- does not cover the final checksum bytes — confirm this is intended.
-- @param req_packet request table; req_packet.payload is overwritten as a side effect
-- @param vlan_mode VLAN filtering mode byte (validated by ncsi_enable_vlan_req)
-- @return the packed request body
local function fill_enable_vlan_req_payload(req_packet, vlan_mode)
    local body = {
        reserved = 0,
        vlan_mode = vlan_mode,
        check_sum = 0,
        pad_data = string.rep('\0', ENABLE_VLAN_REQ_PAD_LEN),
        fcs = 0,
    }
    req_packet.payload = enable_vlan_req_bs:pack(body)

    local checksum_span = ncsi_def.PACKET_HEAD_LEN + ENABLE_VLAN_REQ_LEN
    local crc_span = ncsi_def.PACKET_HEAD_LEN + ENABLE_VLAN_PKT_SIZE - 4
    local raw_check_sum = ncsi_utils.get_checksum(req_packet, checksum_span)
    local raw_crc32 = ncsi_utils.get_crc32(req_packet, crc_span)

    body.check_sum, body.fcs = core.htonl(raw_check_sum), core.htonl(raw_crc32)
    return enable_vlan_req_bs:pack(body)
end

-- Send an NC-SI Enable VLAN request on interface eth_name.
-- Applies the common command configuration and writes the payload length into the
-- header's split length fields (4 high bits + 8 low bits). It then attaches the body,
-- packs the whole packet and passes it to ncsi_protocol_intf.send_ncsi_cmd.
-- @param req_packet request table created by ncsi_packet.create_request_packet
-- @param eth_name interface to send on
-- @param vlan_mode VLAN filtering mode written into the payload
-- @return whatever ncsi_protocol_intf.send_ncsi_cmd returns
local function write_enable_vlan_req(req_packet, eth_name, vlan_mode)
    ncsi_utils.ncsi_cmd_common_config(req_packet)

    -- Look the header up only after common_config, in case that call replaces it.
    local head = req_packet.packet_head
    head.payload_len_hi = (ENABLE_VLAN_REQ_LEN >> 8) & 0x0f
    head.payload_len_lo = ENABLE_VLAN_REQ_LEN & 0xff
    req_packet.payload = fill_enable_vlan_req_payload(req_packet, vlan_mode)

    local packed = ncsi_utils.ncsi_packet_bs:pack(req_packet)
    -- Length handed to the sender: Ethernet header + NC-SI header + payload/checksum/pad/FCS.
    local wire_len = ncsi_def.ETHERNET_HEAD_LEN + ncsi_def.PACKET_HEAD_LEN + ENABLE_VLAN_PKT_SIZE
    return ncsi_protocol_intf.send_ncsi_cmd(packed, wire_len, eth_name)
end

-- Build the Disable VLAN request body: checksum, zero padding and FCS (no payload bytes).
-- The checksum spans PACKET_HEAD_LEN + DISABLE_VLAN_REQ_LEN bytes, i.e. the header alone,
-- since DISABLE_VLAN_REQ_LEN is 0. The CRC spans everything before the 4-byte FCS field.
-- Both values are passed through core.htonl before packing.
-- NOTE(review): unlike fill_enable_vlan_req_payload, req_packet.payload is not staged with
-- a zeroed body before get_crc32, so the CRC sees whatever payload req_packet already holds.
-- That could be the previous body if the packet is reused — confirm.
-- @param req_packet request table (read only here)
-- @return the packed request body
local function fill_disable_vlan_req_payload(req_packet)
    local checksum_span = ncsi_def.PACKET_HEAD_LEN + DISABLE_VLAN_REQ_LEN
    local crc_span = ncsi_def.PACKET_HEAD_LEN + DISABLE_VLAN_PKT_SIZE - 4
    local raw_check_sum = ncsi_utils.get_checksum(req_packet, checksum_span)
    local raw_crc32 = ncsi_utils.get_crc32(req_packet, crc_span)

    return disable_vlan_req_bs:pack({
        check_sum = core.htonl(raw_check_sum),
        pad_data = string.rep('\0', DISABLE_VLAN_REQ_PAD_LEN),
        fcs = core.htonl(raw_crc32),
    })
end

-- Send an NC-SI Disable VLAN request on interface eth_name.
-- Applies the common command configuration and writes the (zero) payload length into the
-- header's split length fields. It then attaches the body, packs the packet and passes it
-- to ncsi_protocol_intf.send_ncsi_cmd.
-- @param req_packet request table created by ncsi_packet.create_request_packet
-- @param eth_name interface to send on
-- @return whatever ncsi_protocol_intf.send_ncsi_cmd returns
local function write_disable_vlan_req(req_packet, eth_name)
    ncsi_utils.ncsi_cmd_common_config(req_packet)

    -- Look the header up only after common_config, in case that call replaces it.
    local head = req_packet.packet_head
    head.payload_len_hi = (DISABLE_VLAN_REQ_LEN >> 8) & 0x0f
    head.payload_len_lo = DISABLE_VLAN_REQ_LEN & 0xff
    req_packet.payload = fill_disable_vlan_req_payload(req_packet)

    local packed = ncsi_utils.ncsi_packet_bs:pack(req_packet)
    -- Length handed to the sender: Ethernet header + NC-SI header + checksum/pad/FCS.
    local wire_len = ncsi_def.ETHERNET_HEAD_LEN + ncsi_def.PACKET_HEAD_LEN + DISABLE_VLAN_PKT_SIZE
    return ncsi_protocol_intf.send_ncsi_cmd(packed, wire_len, eth_name)
end

-- Return a response handler that delegates to ncsi_packet.read_common_rsp with a fixed
-- command label (the label is passed through unchanged).
local function common_rsp_reader(cmd_label)
    return function(rsp)
        return ncsi_packet.read_common_rsp(rsp, cmd_label)
    end
end

-- Handle the Enable VLAN response.
local read_enable_vlan_rsp = common_rsp_reader('enable vlan')

-- Handle the Disable VLAN response.
local read_disable_vlan_rsp = common_rsp_reader('disable vlan')

-- Command dispatch table passed to ncsi_utils.ncsi_cmd_ctrl. Request codes map to the
-- writer that builds and sends the request; response codes map to the response handler.
local vlan_mode_table = {
    -- Requests
    [ENABLE_VLAN_REQ] = write_enable_vlan_req,
    [DISABLE_VLAN_REQ] = write_disable_vlan_req,
    -- Responses
    [ENABLE_VLAN_RSP] = read_enable_vlan_rsp,
    [DISABLE_VLAN_RSP] = read_disable_vlan_rsp,
}

-- Check that vlan_mode equals one of VLAN_ONLY, VLAN_NON_VLAN or ANYVLAN_NON_VLAN.
-- @param vlan_mode candidate mode value (any type)
-- @return true when valid; otherwise logs an error and returns false
local function validate_vlan_mode(vlan_mode)
    for _, allowed in ipairs({ VLAN_ONLY, VLAN_NON_VLAN, ANYVLAN_NON_VLAN }) do
        if vlan_mode == allowed then
            return true
        end
    end

    log:error('Invalid VLAN mode: %s. Must be one of: VLAN_ONLY(1), VLAN_NON_VLAN(2), ANYVLAN_NON_VLAN(3)',
        vlan_mode)
    return false
end

-- Enable VLAN filtering on an NC-SI package/channel, sending through interface eth_name.
-- @param package_id NC-SI package id
-- @param channel_id NC-SI channel id
-- @param eth_name interface used to send the command
-- @param vlan_mode VLAN_ONLY, VLAN_NON_VLAN or ANYVLAN_NON_VLAN
-- @return ncsi_def.NCSI_FAIL if vlan_mode is invalid (no command is issued); otherwise the
--         result of ncsi_utils.ncsi_cmd_ctrl, with an error logged when it is not NCSI_SUCCESS
function ncsi_vlan_mode.ncsi_enable_vlan_req(package_id, channel_id, eth_name, vlan_mode)
    if not validate_vlan_mode(vlan_mode) then
        return ncsi_def.NCSI_FAIL
    end

    local request = ncsi_packet.create_request_packet(package_id, channel_id, ENABLE_VLAN_REQ)
    -- NOTE(review): presumably yields a handler table whose ENABLE_VLAN_REQ entry forwards
    -- vlan_mode to write_enable_vlan_req — confirm in ncsi_utils.create_custom_cmd_table.
    local handlers = ncsi_utils.create_custom_cmd_table(vlan_mode_table, ENABLE_VLAN_REQ,
        write_enable_vlan_req, vlan_mode)

    local result = ncsi_utils.ncsi_cmd_ctrl(package_id, channel_id, request, eth_name, handlers)
    if result == ncsi_def.NCSI_SUCCESS then
        return result
    end

    log:error('NCSI enable VLAN failed, package_id = %s, channel_id = %s, eth_name = %s, vlan_mode = %s',
        package_id, channel_id, eth_name, vlan_mode)
    return result
end

-- Disable VLAN filtering on an NC-SI package/channel, sending through interface eth_name.
-- @param package_id NC-SI package id
-- @param channel_id NC-SI channel id
-- @param eth_name interface used to send the command
-- @return the result of ncsi_utils.ncsi_cmd_ctrl; an error is logged when it is not
--         ncsi_def.NCSI_SUCCESS
function ncsi_vlan_mode.ncsi_disable_vlan_req(package_id, channel_id, eth_name)
    local request = ncsi_packet.create_request_packet(package_id, channel_id, DISABLE_VLAN_REQ)
    local result = ncsi_utils.ncsi_cmd_ctrl(package_id, channel_id, request, eth_name, vlan_mode_table)
    if result == ncsi_def.NCSI_SUCCESS then
        return result
    end

    log:error('NCSI disable VLAN failed, package_id = %s, channel_id = %s, eth_name = %s',
        package_id, channel_id, eth_name)
    return result
end

-- Module table: exported VLAN mode constants plus ncsi_enable_vlan_req / ncsi_disable_vlan_req.
return ncsi_vlan_mode