diff --git a/paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu b/paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu
index d63f5b9652849..3bda19d312da4 100644
--- a/paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu
+++ b/paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu
@@ -38,7 +38,7 @@ COMMON_DECLARE_bool(use_stride_compute_kernel);
 
 namespace phi {
 
-inline void PrepareStridedOut(DenseTensor* out) {
+inline void PrepareStridedOut_elementwise(DenseTensor* out) {
   if (!FLAGS_use_stride_kernel) {
     PADDLE_THROW(common::errors::Fatal(
         "FLAGS_use_stride_kernel is closed. Strided kernel "
@@ -56,7 +56,7 @@ void SumStrideKernel(const Context& dev_ctx,
                      DataType out_dtype,
                      bool keep_dim,
                      DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_elementwise(out);
   phi::SumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
 }
 
diff --git a/paddle/phi/kernels/stride/matmul_grad_stride_kernel.cu b/paddle/phi/kernels/stride/matmul_grad_stride_kernel.cu
index 69859b66b8ba7..aa60a77db2449 100644
--- a/paddle/phi/kernels/stride/matmul_grad_stride_kernel.cu
+++ b/paddle/phi/kernels/stride/matmul_grad_stride_kernel.cu
@@ -44,7 +44,7 @@ inline bool UseCanonicalizedTransposeGradPath(const Context& dev_ctx) {
 #endif
 }
 
-inline void PrepareStridedOut(DenseTensor* out) {
+inline void PrepareStridedOut_matmul(DenseTensor* out) {
   if (out == nullptr) {
     return;
   }
@@ -175,8 +175,8 @@ void MatmulGradStrideKernel(const Context& dev_ctx,
     if (!out_grad_.meta().is_contiguous()) {
       out_grad_ = Tensor2Contiguous(dev_ctx, out_grad_);
     }
-    PrepareStridedOut(dx);
-    PrepareStridedOut(dy);
+    PrepareStridedOut_matmul(dx);
+    PrepareStridedOut_matmul(dy);
     phi::MatmulGradKernel<T, Context>(
         dev_ctx, x_, y_, out_grad_, transpose_x, transpose_y, dx, dy);
     return;
   }
@@ -204,14 +204,14 @@ void MatmulGradStrideKernel(const Context& dev_ctx,
     dx_tmp.Resize(x_.dims());
     dx_out = &dx_tmp;
   } else {
-    PrepareStridedOut(dx_out);
+    PrepareStridedOut_matmul(dx_out);
   }
 
   if (dy != nullptr && y_info.applied) {
     dy_tmp.Resize(y_.dims());
     dy_out = &dy_tmp;
   } else {
-    PrepareStridedOut(dy_out);
+    PrepareStridedOut_matmul(dy_out);
   }
 
   phi::MatmulGradKernel<T, Context>(
diff --git a/paddle/phi/kernels/stride/reduce_stride_kernel.cu b/paddle/phi/kernels/stride/reduce_stride_kernel.cu
index 1c6c1d0bc9655..47b6f0882f32a 100644
--- a/paddle/phi/kernels/stride/reduce_stride_kernel.cu
+++ b/paddle/phi/kernels/stride/reduce_stride_kernel.cu
@@ -34,7 +34,7 @@ COMMON_DECLARE_bool(force_stride_compute_contig_out);
 
 namespace phi {
 
-inline void PrepareStridedOut(DenseTensor* out) {
+inline void PrepareStridedOut_reduce(DenseTensor* out) {
   if (!FLAGS_use_stride_kernel) {
     PADDLE_THROW(common::errors::Fatal(
         "FLAGS_use_stride_kernel is closed. Strided kernel "
@@ -51,7 +51,7 @@ void AMaxStrideKernel(const Context& dev_ctx,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::AMaxKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
 }
 
@@ -62,7 +62,7 @@ void AMinStrideKernel(const Context& dev_ctx,
                       const std::vector<int64_t>& dims,
                       bool keep_dim,
                       DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::AMinKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
 }
 
@@ -73,7 +73,7 @@ void MaxStrideKernel(const Context& dev_ctx,
                      const IntArray& dims,
                      bool keep_dim,
                      DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::MaxKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
 }
 
@@ -84,7 +84,7 @@ void MinStrideKernel(const Context& dev_ctx,
                      const IntArray& dims,
                      bool keep_dim,
                      DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::MinKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
 }
 
@@ -96,7 +96,7 @@ void ProdStrideKernel(const Context& dev_ctx,
                       bool keep_dim,
                       bool reduce_all,
                       DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::ProdKernel<T, Context>(dev_ctx, x, dims, keep_dim, reduce_all, out);
 }
 
@@ -107,7 +107,7 @@ void AllStrideKernel(const Context& dev_ctx,
                      const std::vector<int64_t>& dims,
                      bool keep_dim,
                      DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::AllKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
 }
 
@@ -118,7 +118,7 @@ void AnyStrideKernel(const Context& dev_ctx,
                      const std::vector<int64_t>& dims,
                      bool keep_dim,
                      DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::AnyKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
 }
 
@@ -130,7 +130,7 @@ void SumStrideKernel(const Context& dev_ctx,
                      DataType out_dtype,
                      bool keep_dim,
                      DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::SumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
 }
 
@@ -142,7 +142,7 @@ void NansumStrideKernel(const Context& dev_ctx,
                         DataType out_dtype,
                         bool keep_dim,
                         DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::NansumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
 }
 
@@ -152,7 +152,7 @@ void MeanStrideKernel(const Context& dev_ctx,
                       const IntArray& dims,
                       bool keep_dim,
                      DenseTensor* out) {
-  PrepareStridedOut(out);
+  PrepareStridedOut_reduce(out);
   phi::MeanKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
 }
 
diff --git a/paddle/utils/string/printf.h b/paddle/utils/string/printf.h
index f2c87fb5e8ed3..8d4d5df84b3cf 100644
--- a/paddle/utils/string/printf.h
+++ b/paddle/utils/string/printf.h
@@ -81,7 +81,11 @@ namespace string {
 
 template <typename... Args>
 void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
-  tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
+  try {
+    tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
+  } catch (const tinyformat::detail::FormatError&) {
+    out << fmt;
+  }
 }
 
 inline std::string Sprintf() { return ""; }
@@ -95,9 +99,13 @@ std::string Sprintf(const Args&... args) {
 
 template <typename... Args>
 std::string Sprintf(const char* fmt, const Args&... args) {
-  std::ostringstream oss;
-  Fprintf(oss, fmt, args...);
-  return oss.str();
+  try {
+    std::ostringstream oss;
+    Fprintf(oss, fmt, args...);
+    return oss.str();
+  } catch (const tinyformat::detail::FormatError&) {
+    return fmt;
+  }
 }
 
 template <typename... Args>
diff --git a/paddle/utils/string/tinyformat/tinyformat.h b/paddle/utils/string/tinyformat/tinyformat.h
index 37a469cc763bf..4e9ec0be3f0de 100644
--- a/paddle/utils/string/tinyformat/tinyformat.h
+++ b/paddle/utils/string/tinyformat/tinyformat.h
@@ -119,9 +119,8 @@
 // Additional API information
 // --------------------------
 //
-// Error handling: Define TINYFORMAT_ERROR to customize the error handling for
-// format strings which are unsupported or have the wrong number of format
-// specifiers (calls assert() by default).
+// Error handling: Format errors throw detail::FormatError, which is caught
+// at the public API level to fall back to the raw format string.
 //
 // User defined types: Uses operator<< for user defined types by default.
 // Overload formatValue() for more control.
@@ -139,13 +138,14 @@
 namespace paddle {
 namespace string {
 namespace tinyformat {
-#ifndef TINYFORMAT_ERROR
-#define TINYFORMAT_ERROR(reason) assert(0 && reason)
-#endif
-
 //------------------------------------------------------------------------------
 namespace detail {
 
+// Exception thrown on format errors instead of crashing via assert.
+// Caught at the public API level to fall back to returning the raw format
+// string, so that a wrong PADDLE_ENFORCE format never causes an abort.
+struct FormatError {};
+
 // Test whether type T1 is convertible to type T2
 template <typename T1, typename T2>
 struct is_convertible {
@@ -192,9 +192,7 @@ struct formatValueAsType {
 template <typename T, bool convertible = is_convertible<T, int>::value>
 struct convertToInt {
   static int invoke(const T & /*value*/) {
-    TINYFORMAT_ERROR(
-        "tinyformat: Cannot convert from argument type to "
-        "integer for use as variable width or precision");
+    throw FormatError();
     return 0;
   }
 };
@@ -579,8 +577,7 @@ inline const char *streamStateFromFormat(std::ostream &out,  // NOLINT
                                          int &argIndex,  // NOLINT
                                          int numFormatters) {
   if (*fmtStart != '%') {
-    TINYFORMAT_ERROR(
-        "tinyformat: Not enough conversion specifiers in format string");
+    throw FormatError();
     return fmtStart;
   }
   // Reset stream state to defaults.
@@ -639,8 +636,7 @@ inline const char *streamStateFromFormat(std::ostream &out,  // NOLINT
        if (argIndex < numFormatters)
          width = formatters[argIndex++].toInt();
        else
-         TINYFORMAT_ERROR(
-             "tinyformat: Not enough arguments to read variable width");
+         throw FormatError();
        if (width < 0) {
          // negative widths correspond to '-' flag set
          out.fill(' ');
@@ -659,8 +655,7 @@ inline const char *streamStateFromFormat(std::ostream &out,  // NOLINT
         if (argIndex < numFormatters)
           precision = formatters[argIndex++].toInt();
         else
-          TINYFORMAT_ERROR(
-              "tinyformat: Not enough arguments to read variable precision");
+          throw FormatError();
       } else {
         if (*c >= '0' && *c <= '9')
           precision = parseIntAndAdvance(c);
@@ -724,9 +719,7 @@ inline const char *streamStateFromFormat(std::ostream &out,  // NOLINT
       break;
     case 'a':
     case 'A':
-      TINYFORMAT_ERROR(
-          "tinyformat: the %a and %A conversion specs "
-          "are not supported");
+      throw FormatError();
       break;
     case 'c':
       // Handled as special case inside formatValue()
@@ -738,12 +731,10 @@ inline const char *streamStateFromFormat(std::ostream &out,  // NOLINT
      break;
    case 'n':
      // Not supported - will cause problems!
-     TINYFORMAT_ERROR("tinyformat: %n conversion spec not supported");
+     throw FormatError();
      break;
    case '\0':
-     TINYFORMAT_ERROR(
-         "tinyformat: Conversion spec incorrectly "
-         "terminated by end of string");
+     throw FormatError();
      return c;
    default:
      break;
@@ -785,7 +776,7 @@ inline void formatImpl(std::ostream &out,
                                 numFormatters);
     if (argIndex >= numFormatters) {
       // Check args remain after reading any variable width/precision
-      TINYFORMAT_ERROR("tinyformat: Not enough format arguments");
+      throw FormatError();
       return;
     }
     const FormatArg &arg = formatters[argIndex];
@@ -811,9 +802,7 @@ inline void formatImpl(std::ostream &out,
 
   // Print remaining part of format string.
   fmt = printFormatStringLiteral(out, fmt);
-  if (fmt != nullptr && *fmt != '\0' && *fmt != 0)
-    TINYFORMAT_ERROR(
-        "tinyformat: Too many conversion specifiers in format string");
+  if (fmt != nullptr && *fmt != '\0' && *fmt != 0) throw FormatError();
 
   // Restore stream state
   out.width(origWidth);