-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local log = require 'mc.logging'
local service = require 'observability.service'
local client = require 'observability.client'
local c_initiator = require 'mc.initiator'
local context = require 'mc.context'
local utils = require 'mc.utils'
local vos = require 'utils.vos'
local base_messages = require 'messages.base'

-- Module table; populated below and returned at the end of this file.
local module  = {}

-- Placeholder initiator constructed with 'N/A' for all three arguments.
module.initiator = c_initiator.new('N/A', 'N/A', 'N/A')
-- Module name passed to log:operation for every operation-log entry written here.
local MODULE_NAME = 'observability'
-- Total system memory threshold, set to 1800MB
-- (value presumably expressed in kB to match the MemTotal field of /proc/meminfo -- confirm).
local TOTAL_MEM_THRESHOLD = 1.8 * 1000 * 1024
-- Metric names accepted as activated metrics.
-- modify_supported_metrics() may remove entries from this list in place.
local SUPPORTED_METRICS = {
    "bmc.hw.chip.io",
    "bmc.hw.chip.accessor",
    "bmc.hw.chip.scanner",
    "bmc.mc.flash.io",
    "bmc.rpc.client.request",
    "bmc.persistence.server.request",
    "bmc.persistence.flash.io",
    "bmc.broadcast.producer.send",
    "bmc.system.uptime",
    "bmc.system.cpu.usage",
    "bmc.system.memory.usage",
    "bmc.process.cpu.usage",
    "bmc.process.memory.usage",
    "bmc.system.flash.io",
    "bmc.system.flash.lifeleft"
}
-- Sampling rates accepted by set_sampling_rate(); any other value is rejected.
local SUPPORTED_SAMPLING_RATE = {0.1, 0.05, 0.01, 0.005, 0.001, 0}

--- Return the list of supported traces sampling rates.
-- The module-level table itself is returned (not a copy).
-- @return table SUPPORTED_SAMPLING_RATE
local function get_supported_sampling_rate()
    return SUPPORTED_SAMPLING_RATE
end

--- Return the list of supported metric names.
-- The module-level table itself is returned (not a copy), so entries removed
-- by modify_supported_metrics() are reflected in the result.
-- @return table SUPPORTED_METRICS
local function get_supported_metrics()
    return SUPPORTED_METRICS
end

--- Trim SUPPORTED_METRICS in place on systems with little memory.
-- Reads the MemTotal value from /proc/meminfo through vos.popen_s. When the
-- value cannot be obtained or parsed, an error is logged and the list is left
-- untouched. When the value does not exceed TOTAL_MEM_THRESHOLD, the metrics
-- "bmc.rpc.client.request" and "bmc.broadcast.producer.send" are removed.
local function modify_supported_metrics()
    local ok, ret = pcall(vos.popen_s, "cat /proc/meminfo | grep -i \"MemTotal\" | awk '{ print $2 }'")
    if not ok or not ret then
        log:error("Get system mem info failed")
        return
    end

    -- tonumber yields nil for empty or non-numeric output; comparing nil with a
    -- number would raise and abort initialisation, so bail out explicitly.
    local total_mem = tonumber(ret)
    if not total_mem then
        log:error("Get system mem info failed, invalid MemTotal value: %s", tostring(ret))
        return
    end

    if total_mem > TOTAL_MEM_THRESHOLD then
        return
    end

    -- Environments with less than 1800MB of total memory must drop the
    -- bmc.rpc.client.request and bmc.broadcast.producer.send metrics.
    -- Iterate downwards so table.remove does not skip elements.
    for i = #SUPPORTED_METRICS, 1, -1 do
        local name = SUPPORTED_METRICS[i]
        if name == "bmc.rpc.client.request" or name == "bmc.broadcast.producer.send" then
            table.remove(SUPPORTED_METRICS, i)
        end
    end
end

--- Translate a numeric sampling-policy code into its display name.
-- @param data policy code
-- @return string "Fixed" for 1, "Adaptive" for 2, "" for anything else
local function get_sampling_policy(data)
    local policy_names = { [1] = "Fixed", [2] = "Adaptive" }
    return policy_names[data] or ""
end

--- Translate a numeric sampling-level code into its display name.
-- @param data level code
-- @return string "System" for 1, "Component" for 2, "Function" for 3,
--         "" for anything else
local function get_sampling_level(data)
    local level_names = { [1] = "System", [2] = "Component", [3] = "Function" }
    return level_names[data] or ""
end

