-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local skynet = require 'skynet'
local log = require 'mc.logging'
local vos = require 'utils.vos'
local utils = require 'mc.utils'
local class = require 'mc.class'
local client = require 'general_hardware.client'
local context = require 'mc.context'
local cfg = require 'unit_manager.class.logic_fw.upgrade.fw_cfgs'
local file_sec = require 'utils.file'
local base_messages = require 'messages.base'
local task_mgmt = require 'maintenance_service.task'
local task_prop = require 'mc.mdb.task_mgmt'
local bs = require 'mc.bitstring'
local core = require 'utils.core'
local mdb = require 'mc.mdb'
local sr_upg_service = require 'sr_upg_service.sr_upg_service'

-- Shorthands for the task state / status tables exposed by mc.mdb.task_mgmt.
local task_state <const> = task_prop.state
local task_status <const> = task_prop.status

-- Class table for the CSR maintenance service; it is the value returned by this module.
local maintenance_csr = class(nil, nil, true)
-- Directory that receives the decoded "<uid>.sr" files during an import.
local IMPORT_CSR_PATH <const> = "/dev/shm/import/"
-- Path handed to task_mgmt.create_task when an import task is created.
local MAINTENANCE_PATH <const> = '/bmc/kepler/UpdateService/Maintenance/Csr'
-- Path prefix and interface name used to look up a connector object by name.
local CONNECTOR_PATH <const> = '/bmc/kepler/Connector/'
local CONNECTOR_INTERFACE <const> = 'bmc.kepler.Connector'
-- Number of bytes read from the start of a .bin image as its header.
local EEP_HEADER_LEN <const> = 128

-- Bitstring layout of the image header. Reading plain sizes as bits and
-- "N/string" sizes as bytes, the fields add up to 128 bytes (EEP_HEADER_LEN).
-- NOTE(review): parse_csr treats the *_offset fields as counts of 8-byte units.
local bs_header = bs.new([[<<
    standard_code:12/string,
    spec_version:8,
    elabel_offset:16,
    sys_desc_offset:16,
    inner_area_offset:16,
    psr_offset:16,
    csr_offset:16,
    digital_sign_offset:16,
    _:73/string,
    sr_version:2/string,
    component_uid:24/string,
    verification_code:4/string
>>]])

-- Constructor: stores the bus handle and resets all per-import state.
-- bus: bus object later passed to mdb.get_object and task_mgmt.create_task.
function maintenance_csr:ctor(bus)
    -- No import is running and no tracer is attached until init() is called.
    self.is_importing = false
    self.tracer = nil
    -- Filled in by import_csr / load_firmware once an import is started.
    self.task_id = nil
    self.connector_id = nil
    self.bus = bus
end

-- Optionally attaches a telemetry tracer.
-- The 'telemetry.trace' module is loaded under pcall; when it cannot be
-- required the service keeps running without tracing (self.tracer stays nil).
function maintenance_csr:init()
    local loaded, trace_mod = pcall(require, 'telemetry.trace')
    if loaded then
        self.tracer = trace_mod.get_tracer('general_hardware')
        return
    end
    log:notice('There is no trace to require')
end

-- Returns the position reported by seeking to the end of the file, i.e. its
-- size when the handle behaves like a standard Lua file.
-- file_path: path of the file to measure
-- Returns 0 (after logging an error) when the file cannot be opened.
local function get_file_size(file_path)
    local handle = file_sec.open_s(file_path, "r")
    if handle then
        local size = handle:seek("end")
        handle:close()
        return size
    end
    log:error('Open file failed!')
    return 0
end
-- Starts a tracing span if a tracer is available.
-- name, attribute, span_ctx: forwarded unchanged to tracer:start_span.
-- Returns whatever start_span returns, or nil when tracing is disabled.
function maintenance_csr:enrich_start_span(name, attribute, span_ctx)
    local tracer = self.tracer
    if tracer then
        return tracer:start_span(name, attribute, span_ctx)
    end
    return nil
end

