/*
 * Copyright (c) 2015-2025 The Khronos Group Inc.
 * Copyright (c) 2015-2025 Valve Corporation
 * Copyright (c) 2015-2025 LunarG, Inc.
 * Copyright (c) 2015-2024 Google, Inc.
 * Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 */

#include "../framework/layer_validation_tests.h"
#include "../framework/pipeline_helper.h"
#include "../framework/shader_object_helper.h"
#include "cooperative_matrix_helper.h"

// Shared setup for the VK_KHR_cooperative_matrix tests: registers the extensions and features
// required by the shaders in this file, then initializes the test framework via Init().
// NOTE(review): RETURN_IF_SKIP is defined elsewhere; presumably it returns early when Init() skips the test.
void CooperativeMatrixTest::InitCooperativeMatrixKHR() {
    AddRequiredExtensions(VK_KHR_COOPERATIVE_MATRIX_EXTENSION_NAME);
    // glslang will generate OpCapability VulkanMemoryModel, so the extension needs to be enabled
    AddRequiredExtensions(VK_KHR_VULKAN_MEMORY_MODEL_EXTENSION_NAME);
    AddRequiredFeature(vkt::Feature::cooperativeMatrix);
    AddRequiredFeature(vkt::Feature::vulkanMemoryModel);
    RETURN_IF_SKIP(Init());
}

// Fixture used by the TEST_F cases in this file; it adds nothing on top of CooperativeMatrixTest.
class PositiveShaderCooperativeMatrix : public CooperativeMatrixTest {};

// Positive test: builds a compute pipeline whose shader loads the A/B/C cooperative matrices, multiply-adds
// them and stores the result. The matrix dimensions and component types are substituted into the GLSL from
// the first subgroup-scope VkCooperativeMatrixPropertiesKHR entry that uses no 8-bit or 64-bit component type.
// The test is skipped when no such entry is reported.
TEST_F(PositiveShaderCooperativeMatrix, CooperativeMatrixKHR) {
    TEST_DESCRIPTION("Test VK_KHR_cooperative_matrix.");
    SetTargetApiVersion(VK_API_VERSION_1_3);
    AddRequiredFeature(vkt::Feature::shaderFloat16);
    AddRequiredFeature(vkt::Feature::storageBuffer16BitAccess);
    RETURN_IF_SKIP(InitCooperativeMatrixKHR());
    CooperativeMatrixHelper helper(*this);

    VkCooperativeMatrixPropertiesKHR subgroup_prop = vku::InitStructHelper();
    bool found_scope_subgroup = false;
    for (const auto &prop : helper.coop_matrix_props) {
        // We only have the 16-bit features enabled, but 32-bit also works
        if (prop.scope == VK_SCOPE_SUBGROUP_KHR && !helper.Has8BitComponentType(prop) && !helper.Has64BitComponentType(prop)) {
            found_scope_subgroup = true;
            subgroup_prop = prop;
            break;
        }
    }
    if (!found_scope_subgroup) {
        GTEST_SKIP() << "VK_SCOPE_SUBGROUP_KHR not Found";
    }

    // One storage buffer per matrix operand (A, B, C) plus one for the result.
    const vkt::DescriptorSetLayout dsl(*m_device,
                                       {
                                           {0, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr},
                                           {1, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr},
                                           {2, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr},
                                           {3, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr},
                                       });

    // GLSL template: the %...% tokens are filled in below from the selected property.
    std::string css = R"glsl(
         #version 450 core
         #pragma use_vulkan_memory_model
         #extension GL_KHR_shader_subgroup_basic : enable
         #extension GL_KHR_memory_scope_semantics : enable
         #extension GL_KHR_cooperative_matrix : enable
         #extension GL_EXT_shader_explicit_arithmetic_types : enable
         #extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable
         layout(local_size_x = 64) in;
         layout(set=0, binding=0) coherent buffer InputA { %type_A% x[]; } inputA;
         layout(set=0, binding=1) coherent buffer InputB { %type_B% x[]; } inputB;
         layout(set=0, binding=2) coherent buffer InputC { %type_C% x[]; } inputC;
         layout(set=0, binding=3) coherent buffer Output { %type_R% x[]; } outputO;
         coopmat<%type_A%, gl_ScopeSubgroup, %M%, %K%, gl_MatrixUseA> matA;
         coopmat<%type_B%, gl_ScopeSubgroup, %K%, %N%, gl_MatrixUseB> matB;
         coopmat<%type_C%, gl_ScopeSubgroup, %M%, %N%, gl_MatrixUseAccumulator> matC;
         coopmat<%type_R%, gl_ScopeSubgroup, %M%, %N%, gl_MatrixUseAccumulator> matO;
         void main()
         {
             coopMatLoad(matA, inputA.x, 0, %M%, gl_CooperativeMatrixLayoutRowMajor);
             coopMatLoad(matB, inputB.x, 0, %K%, gl_CooperativeMatrixLayoutRowMajor);
             coopMatLoad(matC, inputC.x, 0, %M%, gl_CooperativeMatrixLayoutRowMajor);
             matO = coopMatMulAdd(matA, matB, matC);
             coopMatStore(matO, outputO.x, 0, %M%, gl_CooperativeMatrixLayoutRowMajor);
         }
    )glsl";

    // Replaces every occurrence of `from` in `str` with `to`. The search resumes right after the inserted
    // text, so the string is not rescanned from the start and a replacement that happens to contain `from`
    // cannot cause an endless loop.
    auto replace = [](std::string &str, const std::string &from, const std::string &to) {
        for (size_t pos = str.find(from); pos != std::string::npos; pos = str.find(from, pos + to.length())) {
            str.replace(pos, from.length(), to);
        }
    };
    replace(css, "%M%", std::to_string(subgroup_prop.MSize));
    replace(css, "%N%", std::to_string(subgroup_prop.NSize));
    replace(css, "%K%", std::to_string(subgroup_prop.KSize));
    replace(css, "%type_A%", helper.VkComponentTypeToGLSL(subgroup_prop.AType));
    replace(css, "%type_B%", helper.VkComponentTypeToGLSL(subgroup_prop.BType));
    replace(css, "%type_C%", helper.VkComponentTypeToGLSL(subgroup_prop.CType));
    replace(css, "%type_R%", helper.VkComponentTypeToGLSL(subgroup_prop.ResultType));

    CreateComputePipelineHelper pipe(*this);
    pipe.cs_ = VkShaderObj(*m_device, css.c_str(), VK_SHADER_STAGE_COMPUTE_BIT, SPV_ENV_VULKAN_1_3);
    pipe.pipeline_layout_ = vkt::PipelineLayout(*m_device, {&dsl});
    // No error is registered as expected in this positive test, so there is nothing to VerifyFound().
    pipe.CreateComputePipeline();
}