--- Render the activated metrics as a comma-separated string.
-- Entries that are not present in SUPPORTED_METRICS are skipped.
-- @param data array of metric names; any non-table value yields ''
-- @return string supported metric names joined with ','
local function get_activated_metrics(data)
    if type(data) == "table" then
        local kept = {}
        for idx = 1, #data do
            local name = data[idx]
            if utils.array_contains(SUPPORTED_METRICS, name) then
                kept[#kept + 1] = name
            end
        end
        return table.concat(kept, ',')
    end

    return ''
end

--- Convert the strings "true" / "false" into the matching boolean.
-- @param data value to convert
-- @return true for "true", false for "false"; nothing for any other input
local function convert_string_to_boolean(data)
    if data == "true" then
        return true
    end
    if data == "false" then
        return false
    end
    -- Anything else is not a recognised boolean string.
    return
end

----------------   traces configuration setters   -------------------------
--- Validate and apply the traces sampling rate.
-- The value must convert to a number contained in SUPPORTED_SAMPLING_RATE.
-- The property is only written when it differs from the current value.
-- @param traces_obj traces object exposing SamplingRate and set_property
-- @param data requested sampling rate
-- Raises PropertyValueFormatError for an unsupported value and InternalError
-- when set_property returns a non-zero code.
local function set_sampling_rate(traces_obj, data)
    -- Write an operation-log entry attributed to the current context's initiator.
    local function record_operation(fmt, ...)
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, fmt, ...)
    end

    local rate = tonumber(data)
    local supported = rate and utils.array_contains(SUPPORTED_SAMPLING_RATE, rate)
    if not supported then
        log:error("Set traces sampling rate to %s failed, parameter is out of range", data)
        record_operation("Set traces sampling rate to %s failed", data)
        error(base_messages.PropertyValueFormatError(data, "Traces/SamplingRate"))
    end

    -- Nothing to do when the requested value is already in effect.
    if rate == traces_obj.SamplingRate then
        return
    end

    local ret, err = traces_obj:set_property("SamplingRate", rate)
    if ret == 0 then
        record_operation("Set traces sampling rate to %s successfully", data)
        return
    end

    log:error("Set traces sampling rate to %s failed, err:%s", data, err)
    record_operation("Set traces sampling rate to %s failed", data)
    error(base_messages.InternalError())
end

--- Validate and apply the traces sampling policy.
-- "Fixed" maps to 1 and "Adaptive" maps to 2. The property is only written
-- when it differs from the current value.
-- @param traces_obj traces object exposing SamplingPolicy and set_property
-- @param data requested policy name
-- Raises PropertyValueNotInList for an unknown name and InternalError when
-- set_property returns a non-zero code.
local function set_sampling_policy(traces_obj, data)
    -- Write an operation-log entry attributed to the current context's initiator.
    local function record_operation(fmt, ...)
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, fmt, ...)
    end

    local policy_codes = { Fixed = 1, Adaptive = 2 }
    local policy = policy_codes[data]
    if policy == nil then
        log:error("Set traces sampling policy to %s failed", data)
        record_operation("Set traces sampling policy to %s failed", data)
        error(base_messages.PropertyValueNotInList(data, "Traces/SamplingPolicy"))
    end

    -- Nothing to do when the requested value is already in effect.
    if policy == traces_obj.SamplingPolicy then
        return
    end

    local ret, err = traces_obj:set_property("SamplingPolicy", policy)
    if ret == 0 then
        record_operation("Set traces sampling policy to %s successfully", data)
        return
    end

    log:error("Set traces sampling policy to %s failed, err:%s", data, err)
    record_operation("Set traces sampling policy to %s failed", data)
    error(base_messages.InternalError())
end

--- Validate and apply the traces sampling level.
-- "System" maps to 1, "Component" to 2 and "Function" to 3. The property is
-- only written when it differs from the current value.
-- @param traces_obj traces object exposing SamplingLevel and set_property
-- @param data requested level name
-- Raises PropertyValueNotInList for an unknown name and InternalError when
-- set_property returns a non-zero code.
local function set_sampling_level(traces_obj, data)
    -- Write an operation-log entry attributed to the current context's initiator.
    local function record_operation(fmt, ...)
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, fmt, ...)
    end

    local level_codes = { System = 1, Component = 2, Function = 3 }
    local level = level_codes[data]
    if level == nil then
        log:error("Set traces sampling level to %s failed", data)
        record_operation("Set traces sampling level to %s failed", data)
        error(base_messages.PropertyValueNotInList(data, "Traces/SamplingLevel"))
    end

    -- Nothing to do when the requested value is already in effect.
    if level == traces_obj.SamplingLevel then
        return
    end

    local ret, err = traces_obj:set_property("SamplingLevel", level)
    if ret == 0 then
        record_operation("Set traces sampling level to %s successfully", data)
        return
    end

    log:error("Set traces sampling level to %s failed, err:%s", data, err)
    record_operation("Set traces sampling level to %s failed", data)
    error(base_messages.InternalError())
