-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

-- Description: 定义VRD升级的生效过程
local log = require 'mc.logging'
local skynet = require 'skynet'
local context = require 'mc.context'
local defs = require 'mcu.upgrade.defs'
local vos = require 'utils.vos'
local utils = require 'mc.utils'
local utils_core = require 'utils.core'
local client = require 'general_hardware.client'
local cmn = require 'common'
local upgrade_service_comm = require 'mcu.upgrade.upgrade_service.upgrade_service_comm'
local fructl = require 'mcu.upgrade.fructl_handler'
local MCU_ENUMS = require 'mcu.enum.mcu_enums'

-- Module table returned at the end of this file.
-- valid_list is keyed by system id: valid.vrd_valid_task sets the entry to
-- true before running the validation task and clears it again afterwards.
local valid = {
    valid_list = {}
}

-- Value passed as the lock-time argument of fructl.set_poweron_lock_until_success.
-- NOTE(review): the original comment says "5min", but 1000 is neither 300 s nor
-- 5 minutes of 10 ms ticks; the unit is not visible in this file -- confirm.
local FRUCTRL_LOCK_TIME = 1000 -- 5min

--- Find the first entry in valid_path that has a matching update-config record.
-- Each name returned by utils_core.dir(valid_path) is looked up in
-- db.FwUpdateCfgTable by Id together with the component id/idex that
-- defs.convert_component maps firmware_type to.
-- @param db            database handle exposing FwUpdateCfgTable and select()
-- @param valid_path    directory to list
-- @param firmware_type key into defs.convert_component
-- @return the first matching name, or nil when none matches
local function get_firmware_name(db, valid_path, firmware_type)
    local cfg_table = db.FwUpdateCfgTable
    local entries = utils_core.dir(valid_path)
    for _, entry_name in ipairs(entries) do
        local query = db:select(cfg_table)
        local component = defs.convert_component[firmware_type]
        local matched = query:where(
            cfg_table.Id:eq(entry_name),
            cfg_table.ComponentId:eq(component.id),
            cfg_table.ComponentIdex:eq(component.idex)
        ):all()
        if #matched > 0 then
            return entry_name
        end
    end
    return nil
end

--- Fetch the update-config record for a firmware file.
-- Queries db.FwUpdateCfgTable by Id (firmware_name) and by the component
-- id/idex that defs.convert_component maps firmware_type to.
-- @param db            database handle exposing FwUpdateCfgTable and select()
-- @param firmware_name value matched against the Id column
-- @param firmware_type key into defs.convert_component
-- @return the first record of the result set; nothing when the set is empty
local function get_cfg_from_db(db, firmware_name, firmware_type)
    local cfg_table = db.FwUpdateCfgTable
    local component = defs.convert_component[firmware_type]
    local query = db:select(cfg_table):where(
        cfg_table.Id:eq(firmware_name),
        cfg_table.ComponentId:eq(component.id),
        cfg_table.ComponentIdex:eq(component.idex)
    )
    local records = query:all()
    -- A VRD upgrade currently carries exactly one firmware file, so only the
    -- first record is of interest.
    local first = records[1]
    if first ~= nil then
        return first
    end
end

--- Split a comma-separated string into its non-empty fields.
-- Consecutive commas produce no empty entries because only runs of
-- non-comma characters are collected.
-- @param str text to split
-- @return array of the fields in their original order
local function string_split(str)
    local parts = {}
    local next_piece = string.gmatch(str, '[^,]+')
    local piece = next_piece()
    while piece do
        table.insert(parts, piece)
        piece = next_piece()
    end
    return parts
end