// Positive test for the linked issue: creates a compute pipeline that uses a 16x16 cooperative matrix with
// local_size_x = 16 while chaining a required subgroup size equal to the device's minSubgroupSize.
// Only runs on MockICD, and only when compute supports required subgroup sizes and minSubgroupSize is 16.
TEST_F(PositiveShaderCooperativeMatrix, RequiredSubgroupSize) {
    TEST_DESCRIPTION("https://github.com/KhronosGroup/Vulkan-ValidationLayers/issues/9843");
    SetTargetApiVersion(VK_API_VERSION_1_3);
    AddRequiredFeature(vkt::Feature::shaderFloat16);
    AddRequiredFeature(vkt::Feature::storageBuffer16BitAccess);
    AddRequiredFeature(vkt::Feature::subgroupSizeControl);
    RETURN_IF_SKIP(InitCooperativeMatrixKHR());
    if (!IsPlatformMockICD()) {
        GTEST_SKIP() << "This makes assumption about possible coop matrix subgroup size and support.";
    }

    // Single storage buffer binding consumed by the shader's InputA block.
    const vkt::DescriptorSetLayout set_layout(*m_device,
                                              {0, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr});
    const vkt::PipelineLayout layout(*m_device, {&set_layout});

    const char *glsl = R"glsl(
         #version 450 core
         #pragma use_vulkan_memory_model
         #extension GL_KHR_shader_subgroup_basic : enable
         #extension GL_KHR_memory_scope_semantics : enable
         #extension GL_KHR_cooperative_matrix : enable
         #extension GL_EXT_shader_explicit_arithmetic_types : enable
         #extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable
         layout(local_size_x = 16) in;
         layout(set=0, binding=0) coherent buffer InputA { uint32_t x[]; } inputA;
         coopmat<uint32_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
         void main() {
             coopMatLoad(matA, inputA.x, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
         }
    )glsl";
    VkShaderObj compute_shader(*m_device, glsl, VK_SHADER_STAGE_COMPUTE_BIT, SPV_ENV_VULKAN_1_3);

    // Query the subgroup size control limits; the Vulkan 1.1 properties struct only serves as the chain head.
    VkPhysicalDeviceSubgroupSizeControlPropertiesEXT size_control_props = vku::InitStructHelper();
    VkPhysicalDeviceVulkan11Properties vk11_props = vku::InitStructHelper(&size_control_props);
    GetPhysicalDeviceProperties2(vk11_props);

    const bool compute_allows_required_size = (size_control_props.requiredSubgroupSizeStages & VK_SHADER_STAGE_COMPUTE_BIT) != 0;
    if (!compute_allows_required_size) {
        GTEST_SKIP() << "Required shader stage not present in requiredSubgroupSizeStages";
    }

    // The shader hard-codes a workgroup size of 16, so the test only proceeds with a matching minimum.
    if (size_control_props.minSubgroupSize != 16) {
        GTEST_SKIP() << "Testing when we go under the limit";
    }

    VkPipelineShaderStageRequiredSubgroupSizeCreateInfo required_size_info = vku::InitStructHelper();
    required_size_info.requiredSubgroupSize = size_control_props.minSubgroupSize;

    CreateComputePipelineHelper pipeline(*this);
    pipeline.cp_ci_.stage = compute_shader.GetStageCreateInfo();
    pipeline.cp_ci_.stage.pNext = &required_size_info;
    pipeline.cp_ci_.layout = layout;
    // NOTE(review): `false` presumably disables the helper's late binding so the stage set above is kept; confirm
    // against CreateComputePipelineHelper.
    pipeline.CreateComputePipeline(false);
}