end

--- Validate and apply the traces export interval.
-- The value must convert to a number and is rejected when below 10 or above
-- 30. The property is only written when it differs from the current value.
-- @param traces_obj traces object exposing ExportIntervalSeconds and set_property
-- @param data requested interval
-- Raises PropertyValueFormatError for a rejected value and InternalError when
-- set_property returns a non-zero code.
local function set_traces_export_interval(traces_obj, data)
    -- Write an operation-log entry attributed to the current context's initiator.
    local function record_operation(fmt, ...)
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, fmt, ...)
    end

    local seconds = tonumber(data)
    if seconds == nil or seconds < 10 or seconds > 30 then
        log:error("Set traces export interval to %s failed", data)
        record_operation("Set traces export interval to %s failed", data)
        error(base_messages.PropertyValueFormatError(data, "Traces/ExportIntervalSeconds"))
    end

    -- Nothing to do when the requested value is already in effect.
    if seconds == traces_obj.ExportIntervalSeconds then
        return
    end

    local ret, err = traces_obj:set_property("ExportIntervalSeconds", seconds)
    if ret == 0 then
        record_operation("Set traces export interval to %s successfully", data)
        return
    end

    log:error("Set traces export interval to %s failed, err:%s", data, err)
    record_operation("Set traces export interval to %s failed", data)
    error(base_messages.InternalError())
end

----------------   metrics configuration setters   -------------------------
--- Validate and apply the metrics export interval.
-- The value must convert to a number and is rejected when below 30 or above
-- 3600. The property is only written when it differs from the current value.
-- @param metrics_obj metrics object exposing ExportIntervalSeconds and set_property
-- @param data requested interval
-- Raises PropertyValueFormatError for a rejected value and InternalError when
-- set_property returns a non-zero code.
local function set_metrics_export_interval(metrics_obj, data)
    -- Write an operation-log entry attributed to the current context's initiator.
    local function record_operation(fmt, ...)
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, fmt, ...)
    end

    local interval = tonumber(data)
    if not interval or interval < 30 or interval > 3600 then
        log:error("Set metrics export interval to %s failed", data)
        record_operation("Set metrics export interval to %s failed", data)
        -- error() does not return, so no trailing return statement is needed.
        error(base_messages.PropertyValueFormatError(data, "Metrics/ExportIntervalSeconds"))
    end

    if interval ~= metrics_obj.ExportIntervalSeconds then
        local ret, err = metrics_obj:set_property("ExportIntervalSeconds", interval)
        if ret ~= 0 then
            log:error("Set metrics export interval to %s failed, err:%s", data, err)
            record_operation("Set metrics export interval to %s failed", data)
            error(base_messages.InternalError())
        else
            record_operation("Set metrics export interval to %s successfully", data)
        end
    end
end

--- Validate and apply the list of activated metrics.
-- `data` is split on ','; every entry must be present in SUPPORTED_METRICS.
-- An empty string results in an empty list. The property is only written when
-- utils.table_compare reports a difference from the current value.
-- @param metrics_obj metrics object exposing ActivatedMetrics and set_property
-- @param data comma-separated metric names
-- Raises PropertyValueNotInList (carrying the zero-based position of the
-- offending entry) for an unsupported metric and InternalError when
-- set_property returns a non-zero code.
local function set_activated_metrics(metrics_obj, data)
    -- Write an operation-log entry attributed to the current context's initiator.
    local function record_operation(fmt, ...)
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, fmt, ...)
    end

    local metrics = {}
    if #data > 0 then
        local activated_metrics = utils.split(data, ',')
        for i = 1, #activated_metrics do
            local name = activated_metrics[i]
            if not utils.array_contains(SUPPORTED_METRICS, name) then
                -- Emit a run log as well, like every other validation failure in this file.
                log:error("Set activated metrics to %s failed, metric %s is not supported", data, name)
                record_operation("Set activated metrics to %s failed", data)
                error(base_messages.PropertyValueNotInList(name,
                    "Metrics/ActivatedMetrics/" .. (i - 1)))
            end
            metrics[#metrics + 1] = name
        end
    end

    if not utils.table_compare(metrics, metrics_obj.ActivatedMetrics) then
        local ret, err = metrics_obj:set_property("ActivatedMetrics", metrics)
        if ret ~= 0 then
            log:error("Set activated metrics to %s failed, err:%s", data, err)
            record_operation("Set activated metrics to %s failed", data)
            error(base_messages.InternalError())
        else
            record_operation("Set activated metrics to %s successfully", data)
        end
    end
end

----------------   logs configuration setters   -------------------------
--- Validate and apply the logs collection switch.
-- `data` must be the string "true" or "false". The property is only written
-- when it differs from the current value.
-- @param logs_obj logs object exposing Enabled and set_property
-- @param data requested state as a string
-- Raises PropertyValueFormatError for any other value and InternalError when
-- set_property returns a non-zero code.
local function set_logs_enabled(logs_obj, data)
    -- Write an operation-log entry attributed to the current context's initiator.
    local function record_operation(fmt, ...)
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, fmt, ...)
    end

    local enabled = convert_string_to_boolean(data)
    if enabled == nil then
        log:error("Set logs enabled to %s failed", data)
        record_operation("Set logs collection status failed")
        error(base_messages.PropertyValueFormatError(data, "Logs/Enabled"))
    end

    -- Nothing to do when the requested state is already in effect.
    if enabled == logs_obj.Enabled then
        return
    end

    -- Verb used in both the run log and the operation log.
    local action = enabled and 'Enable' or 'Disable'
    local ret, err = logs_obj:set_property("Enabled", enabled)
    if ret == 0 then
        record_operation("%s logs collection successfully", action)
        return
    end

    log:error("%s logs collection failed, err:%s", action, err)
    record_operation("%s logs collection failed", action)
    error(base_messages.InternalError())