-- Sets the status of a span; does nothing when child_span is nil/false.
-- Any value returned by set_status is treated as an error message and logged.
function maintenance_csr:enrich_set_status(child_span, status, desc)
    if child_span then
        local failure = child_span:set_status(status, desc)
        if failure then
            log:error("%s", failure)
        end
    end
end

-- Finishes a span; does nothing when child_span is nil/false.
function maintenance_csr:enrich_finish(child_span)
    if child_span then
        child_span:finish()
    end
end

-- Adds an event to a span; does nothing when child_span is nil/false.
-- The extra arguments are forwarded unchanged to child_span:add_event, and
-- any value it returns is treated as an error message and logged.
function maintenance_csr:enrich_add_event(child_span, ...)
    if child_span then
        local failure = child_span:add_event(...)
        if failure then
            log:error("%s", failure)
        end
    end
end

-- Parses the upgrade package, unpacks its firmware archive and extracts the UID.
-- Steps:
--   1. ask the update service to parse the package (up to 10 attempts);
--   2. read 'update.cfg' from the reported firmware directory and unpack the
--      first listed archive into that same directory;
--   3. list the directory for a '*.bin' entry and take its base name as UID.
-- file_path: path of the uploaded upgrade package
-- Returns dir_path, uid on success (uid is nil when no '<name>.bin' entry
-- matches); returns nothing on any failure, after logging the reason.
local function process_package_info(file_path)
    local ret, package_info
    for _ = 1, 10 do
        ret, package_info = pcall(function ()
            return client:UpdateServiceUpdateServiceParseFirmwarePackage(context.new(), file_path)
        end)
        if ret and package_info then
            break
        end
    end
    -- A call that succeeds but yields no package info is a failure as well;
    -- without this check the field access below would raise on nil.
    if not ret or not package_info then
        log:error('Parse package failed, error:%s', package_info)
        return
    end

    local dir_path = package_info.FirmwareDirectory
    if type(dir_path) ~= 'string' then
        log:error('parse package failed, firmware directory is missing')
        return
    end
    -- dir_path is spliced into a shell command below, so reject shell metacharacters.
    ret = file_sec.check_shell_special_character_s(dir_path)
    if ret ~= 0 then
        log:error('parse package failed, file path invalid, err:%s', ret)
        return
    end
    local cfg_path = dir_path .. 'update.cfg'
    local upg_cfg_list = cfg:get_cfg_list(cfg_path)
    if not upg_cfg_list or not next(upg_cfg_list) then
        log:error('parse package failed, upgrade cfg list is empty')
        return
    end
    local fw_path = dir_path .. upg_cfg_list[1].name
    -- NOTE(review): 0x6400000 (100 MiB) and 1024 look like size / entry-count
    -- limits of secure_tar_unzip - confirm against its definition.
    utils.secure_tar_unzip(
        fw_path,
        dir_path,
        0x6400000,
        1024
    )

    local tmp_uid_path = '/dev/shm/tmp/uid'
    ret = vos.system_s('/bin/sh', '-c', 'ls ' .. dir_path .. '| grep .bin &> ' .. tmp_uid_path)
    if ret ~= 0 then
        log:error('excute cmd failed')
        -- Do not leave a partial listing behind.
        utils.remove_file(tmp_uid_path)
        return
    end

    local tmp_uid_path_content = file_sec.read_file_s(tmp_uid_path, get_file_size(tmp_uid_path))
    -- The listing is only needed for this one read; remove it on every path.
    utils.remove_file(tmp_uid_path)
    if type(tmp_uid_path_content) ~= 'string' then
        log:error('read uid list failed')
        return
    end
    local uid = string.match(tmp_uid_path_content, "(%w+)%.bin")
    log:notice('upgrade file uid = %s', uid)
    return dir_path, uid
end

-- Sanity-checks the area offsets of an unpacked image header.
-- header: table with the numeric fields sys_desc_offset, elabel_offset,
--         psr_offset, csr_offset and inner_area_offset
-- Returns true only when every one of those offsets is a number below
-- OFFSET_MAX. A missing header or a missing/non-numeric offset yields false
-- instead of raising, so callers can treat a failed unpack as a bad header.
local function check_header(header)
    local OFFSET_MAX <const> = 25000
    if not header then
        return false
    end
    local offset_fields = {
        'sys_desc_offset', 'elabel_offset', 'psr_offset', 'csr_offset', 'inner_area_offset'
    }
    for _, field in ipairs(offset_fields) do
        local value = header[field]
        if type(value) ~= 'number' or value >= OFFSET_MAX then
            return false
        end
    end
    return true