// Positive test for the linked SPIR-V issue: targets Vulkan 1.1, compiles a cooperative matrix shader for
// SPV_ENV_VULKAN_1_1 and creates a compute pipeline whose stage sets REQUIRE_FULL_SUBGROUPS.
// Only runs on MockICD.
TEST_F(PositiveShaderCooperativeMatrix, RequiredVulkanVersionPipeline) {
    TEST_DESCRIPTION("https://gitlab.khronos.org/spirv/SPIR-V/-/issues/847");
    SetTargetApiVersion(VK_API_VERSION_1_1);
    AddRequiredExtensions(VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME);
    AddRequiredExtensions(VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME);
    AddRequiredExtensions(VK_KHR_16BIT_STORAGE_EXTENSION_NAME);
    AddRequiredFeature(vkt::Feature::shaderFloat16);
    AddRequiredFeature(vkt::Feature::storageBuffer16BitAccess);
    AddRequiredFeature(vkt::Feature::computeFullSubgroups);
    RETURN_IF_SKIP(InitCooperativeMatrixKHR());
    if (!IsPlatformMockICD()) {
        GTEST_SKIP() << "This makes assumption about possible coop matrix subgroup size and support.";
    }

    // Single storage buffer binding consumed by the shader's InputA block.
    const vkt::DescriptorSetLayout set_layout(*m_device,
                                              {0, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr});
    const vkt::PipelineLayout layout(*m_device, {&set_layout});

    const char *glsl = R"glsl(
         #version 450 core
         #pragma use_vulkan_memory_model
         #extension GL_KHR_shader_subgroup_basic : enable
         #extension GL_KHR_memory_scope_semantics : enable
         #extension GL_KHR_cooperative_matrix : enable
         #extension GL_EXT_shader_explicit_arithmetic_types : enable
         #extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable
         layout(local_size_x = 32) in;
         layout(set=0, binding=0) coherent buffer InputA { uint32_t x[]; } inputA;
         coopmat<uint32_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
         void main() {
             coopMatLoad(matA, inputA.x, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
         }
    )glsl";
    VkShaderObj compute_shader(*m_device, glsl, VK_SHADER_STAGE_COMPUTE_BIT, SPV_ENV_VULKAN_1_1);

    CreateComputePipelineHelper pipeline(*this);
    pipeline.cp_ci_.stage = compute_shader.GetStageCreateInfo();
    pipeline.cp_ci_.stage.flags = VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT;
    pipeline.cp_ci_.layout = layout;
    // NOTE(review): `false` presumably disables the helper's late binding so the stage flags set above are kept;
    // confirm against CreateComputePipelineHelper.
    pipeline.CreateComputePipeline(false);
}