--- Build the list of firmware objects to upgrade for a firmware file.
-- Reads the config record for firmware_name, splits its Uid column on commas
-- and hands the resulting uid_list to upgrade_service_comm.get_upgrade_list.
-- @param upg_vrd_obj   object providing db, bus and mcu_collection
-- @param firmware_type key into defs.convert_component
-- @param firmware_name Id of the config record to use
-- @param system_id     forwarded unchanged to get_upgrade_list
-- @return whatever get_upgrade_list returns, or nil when no config record exists
local function get_vrd_upgrade_list(upg_vrd_obj, firmware_type, firmware_name, system_id)
    local record = get_cfg_from_db(upg_vrd_obj.db, firmware_name, firmware_type)
    if record then
        local cfg = { uid_list = string_split(record.Uid) }
        return upgrade_service_comm.get_upgrade_list(upg_vrd_obj.bus, upg_vrd_obj.mcu_collection, cfg,
            firmware_type, system_id)
    end
    return nil
end

--- Collect the major versions of all objects in upgrade_list into one string.
-- Objects without a referenced sub-component, or whose version query fails or
-- yields an empty value, are skipped.
-- @param upgrade_list table of objects exposing get_ref_subcomp()
-- @return the collected versions joined with '.', or '' (after logging an
--         error) when no version could be read
local function get_upgrade_obj_version(upgrade_list)
    local versions = {}
    for _, upgrade_obj in pairs(upgrade_list) do
        local subcomp = upgrade_obj:get_ref_subcomp()
        if subcomp then
            -- The version getter may raise; a failed call is simply ignored.
            local succeeded, major = pcall(subcomp.ge_component_major_version, subcomp)
            if succeeded and major and major ~= '' then
                versions[#versions + 1] = major
            end
        end
    end
    if #versions == 0 then
        log:error('get before version failed')
        return ''
    end
    return table.concat(versions, '.')
end

--- Read the major version of one firmware object.
-- @param obj object exposing get_ref_subcomp()
-- @return the version reported by the referenced sub-component, or '255' when
--         the query raises (including a missing sub-component) or yields
--         nil/false/''
local function get_single_vrd_version(obj)
    local subcomp = obj:get_ref_subcomp()
    -- Wrapped in a closure so that a nil sub-component is caught by pcall too.
    local succeeded, major = pcall(function ()
        return subcomp:ge_component_major_version()
    end)
    if not succeeded or not major or major == '' then
        return '255'
    end
    return major
end

--- Upgrade one firmware entry, record the outcome and release its chip lock.
-- The upgrade itself runs under pcall so that the unlock below is always
-- reached. The progress counter and the aggregated return code are updated
-- outside the protected call: previously an error raised during the upgrade
-- skipped the counter increment (so paraller_load_fw, which waits for
-- upgraded_cnt to reach upgrade_list_cnt, never finished) and left
-- upgrade_ret_code at OK (so the failure was reported as success).
-- @param upg_vrd_obj   object providing bus
-- @param fw_info       queue entry with fw_index and fw_obj
-- @param dir           directory passed through to upgrade_service_comm.upgrade_mcu
-- @param valid_details shared task state; upgraded_cnt and upgrade_ret_code are updated here
local function load_one_fw(upg_vrd_obj, fw_info, dir, valid_details)
    log:notice("[vrd valid task] firmware start valid index:%s", fw_info.fw_index)

    -- Protected call: a failure in here must not prevent the unlock below.
    local ok, ret_code = pcall(function ()
        local old_version = get_single_vrd_version(fw_info.fw_obj)
        local ret = upgrade_service_comm.upgrade_mcu(upg_vrd_obj.bus, fw_info.fw_obj, fw_info, dir)
        upgrade_service_comm.record_upgrade_log(valid_details.firmware_type, ret, fw_info.fw_obj, old_version)
        if ret == MCU_ENUMS.MCU_UPGRADE_STATUS.LOW_VERSION then
            log:warn('[vrd valid task] upgrade file version low, firmware index:%s', fw_info.fw_index)
        end
        return ret
    end)

    -- Count this entry as processed whether or not the upgrade raised.
    valid_details.upgraded_cnt = valid_details.upgraded_cnt + 1
    if not ok then
        -- ret_code holds the error value here, not a status code.
        valid_details.upgrade_ret_code = defs.RET.ERR
    elseif ret_code ~= defs.RET.OK then
        valid_details.upgrade_ret_code = ret_code
    end
    if not ok or ret_code ~= defs.RET.OK then
        log:error("[vrd valid task] firmware valid failed, index:%s, err:%s", fw_info.fw_index, ret_code)
    end

    -- Release the hardware bus lock.
    ok, ret_code = fw_info.fw_obj:chip_unlock(context.new())
    if not ok or ret_code ~= 0 then
        log:error("[vrd valid task] firmware(index:%s) set lock status to 0 failed, ret code:%s", fw_info.fw_index,
            ret_code)
    end

    log:notice("[vrd valid task] firmware validated index:%s", fw_info.fw_index)
end

--- Upgrade all queued firmware entries, each in its own forked task.
-- Once per loop pass the head of valid_details.upgrade_list is dequeued and
-- its chip lock is requested. On success the upgrade is forked; on failure
-- the entry is appended back to the queue for a later retry. The loop sleeps
-- skynet.sleep(100) per pass and ends once upgraded_cnt equals
-- upgrade_list_cnt.
-- @param upg_vrd_obj   passed through to load_one_fw
-- @param valid_details shared task state (upgrade_list, upgraded_cnt, upgrade_list_cnt, valid_path)
local function paraller_load_fw(upg_vrd_obj, valid_details)
    local all_done
    repeat
        local entry = table.remove(valid_details.upgrade_list, 1)
        if entry then
            log:debug("get firmware info to valid, fw index:%s", entry.fw_index)

            local locked, lock_ret = entry.fw_obj:chip_lock(require('mc.context').new(), 1800)
            if locked and lock_ret == 0 then
                skynet.fork_once(function ()
                    load_one_fw(upg_vrd_obj, entry, valid_details.valid_path, valid_details)
                end)
            else
                log:debug("firmware(index:%s) set lock status to 1 failed, ret code:%s", entry.fw_index, lock_ret)
                -- Locking failed: put the entry at the tail and try again later.
                table.insert(valid_details.upgrade_list, entry)
            end
        else
            log:debug("no firmware info to valid")
        end

        skynet.sleep(100)
        -- Finished once every chip has been processed.
        all_done = valid_details.upgraded_cnt == valid_details.upgrade_list_cnt
    until all_done
    log:notice("[vrd valid task] all vrd firmware validated")
end

--- Upgrade all queued firmware entries one after another in the calling task.
-- @param upg_vrd_obj   passed through to load_one_fw
-- @param valid_details shared task state; its upgrade_list is iterated and
--                      its valid_path is handed to load_one_fw
local function serial_load_fw(upg_vrd_obj, valid_details)
    log:notice("serial upgrade mcu/vrd")

    local queue = valid_details.upgrade_list
    for _, entry in pairs(queue) do
        load_one_fw(upg_vrd_obj, entry, valid_details.valid_path, valid_details)
    end
end

--- Locate the firmware package in valid_path and upgrade every matching object.
-- Steps: find the package file, chown it, resolve the objects to upgrade,
-- unpack the package, run the upgrades (in parallel when the objects support
-- chip locking, serially otherwise), refresh the component versions and write
-- the maintenance log.
-- @param upg_vrd_obj   object providing db, bus and mcu_collection
-- @param valid_path    directory holding the package; it is concatenated
--                      directly with the file name, so it presumably ends with
--                      a path separator -- TODO confirm
-- @param system_id     target system; defs.ALL_SYSTEM_ID suppresses the
--                      version maintenance log
-- @param firmware_type key into defs.convert_component
-- @return defs.RET.ERR when preparation fails, otherwise the aggregated
--         upgrade_ret_code (defs.RET.OK unless an upgrade reported otherwise)
local function upgrade_task(upg_vrd_obj, valid_path, system_id, firmware_type)
    local firmware_name = get_firmware_name(upg_vrd_obj.db, valid_path, firmware_type)
    if firmware_name == nil then
        log:error('[McuUpgrade] Can not find upgrade file')
        return defs.RET.ERR
    end

    local upgrade_path = valid_path .. firmware_name
    local ok, rsp = client:PFileFileChown(context.new(), nil, upgrade_path, 104, 104)
    if not ok then
        log:error('[McuUpgrade] chown upgrade file failed, error %s', rsp)
        return defs.RET.ERR
    end
    local upgrade_list = get_vrd_upgrade_list(upg_vrd_obj, firmware_type, firmware_name, system_id)
    if upgrade_list == nil or not next(upgrade_list) then
        log:error('[McuUpgrade] Can not find int cfgs')
        return defs.RET.ERR
    end
    -- Remember the VRD version before the upgrade.
    local old_vrd_version
    if system_id ~= defs.ALL_SYSTEM_ID then
        old_vrd_version = get_upgrade_obj_version(upgrade_list)
    end

    -- Unpack the firmware package; extraction is capped at 100 MB.
    -- NOTE(review): the result of secure_tar_unzip is not checked -- confirm
    -- whether it reports failure via a return value.
    utils.secure_tar_unzip(upgrade_path, valid_path, 0x6400000, 1024)

    -- Build the work queue first and derive the completion target from it.
    -- The queue is filled with pairs(), so taking the target from #upgrade_list
    -- could disagree with the number of queued entries if upgrade_list is not a
    -- plain sequence, and paraller_load_fw would then never see its exit
    -- condition.
    local queue = {}
    for fw_index, fw_obj in pairs(upgrade_list) do
        queue[#queue + 1] = {
            fw_index = fw_index,
            fw_obj = fw_obj }
    end
    local valid_details = {
        upgraded_cnt = 0,
        upgrade_list = queue,
        upgrade_list_cnt = #queue,
        valid_path = valid_path,
        upgrade_ret_code = defs.RET.OK,
        firmware_type = firmware_type
    }

    -- upgrade_list is known to be non-empty, so queue[1] always exists; fall
    -- back to it when index 1 itself is absent.
    local probe_obj = upgrade_list[1] or queue[1].fw_obj
    if probe_obj:chip_lock_supported() then
        paraller_load_fw(upg_vrd_obj, valid_details)
    else
        serial_load_fw(upg_vrd_obj, valid_details)
    end

    -- Refresh the sub-component version numbers in one go.
    upgrade_service_comm.update_vrd_component_version(upg_vrd_obj.bus, upg_vrd_obj.mcu_collection, system_id)
    -- No version maintenance log for ALL_SYSTEM_ID or for VDM firmware.
    if system_id == defs.ALL_SYSTEM_ID or firmware_type == defs.VDM_NAME then
        return valid_details.upgrade_ret_code
    end
    -- Record the VRD version after the upgrade.
    if valid_details.upgrade_ret_code == defs.RET.OK then
        log:maintenance(log.MLOG_INFO, log.FC__PUBLIC_OK, 'Upgrade Power from version %s to version %s successfully',
            old_vrd_version, get_upgrade_obj_version(upgrade_list))
    else
        log:maintenance(log.MLOG_INFO, log.FC__PUBLIC_OK, 'Upgrade Power from version %s failed', old_vrd_version)
    end
    return valid_details.upgrade_ret_code
end

--- Clean up after a VRD upgrade attempt.
-- Calls fructl.set_poweron_lock_until_success with the lock flag set to false
-- and then removes file_path and valid_path if they are accessible.
-- @param file_path  first path to remove
-- @param valid_path second path to remove
-- @param sys_id     system id passed to the power-on lock call
-- @param bus        bus handle passed to the power-on lock call
local function finish_upgrade_vrd(file_path, valid_path, sys_id, bus)
    fructl.set_poweron_lock_until_success(bus, sys_id, false, FRUCTRL_LOCK_TIME, defs.VRD_NAME)

    local function remove_if_accessible(path)
        if vos.get_file_accessible(path) then
            utils.remove_file(path)
        end
    end

    remove_if_accessible(file_path)
    remove_if_accessible(valid_path)
end

--- Poll until upg_vrd_obj:get_vrd_load() reports a truthy value.
-- Makes at most 200 checks, calling skynet.sleep(100) after every failed one.
-- @param upg_vrd_obj object exposing get_vrd_load()
-- @return true as soon as a check succeeds, false after 200 failed checks
local function wait_vrd_load(upg_vrd_obj)
    local attempts_left = 200
    while attempts_left > 0 do
        if upg_vrd_obj:get_vrd_load() then
            return true
        end
        skynet.sleep(100)
        attempts_left = attempts_left - 1
    end
    return false
end

--- Validate (activate) a VRD upgrade package for one system.
-- After the power-on lock has been taken, the remaining steps run under pcall
-- so that finish_upgrade_vrd (lock release and file removal) is executed even
-- when waiting for the VRD or the upgrade itself raises. Previously such an
-- error skipped the cleanup and left the lock set. The error is re-raised
-- after cleanup so that the caller's own pcall still sees and logs it.
-- @param upg_vrd_obj object providing bus and get_vrd_load()
-- @param sys_id      system id the lock and the upgrade apply to
-- @param file_path   path of the hpm package
local function vrd_valid_task(upg_vrd_obj, sys_id, file_path)
    log:notice('[Vrd] recover vrd upgrade, No CPLD is mutually exclusive.')
    -- 2. Read the package info of the hpm to locate the upgrade files.
    local package_info = cmn.get_package_info(file_path)
    if not package_info or not defs.NEED_VALID_FW[package_info.FirmwareType] then
        log:error('[Vrd] hpm is not vrd :%s', package_info and package_info.FirmwareType)
        return
    end
    local valid_path = package_info.FirmwareDirectory
    -- 3. Take the power-on lock (used for the KVM display).
    fructl.set_poweron_lock_until_success(upg_vrd_obj.bus, sys_id, true, FRUCTRL_LOCK_TIME, defs.VRD_NAME)

    local ok, loaded = pcall(wait_vrd_load, upg_vrd_obj)
    if not ok or not loaded then
        log:error('[Vrd]wait vrd load fail.')
        finish_upgrade_vrd(file_path, valid_path, sys_id, upg_vrd_obj.bus)
        if not ok then
            -- Level 0 keeps the original error value untouched.
            error(loaded, 0)
        end
        return
    end
    log:notice('[Vrd]wait vrd load success.')

    -- 4. Run the upgrade.
    local ret
    ok, ret = pcall(upgrade_task, upg_vrd_obj, valid_path, sys_id, package_info.FirmwareType)
    log:notice('[Vrd] VRD finish upgrading. Set Flag to finish')

    finish_upgrade_vrd(file_path, valid_path, sys_id, upg_vrd_obj.bus)
    if not ok then
        log:error('[Vrd] valid Vrd failed.')
        -- Cleanup is done; hand the original error on to the caller.
        error(ret, 0)
    end
    if ret ~= defs.RET.OK then
        log:error('[Vrd] valid Vrd failed.')
    else
        log:notice('[Vrd] valid Vrd successfully.')
    end
end

--- Public entry point: run the VRD validation task for one system.
-- Marks sys_id in valid.valid_list for the duration of the task and runs the
-- task under pcall so the mark is always cleared.
-- @param upg_vrd_obj passed through to the task
-- @param sys_id      system id; also the key used in valid.valid_list
-- @param file_path   passed through to the task
-- @return the task's return value, or nil (after logging) when it raised
function valid.vrd_valid_task(upg_vrd_obj, sys_id, file_path)
    log:notice('start to validate system(%s) vrd', sys_id)

    valid.valid_list[sys_id] = true
    local succeeded, result = pcall(vrd_valid_task, upg_vrd_obj, sys_id, file_path)
    valid.valid_list[sys_id] = nil

    if succeeded then
        return result
    end
    log:error('vrd valid task failed, err: %s', result)
    return nil
end

return valid