end

--- Validate and apply the logs export interval.
-- The value must convert to a number and is rejected when below 10 or above
-- 30. The property is only written when it differs from the current value.
-- @param logs_obj logs object exposing ExportIntervalSeconds and set_property
-- @param data requested interval
-- @return number the validated interval
-- Raises PropertyValueFormatError for a rejected value and InternalError when
-- set_property returns a non-zero code.
local function set_logs_export_interval(logs_obj, data)
    -- Write an operation-log entry attributed to the current context's initiator.
    local function record_operation(fmt, ...)
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, fmt, ...)
    end

    local interval = tonumber(data)
    if not interval or interval < 10 or interval > 30 then
        log:error("Set logs export interval to %s failed", data)
        record_operation("Set logs export interval to %s failed", data)
        error(base_messages.PropertyValueFormatError(data, "Logs/ExportIntervalSeconds"))
    end

    if interval ~= logs_obj.ExportIntervalSeconds then
        local ret, err = logs_obj:set_property("ExportIntervalSeconds", interval)
        if ret ~= 0 then
            log:error("Set logs export interval to %s failed, err:%s", data, err)
            record_operation("Set logs export interval to %s failed", data)
            -- Propagate the failure like every other setter in this file;
            -- otherwise the caller would treat a failed write as a success.
            error(base_messages.InternalError())
        else
            record_operation("Set logs export interval to %s successfully", data)
        end
    end
    return interval
end

--- Read the current traces configuration as a table of strings.
-- Numeric properties are passed through tostring(); the policy and level
-- codes are translated to their names.
-- @return table with SamplingRate, SamplingPolicy, SamplingLevel and
--         ExportIntervalSeconds; nothing when the traces object is unavailable
function module:get_traces_config()
    local traces_obj = client:GetDashboardObservabilityTracesObject()
    if not traces_obj then
        log:error("Get Dashboard traces obj failed")
        return
    end

    return {
        SamplingRate = tostring(traces_obj.SamplingRate) or '',
        SamplingPolicy = get_sampling_policy(traces_obj.SamplingPolicy),
        SamplingLevel = get_sampling_level(traces_obj.SamplingLevel),
        ExportIntervalSeconds = tostring(traces_obj.ExportIntervalSeconds) or '',
    }
end

--- Apply a traces configuration.
-- Each field of `config` is optional. SamplingRate and ExportIntervalSeconds
-- are additionally skipped when they equal the string 'nil'. Fields are
-- applied in the order rate, policy, level, export interval; the first
-- failing setter raises and the remaining fields are not applied.
-- @param config table with optional SamplingRate, SamplingPolicy,
--        SamplingLevel and ExportIntervalSeconds
-- Raises InternalError when the traces object is unavailable.
function module:set_traces_config(config)
    local traces_obj = client:GetDashboardObservabilityTracesObject()
    if not traces_obj then
        log:error("Get Dashboard traces obj failed")
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, "Set traces config failed")
        error(base_messages.InternalError())
    end

    local rate = config.SamplingRate
    if rate and rate ~= 'nil' then
        set_sampling_rate(traces_obj, rate)
    end

    local policy = config.SamplingPolicy
    if policy then
        set_sampling_policy(traces_obj, policy)
    end

    local level = config.SamplingLevel
    if level then
        set_sampling_level(traces_obj, level)
    end

    local interval = config.ExportIntervalSeconds
    if interval and interval ~= 'nil' then
        set_traces_export_interval(traces_obj, interval)
    end