end

-- Extracts the compressed description record from a .bin image and writes
-- the inflated text to dst_csr_path.
-- ori_bin_path: path of the source .bin image
-- dst_csr_path: path of the .sr file to create
-- Returns true on success, false on any failure (the reason is logged).
local function parse_csr(ori_bin_path, dst_csr_path)
    local file = file_sec.open_s(ori_bin_path, 'rb')
    if not file then
        log:error('[SRUpgrade]open bin file failed')
        return false
    end
    local header_bin = file:read(EEP_HEADER_LEN)
    if not header_bin or #header_bin < EEP_HEADER_LEN then
        log:error('[SRUpgrade]read header failed')
        file:close()
        return false
    end
    local header = bs_header:unpack(header_bin)
    if not header or not check_header(header) then
        log:error('[SRUpgrade]header check failed')
        file:close()
        return false
    end

    local len = 0
    local offset = 0
    local psr_offset = header.psr_offset
    local csr_offset = header.csr_offset
    local sign_offset = header.digital_sign_offset
    -- Pick the record to extract and compute its data length and start offset
    -- (both in 8-byte units). The body header takes 7 units, hence the "- 7".
    -- A CSR record takes precedence; a PSR record is used only when there is
    -- no CSR, and then it ends at the signature area. The zero checks must be
    -- explicit because 0 is truthy in Lua.
    if csr_offset ~= 0 then
        len = sign_offset - csr_offset - 7
        offset = csr_offset
    elseif psr_offset ~= 0 then
        len = sign_offset - psr_offset - 7
        offset = psr_offset
    end

    if len <= 0 then
        log:error('[SRUpgrade]invalid record length %s', len)
        file:close()
        return false
    end

    -- Convert from 8-byte units to bytes and skip the 56-byte (7-unit) body header.
    offset = offset * 8
    file:seek('set', offset + 56)
    local compressed_data = file:read(len * 8)
    file:close()
    local ret, json_data = pcall(function()
        return core.inflate(compressed_data)
    end)

    if not ret or json_data == nil then
        log:error('failed to decode eeprom description record data, error is %s', json_data)
        return false
    end

    log:debug('json_data = %s', json_data)

    -- Create the import directory when it does not exist yet.
    if not vos.get_file_accessible(IMPORT_CSR_PATH) then
        os.execute('mkdir -p ' .. IMPORT_CSR_PATH)
        log:notice('create import dir')
    end

    -- Create (or truncate) the destination file and write the decoded record.
    local sr_file = file_sec.open_s(dst_csr_path, 'w+')
    if not sr_file then
        log:error('open file failed.')
        utils.remove_file(IMPORT_CSR_PATH)
        return false
    end
    sr_file:write(json_data)
    sr_file:close()
    return true
end

-- Decodes "<dir_path><uid>.bin" into "<IMPORT_CSR_PATH><uid>.sr" and asks the
-- connector to reload with the new uid.
-- self:             maintenance_csr instance (connector_id is updated)
-- ctx:              request context forwarded to the connector call
-- uid:              UID taken from the upgrade package
-- dir_path:         directory holding the unpacked .bin file
-- expect_connector: connector object to reload
-- Raises base_messages.InternalError() when the image cannot be parsed.
-- NOTE(review): the result of the Reload call is not inspected here.
local function load_firmware(self, ctx, uid, dir_path, expect_connector)
    local bin_file = dir_path .. uid .. '.bin'
    local sr_file = IMPORT_CSR_PATH .. uid .. '.sr'
    if not parse_csr(bin_file, sr_file) then
        log:error('parse csr failed.')
        error(base_messages.InternalError())
    end

    -- Remember the connector's current Id so clear_data can restore it later.
    self.connector_id = expect_connector.Id
    log:notice('Start to load csr')
    expect_connector.pcall:Reload(ctx, expect_connector.Bom, uid, expect_connector.AuxId, 2)
    log:notice('load csr successfully')
