-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local skynet = require 'skynet'
local class = require 'mc.class'
local log = require 'mc.logging'
local observability_client = require 'observability.client'
local utils = require 'mc.utils'
local mc_context = require 'mc.context'
local utils_core = require 'utils.core'
local file_sec = require 'utils.file'
local vos = require 'utils.vos'
local defs = require 'public.defs'
local certificate_service_type = require 'observability.json_types.CertificateService'

local m = class()

-- Fixed (static) part of the fluent-bit YAML configuration, one entry per line.
-- generate() joins these entries with '\n' and appends the receiver-specific
-- settings produced by get_persistence_config(); those settings use the same
-- six-space indentation as the output entry below.
local g_fixed_config = {
    'service:',
    '  flush: 1',
    '  log_level: error',
    '',
    'pipeline:',
    '  inputs:',
    '    - name: opentelemetry',
    '      listen: ' .. defs.FLUENT_BIT_LISTEN_ADDRESS,
    '      port: ' .. defs.FLUENT_BIT_LISTEN_PORT,
    '      buffer_max_size: 1MB',
    '      buffer_chunk_size: 256KB',
    '      mem_buf_limit: 2MB',
    '',
    '  outputs:',
    '    - name: opentelemetry',
    '      match: \"*\"',
    '      metrics_uri: /v1/metrics',
    '      logs_uri: /v1/logs',
    '      traces_uri: /v1/traces',
    '      log_response_payload: true',
    '      http2: off',
    '      tls: on',
    -- The trailing '\n' terminates the last fixed line so that the text
    -- appended by generate() starts on a new line
    '      tls.verify: on\n',
}

-- Constructor.
-- db: database handle; the other methods of this class query its
--     ObservabilityService and Receivers tables through it.
function m:ctor(db)
    -- No SSL key text is cached until generate_ssl_key_fifo() exports one
    self.key_text = nil
    self.db = db
end

-- Write new_config_str to config_path, but only when it differs from the
-- content that is already stored there.
-- config_name:    file name; used in log messages and as the name of the
--                 temporary file under defs.TMP_OBSERVABILITY_PATH
-- new_config_str: complete new file content
-- config_path:    destination path of the configuration file
-- Returns true when the file is up to date or has been replaced, false when a
-- parameter is nil or a path check / open fails.
-- NOTE(review): the results of write(), utils.remove_file() and
-- file_sec.move_file_s() are not checked, so a failure there is still
-- reported as success - confirm whether that is intended.
function m:write_config_to_file(config_name, new_config_str, config_path)
    if not (config_name ~= nil and new_config_str ~= nil and config_path ~= nil) then
        log:error('Invalid parameter')
        return false
    end

    -- Validate the real path and then open the file; logs with the given
    -- format strings and returns nil when either step fails
    local function open_checked(path, mode, check_fmt, open_fmt)
        if file_sec.check_realpath_before_open_s(path) ~= 0 then
            log:error(check_fmt, config_name)
            return nil
        end
        local handle = file_sec.open_s(path, mode)
        if handle == nil then
            log:error(open_fmt, config_name)
        end
        return handle
    end

    -- Read the current content through an 'a+' handle
    local current_file = open_checked(config_path, 'a+', 'Check %s realpath failed', 'Open %s failed')
    if current_file == nil then
        return false
    end
    local previous_content = current_file:read('a')
    current_file:close()

    -- Nothing to do when the configuration has not changed
    if new_config_str == previous_content then
        log:info('No change in %s', config_name)
        return true
    end

    -- Write the new content to a temporary file first, then move it into place
    local staging_path = defs.TMP_OBSERVABILITY_PATH .. config_name
    local staging_file = open_checked(staging_path, 'w+', 'Check tmp %s realpath failed', 'Open tmp %s failed')
    if staging_file == nil then
        return false
    end
    staging_file:write(new_config_str)
    staging_file:close()

    utils.remove_file(config_path)
    file_sec.move_file_s(staging_path, config_path)
    log:info("Generate %s successfully", config_name)

    return true
end

