-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local log = require 'mc.logging'
local ctx = require 'mc.context'
local c_tasks = require 'mc.orm.tasks'
local utils = require 'mc.utils'
local common_def = require 'common_def'
local c_object = require 'mc.orm.object'
local nvme_utils = require 'nvme.utils'
local skynet = require 'skynet'

---@class c_vpd_connect: c_object
---@field Slot integer
---@field RefVPDChip c_object
---@field RefConnector c_object
---@field RefPolicyConnector c_object
local c_vpd_connect = c_object('VirtualVPDConnect') -- class for VirtualVPDConnect objects

-- 1-based positions, within the common-header byte array, of the three class-code
-- bytes (least significant byte first); used by c_vpd_connect:judge_nvme_mi.
local SSD_FORM_CLASS_CODE_OFFSET1 <const> = 1
local SSD_FORM_CLASS_CODE_OFFSET2 <const> = 2
local SSD_FORM_CLASS_CODE_OFFSET3 <const> = 3
-- Number of bytes read from VPD offset 0 when extracting the model number.
local VPD_COMMON_MODEL_NUMBER_LEN <const> = 65
-- Maximum number of model-number read attempts before giving up.
local MAX_CHIP_READ_COUNT <const> = 5

-- Samsung model numbers whose VPD is reported with the Samsung-specific SSD form
-- factor protocol. Used as a set: model number -> true.
local model_number_name_array <const> = {
    ["SAMSUNG MZWLL1T6HAJQ"] = true,
    ["SAMSUNG MZWLL3T2HAJQ"] = true,
    ["SAMSUNG MZWLL6T4HMLA"] = true,
    ["SAMSUNG MZWLL12THMLA"] = true,
    ["SAMSUNG MZQLB960HAJR"] = true,
    ["SAMSUNG MZQLB1T9HAJR"] = true,
    ["SAMSUNG MZQLB3T8HALS"] = true,
    ["SAMSUNG MZQLB7T6HMLA"] = true
}

-- Offset and length of the manufacturer name read from an NVMe-MI formatted VPD.
local NVME_MI_MANUFACTURE_OFFSET <const> = 12
local NVME_MI_MANUFACTURE_LEN <const> = 8
-- Cooling-policy AuxId used when the vendor was read but is not a maintained vendor.
local NVME_COOLING_POLICY_DEFAULT <const> = 0xffff

-- Report whether a vendor id appears among the values of common_def.MANUFACTURE_ID_MAP,
-- i.e. whether the vendor is one of the maintained vendors.
---@param vender_id any vendor id to look up
---@return boolean
local function is_under_maintenance(vender_id)
    local maintained = false
    for _, known_id in pairs(common_def.MANUFACTURE_ID_MAP) do
        if vender_id == known_id then
            maintained = true
            break
        end
    end
    return maintained
end