end

-- Failure clean-up: clears the importing flag, marks the task as failed,
-- removes temporary files and restores the connector.
-- self:             maintenance_csr instance
-- ctx:              request context forwarded to the connector call
-- dir_path:         unpacked upgrade directory to delete (may be nil/empty)
-- expect_connector: connector to restore (may be nil)
local function clear_data(self, ctx, dir_path, expect_connector)
    log:notice('Start to clear temp data')
    -- Allow a new import to be started.
    self.is_importing = false

    -- Report the failure on the task object.
    local failed_props = {
        State = task_state.Exception,
        Status = task_status.Error,
        MessageId = 'InternalError',
        MessageArgs = {}
    }
    task_mgmt.update_task_prop(self.task_id, failed_props)

    -- Remove the unpacked upgrade files and the import directory.
    local has_dir = dir_path and #dir_path ~= 0
    if has_dir then
        utils.remove_file(dir_path)
    end
    if vos.get_file_accessible(IMPORT_CSR_PATH) then
        utils.remove_file(IMPORT_CSR_PATH)
    end

    -- Reload the connector with the Id that was saved before the import.
    local saved_id = self.connector_id
    if expect_connector and saved_id and #saved_id ~= 0 then
        expect_connector.pcall:Reload(ctx, expect_connector.Bom, saved_id, expect_connector.AuxId, 3)
    end
    log:notice('Clear temp data successfully')
end

-- Writes the "import failed" entry to both the operation log and the
-- maintenance log.
-- ctx:         request context; its initiator is recorded in the operation log
-- file_name:   name of the imported file
-- object_name: name of the target connector object
local function record_log(ctx, file_name, object_name)
    local failure_fmt = 'Import HWSR with %s to component loaded by %s failed'
    log:operation(ctx:get_initiator(), 'general_hardware', failure_fmt, file_name, object_name)
    log:maintenance(log.MLOG_INFO, log.FC__PUBLIC_OK, failure_fmt, file_name, object_name)
end

-- Polls an upgrade task until it reports completion.
-- bus:     bus handle used to fetch the task object
-- task_id: id appended to the task service path
-- Makes up to 300 polls with skynet.sleep(100) after each unsuccessful one
-- (roughly five minutes according to the original note).
-- Returns true once State is 'Completed' and Progress is 100, false on timeout.
local function wait_upgrade_finish(bus, task_id)
    local task_obj = mdb.get_object(bus, '/bmc/kepler/UpdateService/TaskService/Tasks/' .. task_id,
        'bmc.kepler.TaskService.Task')
    for attempt = 0, 299 do
        log:debug('upgrade csr: loop=%d, task State=%s, Progress=%d', attempt, task_obj.State, task_obj.Progress)
        if task_obj.State == 'Completed' and task_obj.Progress == 100 then
            return true
        end
        skynet.sleep(100)
    end
    return false
end

-- Waits for an SR upgrade object matching the given position and UID.
-- group_position: value compared with each entry's `position`
-- uid:            value compared with each entry's `mds_obj.UID`
-- Scans sr_upg_service's sr_upgrade_list up to 120 times with
-- skynet.sleep(100) after each unsuccessful scan (roughly two minutes
-- according to the original note).
-- Returns true as soon as a matching entry is found, false on timeout.
local function wait_srupgrade_object(group_position, uid)
    for attempt = 0, 119 do
        log:debug('get csr object: loop = %d', attempt)
        for _, entry in pairs(sr_upg_service.get_instance().sr_upgrade_list) do
            if entry.position == group_position then
                if entry.mds_obj.UID == uid then
                    return true
                end
            end
        end
        skynet.sleep(100)
    end
    return false
end

-- Reports whether an import is currently in progress.
-- Returns the current value of self.is_importing.
function maintenance_csr:get_is_importing()
    local importing = self.is_importing
    return importing
end

