close
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/phi/kernels/stride/elementwise_grad_stride_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ COMMON_DECLARE_bool(use_stride_compute_kernel);

namespace phi {

inline void PrepareStridedOut(DenseTensor* out) {
inline void PrepareStridedOut_elementwise(DenseTensor* out) {
if (!FLAGS_use_stride_kernel) {
PADDLE_THROW(common::errors::Fatal(
"FLAGS_use_stride_kernel is closed. Strided kernel "
Expand All @@ -56,7 +56,7 @@ void SumStrideKernel(const Context& dev_ctx,
DataType out_dtype,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_elementwise(out);

phi::SumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
}
Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/kernels/stride/matmul_grad_stride_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ inline bool UseCanonicalizedTransposeGradPath(const Context& dev_ctx) {
#endif
}

inline void PrepareStridedOut(DenseTensor* out) {
inline void PrepareStridedOut_matmul(DenseTensor* out) {
if (out == nullptr) {
return;
}
Expand Down Expand Up @@ -175,8 +175,8 @@ void MatmulGradStrideKernel(const Context& dev_ctx,
if (!out_grad_.meta().is_contiguous()) {
out_grad_ = Tensor2Contiguous<Context>(dev_ctx, out_grad_);
}
PrepareStridedOut(dx);
PrepareStridedOut(dy);
PrepareStridedOut_matmul(dx);
PrepareStridedOut_matmul(dy);
phi::MatmulGradKernel<T, Context>(
dev_ctx, x_, y_, out_grad_, transpose_x, transpose_y, dx, dy);
return;
Expand Down Expand Up @@ -204,14 +204,14 @@ void MatmulGradStrideKernel(const Context& dev_ctx,
dx_tmp.Resize(x_.dims());
dx_out = &dx_tmp;
} else {
PrepareStridedOut(dx_out);
PrepareStridedOut_matmul(dx_out);
}

if (dy != nullptr && y_info.applied) {
dy_tmp.Resize(y_.dims());
dy_out = &dy_tmp;
} else {
PrepareStridedOut(dy_out);
PrepareStridedOut_matmul(dy_out);
}

phi::MatmulGradKernel<T, Context>(
Expand Down
22 changes: 11 additions & 11 deletions paddle/phi/kernels/stride/reduce_stride_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ COMMON_DECLARE_bool(force_stride_compute_contig_out);

namespace phi {

inline void PrepareStridedOut(DenseTensor* out) {
inline void PrepareStridedOut_reduce(DenseTensor* out) {
if (!FLAGS_use_stride_kernel) {
PADDLE_THROW(common::errors::Fatal(
"FLAGS_use_stride_kernel is closed. Strided kernel "
Expand All @@ -51,7 +51,7 @@ void AMaxStrideKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::AMaxKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
}
Expand All @@ -62,7 +62,7 @@ void AMinStrideKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::AMinKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
}
Expand All @@ -73,7 +73,7 @@ void MaxStrideKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::MaxKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
}
Expand All @@ -84,7 +84,7 @@ void MinStrideKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::MinKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
}
Expand All @@ -96,7 +96,7 @@ void ProdStrideKernel(const Context& dev_ctx,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::ProdKernel<T, Context>(dev_ctx, x, dims, keep_dim, reduce_all, out);
}
Expand All @@ -107,7 +107,7 @@ void AllStrideKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::AllKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
}
Expand All @@ -118,7 +118,7 @@ void AnyStrideKernel(const Context& dev_ctx,
const std::vector<int64_t>& dims,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::AnyKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
}
Expand All @@ -130,7 +130,7 @@ void SumStrideKernel(const Context& dev_ctx,
DataType out_dtype,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::SumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
}
Expand All @@ -142,7 +142,7 @@ void NansumStrideKernel(const Context& dev_ctx,
DataType out_dtype,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);
phi::NansumKernel<T, Context>(dev_ctx, x, dims, out_dtype, keep_dim, out);
}

Expand All @@ -152,7 +152,7 @@ void MeanStrideKernel(const Context& dev_ctx,
const IntArray& dims,
bool keep_dim,
DenseTensor* out) {
PrepareStridedOut(out);
PrepareStridedOut_reduce(out);

phi::MeanKernel<T, Context>(dev_ctx, x, dims, keep_dim, out);
}
Expand Down
16 changes: 12 additions & 4 deletions paddle/utils/string/printf.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ namespace string {

template <typename... Args>
void Fprintf(std::ostream& out, const char* fmt, const Args&... args) {
tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
try {
tinyformat::vformat(out, fmt, tinyformat::makeFormatList(args...));
} catch (const tinyformat::detail::FormatError&) {
out << fmt;
}
}

inline std::string Sprintf() { return ""; }
Expand All @@ -95,9 +99,13 @@ std::string Sprintf(const Args&... args) {

template <typename... Args>
std::string Sprintf(const char* fmt, const Args&... args) {
std::ostringstream oss;
Fprintf(oss, fmt, args...);
return oss.str();
try {
std::ostringstream oss;
Fprintf(oss, fmt, args...);
return oss.str();
} catch (const tinyformat::detail::FormatError&) {
return fmt;
}
Comment on lines 101 to +108
}

template <typename... Args>
Expand Down
43 changes: 16 additions & 27 deletions paddle/utils/string/tinyformat/tinyformat.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,8 @@
// Additional API information
// --------------------------
//
// Error handling: Define TINYFORMAT_ERROR to customize the error handling for
// format strings which are unsupported or have the wrong number of format
// specifiers (calls assert() by default).
// Error handling: Format errors throw detail::FormatError, which is caught
// at the public API level to fall back to the raw format string.
Comment on lines +122 to +123
//
// User defined types: Uses operator<< for user defined types by default.
// Overload formatValue() for more control.
Expand All @@ -139,13 +138,14 @@ namespace paddle {
namespace string {
namespace tinyformat {

#ifndef TINYFORMAT_ERROR
#define TINYFORMAT_ERROR(reason) assert(0 && reason)
#endif

//------------------------------------------------------------------------------
namespace detail {

// Exception thrown on format errors instead of crashing via assert.
// Caught at the public API level to fall back to returning the raw format
// string, so that a wrong PADDLE_ENFORCE format never causes an abort.
struct FormatError {};

// Test whether type T1 is convertible to type T2
template <typename T1, typename T2>
struct is_convertible {
Expand Down Expand Up @@ -192,9 +192,7 @@ struct formatValueAsType<T, fmtT, true> {
template <typename T, bool convertible = is_convertible<T, int>::value>
struct convertToInt {
static int invoke(const T & /*value*/) {
TINYFORMAT_ERROR(
"tinyformat: Cannot convert from argument type to "
"integer for use as variable width or precision");
throw FormatError();
return 0;
}
};
Expand Down Expand Up @@ -579,8 +577,7 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
int &argIndex, // NOLINT
int numFormatters) {
if (*fmtStart != '%') {
TINYFORMAT_ERROR(
"tinyformat: Not enough conversion specifiers in format string");
throw FormatError();
return fmtStart;
}
// Reset stream state to defaults.
Expand Down Expand Up @@ -639,8 +636,7 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
if (argIndex < numFormatters)
width = formatters[argIndex++].toInt();
else
TINYFORMAT_ERROR(
"tinyformat: Not enough arguments to read variable width");
throw FormatError();
if (width < 0) {
// negative widths correspond to '-' flag set
out.fill(' ');
Expand All @@ -659,8 +655,7 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
if (argIndex < numFormatters)
precision = formatters[argIndex++].toInt();
else
TINYFORMAT_ERROR(
"tinyformat: Not enough arguments to read variable precision");
throw FormatError();
} else {
if (*c >= '0' && *c <= '9')
precision = parseIntAndAdvance(c);
Expand Down Expand Up @@ -724,9 +719,7 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
break;
case 'a':
case 'A':
TINYFORMAT_ERROR(
"tinyformat: the %a and %A conversion specs "
"are not supported");
throw FormatError();
break;
case 'c':
// Handled as special case inside formatValue()
Expand All @@ -738,12 +731,10 @@ inline const char *streamStateFromFormat(std::ostream &out, // NOLINT
break;
case 'n':
// Not supported - will cause problems!
TINYFORMAT_ERROR("tinyformat: %n conversion spec not supported");
throw FormatError();
break;
case '\0':
TINYFORMAT_ERROR(
"tinyformat: Conversion spec incorrectly "
"terminated by end of string");
throw FormatError();
return c;
default:
break;
Expand Down Expand Up @@ -785,7 +776,7 @@ inline void formatImpl(std::ostream &out,
numFormatters);
if (argIndex >= numFormatters) {
// Check args remain after reading any variable width/precision
TINYFORMAT_ERROR("tinyformat: Not enough format arguments");
throw FormatError();
return;
}
const FormatArg &arg = formatters[argIndex];
Expand All @@ -811,9 +802,7 @@ inline void formatImpl(std::ostream &out,

// Print remaining part of format string.
fmt = printFormatStringLiteral(out, fmt);
if (fmt != nullptr && *fmt != '\0' && *fmt != 0)
TINYFORMAT_ERROR(
"tinyformat: Too many conversion specifiers in format string");
if (fmt != nullptr && *fmt != '\0' && *fmt != 0) throw FormatError();

// Restore stream state
out.width(origWidth);
Expand Down
Loading