-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local class = require 'mc.class'
local c_object_index = require 'object_manager.object_index'
require 'object_manager.basic'

-- Name under which the primary-key (main) index is stored in `self.indexs`.
local MAIN_INDEX<const> = 'main_index'

--- Collection of objects of one class, kept in a main index built over the
--- primary keys plus any number of named secondary indexes.
---@class c_object_collection: c_basic_class
local c_object_collection = class()

--- Construct a collection for objects of class `cls`.
--- Sets up an empty table of named indexes, the object counter, the list of
--- primary-key fields used by the main index (created later in `init`), and the
--- metatable given to raw objects built by `new_raw_object`.
---@param cls table class of the collected objects; its fields shadow the base object's fields on raw objects
---@param primary_keys string[]|nil key field names of the main index (defaults to {'object_id'})
function c_object_collection:ctor(cls, primary_keys)
    self.indexs = {}
    self.cls = cls
    self.count = 0
    self.primary_keys = primary_keys or {'object_id'}
    self.raw_object_mt = {
        -- Missing fields on a raw object resolve to class members first, then
        -- to the wrapped base object.
        __index = function(obj, name)
            local val = cls[name]
            if val ~= nil then
                return val
            end
            -- rawget: reading `obj.base_obj` normally would re-enter this
            -- __index when base_obj is nil and recurse until stack overflow;
            -- this way a missing base object fails immediately instead.
            return rawget(obj, 'base_obj')[name]
        end
    }
end

--- Create the main index over `self.primary_keys` and register it under
--- MAIN_INDEX, replacing any main index registered before.
function c_object_collection:init()
    local main_index = c_object_index.new(self.primary_keys)
    self.indexs[MAIN_INDEX] = main_index
end

--- Call `reset()` on every registered index and zero the object counter.
--- The indexes themselves stay registered.
function c_object_collection:reset()
    for _, idx in pairs(self.indexs) do
        idx:reset()
    end
    self.count = 0
end

--- Wrap `base_obj` (and its `position`) in a raw object.
--- The result uses `self.raw_object_mt`, so field reads it does not hold itself
--- go to the collection's class first and then to `base_obj`.
---@return table the new raw object
function c_object_collection:new_raw_object(base_obj, position)
    local raw = {base_obj = base_obj, position = position}
    return setmetatable(raw, self.raw_object_mt)
end

--- Fold over the main index by delegating to its `fold` with the same
--- callback and initial accumulator; returns whatever that call returns.
function c_object_collection:fold(cb, acc)
    local main_index = self.indexs[MAIN_INDEX]
    return main_index:fold(cb, acc)
end

--- Delegate to the main index's `safe_fold` with the same callback and initial
--- accumulator; returns whatever that call returns. How it differs from `fold`
--- is defined in object_index (NOTE(review): presumably safe against
--- modification during iteration — confirm there).
function c_object_collection:safe_fold(cb, acc)
    local main_index = self.indexs[MAIN_INDEX]
    return main_index:safe_fold(cb, acc)
end

--- Return the main index's `objects` table itself (not a copy).
function c_object_collection:objects()
    local main_index = self.indexs[MAIN_INDEX]
    return main_index.objects
end

--- Return the index registered under `index_name`, or nil if there is none.
--- Passing nil (or omitting the argument) returns the main index, i.e. the
--- primary-key index created by `init`.
---@param index_name string | nil
function c_object_collection:index(index_name)
    return self.indexs[index_name or MAIN_INDEX]
end

--- Register a secondary index named `index_name` over the fields `keys`.
--- If an index with that name already exists it is returned as-is (and `keys`
--- is ignored). Otherwise a new index is filled from every object currently in
--- the main index, and only registered once that succeeded, so a failure leaves
--- the collection's indexes unchanged.
---@param index_name string name to register the index under
---@param keys string[] field names forming the index key
---@return table the existing or newly created index
--- Raises an error when two existing objects yield the same non-nil key.
function c_object_collection:add_index(index_name, keys)
    local existing = self.indexs[index_name]
    if existing then
        return existing
    end

    local index = c_object_index.new(keys)
    self:fold(function(_, object)
        local key = object:get_key(index.keys)
        -- Same duplicate rule as add_object: only non-nil keys can collide.
        if key and index:get_object(key) then
            error(string.format('add index [%s]%s failed: object duplicate, key=%s',
                self.cls.__class_name, index_name, key))
        end
        -- Same (key, object) calling convention that add_object uses.
        index:add_object(key, object)
    end)
    self.indexs[index_name] = index
    return index