end

--- Read the current metrics configuration as a table of strings.
-- @return table with ExportIntervalSeconds (via tostring) and ActivatedMetrics
--         (comma-separated supported metric names); nothing when the metrics
--         object is unavailable
function module:get_metrics_config()
    local metrics_obj = client:GetDashboardObservabilityMetricsObject()
    if not metrics_obj then
        log:error("Get Dashboard metrics obj failed")
        return
    end

    return {
        ExportIntervalSeconds = tostring(metrics_obj.ExportIntervalSeconds) or '',
        ActivatedMetrics = get_activated_metrics(metrics_obj.ActivatedMetrics),
    }
end

--- Apply a metrics configuration.
-- Each field of `config` is optional and is skipped when missing or equal to
-- the string 'nil'. The export interval is applied before the activated
-- metrics; the first failing setter raises and the rest is not applied.
-- @param config table with optional ExportIntervalSeconds and ActivatedMetrics
-- Raises InternalError when the metrics object is unavailable.
function module:set_metrics_config(config)
    local metrics_obj = client:GetDashboardObservabilityMetricsObject()
    if not metrics_obj then
        log:error("Get Dashboard metrics obj failed")
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, "Set metrics config failed")
        -- error() does not return, so no trailing return statement is needed.
        error(base_messages.InternalError())
    end

    if config.ExportIntervalSeconds and config.ExportIntervalSeconds ~= 'nil' then
        set_metrics_export_interval(metrics_obj, config.ExportIntervalSeconds)
    end

    if config.ActivatedMetrics and config.ActivatedMetrics ~= 'nil' then
        set_activated_metrics(metrics_obj, config.ActivatedMetrics)
    end
end

--- Read the current logs configuration as a table of strings.
-- @return table with Enabled and ExportIntervalSeconds (both via tostring);
--         nothing when the logs object is unavailable
function module:get_logs_config()
    local logs_obj = client:GetDashboardObservabilityLogsObject()
    if not logs_obj then
        log:error("Get Dashboard logs obj failed")
        return
    end

    return {
        Enabled = tostring(logs_obj.Enabled) or '',
        ExportIntervalSeconds = tostring(logs_obj.ExportIntervalSeconds) or '',
    }
end

--- Apply a logs configuration.
-- Each field of `config` is optional and is skipped when missing or equal to
-- the string 'nil'. The enabled switch is applied before the export interval.
-- @param config table with optional Enabled and ExportIntervalSeconds
-- Raises InternalError when the logs object is unavailable.
function module:set_logs_config(config)
    local logs_obj = client:GetDashboardObservabilityLogsObject()
    if not logs_obj then
        log:error("Get Dashboard logs obj failed")
        log:operation((context.get_context() or context.new('N/A', 'N/A', '127.0.0.1')):get_initiator(),
            MODULE_NAME, "Set logs collection config failed")
        error(base_messages.InternalError())
    end

    local enabled = config.Enabled
    if enabled and enabled ~= 'nil' then
        set_logs_enabled(logs_obj, enabled)
    end

    local interval = config.ExportIntervalSeconds
    if interval and interval ~= 'nil' then
        set_logs_export_interval(logs_obj, interval)
    end
end

--- Create the traces, metrics and logs services and publish the supported values.
-- Stores `db` and `bus` on the module, then creates the services in the order
-- traces, metrics, logs. The traces object receives SupportedSamplingRate and
-- the metrics object receives SupportedMetrics.
-- @param db stored as self.db
-- @param bus stored as self.bus
function module:register_observability_policy_service(db, bus)
    self.db = db
    self.bus = bus

    local traces_mdb = service:CreateTraces():get_mdb_object("bmc.kepler.ObservabilityService.Traces")
    traces_mdb.SupportedSamplingRate = get_supported_sampling_rate()

    local metrics_mdb = service:CreateMetrics():get_mdb_object("bmc.kepler.ObservabilityService.Metrics")
    metrics_mdb.SupportedMetrics = get_supported_metrics()

    service:CreateLogs()
end

--- Initialise the module.
-- Trims the supported-metrics list first, so that the list published during
-- service registration already reflects it.
-- @param db forwarded to register_observability_policy_service
-- @param bus forwarded to register_observability_policy_service
function module:init(db, bus)
    modify_supported_metrics()
    self:register_observability_policy_service(db, bus)
end

return module