-- Validates a bare file name.
-- A name is accepted when it consists only of alphanumerics, '_', '(', ')',
-- '-', '.' and spaces, and does not start with a dot.
-- sec_file_name: file name to check
-- Returns true when valid; logs an error and returns false otherwise.
local function is_valid_filename(sec_file_name)
    local allowed = '^[%w%_%(%)%-%.% ]+$'
    local well_formed = string.match(sec_file_name, allowed) ~= nil
    if well_formed and string.sub(sec_file_name, 1, 1) ~= '.' then
        return true
    end
    log:error('file name [%s] is not valid', sec_file_name)
    return false
end

-- Checks whether a path is accepted as a local path under '/tmp'.
-- file_path: path to check
-- Returns true when check_realpath_before_open_s(file_path, '/tmp') yields 0,
-- false otherwise.
local function is_local_path(file_path)
    local check_result = file_sec.check_realpath_before_open_s(file_path, '/tmp')
    if check_result ~= 0 then
        return false
    end
    log:notice('The file path is local.')
    return true
end

-- Resolves file_path with utils.realpath inside a tracing span.
-- self:      maintenance_csr instance (used for tracing)
-- file_path: path to resolve
-- Returns the value produced by utils.realpath, or '' when it raises.
local function get_valid_filepath(self, file_path)
    local span = self:enrich_start_span("get_valid_filepath", {file_path = file_path})
    local resolved, result = pcall(utils.realpath, file_path)
    if resolved then
        self:enrich_finish(span)
        return result
    end
    log:error('get realpath error.')
    self:enrich_set_status(span, 'error', 'get realpath error.')
    self:enrich_finish(span)
    return ''
end

-- Resolves and validates an uploaded file path.
-- self:      maintenance_csr instance (used for tracing)
-- file_path: path supplied by the caller
-- Returns real_path, file_name when the resolved path is local and the file
-- name is valid. Returns false when the file name cannot be extracted, and
-- nil otherwise. A local file with an invalid name is removed.
local function get_valid_file_path(self, file_path)
    local span = self:enrich_start_span("get_valid_file_path", {file_path = file_path})
    local resolved_path = get_valid_filepath(self, file_path)
    local got_name, name_or_err = pcall(file_sec.get_file_name_s, file_path)
    if not got_name then
        log:error('get_file_name_s error:%s', name_or_err)
        self:enrich_add_event(span, 'get_file_name_s error', {file_path = file_path})
        self:enrich_finish(span)
        return false
    end

    local path_is_local = resolved_path and is_local_path(resolved_path)
    if path_is_local and is_valid_filename(name_or_err) then
        local event_name = string.format('get valid filepath = %s', resolved_path)
        self:enrich_add_event(span, event_name, {real_path = resolved_path, sec_file_name = name_or_err})
        self:enrich_finish(span)
        return resolved_path, name_or_err
    end

    -- The file is local but its name was rejected: delete it.
    if path_is_local then
        utils.remove_file(resolved_path)
    end
    local reason = string.format('real_path: %s or is_local_path(real_path) is false', resolved_path)
    self:enrich_set_status(span, 'error', reason)
    self:enrich_finish(span)
    return nil
end

-- Aborts an import: runs clear_data under pcall (logging any error it
-- raises) and then records the failure in the logs.
local function quit_import(self, ctx, file_name, dir_path, expect_connector, object_name)
    local cleaned, clean_err = pcall(clear_data, self, ctx, dir_path, expect_connector)
    if not cleaned then
        log:error('clear data failed, error:%s', clean_err)
    end
    record_log(ctx, file_name, object_name)
end

-- Looks up the target connector and checks that an import is allowed.
-- self:        maintenance_csr instance (provides the bus)
-- uid:         UID of the CSR about to be imported
-- object_name: connector name appended to CONNECTOR_PATH
-- Returns the connector object, or nil (after logging) when the lookup
-- fails or the import is not permitted.
local function get_expect_connector(self, uid, object_name)
    local connector = mdb.get_object(self.bus, CONNECTOR_PATH .. object_name, CONNECTOR_INTERFACE)
    if not connector then
        log:error('get connector object failed.')
        return nil
    end

    -- Reject when Presence is 0 or IdentifyMode is 2 (per the original note:
    -- component absent, or not a supported component type).
    local absent = connector.Presence == 0
    if absent or connector.IdentifyMode == 2 then
        log:error('dont support import')
        return nil
    end

    -- Reject when the connector already carries this uid with LoadStatus 0.
    local already_loaded = connector.Id == uid and connector.LoadStatus == 0
    if already_loaded then
        log:error('not change uid, dont load')
        return nil
    end

    return connector