end

--- Locate the stored object that corresponds to a raw object.
--- Every index is probed (in `pairs` order) with the raw object's key for that
--- index; indexes for which no key can be built are skipped.
---@return any, string, any the found object, the name of the index that hit, and the key; nothing when no index hits
function c_object_collection:find_by_raw_object(raw_obj)
    for name, idx in pairs(self.indexs) do
        local key = raw_obj:get_key(idx.keys)
        local hit = key and idx:get_object(key)
        if hit then
            return hit, name, key
        end
    end
end

--- Return `object` when every field of `condition` equals the same field of
--- `object` (compared with `~=`); return nil otherwise, or when `object` is
--- nil/false. An empty condition matches any non-nil object.
local function match_object(condition, object)
    if object then
        for field, expected in pairs(condition) do
            if object[field] ~= expected then
                return nil
            end
        end
        return object
    end
    return nil
end

--- Find an object by predicate or by a table of field values.
--- With a function, the main index's `find_object` is called with it directly.
--- With a table, each index is first probed with the key `cls.get_key` builds
--- from the table, and a candidate is accepted only if all of the table's
--- fields match. If no index probe hits, the main index is scanned with the
--- same field match.
---@param cb_or_row (fun(object: any):boolean) | table<string, any>
---@return any the matched object; when nothing matches, whatever find_object returns (presumably nil)
function c_object_collection:find(cb_or_row)
    if type(cb_or_row) == 'function' then
        return self:index():find_object(cb_or_row)
    end

    -- Fast path: direct key lookups through every index.
    for _, index in pairs(self.indexs) do
        local key = self.cls.get_key(cb_or_row, index.keys)
        -- Skip indexes for which no key could be built from the row, the same
        -- way find_by_raw_object does, instead of looking up a nil key.
        if key then
            local object = match_object(cb_or_row, index:get_object(key))
            if object then
                return object
            end
        end
    end

    -- No index hit: fall back to a linear scan of the main index.
    return self:index():find_object(function(object)
        return match_object(cb_or_row, object) ~= nil
    end)
end

--- Insert `object` into every registered index and increment the counter.
--- All indexes are checked for a collision before any of them is modified, so
--- a duplicate raises an error before the collection is touched. A nil key for
--- an index is not collision-checked against that index.
---@param object any object providing `get_key(keys)`
---@return any the same object
function c_object_collection:add_object(object)
    -- Pass 1: validate against every index.
    for name, idx in pairs(self.indexs) do
        local key = object:get_key(idx.keys)
        local clash = key and idx:get_object(key)
        if clash then
            error(string.format('add object by index [%s]%s duplicate: key=%s',
                self.cls.__class_name, name, key))
        end
    end

    -- Pass 2: insert into every index.
    for _, idx in pairs(self.indexs) do
        local key = object:get_key(idx.keys)
        idx:add_object(key, object)
    end

    self.count = 1 + self.count
    return object
end

--- Remove the object corresponding to `raw_obj` from every index and
--- decrement the counter. Does nothing when `find_by_raw_object` finds no
--- object. Before removing anything, checks that every index maps the object's
--- key for that index back to this very object, and raises an error otherwise.
function c_object_collection:del_object(raw_obj)
    local target = self:find_by_raw_object(raw_obj)
    if not target then
        return
    end

    -- Consistency check across all indexes before mutating any of them.
    for name, idx in pairs(self.indexs) do
        local key = target:get_key(idx.keys)
        local stored = idx:get_object(key)
        if stored ~= target then
            error(string.format('del object by index [%s]%s failed: object not match, key=%s',
                self.cls.__class_name, name, key))
        end
    end

    for _, idx in pairs(self.indexs) do
        local key = target:get_key(idx.keys)
        idx:del_object(key)
    end

    self.count = self.count - 1
end

-- Module export: the collection class.
return c_object_collection
