Skip to content

Commit ec42e19

Browse files
sbenzaquencopybara-github
authored andcommitted
Add native Abseil flag support to protobuf message types.
Now protobuf messages can be used with ABSL_FLAG declarations. It uses the text format representation. PiperOrigin-RevId: 867760218
1 parent a1e184a commit ec42e19

6 files changed

Lines changed: 300 additions & 0 deletions

File tree

src/google/protobuf/BUILD.bazel

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1145,9 +1145,11 @@ filegroup(
11451145

11461146
filegroup(
11471147
name = "test_proto_all_srcs",
1148+
testonly = 1,
11481149
srcs = [
11491150
":test_proto_editions_srcs",
11501151
":test_proto_srcs",
1152+
"//src/google/protobuf/test_protos:test_proto_srcs",
11511153
],
11521154
visibility = ["//:__subpackages__"],
11531155
)
@@ -2152,15 +2154,19 @@ cc_test(
21522154
"//src/google/protobuf/io",
21532155
"//src/google/protobuf/io:io_win32",
21542156
"//src/google/protobuf/stubs",
2157+
"//src/google/protobuf/test_protos:test_cc_protos",
21552158
"//src/google/protobuf/testing",
21562159
"//src/google/protobuf/testing:file",
21572160
"//src/google/protobuf/util:differencer",
21582161
"@abseil-cpp//absl/base:config",
2162+
"@abseil-cpp//absl/flags:flag",
2163+
"@abseil-cpp//absl/flags:marshalling",
21592164
"@abseil-cpp//absl/hash:hash_testing",
21602165
"@abseil-cpp//absl/log:absl_check",
21612166
"@abseil-cpp//absl/log:scoped_mock_log",
21622167
"@abseil-cpp//absl/strings",
21632168
"@abseil-cpp//absl/strings:cord",
2169+
"@abseil-cpp//absl/strings:str_format",
21642170
"@googletest//:gtest",
21652171
"@googletest//:gtest_main",
21662172
],

src/google/protobuf/message.cc

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,30 @@
1414
#include <cstddef>
1515
#include <cstdint>
1616
#include <cstring>
17+
#include <initializer_list>
1718
#include <new> // IWYU pragma: keep for operator new().
1819
#include <queue>
1920
#include <string>
21+
#include <utility>
22+
#include <variant>
2023
#include <vector>
2124

25+
#include "absl/algorithm/container.h"
2226
#include "absl/base/call_once.h"
2327
#include "absl/base/optimization.h"
2428
#include "absl/container/flat_hash_map.h"
2529
#include "absl/container/flat_hash_set.h"
2630
#include "absl/hash/hash.h"
2731
#include "absl/log/absl_check.h"
2832
#include "absl/log/absl_log.h"
33+
#include "absl/strings/escaping.h"
34+
#include "absl/strings/match.h"
35+
#include "absl/strings/str_cat.h"
36+
#include "absl/strings/str_format.h"
2937
#include "absl/strings/str_join.h"
38+
#include "absl/strings/str_split.h"
3039
#include "absl/strings/string_view.h"
40+
#include "absl/strings/strip.h"
3141
#include "absl/synchronization/mutex.h"
3242
#include "absl/types/optional.h"
3343
#include "google/protobuf/descriptor.h"
@@ -37,13 +47,15 @@
3747
#include "google/protobuf/generated_message_tctable_impl.h"
3848
#include "google/protobuf/generated_message_util.h"
3949
#include "google/protobuf/io/coded_stream.h"
50+
#include "google/protobuf/io/tokenizer.h"
4051
#include "google/protobuf/map_field.h"
4152
#include "google/protobuf/message_lite.h"
4253
#include "google/protobuf/parse_context.h"
4354
#include "google/protobuf/port.h"
4455
#include "google/protobuf/reflection_internal.h"
4556
#include "google/protobuf/reflection_ops.h"
4657
#include "google/protobuf/reflection_visit_fields.h"
58+
#include "google/protobuf/text_format.h"
4759
#include "google/protobuf/unknown_field_set.h"
4860
#include "google/protobuf/wire_format.h"
4961

@@ -124,6 +136,205 @@ PROTOBUF_CONSTINIT PROTOBUF_EXPORT const DescriptorMethods
124136
using internal::ReflectionOps;
125137
using internal::WireFormat;
126138

139+
namespace {
140+
141+
enum class AbslFlagFormat {
142+
kTextFormat,
143+
kSerialized,
144+
};
145+
146+
struct AbslFlagHeader {
147+
AbslFlagFormat format;
148+
absl::string_view format_name;
149+
std::vector<absl::string_view> options;
150+
bool uses_dead_char = false;
151+
bool uses_prefix = false;
152+
};
153+
154+
std::variant<AbslFlagHeader, std::string> ConsumeAbslFlagHeader(
155+
absl::string_view& text) {
156+
AbslFlagHeader header;
157+
158+
if (text.empty()) {
159+
// Whatever format is fine.
160+
header.format = AbslFlagFormat::kTextFormat;
161+
return header;
162+
}
163+
164+
if (absl::ConsumePrefix(&text, ":")) {
165+
header.uses_dead_char = true;
166+
}
167+
168+
auto pos = text.find(':');
169+
if (pos == text.npos) {
170+
header.format = AbslFlagFormat::kTextFormat;
171+
return header;
172+
}
173+
174+
header.uses_prefix = true;
175+
176+
absl::string_view format_spec = text.substr(0, pos);
177+
if (!header.uses_dead_char) {
178+
header.format_name = format_spec;
179+
// Legacy specs.
180+
if (format_spec == "text") {
181+
header.format = AbslFlagFormat::kTextFormat;
182+
} else if (format_spec == "base64text") {
183+
header.format = AbslFlagFormat::kTextFormat;
184+
header.options = {"base64"};
185+
} else if (format_spec == "base64serialized") {
186+
header.format = AbslFlagFormat::kSerialized;
187+
header.options = {"base64"};
188+
} else {
189+
if (absl::StrContains(format_spec, ",")) {
190+
return absl::StrFormat(
191+
"Format options are only allowed with delimited format specifier. "
192+
"Use `:%1$s:` instead of `%1$s:`",
193+
format_spec);
194+
}
195+
header.uses_prefix = false;
196+
header.format = AbslFlagFormat::kTextFormat;
197+
return header;
198+
}
199+
} else {
200+
std::vector<absl::string_view> parts = absl::StrSplit(format_spec, ',');
201+
header.format_name = parts[0];
202+
203+
if (header.format_name == "text") {
204+
header.format = AbslFlagFormat::kTextFormat;
205+
} else if (header.format_name == "serialized") {
206+
header.format = AbslFlagFormat::kSerialized;
207+
} else {
208+
return absl::StrFormat("Invalid format `%s`.", header.format_name);
209+
}
210+
211+
header.options.assign(parts.begin() + 1, parts.end());
212+
}
213+
214+
if (header.uses_prefix) {
215+
text.remove_prefix(pos + 1);
216+
}
217+
return header;
218+
}
219+
220+
} // namespace
221+
222+
bool Message::AbslParseFlagImpl(absl::string_view text, std::string& error) {
223+
Clear();
224+
225+
auto header_or_error = ConsumeAbslFlagHeader(text);
226+
if (std::holds_alternative<std::string>(header_or_error)) {
227+
error = std::get<std::string>(header_or_error);
228+
return false;
229+
}
230+
auto header = std::get<AbslFlagHeader>(std::move(header_or_error));
231+
232+
if (!header.uses_dead_char) {
233+
error = "Prefix must start with a `:`. Eg `:text:`.";
234+
return false;
235+
}
236+
237+
// If we have a prefix without a dead char, verify that the message does not
238+
// have a field by that name as that would be ambiguous.
239+
if (!header.uses_dead_char && header.uses_prefix &&
240+
GetDescriptor()->FindFieldByName(header.format_name) != nullptr) {
241+
error = absl::StrFormat(
242+
"Prefix `%s:` used is ambiguous with message fields. If you meant to "
243+
"use this prefix, use `:%s:` instead. If you meant to use text "
244+
"format, use `:text:` as a prefix.",
245+
header.format_name, header.format_name);
246+
return false;
247+
}
248+
249+
const auto verify_options =
250+
[&](std::initializer_list<absl::string_view> valid_options) -> bool {
251+
for (absl::string_view o : header.options) {
252+
if (!absl::c_linear_search(valid_options, o)) {
253+
error = absl::StrFormat("Unknown option `%s` for format `%s`.", o,
254+
header.format_name);
255+
return false;
256+
}
257+
}
258+
return true;
259+
};
260+
261+
static constexpr absl::string_view kBase64 = "base64";
262+
263+
std::string unescaped;
264+
const auto unescape_if_needed = [&] {
265+
if (absl::c_linear_search(header.options, kBase64)) {
266+
if (!absl::Base64Unescape(text, &unescaped)) {
267+
error = absl::StrFormat("Invalid base64 input.");
268+
return false;
269+
}
270+
text = unescaped;
271+
}
272+
return true;
273+
};
274+
275+
switch (header.format) {
276+
case AbslFlagFormat::kTextFormat: {
277+
static constexpr absl::string_view kIgnoreUnknown = "ignore_unknown";
278+
if (!verify_options({kIgnoreUnknown, kBase64})) return false;
279+
if (!unescape_if_needed()) return false;
280+
TextFormat::Parser parser;
281+
struct StringErrorCollector : io::ErrorCollector {
282+
explicit StringErrorCollector(std::string& error) : error(error) {}
283+
std::string& error;
284+
void RecordError(int line, io::ColumnNumber column,
285+
absl::string_view message) override {
286+
error = absl::StrFormat("(Line %v, Column %v): %v", line, column,
287+
message);
288+
}
289+
} collector(error);
290+
if (absl::c_linear_search(header.options, kIgnoreUnknown)) {
291+
parser.AllowUnknownField(true);
292+
parser.AllowUnknownExtension(true);
293+
}
294+
parser.RecordErrorsTo(&collector);
295+
return parser.ParseFromString(text, this);
296+
}
297+
298+
case AbslFlagFormat::kSerialized: {
299+
if (!verify_options({kBase64})) return false;
300+
if (!unescape_if_needed()) return false;
301+
return ParseFromString(text);
302+
}
303+
304+
default:
305+
internal::Unreachable();
306+
}
307+
}
308+
309+
std::string Message::AbslUnparseFlagImpl() const {
310+
bool has_ufs = !GetReflection()->GetUnknownFields(*this).empty();
311+
internal::VisitMessageFields(*this, [&](const auto& msg) {
312+
has_ufs = has_ufs || !msg.GetReflection()->GetUnknownFields(msg).empty();
313+
});
314+
315+
if (has_ufs) {
316+
// We can't use text format because it won't round trip
317+
// Use binary instead.
318+
return absl::StrCat(":serialized,base64:",
319+
absl::Base64Escape(SerializeAsString()));
320+
} else {
321+
TextFormat::Printer printer;
322+
printer.SetSingleLineMode(true);
323+
printer.SetUseShortRepeatedPrimitives(true);
324+
std::string str;
325+
// PrintToString can't really fail.
326+
(void)printer.PrintToString(*this, &str);
327+
328+
// If completely empty, just return the empty string.
329+
// It is usually the default and nicer to read.
330+
if (str.empty()) {
331+
return str;
332+
}
333+
334+
return absl::StrCat(":text:", str);
335+
}
336+
}
337+
127338
void Message::MergeImpl(MessageLite& to, const MessageLite& from) {
128339
ReflectionOps::Merge(DownCastMessage<Message>(from),
129340
DownCastMessage<Message>(&to));

src/google/protobuf/message.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,33 @@ class PROTOBUF_EXPORT Message : public MessageLite {
388388
return GetMetadata().reflection;
389389
}
390390

391+
// Abseil flag support for Messages.
392+
//
393+
// Flag syntax is `:format,options...:value` where:
394+
// - `format` is one of `text`, `serialized`.
395+
// - `options` is a possibly empty list of options. Each format has its
396+
// supported options.
397+
// - `value` is the payload in the specified format.
398+
//
399+
// The valid options are:
400+
//
401+
// * For `text`:
402+
// - `base64`: indicates that `value` is encoded as base64.
403+
// - `ignore_unknown`: when specified, unknown field/extensions are
404+
// dropped. Otherwise, they cause a parse failure.
405+
//
406+
// * For `serialized`:
407+
// - `base64`: indicates that `value` is encoded as base64. It is
408+
// recommended to use `serialized` with `base64` given that passing
409+
// binary data in shells is difficult and error prone.
410+
friend bool AbslParseFlag(absl::string_view text, Message* msg,
411+
std::string* error) {
412+
return msg->AbslParseFlagImpl(text, *error);
413+
}
414+
friend std::string AbslUnparseFlag(const Message& msg) {
415+
return msg.AbslUnparseFlagImpl();
416+
}
417+
391418
protected:
392419
#if !defined(PROTOBUF_CUSTOM_VTABLE)
393420
constexpr Message() {}
@@ -402,6 +429,9 @@ class PROTOBUF_EXPORT Message : public MessageLite {
402429
// For CODE_SIZE types
403430
static bool IsInitializedImpl(const MessageLite&);
404431

432+
bool AbslParseFlagImpl(absl::string_view text, std::string& error);
433+
std::string AbslUnparseFlagImpl() const;
434+
405435
size_t ComputeUnknownFieldsSize(
406436
size_t total_size, const internal::CachedSize* cached_size) const;
407437
size_t MaybeComputeUnknownFieldsSize(

src/google/protobuf/message_unittest.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,17 @@
1616
#include <memory>
1717
#include <string>
1818
#include <tuple>
19+
#include <variant>
1920
#include <vector>
2021

2122
#include <gmock/gmock.h>
2223
#include <gtest/gtest.h>
24+
#include "absl/flags/marshalling.h"
2325
#include "absl/hash/hash_testing.h"
2426
#include "absl/log/absl_check.h"
27+
#include "absl/strings/escaping.h"
2528
#include "absl/strings/str_cat.h"
29+
#include "absl/strings/str_format.h"
2630
#include "absl/strings/string_view.h"
2731
#include "google/protobuf/arena.h"
2832
#include "google/protobuf/dynamic_message.h"
@@ -32,6 +36,7 @@
3236
#include "google/protobuf/internal_visibility.h"
3337
#include "google/protobuf/message_lite.h"
3438
#include "google/protobuf/port.h"
39+
#include "google/protobuf/test_protos/abseil_flag_test.pb.h"
3540
#include "google/protobuf/unittest.pb.h"
3641
#include "google/protobuf/unittest_import.pb.h"
3742
#include "google/protobuf/unittest_lite.pb.h"
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
load("//bazel:cc_proto_library.bzl", "cc_proto_library")
2+
load("//bazel:proto_library.bzl", "proto_library")
3+
4+
package(
5+
default_applicable_licenses = ["//:license"],
6+
default_testonly = 1,
7+
default_visibility = ["//src/google/protobuf:__subpackages__"],
8+
)
9+
10+
proto_library(
11+
name = "test_protos",
12+
srcs = glob(["*.proto"]),
13+
strip_import_prefix = "/src",
14+
)
15+
16+
cc_proto_library(
17+
name = "test_cc_protos",
18+
deps = [":test_protos"],
19+
)
20+
21+
filegroup(
22+
name = "test_proto_srcs",
23+
srcs = glob(["*.proto"]),
24+
visibility = ["//:__subpackages__"],
25+
)

0 commit comments

Comments
 (0)