// Shader-object counterpart of RequiredVulkanVersionPipeline: targets Vulkan 1.1, compiles the same
// cooperative matrix shader to SPIR-V for SPV_ENV_VULKAN_1_1 and creates a vkt::Shader with
// VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT. Only runs on MockICD.
TEST_F(PositiveShaderCooperativeMatrix, RequiredVulkanVersionShaderObject) {
    TEST_DESCRIPTION("https://gitlab.khronos.org/spirv/SPIR-V/-/issues/847");
    SetTargetApiVersion(VK_API_VERSION_1_1);
    AddRequiredExtensions(VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME);
    AddRequiredExtensions(VK_EXT_SUBGROUP_SIZE_CONTROL_EXTENSION_NAME);
    AddRequiredExtensions(VK_KHR_16BIT_STORAGE_EXTENSION_NAME);
    AddRequiredExtensions(VK_EXT_SHADER_OBJECT_EXTENSION_NAME);
    AddRequiredFeature(vkt::Feature::shaderObject);
    AddRequiredFeature(vkt::Feature::shaderFloat16);
    AddRequiredFeature(vkt::Feature::storageBuffer16BitAccess);
    AddRequiredFeature(vkt::Feature::computeFullSubgroups);
    RETURN_IF_SKIP(InitCooperativeMatrixKHR());
    if (!IsPlatformMockICD()) {
        GTEST_SKIP() << "This makes assumption about possible coop matrix subgroup size and support.";
    }

    // Single (dynamic) storage buffer binding handed to the shader object as its only set layout.
    const vkt::DescriptorSetLayout set_layout(
        *m_device, {0, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER_DYNAMIC, 1, VK_SHADER_STAGE_COMPUTE_BIT, nullptr});

    const char *glsl = R"glsl(
         #version 450 core
         #pragma use_vulkan_memory_model
         #extension GL_KHR_shader_subgroup_basic : enable
         #extension GL_KHR_memory_scope_semantics : enable
         #extension GL_KHR_cooperative_matrix : enable
         #extension GL_EXT_shader_explicit_arithmetic_types : enable
         #extension GL_EXT_shader_explicit_arithmetic_types_float16 : enable
         layout(local_size_x = 32) in;
         layout(set=0, binding=0) coherent buffer InputA { uint32_t x[]; } inputA;
         coopmat<uint32_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
         void main() {
             coopMatLoad(matA, inputA.x, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
         }
    )glsl";

    const auto spirv = GLSLToSPV(VK_SHADER_STAGE_COMPUTE_BIT, glsl, SPV_ENV_VULKAN_1_1);
    auto create_info = ShaderCreateInfoNoNextStage(spirv, VK_SHADER_STAGE_COMPUTE_BIT, 1, &set_layout.handle());
    create_info.flags = VK_SHADER_CREATE_REQUIRE_FULL_SUBGROUPS_BIT_EXT;
    // Constructing the shader object is the whole test; nothing is done with it afterwards.
    const vkt::Shader compute_shader_object(*m_device, create_info);
}

// Positive test: creates a compute pipeline whose shader constructs a 16x16 subgroup-scope cooperative
// matrix of bfloat16_t, with the bfloat16 type and cooperative matrix features required.
TEST_F(PositiveShaderCooperativeMatrix, BFloat16) {
    SetTargetApiVersion(VK_API_VERSION_1_3);
    AddRequiredExtensions(VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME);
    AddRequiredExtensions(VK_KHR_SHADER_FLOAT16_INT8_EXTENSION_NAME);
    AddRequiredFeature(vkt::Feature::shaderFloat16);
    AddRequiredFeature(vkt::Feature::shaderBFloat16Type);
    AddRequiredFeature(vkt::Feature::shaderBFloat16CooperativeMatrix);
    RETURN_IF_SKIP(InitCooperativeMatrixKHR());

    // The shader uses no descriptors, so the pipeline helper's own layout is left untouched.
    const char *glsl = R"glsl(
        #version 450 core
        #extension GL_EXT_bfloat16 : require
        #extension GL_EXT_shader_explicit_arithmetic_types : enable
        #extension GL_KHR_memory_scope_semantics : enable
        #extension GL_KHR_cooperative_matrix : enable
        layout(local_size_x = 32) in;
        void main() {
            coopmat<bfloat16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> cmA = coopmat<bfloat16_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA>(3.0);
        }
    )glsl";

    CreateComputePipelineHelper pipeline(*this);
    pipeline.cs_ = VkShaderObj(*m_device, glsl, VK_SHADER_STAGE_COMPUTE_BIT, SPV_ENV_VULKAN_1_1);
    pipeline.CreateComputePipeline();
}

