-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local bs = require 'mc.bitstring'
local log = require 'mc.logging'
local core = require 'network.core'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_packet = require 'ncsi.ncsi_protocol.ncsi_packet'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

local ncsi_mac = {}

-- Set MAC Address command code and its response code (0x80 | command code)
local SET_MAC_ADDRESS = 0x0E
local SET_MAC_ADDRESS_RSP = 0x8E

-- Payload sizes in bytes: MAC_ADDR_PKT_SIZE is the whole request payload
-- (filter fields, checksum, padding and FCS); MAC_ADDR_REQ_LEN is the length
-- written into the header's payload-length fields.
local MAC_ADDR_PKT_SIZE = 34
local MAC_ADDR_REQ_LEN = 8

-- MAC address sizes: raw byte count, and upper bound for the formatted string
local NCSI_MAC_LEN = 6
local MAX_MAC_STR_LEN = 32

-- Address type values and the filter enable flag value
local UNICAST_MAC_ADDRESS = 0
local MULTICAST_MAC_ADDRESS = 1
local MAC_FILTER_ENABLE = 1

-- Layout of the Set MAC Address request payload, 34 bytes in total (matches
-- MAC_ADDR_PKT_SIZE). This assumes mc.bitstring reads `name:N` as an N-bit integer
-- and `name:N/string` as an N-byte string (TODO confirm against mc.bitstring):
--   mac_filter    6 bytes  MAC address placed in the filter
--   mac_number    1 byte   filter number (from pmac_addr.mac_addr_num)
--   mac_enable    1 bit    |
--   reserved      4 bits   | together one byte
--   address_type  3 bits   |
--   check_sum     4 bytes  checksum over the NC-SI header + first MAC_ADDR_REQ_LEN bytes
--   data         18 bytes  filled with zeros by fill_mac_addr_payload
--   fcs           4 bytes  CRC32 computed by fill_mac_addr_payload
local mac_addr_req_bs = bs.new([[<<
    mac_filter:6/string,
    mac_number:8,
    mac_enable:1,
    reserved:4,
    address_type:3,
    check_sum:32,
    data:18/string,
    fcs:32
>>]])

--- Build the packed Set MAC Address request payload, including checksum and FCS.
-- For a unicast filter the system source address (ncsi_utils.src_addr) is used;
-- for a multicast filter the bytes of pmac_addr.mac_addr are used (missing bytes
-- default to 0).
-- Side effect: req_packet.payload is overwritten with intermediate payloads, because
-- the checksum/CRC helpers read the packet as a whole.
-- @param req_packet request packet whose header is already configured
-- @param pmac_addr table with addr_type, mac_addr_num, mac_enable and, for
--   multicast, mac_addr (sequence of byte values)
-- @return the final packed payload string
local function fill_mac_addr_payload(req_packet, pmac_addr)
    local mac_filter = ncsi_utils.src_addr
    if pmac_addr.addr_type == MULTICAST_MAC_ADDRESS then
        -- Tolerate a missing address table; absent bytes become 0 as before.
        local bytes = pmac_addr.mac_addr or {}
        mac_filter = string.char(
            bytes[1] or 0,
            bytes[2] or 0,
            bytes[3] or 0,
            bytes[4] or 0,
            bytes[5] or 0,
            bytes[6] or 0
        )
    end

    local payload_data = {
        mac_filter = mac_filter,
        mac_number = pmac_addr.mac_addr_num,
        mac_enable = pmac_addr.mac_enable,
        reserved = 0,
        address_type = pmac_addr.addr_type,
        check_sum = 0, -- filled in below
        data = string.rep('\0', 18),
        fcs = 0        -- filled in below
    }

    -- Step 1: checksum over the header plus the first MAC_ADDR_REQ_LEN payload bytes,
    -- computed with the checksum/FCS fields still zero.
    req_packet.payload = mac_addr_req_bs:pack(payload_data)
    local check_sum = ncsi_utils.get_checksum(req_packet, ncsi_def.PACKET_HEAD_LEN + MAC_ADDR_REQ_LEN)
    payload_data.check_sum = core.htonl(check_sum)

    -- Step 2: the CRC region (everything except the trailing 4-byte FCS) contains the
    -- checksum field, so repack with the real checksum before computing the CRC.
    -- Otherwise the FCS would describe a zero checksum instead of the bytes actually sent.
    req_packet.payload = mac_addr_req_bs:pack(payload_data)
    local crc32 = ncsi_utils.get_crc32(req_packet, ncsi_def.PACKET_HEAD_LEN + MAC_ADDR_PKT_SIZE - 4)
    payload_data.fcs = core.htonl(crc32)

    return mac_addr_req_bs:pack(payload_data)
end