end

-- Starts a firmware upgrade for the package and waits for it to complete.
-- ctx:       request context forwarded to the update service
-- bus:       bus handle used to poll the upgrade task
-- file_path: path of the upgrade package
-- Returns true only when the task was created and wait_upgrade_finish
-- reported completion; false on creation failure, a polling error or timeout.
local function start_upgrade_csr(ctx, bus, file_path)
    local ok, rsp = client:PUpdateServiceUpdateServiceStartUpgrade(ctx, file_path, {})
    if not ok then
        log:error('create upgrade task failed, error is %s', rsp)
        return false
    end
    if not rsp or rsp.TaskId == nil then
        log:error('create upgrade task failed, no task id returned')
        return false
    end

    -- Wait for the upgrade to end. pcall's status only says whether polling
    -- raised; the polling result itself says whether the upgrade completed.
    local finished
    ok, finished = pcall(wait_upgrade_finish, bus, rsp.TaskId)
    if not ok then
        log:error('upgrade csr failed, error is %s', finished)
        return false
    end
    if not finished then
        log:error('upgrade csr failed, task did not complete in time')
        return false
    end

    return true
end

-- Background body of a CSR import (forked by import_csr).
-- Stages and reported progress:
--   20  package parsed and unpacked, UID extracted
--   30  target connector found and eligible
--   60  CSR decoded and connector reloaded
--   100 matching SR upgrade object appeared; the firmware upgrade is started
-- Any failure before progress 100 goes through quit_import (clean-up plus
-- failure log). A failure of the final upgrade only triggers clear_data.
-- ctx:         request context
-- object_name: connector name
-- file_path:   validated path of the upgrade package
-- file_name:   file name used in log messages
function maintenance_csr:import_csr_task(ctx, object_name, file_path, file_name)
    local ok, rsp, expect_connector, dir_path, uid

    -- Forget the Id saved by a previous import so a failure that happens
    -- before load_firmware cannot "restore" a stale Id onto this connector.
    self.connector_id = nil

    -- Unpack the upgrade package.
    ok, dir_path, uid = pcall(process_package_info, file_path)
    if not ok then
        -- On a raised error pcall's second result is the error value, not a
        -- directory; it must not reach clear_data as a path to remove.
        log:error('Process package info failed, error:%s', dir_path)
        dir_path = nil
    end
    if not ok or not dir_path or not uid then
        log:error('Process package info failed.')
        quit_import(self, ctx, file_name, dir_path, nil, object_name)
        return
    end
    task_mgmt.update_task_prop(self.task_id, {Progress = 20})

    -- Fetch the connector object by name.
    ok, expect_connector = pcall(get_expect_connector, self, uid, object_name)
    log:notice('uid = %s, object name = %s', uid, object_name)
    if not ok or not expect_connector then
        log:error('Get expect connector failed, error:%s', expect_connector)
        -- No usable connector here (nil or an error value), so pass none.
        quit_import(self, ctx, file_name, dir_path, nil, object_name)
        return
    end

    task_mgmt.update_task_prop(self.task_id, {Progress = 30})
    -- Record the current Id, then parse and load the CSR file.
    ok, rsp = pcall(load_firmware, self, ctx, uid, dir_path, expect_connector)
    if not ok then
        log:error('load firmware failed, error:%s', rsp)
        quit_import(self, ctx, file_name, dir_path, expect_connector, object_name)
        return
    end

    task_mgmt.update_task_prop(self.task_id, {Progress = 60})
    -- Wait for the SR upgrade object of the reloaded component.
    ok, rsp = pcall(wait_srupgrade_object, expect_connector.GroupPosition, uid)
    if not ok or not rsp then
        log:error('get csr upgrade object failed, error is %s', rsp)
        quit_import(self, ctx, file_name, dir_path, expect_connector, object_name)
        return
    end
    utils.remove_file(IMPORT_CSR_PATH)
    log:notice('Start to upgrade csr')
    task_mgmt.update_task_prop(self.task_id, {Progress = 100})
    log:operation(ctx:get_initiator(), 'general_hardware',
        'Import HWSR with %s to component loaded by %s successfully', file_name, object_name)
    log:maintenance(log.MLOG_INFO, log.FC__PUBLIC_OK,
        'Import HWSR with %s to component loaded by %s successfully', file_name, object_name)
    -- Clear the importing flag, then call the upgrade interface.
    self.is_importing = false
    ok, rsp = pcall(start_upgrade_csr, ctx, self.bus, file_path)
    if not ok or not rsp then
        log:error('Upgrade csr failed, error is %s', rsp)
        ok, rsp = pcall(clear_data, self, ctx, dir_path, expect_connector)
        if not ok then
            log:error('clear data failed, error:%s', rsp)
        end
    end
