#include "checker.h"

#include <dlfcn.h>
#include <getopt.h>
#include <unistd.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
|
|
| std::pair<bool, std::string> verbose_allclose(const torch::Tensor &received, const torch::Tensor &expected, |
| float rtol = 1e-05, float atol = 1e-08, int max_print = 5) { |
| |
| if (received.sizes() != expected.sizes()) { |
| std::string expected_shape_str = "["; |
| std::string received_shape_str = "["; |
| auto expected_sizes = expected.sizes(); |
| auto received_sizes = received.sizes(); |
|
|
| for (int i = 0; i < expected_sizes.size(); i++) { |
| expected_shape_str += std::to_string(expected_sizes[i]); |
| if (i < expected_sizes.size() - 1) |
| expected_shape_str += ", "; |
| } |
| expected_shape_str += "]"; |
|
|
| for (int i = 0; i < received_sizes.size(); i++) { |
| received_shape_str += std::to_string(received_sizes[i]); |
| if (i < received_sizes.size() - 1) |
| received_shape_str += ", "; |
| } |
| received_shape_str += "]"; |
|
|
| return {false, "SIZE MISMATCH: expected " + expected_shape_str + " but got " + received_shape_str}; |
| } |
|
|
| auto diff = torch::abs(received.to(torch::kFloat32) - expected.to(torch::kFloat32)); |
|
|
| auto tolerance = atol + rtol * torch::abs(expected); |
|
|
| auto tol_mismatched = diff > tolerance; |
| auto nan_mismatched = torch::logical_xor(torch::isnan(received), torch::isnan(expected)); |
| auto posinf_mismatched = torch::logical_xor(torch::isposinf(received), torch::isposinf(expected)); |
| auto neginf_mismatched = torch::logical_xor(torch::isneginf(received), torch::isneginf(expected)); |
|
|
| auto mismatched = torch::logical_or(torch::logical_or(tol_mismatched, nan_mismatched), |
| torch::logical_or(posinf_mismatched, neginf_mismatched)); |
|
|
| auto mismatched_indices = torch::nonzero(mismatched); |
|
|
| |
| int64_t num_mismatched = mismatched.sum().item<int64_t>(); |
|
|
| |
| if (num_mismatched >= 1) { |
| std::stringstream mismatch_details; |
| auto sizes = received.sizes(); |
| mismatch_details << "Mismatch found in tensors with shape ["; |
| for (int i = 0; i < sizes.size(); i++) { |
| mismatch_details << sizes[i]; |
| if (i < sizes.size() - 1) |
| mismatch_details << ", "; |
| } |
| mismatch_details << "]:\n"; |
| mismatch_details << "Number of mismatched elements: " << num_mismatched << "\n"; |
|
|
| for (int i = 0; i < std::min(max_print, (int)mismatched_indices.size(0)); i++) { |
| auto index = mismatched_indices[i]; |
| std::vector<int64_t> idx_vec; |
| for (int j = 0; j < index.size(0); j++) { |
| idx_vec.push_back(index[j].item<int64_t>()); |
| } |
|
|
| |
| std::string idx_str = "("; |
| for (size_t j = 0; j < idx_vec.size(); j++) { |
| idx_str += std::to_string(idx_vec[j]); |
| if (j < idx_vec.size() - 1) |
| idx_str += ", "; |
| } |
| idx_str += ")"; |
|
|
| float received_val, expected_val; |
| torch::Tensor received_elem = received; |
| torch::Tensor expected_elem = expected; |
|
|
| for (size_t j = 0; j < idx_vec.size(); j++) { |
| received_elem = received_elem[idx_vec[j]]; |
| expected_elem = expected_elem[idx_vec[j]]; |
| } |
|
|
| received_val = received_elem.item<float>(); |
| expected_val = expected_elem.item<float>(); |
|
|
| mismatch_details << "ERROR at " << idx_str << ": " << received_val << " " << expected_val << "\n"; |
| } |
|
|
| if (num_mismatched > max_print) { |
| mismatch_details << "... and " << (num_mismatched - max_print) << " more mismatched elements."; |
| } |
|
|
| return {false, mismatch_details.str()}; |
| } |
|
|
| return {true, "Maximum error: " + std::to_string(diff.max().item<float>())}; |
| } |
|
|
| |
| std::pair<bool, std::string> check_implementation(std::ofstream &fout, const torch::Tensor &output, |
| const torch::Tensor &expected, float rtol = 2e-02, float atol = 1e-03, |
| CheckerMode mode = CheckerMode::kElementWise) { |
| if (mode == CheckerMode::kRowIndex) { |
| |
| |
| auto sorted_output = output.clone(); |
| auto sorted_expected = expected.clone(); |
|
|
| sorted_output = std::get<0>(torch::sort(output, 1)); |
| sorted_expected = std::get<0>(torch::sort(expected, 1)); |
|
|
| return verbose_allclose(sorted_output, sorted_expected, rtol, atol); |
| } else if (mode == CheckerMode::kJustDump) { |
| |
| { |
| fout << "=====OUTPUT=====" << std::endl; |
| fout << output.sizes() << std::endl; |
|
|
| |
| auto sizes = output.sizes(); |
| if (sizes.size() == 2) { |
| |
| for (int64_t i = 0; i < sizes[0]; i++) { |
| for (int64_t j = 0; j < sizes[1]; j++) { |
| fout << std::setw(12) << std::setprecision(6) << output[i][j].item<float>() << " "; |
| } |
| fout << std::endl; |
| } |
| } else { |
| |
| fout << output << std::endl; |
| } |
| } |
|
|
| { |
| fout << "=====EXPECTED=====" << std::endl; |
| fout << expected.sizes() << std::endl; |
|
|
| |
| auto sizes = output.sizes(); |
| if (sizes.size() == 2) { |
| |
| for (int64_t i = 0; i < sizes[0]; i++) { |
| for (int64_t j = 0; j < sizes[1]; j++) { |
| fout << std::setw(12) << std::setprecision(6) << expected[i][j].item<float>() << " "; |
| } |
| fout << std::endl; |
| } |
| } else { |
| |
| fout << output << std::endl; |
| } |
| } |
|
|
| return {true, ""}; |
| } |
| return verbose_allclose(output, expected, rtol, atol); |
| } |
|
|
// How many times each kernel is launched (and timed) per test case when benchmark mode is on,
// and in the single sub-case timing mode; the best/slowest runs are reported.
constexpr int BENCHMARK_ITERS{5};
|
|
/// Loads rocBLAS, hipBLAS and hipBLASLt from the PyTorch install with RTLD_NOW | RTLD_GLOBAL.
/// Exits the process with status 1 if any of them cannot be loaded.
///
/// The paths are hard-coded for a python3.10 dist-packages install of torch. The handles are
/// never dlclose()d, so the libraries stay loaded for the rest of the process. RTLD_GLOBAL makes
/// their symbols available when resolving libraries loaded later, which is presumably why they are
/// preloaded (confirm with the caller).
void preload() {
  static const char *const kLibraryPaths[] = {
      "/usr/local/lib/python3.10/dist-packages/torch/lib/librocblas.so",
      "/usr/local/lib/python3.10/dist-packages/torch/lib/libhipblas.so",
      "/usr/local/lib/python3.10/dist-packages/torch/lib/libhipblaslt.so",
  };

  bool all_loaded = true;
  for (const char *path : kLibraryPaths) {
    // dlerror() is only reliable right after the failing call. Reading it once after all three
    // dlopen()s could describe the wrong library or return NULL, and NULL is UB for "%s".
    if (dlopen(path, RTLD_NOW | RTLD_GLOBAL) == nullptr) {
      const char *reason = dlerror();
      fprintf(stderr, "Failed to load required libraries: %s\n", reason != nullptr ? reason : path);
      all_loaded = false;
    }
  }

  if (!all_loaded) {
    exit(1);
  }
}
|
|
| int main(int argc, char **argv) { |
| |
| |
| bool benchmark = true; |
| bool profile_mode = false; |
| int target_test_case = -1; |
| int target_sub_case = -1; |
| int opt; |
|
|
| while ((opt = getopt(argc, argv, "bpt:c:")) != -1) { |
| switch (opt) { |
| case 'b': |
| benchmark = false; |
| break; |
| case 'p': |
| profile_mode = true; |
| break; |
| case 't': |
| target_sub_case = std::stoi(optarg); |
| break; |
| case 'c': |
| target_test_case = std::stoi(optarg); |
| break; |
| default: |
| fprintf(stderr, "Usage: %s [-b] [-p] [-t subcase_index] [-c test_case_index]\n", argv[0]); |
| fprintf(stderr, " -b: Disable benchmark mode\n"); |
| fprintf(stderr, " -p: Enable profile mode (skips reference kernel and comparison)\n"); |
| fprintf(stderr, " -t: Run only the specified subcase index\n"); |
| fprintf(stderr, " -c: Run only the specified test case index\n"); |
| exit(EXIT_FAILURE); |
| } |
| } |
|
|
| case_initialize(); |
| int num_params, passed_cases = 0; |
| num_params = get_params_count(); |
|
|
| |
| if (target_test_case >= 0) { |
| if (target_test_case >= num_params) { |
| std::cerr << "Error: Test case index " << target_test_case << " is out of range (0-" << (num_params - 1) |
| << ")" << std::endl; |
| exit(EXIT_FAILURE); |
| } |
| } |
|
|
| std::vector<std::vector<PerfMetrics>> run_times(num_params); |
| std::vector<std::tuple<bool, std::string, std::vector<std::pair<float, float>>>> results; |
|
|
| |
| if (target_test_case >= 0 && target_sub_case >= 0) { |
| void *input = case_get_input(target_test_case); |
| std::vector<Checkee> output; |
| float best_time = std::numeric_limits<float>::max(); |
|
|
| for (int j = 0; j < BENCHMARK_ITERS; j++) { |
| PerfMetrics metrics; |
| output = case_run_kernel(input, &metrics); |
|
|
| if (metrics.count <= target_sub_case) { |
| std::cerr << "Error: Subcase index " << target_sub_case << " is out of range (0-" << (metrics.count - 1) |
| << ")" << std::endl; |
| exit(EXIT_FAILURE); |
| } |
|
|
| best_time = std::min(best_time, metrics.entries[target_sub_case].time); |
| } |
|
|
| std::cout << std::fixed << std::setprecision(6) << best_time * 1e3 << std::endl; |
| case_destroy(input); |
| return 0; |
| } |
|
|
| |
| if (!profile_mode && target_test_case < 0) { |
| std::cout << "Found " << num_params << " test cases for " << case_get_name() << '\n'; |
| } |
| if (benchmark) { |
| std::cout << "Benchmark mode enabled\n"; |
| } |
| if (profile_mode) { |
| std::cout << "Profile mode enabled (skipping reference kernels and comparison)\n"; |
| } |
|
|
| |
| std::vector<int> test_cases_to_run; |
| if (target_test_case >= 0) { |
| test_cases_to_run.push_back(target_test_case); |
| } else { |
| for (int i = 0; i < num_params; i++) { |
| test_cases_to_run.push_back(i); |
| } |
| } |
|
|
| for (int i : test_cases_to_run) { |
| std::ofstream *fout = nullptr; |
| void *input = case_get_input(i); |
| if (!profile_mode && target_test_case < 0) { |
| std::cerr << "Running test case " << i << std::flush; |
| } |
| std::vector<Checkee> reference; |
| if (!profile_mode) { |
| reference = case_run_ref_kernel(input); |
| } |
| std::vector<Checkee> output; |
| for (int j = 0; j < (benchmark ? BENCHMARK_ITERS : 1); j++) { |
| PerfMetrics metrics; |
| output = case_run_kernel(input, &metrics); |
| run_times[i].push_back(metrics); |
| } |
|
|
| bool match = true; |
| std::string case_message; |
|
|
| if (!profile_mode) { |
| if (reference.size() != output.size()) { |
| std::cerr << "Wrong test definition: reference and output have different sizes" << '\n'; |
| abort(); |
| } |
|
|
| for (int j = 0; j < reference.size(); j++) { |
| float rtol, atol; |
| get_error_tolerance(&rtol, &atol); |
| if (output[j].mode == CheckerMode::kJustDump) { |
| if (!fout) { |
| fout = new std::ofstream(std::string("case_") + std::to_string(i) + ".txt"); |
| } |
| *fout << "===== SUBCASE " << output[j].name << "=====" << std::endl; |
| } |
| auto [match_sub, message_sub] = |
| check_implementation(*fout, *output[j].tensor, *reference[j].tensor, rtol, atol, output[j].mode); |
| if (!match_sub) { |
| case_message += "Err on sub case " + std::to_string(j) + ": " + message_sub + "\n"; |
| match = false; |
| } |
| } |
| if (match) { |
| passed_cases++; |
| } |
| } else { |
| match = true; |
| passed_cases++; |
| } |
|
|
| std::vector<std::pair<float, float>> case_metrics; |
|
|
| |
| for (const auto &run : run_times[i]) { |
| if (run.count == 1) { |
| |
| case_metrics.push_back({run.entries[0].time, run.entries[0].gflops}); |
| } else { |
| |
| case_metrics.push_back({run.entries[0].time, run.entries[0].gflops}); |
| } |
| } |
|
|
| results.push_back(std::make_tuple(match, case_message, case_metrics)); |
| case_destroy(input); |
| if (!profile_mode && target_test_case < 0) { |
| std::cout << "\033[2K\r" << std::flush; |
| } |
| } |
|
|
| |
| if (target_test_case < 0) { |
| std::cout << "=======================" << '\n'; |
| if (!profile_mode) { |
| if (passed_cases == num_params) { |
| std::cout << "✅ All " << num_params << " test cases passed!" << '\n'; |
| } else { |
| std::cout << "❌ [" << num_params - passed_cases << "/" << num_params << "] test cases failed!" << '\n'; |
| } |
| } else { |
| std::cout << "Profile mode: results comparison skipped" << '\n'; |
| } |
| std::cout << "-----------------------" << '\n'; |
|
|
| for (int i = 0; i < num_params; i++) { |
| auto [match, message, metrics] = results[i]; |
|
|
| |
| float best_time = std::numeric_limits<float>::max(); |
| float best_gflops = 0.0f; |
| float worst_time = 0.0f; |
| float worst_gflops = std::numeric_limits<float>::max(); |
|
|
| for (const auto &[time, gflops] : metrics) { |
| best_time = std::min(best_time, time); |
| best_gflops = std::max(best_gflops, gflops); |
| worst_time = std::max(worst_time, time); |
| worst_gflops = std::min(worst_gflops, gflops); |
| } |
|
|
| std::string timing_info; |
| if (benchmark) { |
| std::stringstream ss; |
| ss << std::fixed << std::setprecision(2); |
| ss << "Best: [\033[1m" << best_time * 1e3 << "\033[0m us, \033[1m" << best_gflops / 1e3 |
| << "\033[0m TFLOPS], " |
| << "\033[2mSlowest: [" << worst_time * 1e3 << " us, " << worst_gflops / 1e3 << " TFLOPS]\033[0m"; |
| timing_info = ss.str(); |
| } else { |
| std::stringstream ss; |
| ss << std::fixed << std::setprecision(2); |
| ss << "Time: " << best_time * 1e3 << " us, TFLOPS: " << best_gflops / 1e3; |
| timing_info = ss.str(); |
| } |
|
|
| if (!profile_mode && !match) { |
| std::cout << "❌ Test case " << i << ": " << timing_info << "\n" << message << '\n'; |
| } else { |
| std::cout << "✅ Test case " << i << ": " << timing_info << "\n"; |
| } |
|
|
| |
| if (run_times[i][0].count > 1) { |
| for (int j = 1; j < run_times[i][0].count; j++) { |
| std::stringstream ss; |
| ss << std::fixed << std::setprecision(2); |
| ss << " - Sub-case " << run_times[i][0].entries[j].name << ": "; |
|
|
| if (benchmark) { |
| float sub_best_time = std::numeric_limits<float>::max(); |
| float sub_best_gflops = 0.0f; |
| float sub_worst_time = 0.0f; |
| float sub_worst_gflops = std::numeric_limits<float>::max(); |
|
|
| for (const auto &run : run_times[i]) { |
| sub_best_time = std::min(sub_best_time, run.entries[j].time); |
| sub_best_gflops = std::max(sub_best_gflops, run.entries[j].gflops); |
| sub_worst_time = std::max(sub_worst_time, run.entries[j].time); |
| sub_worst_gflops = std::min(sub_worst_gflops, run.entries[j].gflops); |
| } |
|
|
| ss << "Best: [\033[1m" << sub_best_time * 1e3 << "\033[0m us, \033[1m" << sub_best_gflops / 1e3 |
| << "\033[0m TFLOPS], " |
| << "\033[2mSlowest: [" << sub_worst_time * 1e3 << " us, " << sub_worst_gflops / 1e3 |
| << " TFLOPS]\033[0m"; |
| } else { |
| ss << "Time: " << run_times[i][0].entries[j].time * 1e3 |
| << " us, TFLOPS: " << run_times[i][0].entries[j].gflops / 1e3; |
| } |
|
|
| std::cout << ss.str() << std::endl; |
| } |
| } |
| } |
| std::cout << "-----------------------" << '\n'; |
|
|
| |
| double geo_mean_time = 1.0; |
| double geo_mean_gflops = 1.0; |
|
|
| for (int i = 0; i < num_params; i++) { |
| auto [match, message, metrics] = results[i]; |
| |
| float best_time = std::numeric_limits<float>::max(); |
| float best_gflops = 0.0f; |
|
|
| for (const auto &[time, gflops] : metrics) { |
| best_time = std::min(best_time, time); |
| best_gflops = std::max(best_gflops, gflops); |
| } |
|
|
| geo_mean_time *= best_time; |
| geo_mean_gflops *= best_gflops; |
| } |
|
|
| geo_mean_time = std::pow(geo_mean_time, 1.0 / num_params); |
| geo_mean_gflops = std::pow(geo_mean_gflops, 1.0 / num_params); |
|
|
| if (benchmark) { |
| std::stringstream ss; |
| ss << std::fixed << std::setprecision(2); |
| ss << "GeoMean - Best Time: \033[1m" << geo_mean_time * 1e3 << "\033[0m us, Best TFLOPS: \033[1m" |
| << geo_mean_gflops / 1e3 << "\033[0m"; |
| std::cout << ss.str() << std::endl; |
| } else { |
| std::stringstream ss; |
| ss << std::fixed << std::setprecision(2); |
| ss << "GeoMean - Time: " << geo_mean_time * 1e3 << " us, TFLOPS: " << geo_mean_gflops / 1e3; |
| std::cout << ss.str() << std::endl; |
| } |
| std::cout << "=======================" << '\n'; |
| } |
|
|
| return 0; |
| } |
|
|