--- Assemble a Set MAC Address request and send it on the given interface.
-- Applies the common command configuration, writes MAC_ADDR_REQ_LEN into the
-- header's split payload-length fields, fills the payload, packs the packet and
-- hands it to ncsi_protocol_intf.send_ncsi_cmd.
-- @param req_packet request packet to populate
-- @param eth_name interface used for sending
-- @param pmac_addr filter parameters forwarded to fill_mac_addr_payload
-- @return whatever ncsi_protocol_intf.send_ncsi_cmd returns
local function write_mac_addr_req(req_packet, eth_name, pmac_addr)
    ncsi_utils.ncsi_cmd_common_config(req_packet)

    -- Payload length is split into a 4-bit high part and an 8-bit low part.
    local head = req_packet.packet_head
    head.payload_len_hi = (MAC_ADDR_REQ_LEN >> 8) & 0x0f
    head.payload_len_lo = MAC_ADDR_REQ_LEN & 0xff

    req_packet.payload = fill_mac_addr_payload(req_packet, pmac_addr)

    local frame_len = ncsi_def.ETHERNET_HEAD_LEN + ncsi_def.PACKET_HEAD_LEN + MAC_ADDR_PKT_SIZE
    local frame = ncsi_utils.ncsi_packet_bs:pack(req_packet)
    return ncsi_protocol_intf.send_ncsi_cmd(frame, frame_len, eth_name)
end

--- Decode a Set MAC Address response with the shared NC-SI response reader.
-- @param response raw response, passed unchanged to ncsi_packet.read_common_rsp
-- @return all values returned by ncsi_packet.read_common_rsp
local function read_mac_addr_rsp(response)
    local cmd_desc = 'set mac addr'
    return ncsi_packet.read_common_rsp(response, cmd_desc)
end

-- Command table handed to ncsi_utils.create_custom_cmd_table:
-- response code -> response reader, request code -> request writer.
local mac_addr_table = {
    [SET_MAC_ADDRESS_RSP] = read_mac_addr_rsp,
    [SET_MAC_ADDRESS] = write_mac_addr_req,
}

--- Send a Set MAC Address command to one package/channel and wait for the result.
-- @param eth_name interface used for sending
-- @param package_id NC-SI package id
-- @param channel_id NC-SI channel id
-- @param pmac_addr filter parameters (addr_type, mac_addr_num, mac_enable, mac_addr)
-- @return the first value returned by ncsi_utils.ncsi_cmd_ctrl; a failure is also logged
function ncsi_mac.ncsi_set_mac_addr(eth_name, package_id, channel_id, pmac_addr)
    local request = ncsi_packet.create_request_packet(package_id, channel_id, SET_MAC_ADDRESS)
    local cmd_table = ncsi_utils.create_custom_cmd_table(mac_addr_table, SET_MAC_ADDRESS,
        write_mac_addr_req, pmac_addr)

    local result = ncsi_utils.ncsi_cmd_ctrl(package_id, channel_id, request, eth_name, cmd_table)
    if result == ncsi_def.NCSI_SUCCESS then
        return result
    end

    log:error('ncsi cmd ctrl set mac address failed, package_id = %s, channel_id = %s, eth_name = %s',
        package_id, channel_id, eth_name)
    return result
end

--- Enable unicast MAC filter number `mac_num` on a channel.
-- The filter address itself comes from the system source address (see
-- fill_mac_addr_payload). mac_num is checked against the larger of the channel's
-- unicast and mixed filter counts from the cached NC-SI parameters.
-- @param pkg_id NC-SI package id
-- @param chan_id NC-SI channel id
-- @param eth_name interface used for sending
-- @param mac_num filter number to program
-- @return ncsi_def.NCSI_SUCCESS on success; ncsi_def.NCSI_FAIL when the channel has no
--   capability entry, mac_num is not a number or exceeds the maximum, or the command fails
function ncsi_mac.ncsi_set_phy_mac_filter(pkg_id, chan_id, eth_name, mac_num)
    local ncsi_para = ncsi_parameter.get_instance():get_ncsi_parameter()

    -- Return a failure code instead of raising when the package/channel is unknown.
    local pkg_caps = ncsi_para.channel_cap and ncsi_para.channel_cap[pkg_id]
    local chan_caps = pkg_caps and pkg_caps[chan_id]
    if not chan_caps then
        log:error('no channel capability for package_id = %s, channel_id = %s', pkg_id, chan_id)
        return ncsi_def.NCSI_FAIL
    end

    if type(mac_num) ~= 'number' then
        log:error('mac_num %s is not a number', tostring(mac_num))
        return ncsi_def.NCSI_FAIL
    end

    local max_mac_num = math.max(chan_caps.unicast_filter_cnt, chan_caps.mix_filter_cnt)
    if max_mac_num < mac_num then
        log:error("mac_num %d invalid, max %d", mac_num, max_mac_num)
        return ncsi_def.NCSI_FAIL
    end

    local mac_addr_filter_param = {
        mac_addr_len = NCSI_MAC_LEN,
        addr_type = UNICAST_MAC_ADDRESS,
        mac_enable = MAC_FILTER_ENABLE,
        mac_addr_num = mac_num
    }

    local ret = ncsi_mac.ncsi_set_mac_addr(eth_name, pkg_id, chan_id, mac_addr_filter_param)
    if ret ~= ncsi_def.NCSI_SUCCESS then
        log:error("set mac addr filter failed")
        return ncsi_def.NCSI_FAIL
    end

    return ncsi_def.NCSI_SUCCESS