// Positive test: creates a compute pipeline whose shader constructs a 16x16 subgroup-scope cooperative
// matrix of floate4m3_t, with the float8 type and cooperative matrix features required.
TEST_F(PositiveShaderCooperativeMatrix, Float8) {
    SetTargetApiVersion(VK_API_VERSION_1_3);
    AddRequiredExtensions(VK_EXT_SHADER_FLOAT8_EXTENSION_NAME);
    AddRequiredFeature(vkt::Feature::storageBuffer8BitAccess);
    AddRequiredFeature(vkt::Feature::shaderFloat8);
    AddRequiredFeature(vkt::Feature::shaderFloat8CooperativeMatrix);
    AddRequiredFeature(vkt::Feature::shaderInt8);
    RETURN_IF_SKIP(InitCooperativeMatrixKHR());

    // The shader uses no descriptors, so the pipeline helper's own layout is left untouched.
    const char *glsl = R"glsl(
        #version 450 core
        #extension GL_EXT_float_e4m3 : require
        #extension GL_EXT_shader_explicit_arithmetic_types : enable
        #extension GL_KHR_memory_scope_semantics : enable
        #extension GL_KHR_cooperative_matrix : enable
        layout(local_size_x = 32) in;
        void main() {
            coopmat<floate4m3_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> cmA = coopmat<floate4m3_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA>(3.0);
        }
    )glsl";

    CreateComputePipelineHelper pipeline(*this);
    pipeline.cs_ = VkShaderObj(*m_device, glsl, VK_SHADER_STAGE_COMPUTE_BIT, SPV_ENV_VULKAN_1_1);
    pipeline.CreateComputePipeline();
}

// Positive test: builds a compute pipeline whose shader multiplies 16x16 uint8_t A/B cooperative matrices
// and accumulates into uint32_t C/result matrices. Skipped when the helper does not report the
// 16x16 uint property the shader relies on.
TEST_F(PositiveShaderCooperativeMatrix, Int8) {
    SetTargetApiVersion(VK_API_VERSION_1_3);
    AddRequiredFeature(vkt::Feature::shaderInt8);
    AddRequiredFeature(vkt::Feature::storageBuffer8BitAccess);
    RETURN_IF_SKIP(InitCooperativeMatrixKHR());
    CooperativeMatrixHelper helper(*this);
    if (!helper.Has16x16UintProperty()) {
        GTEST_SKIP() << "desired VkCooperativeMatrixPropertiesKHR not found";
    }

    // One storage buffer per matrix operand (A, B, C) plus one for the result; only the layout is used here.
    OneOffDescriptorSet descriptor_set(m_device, {{0, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_ALL, nullptr},
                                                  {1, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_ALL, nullptr},
                                                  {2, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_ALL, nullptr},
                                                  {3, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, 1, VK_SHADER_STAGE_ALL, nullptr}});
    vkt::PipelineLayout pl(*m_device, {&descriptor_set.layout_});

    // The source is never modified, so a plain string literal is enough (matches the other tests in this file).
    const char *cs_source = R"glsl(
         #version 450 core
         #pragma use_vulkan_memory_model
         #extension GL_KHR_shader_subgroup_basic : enable
         #extension GL_KHR_memory_scope_semantics : enable
         #extension GL_KHR_cooperative_matrix : enable
         #extension GL_EXT_shader_explicit_arithmetic_types : enable
         #extension GL_EXT_shader_explicit_arithmetic_types_int8 : enable
         layout(local_size_x = 64) in;
         layout(set=0, binding=0) coherent buffer InputA { uint8_t x[]; } inputA;
         layout(set=0, binding=1) coherent buffer InputB { uint8_t x[]; } inputB;
         layout(set=0, binding=2) coherent buffer InputC { uint32_t x[]; } inputC;
         layout(set=0, binding=3) coherent buffer Output { uint32_t x[]; } outputO;
         coopmat<uint8_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseA> matA;
         coopmat<uint8_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseB> matB;
         coopmat<uint32_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> matC;
         coopmat<uint32_t, gl_ScopeSubgroup, 16, 16, gl_MatrixUseAccumulator> matO;
         void main() {
             coopMatLoad(matA, inputA.x, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
             coopMatLoad(matB, inputB.x, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
             coopMatLoad(matC, inputC.x, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
             matO = coopMatMulAdd(matA, matB, matC);
             coopMatStore(matO, outputO.x, 0, 16, gl_CooperativeMatrixLayoutRowMajor);
         }
    )glsl";

    CreateComputePipelineHelper pipe(*this);
    pipe.cs_ = VkShaderObj(*m_device, cs_source, VK_SHADER_STAGE_COMPUTE_BIT, SPV_ENV_VULKAN_1_3);
    pipe.cp_ci_.layout = pl;
    // No error is registered as expected in this positive test, so there is nothing to VerifyFound().
    pipe.CreateComputePipeline();
}