-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.
local lu = require 'luaunit'
local trace = require 'otel.trace'
local utils = require 'mc.utils'
local paser = require 'parser.trace'
local dump_file = './span_test.txt'

-- Test suite table. Intentionally global (no `local`): presumably so the
-- luaunit runner can discover it by its `Test` name prefix.
TestTrace = {}

--- Per-test fixture.
-- Lazily creates the shared tracer and a force-sampled root span ("test_span")
-- the first time it runs, then records a "start test" event on that root span
-- before every test.
function TestTrace:setup()
    if not self.tracer then
        self.tracer = trace.get_tracer("test_trace", "1.0.0", "")
    end
    if not self.span then
        local root_attrs = {component = "observability", service_id = 123}
        self.span = self.tracer:start_span("test_span", root_attrs, {force_sample = true})
    end
    self.span:add_event("start test", {})
end

--- Class-level teardown: delete the span dump file that the tests parse after
-- calling trace.flush().
function TestTrace:teardownClass()
    local path = dump_file
    utils.remove_file(path)
end

--- The root span's context must expose every span-context field
-- (trace_id, span_id, trace_flags, trace_state, is_remote) as non-nil.
function TestTrace:test_get_context()
    local ctx = self.span:get_context()
    local fields = {"trace_id", "span_id", "trace_flags", "trace_state", "is_remote"}
    for _, field in ipairs(fields) do
        lu.assertNotIsNil(ctx[field])
    end
end