-- Generate the fluent-bit configuration file.
-- Joins the fixed configuration with the persisted receiver/TLS/identity
-- settings, writes the result to defs.FLUENT_BIT_CFG_PATH and then sets the
-- file mode (S_IRUSR | S_IRGRP) and owner (COMM_USER_UID / COMM_USER_GID).
-- Returns true on success, false when the persisted configuration cannot be
-- built or the file cannot be written.
function m:generate()
    -- 0 is the receiver_id handed to get_persistence_config
    local per_config = m.get_persistence_config(self, 0)
    if per_config == nil then
        return false
    end

    local fixed_config = table.concat(g_fixed_config, '\n')
    local config_path = defs.FLUENT_BIT_CFG_PATH

    local new_config_str = fixed_config .. per_config
    -- Pass the instance as the receiver (the original passed the class table
    -- `m`), matching the call style used for get_persistence_config above
    local ok = m.write_config_to_file(self, 'fluent-bit.yaml', new_config_str, config_path)
    if not ok then
        log:error('Write config to file failed')
        return false
    end

    utils_core.chmod_s(config_path, (utils.S_IRUSR | utils.S_IRGRP))
    utils_core.chown_s(config_path, defs.COMM_USER_UID, defs.COMM_USER_GID)
    return true
end

-- Get the file path of the BMC certificate.
-- cert_objs: table of certificate objects keyed by object path (may be nil)
-- Returns the FilePath of the object stored under
-- defs.HTTPS_CERTIFICATES_PATH, or nil when cert_objs is nil or has no such
-- entry.
local function get_server_cert_path(cert_objs)
    -- A nil table is reported as a failure instead of raising inside next();
    -- an empty table is covered by the lookup returning nil
    if cert_objs == nil or cert_objs[defs.HTTPS_CERTIFICATES_PATH] == nil then
        log:error('Get SSL certificate objects failed')
        return nil
    end

    return cert_objs[defs.HTTPS_CERTIFICATES_PATH].FilePath
end

-- Export the manager SSL certificate key through a FIFO and read it.
-- Calls CertificateServiceExportCertKeyByFIFO, retrying up to five more times
-- (5 s apart) when the call raises or returns no file path, then waits for
-- the returned path to become accessible and reads its whole content.
-- Returns the key text, or nil when the export, the wait or the open fails.
local function export_ssl_key()
    -- Named usage_type so that the builtin `type` is not shadowed
    local usage_type = certificate_service_type.CertificateUsageType.ManagerSSLCertificate
    local ctx = mc_context.get_context_or_default()
    local max_retry = 5 -- retry five more times after the first failure
    local res

    -- One export attempt; returns the pcall results (ok, err)
    local function try_export()
        res = nil
        return pcall(function ()
            res = observability_client:CertificateServiceCertificateServiceExportCertKeyByFIFO(ctx, usage_type)
        end)
    end

    -- True when the call succeeded and returned a non-empty file path
    local function export_succeeded(ok)
        return ok and res ~= nil and res.FilePath ~= nil and #res.FilePath > 0
    end

    local ok, err = try_export()
    while not export_succeeded(ok) and max_retry > 0 do
        log:notice("Export ssl key failed, will try again, err:%s", err)
        -- Count down: the original incremented here, so the loop could
        -- never terminate while the export kept failing
        max_retry = max_retry - 1
        skynet.sleep(500) -- retry every 5 s
        ok, err = try_export()
    end

    if not export_succeeded(ok) then
        log:error("Export ssl key failed, err:%s", err)
        return nil
    end

    -- Exporting the key to the FIFO is asynchronous: poll every 100 ms
    -- (at most 10 times), otherwise the open below may block
    local accessible = false
    for _ = 1, 10 do
        skynet.sleep(10)
        if vos.get_file_accessible(res.FilePath) then
            accessible = true
            break
        end
    end
    if not accessible then
        log:error("Export ssl key timed out")
        return nil
    end

    if file_sec.check_realpath_before_open_s(res.FilePath) ~= 0 then
        log:error('Check SSL key realpath failed')
        return nil
    end

    local key_file = file_sec.open_s(res.FilePath, "r")
    if key_file == nil then
        log:error('Open SSL fifo file failed')
        return nil
    end

    local key_text = key_file:read("a")
    key_file:close()
    -- The exported file is removed once its content has been read
    utils.remove_file(res.FilePath)
    return key_text
end