end

-- RPC method: starts a CSR import for a connector.
-- ctx:            request context
-- connector_name: name of the target connector object
-- file_path:      path of the uploaded upgrade package
-- Returns nil when another import is already running; otherwise the id of
-- the created task. When the file path is invalid the task is marked as
-- failed and its id is still returned. An error raised while creating the
-- task is re-raised after the importing flag has been cleared.
function maintenance_csr:import_csr(ctx, connector_name, file_path)
    local root_span = self:enrich_start_span("import_csr", {connector_name = connector_name, file_path = file_path})
    log:notice('Connector Name = %s', connector_name)
    self:enrich_add_event(root_span, string.format('Connector Name = %s', connector_name),
        {connector_name = connector_name, file_path = file_path})
    -- An import task already exists.
    if self:get_is_importing() then
        log:error('A csr import task already exists.')
        self:enrich_set_status(root_span, 'error', 'A csr import task already exists.')
        self:enrich_finish(root_span)
        return nil
    end

    -- Create the task. The flag is raised first so a concurrent request is
    -- rejected; it must be lowered again if task creation raises, otherwise
    -- every later import would be refused.
    self.is_importing = true
    local created, task_id = pcall(task_mgmt.create_task, self.bus, 'Import Csr Task', MAINTENANCE_PATH, 20)
    if not created then
        self.is_importing = false
        log:error('create import csr task failed, error:%s', task_id)
        self:enrich_set_status(root_span, 'error', 'create import csr task failed')
        self:enrich_finish(root_span)
        -- Level 0 re-raises the original error value unchanged.
        error(task_id, 0)
    end
    self.task_id = task_id

    local ok, real_path, file_name = pcall(get_valid_file_path, self, file_path)
    if not ok or not real_path or not file_name then
        log:error('file path is not valid')
        task_mgmt.update_task_prop(self.task_id, {State = task_state.Exception, Status = task_status.Error,
            MessageId = 'InternalError', MessageArgs = {}})
        self.is_importing = false
        -- The task id needs its own %s, otherwise it is dropped from the message.
        self:enrich_set_status(root_span, 'error', string.format(
            'connector_name: %s, path: %s, is not valid. task_id: %s', connector_name, file_path, self.task_id))
        self:enrich_finish(root_span)
        return self.task_id
    end

    -- Run the import in a forked coroutine.
    skynet.fork_once(self.import_csr_task, self, ctx, connector_name, real_path, file_name)
    log:notice('Import csr task start successfully')
    self:enrich_add_event(root_span, 'Import csr task start successfully',
        {connector_name = connector_name, real_path = real_path, file_name = file_name, task_id = self.task_id})
    log:operation(ctx:get_initiator(), 'general_hardware',
        'Import HWSR with %s to component loaded by %s started', file_name, connector_name)
    log:maintenance(log.MLOG_INFO, log.FC__PUBLIC_OK,
        'Import HWSR with %s to component loaded by %s started', file_name, connector_name)
    self:enrich_finish(root_span)
    return self.task_id
end

return maintenance_csr