--- set_attribute accepts string, integer, boolean and float values, rejects a
-- non-string key and a table value, and the dumped span carries the accepted
-- attributes as strings.
function TestTrace:test_set_attribute()
    local span = self.tracer:start_span("test_set_attribute", {}, {force_sample = true})
    lu.assertEquals(span:is_recording(), true)

    -- One attribute per supported value type.
    span:set_attribute("method", "test_set_attribute")
    span:set_attribute("opcode", 1)
    span:set_attribute("is_error", false)
    span:set_attribute("cost_time", 1.22)

    -- Table used as the key.
    lu.assertErrorMsgContains("invalid parameter, set attribute failed", function()
        span:set_attribute({}, "test_set_attribute")
    end)
    -- Table used as the value.
    lu.assertErrorMsgContains("unsupported attribute value type", function()
        span:set_attribute("test error type value", {})
    end)

    span:finish()
    trace.flush()

    local exported = paser:parse_span_file(dump_file)
    local last = exported[#exported]
    lu.assertEquals(last.name, "test_set_attribute")

    local attrs = last.attributes
    lu.assertNotIsNil(attrs)
    -- The dump stores values as text; `false` comes back as '0'.
    lu.assertEquals(attrs.method, "test_set_attribute")
    lu.assertEquals(attrs.opcode, '1')
    lu.assertEquals(attrs.is_error, '0')
    lu.assertEquals(attrs.cost_time, '1.22')
end

--- add_event records events with and without attributes, rejects malformed
-- arguments, and the dumped span lists the events in the order they were added.
function TestTrace:test_add_event()
    local span = self.tracer:start_span("test_add_event", {method = "test_add_event"}, {force_sample = true})
    lu.assertEquals(span:is_recording(), true)

    -- An event without attributes, then one with attributes.
    span:add_event("testing add event")
    span:add_event("testing add event1", {opcode = 1, name = "testing add event1"})

    -- Table used as the event name.
    lu.assertErrorMsgContains("invalid parameter, add event failed", function()
        span:add_event({}, {})
    end)
    -- Attributes given as a string instead of a table.
    lu.assertErrorMsgContains("invalid parameter, get attributes failed", function()
        span:add_event("name", "error type")
    end)
    -- Attribute value of an unsupported type.
    lu.assertErrorMsgContains("unsupported attribute value type", function()
        span:add_event("name", {name = {}})
    end)

    span:finish()
    trace.flush()

    local exported = paser:parse_span_file(dump_file)
    local last = exported[#exported]
    lu.assertEquals(last.name, "test_add_event")
    lu.assertNotIsNil(last.attributes)
    lu.assertEquals(last.attributes.method, "test_add_event")

    local events = last.events
    lu.assertNotIsNil(events)
    local plain, tagged = events[1], events[2]
    lu.assertEquals(plain.name, 'testing add event')
    lu.assertEquals(plain.attributes, {})
    lu.assertEquals(tagged.name, 'testing add event1')
    lu.assertNotIsNil(tagged.attributes)
    lu.assertEquals(tagged.attributes.opcode, '1')
end

--- add_link records links to another span's context, with and without
-- attributes, and the dumped span lists both links in insertion order.
function TestTrace:test_add_link()
    local span = self.tracer:start_span("test_add_link", {method = "test_add_link"}, {force_sample = true})
    lu.assertEquals(span:is_recording(), true)

    -- Link twice to the shared root span: once bare, once with an attribute.
    local target = self.span:get_context()
    span:add_link(target)
    span:add_link(target, {link = "parent span"})

    span:finish()
    trace.flush()

    local exported = paser:parse_span_file(dump_file)
    local last = exported[#exported]
    lu.assertEquals(last.name, "test_add_link")
    lu.assertNotIsNil(last.attributes)
    lu.assertEquals(last.attributes.method, "test_add_link")
    lu.assertEquals(last.events, {})
    lu.assertNotIsNil(last.links)

    -- Each link must identify the root span. The dump names the state field
    -- `tracestate`, while the context calls it `trace_state`.
    local function assert_points_at_target(link)
        lu.assertEquals(link.trace_id, target.trace_id)
        lu.assertEquals(link.span_id, target.span_id)
        lu.assertEquals(link.tracestate, target.trace_state)
    end

    local bare, tagged = last.links[1], last.links[2]
    assert_points_at_target(bare)
    lu.assertEquals(bare.attributes, {})
    assert_points_at_target(tagged)
    lu.assertNotIsNil(tagged.attributes)
    lu.assertEquals(tagged.attributes.link, "parent span")
end

--- set_status stores a valid status and description, rejects malformed
-- arguments, and the dumped span reports the status as "Ok".
function TestTrace:test_set_status()
    local span = self.tracer:start_span("test_set_status", {method = "test_set_status"}, {force_sample = true})
    lu.assertEquals(span:is_recording(), true)

    span:set_status("ok", "test set status ok")

    -- Every malformed call must fail with the same message.
    local rejected = {
        {{}, "12"},                         -- table as the status
        {"12", {}},                         -- table as the description
        {"nostatus", "test error status"},  -- unrecognised status string
    }
    for _, args in ipairs(rejected) do
        lu.assertErrorMsgContains("invalid parameter, set status failed", function()
            span:set_status(args[1], args[2])
        end)
    end

    span:finish()
    trace.flush()

    local exported = paser:parse_span_file(dump_file)
    local last = exported[#exported]
    lu.assertEquals(last.name, "test_set_status")
    lu.assertNotIsNil(last.attributes)
    lu.assertEquals(last.attributes.method, "test_set_status")
    lu.assertEquals(last.status, "Ok")
    lu.assertEquals(last.description, "test set status ok")
    lu.assertEquals(last.events, {})
    lu.assertEquals(last.links, {})
end

--- A span started under an explicit parent context takes the parent's trace id
-- and records the parent span id. start_span also rejects a non-string name.
function TestTrace:test_start_child_span()
    local parent_trace_id = "00000000000000000000000000000001"
    local parent_span_id = "0000000000000001"
    -- trace_flags = 1 and no force_sample. The child is still expected to be
    -- recorded, presumably because the parent's sampled flag is honoured; confirm.
    local parent_ctx = {
        trace_id = parent_trace_id,
        span_id = parent_span_id,
        trace_flags = 1,
        trace_state = "",
        is_remote = false
    }
    local span = self.tracer:start_span("test_child_span", {method = "test_child_span"}, {parent = parent_ctx})
    lu.assertEquals(span:is_recording(), true)

    span:set_status("ok", "set child span status ok")
    span:add_event("testing child span add event", {opcode = 1})
    -- Overwrites the "method" attribute passed at start.
    span:set_attribute("method", "test_start_child_span")

    -- Table used as the span name.
    lu.assertErrorMsgContains("invalid parameter, start span failed", function()
        self.tracer:start_span({}, {method = "test_child_span"}, {parent = parent_ctx})
    end)

    span:finish()
    trace.flush()

    local exported = paser:parse_span_file(dump_file)
    local last = exported[#exported]
    lu.assertEquals(last.name, "test_child_span")
    lu.assertEquals(last.status, "Ok")
    lu.assertEquals(last.description, "set child span status ok")
    lu.assertEquals(last.parent_span_id, parent_span_id)
    lu.assertEquals(last.trace_id, parent_trace_id)
    lu.assertEquals(last.links, {})
    lu.assertNotIsNil(last.attributes)
    lu.assertEquals(last.attributes.method, "test_start_child_span")

    local events = last.events
    lu.assertNotIsNil(events)
    local event = events[1]
    lu.assertEquals(event.name, 'testing child span add event')
    lu.assertNotIsNil(event.attributes)
    lu.assertEquals(event.attributes.opcode, '1')
end

--- start_span rejects a parent context whose trace_id is malformed: a number,
-- or a 16-character string (the valid trace ids used in this file are 32).
function TestTrace:test_spancontext_trace_id()
    for _, bad_trace_id in ipairs({1, "0000000000000001"}) do
        local parent = {
            trace_id = bad_trace_id,
            span_id = "0000000000000001",
            trace_flags = 0,
            trace_state = "",
            is_remote = false
        }
        lu.assertErrorMsgContains("invalid trace id", function()
            self.tracer:start_span("test_spancontext_trace_id", {}, {parent = parent})
        end)
    end
end

--- start_span rejects a parent context whose span_id is malformed: a number,
-- or a 32-character string (the valid span ids used in this file are 16).
function TestTrace:test_spancontext_span_id()
    for _, bad_span_id in ipairs({2, "00000000000000000000000000000001"}) do
        local parent = {
            trace_id = "00000000000000000000000000000001",
            span_id = bad_span_id,
            trace_flags = 0,
            trace_state = "",
            is_remote = false
        }
        lu.assertErrorMsgContains("invalid span id", function()
            self.tracer:start_span("test_spancontext_span_id", {}, {parent = parent})
        end)
    end
end

--- start_span rejects a parent context whose trace_flags is a string.
function TestTrace:test_spancontext_trace_flags()
    local bad_parent = {
        trace_id = "00000000000000000000000000000001",
        span_id = "0000000000000001",
        trace_flags = "",
        trace_state = "",
        is_remote = false
    }
    local function start_under_bad_parent()
        self.tracer:start_span("test_spancontext_trace_flags", {}, {parent = bad_parent})
    end
    lu.assertErrorMsgContains("invalid trace flags", start_under_bad_parent)
end

--- start_span rejects a parent context whose trace_state is a table.
function TestTrace:test_spancontext_trace_state()
    local bad_parent = {
        trace_id = "00000000000000000000000000000001",
        span_id = "0000000000000001",
        trace_flags = 0,
        trace_state = {},
        is_remote = false
    }
    local function start_under_bad_parent()
        self.tracer:start_span("test_spancontext_trace_state", {}, {parent = bad_parent})
    end
    lu.assertErrorMsgContains("invalid trace state", start_under_bad_parent)
end

--- Sampling decisions follow specified_ratio and span kind.
-- Ratio 0.0 yields no Server span and ratio 1.0 always yields one. Ratio 0.5
-- yields one only sometimes. An Internal span is not produced even at ratio 1.0.
-- Unsampled spans come back as nil.
function TestTrace:test_sampling_ratio()
    -- Exercise different sampling ratios.
    local span1 = self.tracer:start_span("test_ratio_0", {}, {specified_ratio = 0.0, kind = "Server"})
    local span2 = self.tracer:start_span("test_ratio_1", {}, {specified_ratio = 1.0, kind = "Server"})

    -- Ratio 0.5 is presumably random, so retry until one span is sampled.
    -- With only 10 attempts the test failed spuriously ~0.1% of the time
    -- (0.5^10). 32 attempts keeps the false-failure chance below 1e-9.
    local MAX_ATTEMPTS = 32
    local span3
    for _ = 1, MAX_ATTEMPTS do
        span3 = self.tracer:start_span("test_ratio_0_5", {}, {specified_ratio = 0.5, kind = "Server"})
        if span3 ~= nil then
            break
        end
    end
    local span4 = self.tracer:start_span("test_ratio_1", {}, {specified_ratio = 1.0, kind = "Internal"})

    -- Verify the sampling results.
    lu.assertNil(span1)
    lu.assertNotIsNil(span2)
    lu.assertNotIsNil(span3)
    lu.assertNil(span4)
    span2:finish()
    span3:finish()
end

--- Every supported span kind yields a span when sampling is forced. An unknown
-- kind is rejected with an error.
function TestTrace:test_span_kind()
    local kinds = {"Server", "Client", "Producer", "Consumer", "Internal"}
    for idx = 1, #kinds do
        local kind = kinds[idx]
        local span = self.tracer:start_span("test_" .. kind, {}, {kind = kind, force_sample = true})
        lu.assertNotIsNil(span)
        span:finish()
    end

    -- An invalid span kind is rejected even with force_sample.
    local function start_with_unknown_kind()
        self.tracer:start_span("test_invalid_kind", {}, {kind = "Invalid", force_sample = true})
    end
    lu.assertErrorMsgContains("invalid span kind value", start_with_unknown_kind)
end

--- start_span rejects a non-numeric specified_ratio option.
function TestTrace:test_invalid_ratio_type()
    local opts = {specified_ratio = "not_a_number"}
    lu.assertErrorMsgContains("specified_ratio must be number", function()
        self.tracer:start_span("test_invalid_ratio", {}, opts)
    end)
end

--- start_span rejects a non-boolean force_sample option.
function TestTrace:test_invalid_force_sample_type()
    local opts = {force_sample = "not_a_bool"}
    lu.assertErrorMsgContains("force_sample must be boolean", function()
        self.tracer:start_span("test_invalid_force_sample", {}, opts)
    end)
end

--- Force a full garbage-collection cycle.
-- NOTE(review): asserts nothing itself. Presumably meant to exercise the
-- finalizers of native tracer/span objects without crashing; confirm intent.
function TestTrace:test_gc()
    local full_cycle = 'collect'
    collectgarbage(full_cycle)
end