-- Read the manufacturer name from an NVMe-MI formatted VPD and map it to a vendor id
-- through common_def.MANUFACTURE_ID_MAP.
---@param nvme any drive object; its VPDChip (may be nil) and Slot fields are read
---@return boolean read_ok true when the VPD read succeeded
---@return boolean maintained true when the trimmed name is a key of MANUFACTURE_ID_MAP
---@return any vender_id mapped id (nil if unknown), or common_def.INVALID_STRING on failure
local function get_from_vpd_nvme_mi_form(nvme)
    if not nvme.VPDChip then
        return false, false, common_def.INVALID_STRING
    end
    local ok, ret = pcall(function()
        return nvme.VPDChip:Read(ctx.get_context_or_default(),
                                 NVME_MI_MANUFACTURE_OFFSET,
                                 NVME_MI_MANUFACTURE_LEN)
    end)
    if not ok or not ret then
        log:info("get nvme %s manufacture by nvme mi form failed, %s", nvme.Slot, ret)
        return false, false, common_def.INVALID_STRING
    end

    -- Rebuild the ASCII name from the decoded byte values, dropping NUL padding.
    -- ipairs keeps byte order (pairs does not guarantee it); assumes string_split
    -- returns a sequence.
    local chars = {}
    for _, byte in ipairs(nvme_utils.string_split(utils.to_hex(ret), ' ', 16)) do
        if byte ~= 0 then
            chars[#chars + 1] = string.char(byte)
        end
    end
    -- Trim surrounding whitespace; the parentheses keep only gsub's first result.
    local manufacture = (table.concat(chars):gsub("^%s*(.-)%s*$", "%1"))
    local vender_id = common_def.MANUFACTURE_ID_MAP[manufacture]
    return true, vender_id ~= nil, vender_id
end

-- Read the vendor id from an SSD form factor VPD.
---@param nvme any drive object; its VPDChip (may be nil) and Slot fields are read
---@return boolean read_ok true when the VPD read succeeded
---@return boolean maintained whether the id is a value of common_def.MANUFACTURE_ID_MAP
---@return integer vender_id id assembled little-endian, or common_def.INVALID_U32 on failure
local function get_from_ssd_form(nvme)
    if not nvme.VPDChip then
        return false, false, common_def.INVALID_U32
    end
    local ok, ret = pcall(function()
        return nvme.VPDChip:Read(ctx.get_context_or_default(),
                                 common_def.VPD_SSD_FORM_VENDOR_ID_OFFSET,
                                 common_def.VPD_SSD_FORM_VENDOR_ID_LEN)
    end)
    if not ok or not ret then
        log:info("get nvme %s manufacture id by ssd form failed, %s", nvme.Slot, ret)
        return false, false, common_def.INVALID_U32
    end

    -- The first byte read is the least significant byte of the id.
    local vender_id = 0
    local bytes = nvme_utils.string_split(utils.to_hex(ret), ' ', 16)
    for index, byte in ipairs(bytes) do
        vender_id = vender_id + (byte << (8 * (index - 1)))
    end
    return true, is_under_maintenance(vender_id), vender_id
end

-- Look up the drive's vendor id, picking the VPD layout from the drive's protocol type
-- (NVMe-MI layout for NVME_VPD_PROTOCOL_NVME_MI, SSD form factor layout otherwise).
-- Returns: (1) whether the vendor id was obtained, (2) whether that vendor id is
-- maintained, (3) the vendor id itself.
local function get_nvme_manufacture(nvme)
    local reader = get_from_ssd_form
    if nvme:pcie_nvme_get_protocol_type() == common_def.NVME_VPD_PROTOCOL_NVME_MI then
        reader = get_from_vpd_nvme_mi_form
    end
    return reader(nvme)
end

-- Write aux_id, formatted as lowercase hex, to RefPolicyConnector.AuxId, then wait until
-- the property reads back that value before setting Presence to 1. Polls with
-- skynet.sleep(100) between checks and does not return until the value matches.
---@param aux_id integer
function c_vpd_connect:update_policy_connector(aux_id)
    local expected = string.format("%x", aux_id)
    self.RefPolicyConnector.AuxId = expected
    -- Presence must only be raised once AuxId has taken effect.
    while self.RefPolicyConnector.AuxId ~= expected do
        skynet.sleep(100)
    end
    self.RefPolicyConnector.Presence = 1
end

-- Start a polling task that resolves RefPolicyConnector.AuxId from the drive vendor.
-- AuxId values: "255" (tostring(common_def.INVALID_U8), the initial value from the SR
-- configuration), or the hex string written by update_policy_connector. That string is
-- either the vendor id from common_def.MANUFACTURE_ID_MAP, or NVME_COOLING_POLICY_DEFAULT
-- ("ffff") when the vendor was read but is not maintained.
-- The task stops once AuxId no longer holds the initial value. A failed vendor read
-- leaves AuxId untouched so a later iteration tries again.
---@param nvme any drive object passed to get_nvme_manufacture
function c_vpd_connect:load_cooling_policy(nvme)
    local task_name = string.format("update_policy_connector%s", nvme.Slot)

    local function poll(task)
        if self.RefPolicyConnector.AuxId ~= tostring(common_def.INVALID_U8) then
            task:stop()
            return
        end
        local ret, maintenance, vender_id = get_nvme_manufacture(nvme)
        if not ret then
            return
        end
        if maintenance then
            self:update_policy_connector(vender_id)
        else
            self:update_policy_connector(NVME_COOLING_POLICY_DEFAULT)
        end
    end

    c_tasks.get_instance():new_task(task_name):loop(poll):set_timeout_ms(30000)
end

-- Read common_def.NVME_COMMON_HEADER_LEN bytes from VPD offset 0 and decode them into
-- an array of byte values.
---@return table|nil header decoded bytes, or nil when RefVPDChip is missing or the
---  read/decoding raised an error (both cases are logged)
function c_vpd_connect:get_common_header()
    local chip = self.RefVPDChip
    if not chip then
        log:error('RefVPDChip not exist.')
        return
    end

    local function read_header()
        local raw = chip:Read(ctx.get_context_or_default(), 0, common_def.NVME_COMMON_HEADER_LEN)
        return nvme_utils.string_split(utils.to_hex(raw), ' ', 16)
    end

    local ok, header = pcall(read_header)
    if ok then
        return header
    end
    log:error('get header data failed, ret is %s', header)
end

-- Read `length` bytes at `offset` from RefVPDChip in protected mode.
---@return boolean ok pcall status
---@return any result the Read result on success, the error value otherwise
function c_vpd_connect:read_vpd_chip(offset, length)
    local chip_read = function()
        return self.RefVPDChip:Read(ctx.get_context_or_default(), offset, length)
    end
    return pcall(chip_read)
end

-- Read the model number from the start of the VPD (single attempt, no retry).
-- NUL bytes are dropped and surrounding whitespace is trimmed.
---@return string|nil model_number nil when the read failed or returned nothing
function c_vpd_connect:get_model_number()
    local ok, ret = self:read_vpd_chip(0, VPD_COMMON_MODEL_NUMBER_LEN)
    if not ok or not ret then
        return
    end

    local chars = {}
    -- ipairs keeps byte order; assumes string_split returns a sequence.
    for _, byte in ipairs(nvme_utils.string_split(utils.to_hex(ret), ' ', 16)) do
        -- Each byte is an ASCII code; skip NUL padding.
        if byte ~= 0 then
            chars[#chars + 1] = string.char(byte)
        end
    end
    -- Parentheses drop gsub's second return value (the substitution count).
    return (table.concat(chars):gsub("^%s*(.-)%s*$", "%1"))
end

-- Read the model number, trying up to MAX_CHIP_READ_COUNT times so that a transient
-- link failure does not lose the value.
---@return string|nil model_number nil when every attempt failed
function c_vpd_connect:get_model_number_with_repeat()
    for attempt = 1, MAX_CHIP_READ_COUNT do
        local model_number = self:get_model_number()
        if model_number then
            return model_number
        end
        -- Back off before the next attempt; there is no point sleeping after the last one.
        if attempt < MAX_CHIP_READ_COUNT then
            skynet.sleep(300)
        end
    end
end

-- Get the model number of a Samsung drive from its VPD, or '' when it cannot be read.
-- (The original note names VirtualVPDTmpAccessor as the accessor used to fetch the
-- model number from the VPD structure — presumably what backs RefVPDChip; confirm.)
---@return string
function c_vpd_connect:pcie_card_get_vpd_model_number_value()
    local model_number = self:get_model_number_with_repeat()
    if not model_number then
        return ''
    end
    return model_number
end


-- Choose the protocol type for a Samsung drive from its model number. Models listed in
-- model_number_name_array use the Samsung-specific SSD form factor protocol; every other
-- model uses the standard SSD form factor protocol.
function c_vpd_connect:get_samsung_vpd_type()
    local model_number = self:pcie_card_get_vpd_model_number_value()
    if model_number_name_array[model_number] then
        return common_def.SAMSUNG_NVME_VPD_PROTOCOL_SSD_FORM_FACTOR
    end
    return common_def.NVME_VPD_PROTOCOL_SSD_FORM_FACTOR
end


-- Decide whether a common header describes an NVMe-MI VPD. Both conditions must hold:
--   1. the last header byte (index NVME_COMMON_HEADER_LEN) is the checksum of the bytes
--      before it, i.e. the 8-bit two's complement of their sum;
--   2. the first three bytes (little-endian) differ from the SSD form factor class code.
---@param common_header table array of header byte values (1-based)
---@return boolean
function c_vpd_connect:judge_nvme_mi(common_header)
    -- 0x010802: class code of the SSD form factor protocol.
    local ssd_form_class_code = 0x010802
    local checksum_index = common_def.NVME_COMMON_HEADER_LEN

    local header_sum = 0
    for i = 1, checksum_index - 1 do
        header_sum = header_sum + common_header[i]
    end
    -- Lua's % always yields a non-negative result, so this is the expected checksum in
    -- 0..255. A sum that is a multiple of 256 correctly maps to 0.
    -- NOTE(review): this means an all-zero header passes the checksum; condition 2 then
    -- decides on its own.
    local expected_checksum = (-header_sum) % 256

    local class_code = common_header[SSD_FORM_CLASS_CODE_OFFSET1] |
        (common_header[SSD_FORM_CLASS_CODE_OFFSET2] << 8) |
        (common_header[SSD_FORM_CLASS_CODE_OFFSET3] << 16)
    log:notice('connector%s classcode is %s, header_sum is %s, target sum is %s', self.Slot,
        class_code, expected_checksum, common_header[checksum_index])

    return expected_checksum == common_header[checksum_index] and
        class_code ~= ssd_form_class_code
end

-- Determine which VPD protocol the drive in this slot uses.
---@return any protocol common_def.INVALID_U8 when the common header cannot be read;
---  otherwise NVME_VPD_PROTOCOL_NVME_MI, the result of get_samsung_vpd_type for the
---  Samsung vendor id, or NVME_VPD_PROTOCOL_SSD_FORM_FACTOR
function c_vpd_connect:verify_vpd_protocol()
    local header = self:get_common_header()
    if not header then
        log:notice('vpd_connector%s get common header failed', self.Slot)
        return common_def.INVALID_U8
    end

    -- Vendor id, little-endian, from two header bytes.
    local vendor_lo = header[common_def.NVME_VPD_PROTOCOL_SSD_FORM_VENDOR_ID_OFFSET_L]
    local vendor_hi = header[common_def.NVME_VPD_PROTOCOL_SSD_FORM_VENDOR_ID_OFFSET_H]
    local vender_id = vendor_lo | (vendor_hi << 8)

    if self:judge_nvme_mi(header) then
        return common_def.NVME_VPD_PROTOCOL_NVME_MI
    end
    if vender_id == common_def.NVME_VPD_VENDOR_ID_FOR_SAMSUNG then
        return self:get_samsung_vpd_type()
    end
    return common_def.NVME_VPD_PROTOCOL_SSD_FORM_FACTOR
end

-- Detect the VPD protocol and publish it on RefConnector. AuxId is set to the protocol
-- id (as a string), and once that value reads back, Presence is set to 1.
-- Nothing is written when detection raises an error or yields common_def.INVALID_U8.
-- AuxId therefore keeps its unresolved value and a polling caller (such as the task
-- started in init) can try again.
function c_vpd_connect:update_connector()
    if not self.RefConnector then
        return
    end

    local ok, aux_id = pcall(function()
        return self:verify_vpd_protocol()
    end)
    if not ok then
        -- aux_id holds the error value here; it must never be written to AuxId.
        log:error('connector%s get protocol failed, %s', self.Slot, aux_id)
        return
    end

    log:notice('connector get protocol is %s', aux_id)
    if aux_id == common_def.INVALID_U8 then
        return
    end

    local expected = tostring(aux_id)
    self.RefConnector.AuxId = expected
    -- Make sure AuxId is updated before Presence; sleep between checks so that other
    -- coroutines are not blocked.
    while self.RefConnector.AuxId ~= expected do
        skynet.sleep(100)
    end
    self.RefConnector.Presence = 1
end

-- Initialise the object and start a polling task. While RefConnector.AuxId still equals
-- tostring(common_def.INVALID_U8), the task calls update_connector; once AuxId changes,
-- the task stops.
function c_vpd_connect:init()
    c_vpd_connect.super.init(self)
    local task_name = string.format('update_vpd_connector%s', self.Slot)
    c_tasks.get_instance():new_task(task_name):loop(function(task)
        -- update_connector tolerates a missing RefConnector. Do the same here instead of
        -- raising on the AuxId lookup, and look again on the next iteration.
        if not self.RefConnector then
            return
        end
        if self.RefConnector.AuxId == tostring(common_def.INVALID_U8) then
            log:notice('begin update_connector%s', self.Slot)
            self:update_connector()
        else
            task:stop()
        end
    end):set_timeout_ms(30000)
end

-- Stop the 'update_vpd_connector<Slot>' task started by init, if it is still registered.
-- NOTE(review): the 'update_policy_connector<nvme.Slot>' task created by
-- load_cooling_policy is not stopped here — confirm whether it should be.
function c_vpd_connect:dtor()
    local polling_task = c_tasks.get_instance():get_task(
        string.format('update_vpd_connector%s', self.Slot))
    if not polling_task then
        return
    end
    polling_task:stop()
end

-- Module export: the VirtualVPDConnect class.
return c_vpd_connect