-- Export the SSL key into _self.key_text and (re)create the key FIFO.
-- _self: instance whose key_text field receives the exported key
-- Returns true when the key was exported and the FIFO at
-- defs.SSL_KEY_FILE_PATH was created with mode 0600; false otherwise.
-- On an invalid key, _self.key_text is left as nil.
local function generate_ssl_key_fifo(_self)
    local exported_key = export_ssl_key()
    if exported_key == nil or #exported_key == 0 then
        log:error('SSL key is invalid')
        _self.key_text = nil
        return false
    end
    _self.key_text = exported_key

    local fifo_path = defs.SSL_KEY_FILE_PATH

    -- Drop any previous file, then make sure the path is safe to hand to the shell
    utils.remove_file(fifo_path)
    if file_sec.check_shell_special_character_s(fifo_path) ~= 0 then
        log:error('Create fifo file failed, file path is invalid')
        return false
    end

    if vos.system_s('/bin/sh', '-c', 'mkfifo -m 0600 ' .. fifo_path) ~= 0 then
        log:error('mkfifo SSL key failed')
        return false
    end

    utils_core.chown_s(fifo_path, defs.COMM_USER_UID, defs.COMM_USER_GID)
    return true
end

-- Append the TLS settings of the output to cfg.
-- _self: instance handed to generate_ssl_key_fifo (receives key_text)
-- cfg:   array of configuration fragments; new entries are appended
-- obj:   observability service object; its TLSMode selects mTLS
-- Returns true on success; false when the CA file is missing or empty, or
-- when the mTLS key FIFO or client certificate path cannot be obtained.
local function add_tls_config(_self, cfg, obj)
    -- Both one-way TLS and mTLS need a non-empty CA file
    if not utils_core.is_file(defs.TLS_CA_FILE_PATH) then
        return false
    end
    local file_stat = utils_core.stat(defs.TLS_CA_FILE_PATH)
    -- Treat a failed stat like an empty file instead of indexing nil
    if file_stat == nil or file_stat.st_size == nil or file_stat.st_size <= 0 then
        return false
    end
    cfg[#cfg + 1] = string.format('      tls.ca_file: %s\n', defs.TLS_CA_FILE_PATH)

    -- mTLS additionally needs the client certificate and key
    if obj.TLSMode == "mTLS" then
        -- Create the private key FIFO
        local ok, ret = pcall(function ()
            return generate_ssl_key_fifo(_self)
        end)
        if not ok or not ret then
            log:error("Generate ssl key file failed, ret:%s", ret)
            return false
        end

        -- Get the client (BMC) certificate path; a nil object table is
        -- handled here so the helper is never called with nil
        local cert_objs = observability_client:GetCertificateObjects()
        local cert_path
        if cert_objs ~= nil then
            cert_path = get_server_cert_path(cert_objs)
        end
        if cert_path == nil then
            log:error("Get client certificate path failed")
            return false
        end

        cfg[#cfg + 1] = string.format('      tls.crt_file: %s\n', cert_path)
        cfg[#cfg + 1] = string.format('      tls.key_file: %s\n', defs.SSL_KEY_FILE_PATH)
    end

    return true
end

-- Write the cached SSL key text to the key file at defs.SSL_KEY_FILE_PATH.
-- Does nothing (and returns true) unless the persisted TLSMode is 'mTLS'.
-- Returns false when the observability configuration cannot be read or the
-- key file cannot be opened, true otherwise.
function m:write_data_to_fifo()
    local service = self.db:select(self.db.ObservabilityService):first()
    if service == nil then
        log:error("Get observability config failed")
        return false
    end

    if service.TLSMode == 'mTLS' then
        local key_fifo = file_sec.open_s(defs.SSL_KEY_FILE_PATH, 'w')
        if key_fifo == nil then
            log:error('Open SSL key file failed')
            return false
        end
        -- The file is opened and closed even when no key text is cached
        if self.key_text then
            key_fifo:write(self.key_text)
        end
        key_fifo:close()
    end

    return true
end

-- Get the board serial number of the FRU whose FruId equals defs.FRU_ID.
-- Board objects are matched to overview objects by object path.
-- Returns the serial number, or '' when the objects cannot be fetched, no
-- FRU matches, or the matching board has no BoardSerialNumber.
local function get_board_sn()
    local board_objs = observability_client:GetBoardObjects()
    local overview_objs = observability_client:GetOverviewObjects()
    if board_objs == nil or overview_objs == nil then
        log:error("Get board serial number failed")
        return ''
    end

    for path, obj in pairs(board_objs) do
        -- A board path without an overview object is skipped instead of
        -- raising on a nil index
        local overview = overview_objs[path]
        if overview ~= nil then
            local fru_id = overview.FruId
            if fru_id and fru_id == defs.FRU_ID then
                return obj.BoardSerialNumber or ''
            end
        end
    end

    log:error('Get board serial number failed')
    return ''
end

-- Get the product asset tag of the FRU whose FruId equals defs.FRU_ID.
-- Product objects are matched to overview objects by object path.
-- Returns the asset tag, or '' when the objects cannot be fetched, no FRU
-- matches, or the matching product has no AssetTag.
local function get_asset_tag()
    local product_objs = observability_client:GetProductObjects()
    local overview_objs = observability_client:GetOverviewObjects()
    if product_objs == nil or overview_objs == nil then
        log:error("Get asset tag failed")
        return ''
    end

    for path, obj in pairs(product_objs) do
        -- A product path without an overview object is skipped instead of
        -- raising on a nil index
        local overview = overview_objs[path]
        if overview ~= nil then
            local fru_id = overview.FruId
            if fru_id and fru_id == defs.FRU_ID then
                return obj.AssetTag or ''
            end
        end
    end

    log:error("Get asset tag failed")
    return ''
end

-- Get the host name from the assembly object at defs.ASSEMBLY_PATH.
-- Returns the HostName, or '' when the assembly objects cannot be fetched,
-- the object at defs.ASSEMBLY_PATH is missing, or it has no HostName.
local function get_host_name()
    local assembly_objs = observability_client:GetAssemblyObjects()
    if assembly_objs == nil then
        log:error("Get host name failed")
        return ''
    end

    -- A missing assembly object is reported like a missing host name
    -- instead of raising on a nil index
    local assembly = assembly_objs[defs.ASSEMBLY_PATH]
    if assembly == nil or assembly.HostName == nil then
        log:error("Get host name failed")
        return ''
    end

    return assembly.HostName
end

-- Build the server-identity part of the output configuration.
-- obj: observability service object; obj.ServerIdentity selects the source
--      of the identity ('BoardSN', 'ProductAssetTag' or 'HostName')
-- Returns the YAML text that adds the ServerIdentity label and inserts it
-- into traces and logs through content_modifier processors. An unknown
-- source or an empty identity is written as "".
-- NOTE(review): the identity is inserted into the YAML unquoted; confirm
-- that the source values cannot contain YAML-significant characters.
local function add_identity_config(obj)
    local identity_sources = {
        BoardSN = get_board_sn,
        ProductAssetTag = get_asset_tag,
        HostName = get_host_name
    }

    local identity
    local read_identity = identity_sources[obj.ServerIdentity]
    if read_identity then
        identity = read_identity()
    else
        identity = ''
    end
    if identity == '' then
        identity = '\"\"'
    end

    local lines = {
        '      add_label: ServerIdentity ' .. identity,
        '      processors:',
    }

    -- Append one content_modifier processor that inserts ServerIdentity
    local function append_modifier(section_line, context_line)
        lines[#lines + 1] = section_line
        lines[#lines + 1] = '          - name: content_modifier'
        lines[#lines + 1] = context_line
        lines[#lines + 1] = '            action: insert'
        lines[#lines + 1] = '            key: ServerIdentity'
        lines[#lines + 1] = '            value: ' .. identity
    end

    append_modifier('        traces:', '            context: span_attributes')
    append_modifier('        logs:', '            context: otel_resource_attributes')

    -- Final element; joined with '\n' it leaves two trailing newlines
    lines[#lines + 1] = '\n'

    return table.concat(lines, '\n')
end

-- Build the configuration text derived from persisted data.
-- receiver_id: ReceiverId of the receiver whose address and port are used
-- Returns the host/port, TLS and server-identity settings as one string, or
-- nil when the persisted objects cannot be read or the TLS settings cannot
-- be generated.
function m:get_persistence_config(receiver_id)
    local service = self.db:select(self.db.ObservabilityService):first()

    local receiver_query = self.db:select(self.db.Receivers)
    local by_id = self.db.Receivers.ReceiverId:eq(receiver_id)
    local receiver = receiver_query:where(by_id):first()

    if service == nil or receiver == nil then
        log:error("Get observability config failed")
        return nil
    end

    local parts = {
        string.format('      host: %s\n', receiver.Address),
        string.format('      port: %s\n', receiver.Port),
    }

    -- add_tls_config appends its entries to parts
    if not add_tls_config(self, parts, service) then
        log:error("Add tls config failed")
        return nil
    end

    local identity_part = add_identity_config(service)

    -- Blank line between the TLS settings and the identity settings
    parts[#parts + 1] = '\n'
    parts[#parts + 1] = identity_part
    return table.concat(parts)
end

return m