end

--- Convert a MAC address given as a sequence of byte values into text.
-- Produces lowercase, colon-separated two-digit hex ("xx:xx:xx:xx:xx:xx").
-- @param int_mac table whose entries 1..NCSI_MAC_LEN are the address bytes
-- @return the formatted string, or nil when int_mac is not a table, any of its first
--   NCSI_MAC_LEN entries is not an integer in [0, 255], or the result would exceed
--   MAX_MAC_STR_LEN characters
local function arith_intmac_to_strmac(int_mac)
    if type(int_mac) ~= 'table' then
        return nil
    end

    local octets = {}
    for i = 1, NCSI_MAC_LEN do
        local value = int_mac[i]
        -- Reject nil, non-numbers, non-integral floats and out-of-range values up front.
        -- These would otherwise raise in string.format, or pass here and later raise in
        -- string.char when the payload is built.
        local byte = type(value) == 'number' and math.tointeger(value) or nil
        if not byte or byte < 0 or byte > 0xff then
            return nil
        end
        octets[i] = string.format("%02x", byte)
    end

    local mac_str = table.concat(octets, ":")
    if #mac_str > MAX_MAC_STR_LEN then
        return nil
    end

    return mac_str
end

--- Configure a multicast MAC address filter on a channel.
-- The filter number used is the larger of the channel's multicast and mixed filter
-- counts from the cached NC-SI parameters.
-- @param package_id NC-SI package id
-- @param channel_id NC-SI channel id
-- @param eth_name interface used for sending
-- @param mac_addr table whose entries 1..NCSI_MAC_LEN are the address bytes
-- @param mac_enable value written to the filter's enable field
-- @return ncsi_def.NCSI_SUCCESS on success; ncsi_def.NCSI_FAIL when the address is invalid,
--   the channel has no capability entry, or the command fails
function ncsi_mac.ncsi_set_multicast_mac_filter(package_id, channel_id, eth_name, mac_addr, mac_enable)
    local func_name = 'ncsi_set_multicast_mac_filter'

    -- Validate/format the address BEFORE indexing it, so that a nil address returns
    -- NCSI_FAIL instead of raising. The string is reused in the failure log below.
    local mac_addr_str = arith_intmac_to_strmac(mac_addr)
    if not mac_addr_str then
        log:error('%s: convert mac address to string failed', func_name)
        return ncsi_def.NCSI_FAIL
    end

    local ncsi_para = ncsi_parameter.get_instance():get_ncsi_parameter()
    local pkg_caps = ncsi_para.channel_cap and ncsi_para.channel_cap[package_id]
    local chan_caps = pkg_caps and pkg_caps[channel_id]
    if not chan_caps then
        log:error('%s: no channel capability for package_id = %s, channel_id = %s',
            func_name, package_id, channel_id)
        return ncsi_def.NCSI_FAIL
    end

    local mac_num = math.max(chan_caps.mul_filter_cnt, chan_caps.mix_filter_cnt)

    -- Copy the address bytes so later changes to the caller's table don't affect the request.
    local mac_copy = {}
    for i = 1, NCSI_MAC_LEN do
        mac_copy[i] = mac_addr[i]
    end

    local mac_addr_filter_param = {
        mac_addr_len = NCSI_MAC_LEN,
        addr_type = MULTICAST_MAC_ADDRESS,
        mac_enable = mac_enable,
        mac_addr_num = mac_num,
        mac_addr = mac_copy
    }

    local ret = ncsi_mac.ncsi_set_mac_addr(eth_name, package_id, channel_id, mac_addr_filter_param)
    if ret ~= ncsi_def.NCSI_SUCCESS then
        log:error('%s: set mac addr filter failed, mac = %s', func_name, mac_addr_str)
        return ncsi_def.NCSI_FAIL
    end

    return ncsi_def.NCSI_SUCCESS
end

-- Module table exposing ncsi_set_mac_addr, ncsi_set_phy_mac_filter and
-- ncsi_set_multicast_mac_filter.
return ncsi_mac