diff --git a/src/brpc/policy/baidu_rpc_meta.proto b/src/brpc/policy/baidu_rpc_meta.proto index 5591c5dab7..916f604ff9 100644 --- a/src/brpc/policy/baidu_rpc_meta.proto +++ b/src/brpc/policy/baidu_rpc_meta.proto @@ -28,7 +28,8 @@ message RpcMeta { optional RpcResponseMeta response = 2; optional int32 compress_type = 3; optional int64 correlation_id = 4; - optional int32 attachment_size = 5; + optional int32 attachment_size = 5; // For compatibility, use attachment_size_long if size > INT32_MAX + optional int64 attachment_size_long = 13; // For attachment size > INT32_MAX optional ChunkInfo chunk_info = 6; optional bytes authentication_data = 7; optional StreamSettings stream_settings = 8; diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 5adf77b2c5..c6da344162 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -16,6 +16,9 @@ // under the License. +#include // PRId64, PRIu64 +#include // UINT32_MAX, INT32_MAX +#include // INT32_MAX #include // MethodDescriptor #include // Message #include @@ -23,6 +26,7 @@ #include #include "butil/logging.h" // LOG() #include "butil/iobuf.h" // butil::IOBuf +#include "butil/macros.h" // ALLOW_UNUSED #include "butil/raw_pack.h" // RawPacker RawUnpacker #include "butil/memory/scope_guard.h" #include "json2pb/json_to_pb.h" @@ -62,36 +66,101 @@ DEFINE_bool(baidu_std_protocol_deliver_timeout_ms, false, DECLARE_bool(pb_enum_as_number); // Notes: -// 1. 12-byte header [PRPC][body_size][meta_size] +// 1. Header format: +// - Normal format (12 bytes): [PRPC][body_size(32bit)][meta_size(32bit)] +// - Extended format (20 bytes): [PRPC][UINT32_MAX][meta_size(32bit)][body_size(64bit)] +// Extended format is used when body_size > UINT32_MAX // 2. body_size and meta_size are in network byte order // 3. Use service->full_name() + method_name to specify the method to call // 4. `attachment_size' is set iff request/response has attachment // 5. Not supported: chunk_info +// Helper function to get attachment size from RpcMeta, with backward compatibility +static int64_t GetAttachmentSize(const RpcMeta& meta) { + if (meta.has_attachment_size_long()) { + return meta.attachment_size_long(); + } + if (meta.has_attachment_size()) { + return static_cast(meta.attachment_size()); + } + return 0; +} + +// Helper function to set attachment size in RpcMeta, with backward compatibility +static void SetAttachmentSize(RpcMeta* meta, size_t size) { + const size_t INT32_MAX_VALUE = static_cast(INT32_MAX); + if (size > INT32_MAX_VALUE) { + meta->set_attachment_size_long(static_cast(size)); + } else { + meta->set_attachment_size(static_cast(size)); + } +} + +// Helper function to get attachment size from RpcDumpMeta, with backward compatibility +// Marked unused to avoid -Werror-unused-function when not referenced. +static int64_t ALLOW_UNUSED GetAttachmentSizeFromDump(const RpcDumpMeta& meta) { + if (meta.has_attachment_size_long()) { + return meta.attachment_size_long(); + } + if (meta.has_attachment_size()) { + return static_cast(meta.attachment_size()); + } + return 0; +} + +// Helper function to set attachment size in RpcDumpMeta, with backward compatibility +static void SetAttachmentSizeInDump(RpcDumpMeta* meta, size_t size) { + const size_t INT32_MAX_VALUE = static_cast(INT32_MAX); + if (size > INT32_MAX_VALUE) { + meta->set_attachment_size_long(static_cast(size)); + } else { + meta->set_attachment_size(static_cast(size)); + } +} + // Pack header into `buf' -inline void PackRpcHeader(char* rpc_header, uint32_t meta_size, int payload_size) { +// Returns the size of header written (12 for normal, 20 for extended) +inline size_t PackRpcHeader(char* rpc_header, uint32_t meta_size, size_t payload_size) { uint32_t* dummy = (uint32_t*)rpc_header; // suppress strict-alias warning *dummy = *(uint32_t*)"PRPC"; - butil::RawPacker(rpc_header + 4) - .pack32(meta_size + payload_size) - .pack32(meta_size); + const uint64_t total_size = static_cast(meta_size) + payload_size; + if (total_size > UINT32_MAX) { + // Extended format: use UINT32_MAX as flag, followed by 64-bit total_size + butil::RawPacker(rpc_header + 4) + .pack32(UINT32_MAX) + .pack32(meta_size) + .pack64(total_size); + return 20; // 4 (magic) + 4 (flag) + 4 (meta_size) + 8 (total_size) + } else { + // Normal format: 32-bit total_size + butil::RawPacker(rpc_header + 4) + .pack32(static_cast(total_size)) + .pack32(meta_size); + return 12; // 4 (magic) + 4 (body_size) + 4 (meta_size) + } } static void SerializeRpcHeaderAndMeta( - butil::IOBuf* out, const RpcMeta& meta, int payload_size) { + butil::IOBuf* out, const RpcMeta& meta, size_t payload_size) { const uint32_t meta_size = GetProtobufByteSize(meta); - if (meta_size <= 244) { // most common cases + const uint64_t total_size = static_cast(meta_size) + payload_size; + const bool use_extended = (total_size > UINT32_MAX); + + if (meta_size <= 244 && !use_extended) { + // Most common cases with normal format: optimize by combining header and meta char header_and_meta[12 + meta_size]; - PackRpcHeader(header_and_meta, meta_size, payload_size); + const size_t actual_header_size = PackRpcHeader(header_and_meta, meta_size, payload_size); + CHECK_EQ(actual_header_size, 12U); // Should be 12 for normal format ::google::protobuf::io::ArrayOutputStream arr_out(header_and_meta + 12, meta_size); ::google::protobuf::io::CodedOutputStream coded_out(&arr_out); meta.SerializeWithCachedSizes(&coded_out); // not calling ByteSize again CHECK(!coded_out.HadError()); CHECK_EQ(0, out->append(header_and_meta, sizeof(header_and_meta))); } else { - char header[12]; - PackRpcHeader(header, meta_size, payload_size); - CHECK_EQ(0, out->append(header, sizeof(header))); + // Extended format or large meta: write header and meta separately + char header[20]; // Enough for both normal and extended format + const size_t actual_header_size = PackRpcHeader(header, meta_size, payload_size); + CHECK_EQ(0, out->append(header, actual_header_size)); butil::IOBufAsZeroCopyOutputStream buf_stream(out); ::google::protobuf::io::CodedOutputStream coded_out(&buf_stream); meta.SerializeWithCachedSizes(&coded_out); @@ -101,8 +170,9 @@ static void SerializeRpcHeaderAndMeta( ParseResult ParseRpcMessage(butil::IOBuf* source, Socket* socket, bool /*read_eof*/, const void*) { - char header_buf[12]; - const size_t n = source->copy_to(header_buf, sizeof(header_buf)); + // First read at least 12 bytes to check magic and determine format + char header_buf[20]; + const size_t n = source->copy_to(header_buf, 12); if (n >= 4) { void* dummy = header_buf; if (*(const uint32_t*)dummy != *(const uint32_t*)"PRPC") { @@ -113,29 +183,50 @@ ParseResult ParseRpcMessage(butil::IOBuf* source, Socket* socket, return MakeParseError(PARSE_ERROR_TRY_OTHERS); } } - if (n < sizeof(header_buf)) { + if (n < 12) { return MakeParseError(PARSE_ERROR_NOT_ENOUGH_DATA); } - uint32_t body_size; + + uint32_t body_size_32; uint32_t meta_size; - butil::RawUnpacker(header_buf + 4).unpack32(body_size).unpack32(meta_size); + uint64_t body_size; + size_t header_size; + + butil::RawUnpacker unpacker(header_buf + 4); + unpacker.unpack32(body_size_32).unpack32(meta_size); + + if (body_size_32 == UINT32_MAX) { + // Extended format: read 8 more bytes for 64-bit body_size + if (source->length() < 20) { + return MakeParseError(PARSE_ERROR_NOT_ENOUGH_DATA); + } + source->copy_to(header_buf, 20); + unpacker = butil::RawUnpacker(header_buf + 12); + unpacker.unpack64(body_size); + header_size = 20; + } else { + // Normal format: use 32-bit body_size + body_size = static_cast(body_size_32); + header_size = 12; + } + if (body_size > FLAGS_max_body_size) { // We need this log to report the body_size to give users some clues // which is not printed in InputMessenger. LOG(ERROR) << "body_size=" << body_size << " from " << socket->remote_side() << " is too large"; return MakeParseError(PARSE_ERROR_TOO_BIG_DATA); - } else if (source->length() < sizeof(header_buf) + body_size) { + } else if (source->length() < header_size + body_size) { return MakeParseError(PARSE_ERROR_NOT_ENOUGH_DATA); } if (meta_size > body_size) { LOG(ERROR) << "meta_size=" << meta_size << " is bigger than body_size=" << body_size; // Pop the message - source->pop_front(sizeof(header_buf) + body_size); + source->pop_front(header_size + body_size); return MakeParseError(PARSE_ERROR_TRY_OTHERS); } - source->pop_front(sizeof(header_buf)); + source->pop_front(header_size); MostCommonMessage* msg = MostCommonMessage::Get(); source->cutn(&msg->meta, meta_size); source->cutn(&msg->payload, body_size - meta_size); @@ -347,7 +438,7 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, meta.set_checksum_type(cntl->response_checksum_type()); meta.set_checksum_value(accessor.checksum_value()); if (attached_size > 0) { - meta.set_attachment_size(attached_size); + SetAttachmentSize(&meta, attached_size); } StreamId response_stream_id = INVALID_STREAM_ID; SocketUniquePtr stream_ptr; @@ -585,7 +676,15 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { sample->meta.set_method_name(request_meta.method_name()); sample->meta.set_compress_type((CompressType)meta.compress_type()); sample->meta.set_protocol_type(PROTOCOL_BAIDU_STD); - sample->meta.set_attachment_size(meta.attachment_size()); + const int64_t attachment_size = GetAttachmentSize(meta); + // Only set attachment_size if it's valid (non-negative) + if (attachment_size > 0) { + SetAttachmentSizeInDump(&sample->meta, static_cast(attachment_size)); + } else if (attachment_size < 0) { + // Log warning for invalid negative attachment_size in sampling + LOG(WARNING) << "Invalid negative attachment_size=" << attachment_size + << " in sampled request, ignoring"; + } sample->meta.set_authentication_data(meta.authentication_data()); sample->request = msg->payload; sample->submit(start_parse_us); @@ -676,14 +775,13 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { break; } - const int req_size = static_cast(msg->payload.size()); - if (meta.has_attachment_size()) { - if (req_size < meta.attachment_size()) { - cntl->SetFailed(EREQUEST, - "attachment_size=%d is larger than request_size=%d", - meta.attachment_size(), req_size); - break; - } + const size_t req_size = msg->payload.size(); + const int64_t attachment_size = GetAttachmentSize(meta); + if (attachment_size < 0 || static_cast(attachment_size) > req_size) { + cntl->SetFailed(EREQUEST, + "attachment_size=%" PRId64 " is invalid or larger than request_size=%zu", + attachment_size, req_size); + break; } google::protobuf::Service* svc = NULL; @@ -723,9 +821,10 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { } messages = BaiduProxyPBMessages::Get(); + // attachment_size already retrieved and validated at line 777 msg->payload.cutn( &((SerializedRequest*)messages->Request())->serialized_data(), - req_size - meta.attachment_size()); + req_size - static_cast(attachment_size)); if (!msg->payload.empty()) { cntl->request_attachment().swap(msg->payload); } @@ -792,9 +891,10 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { } butil::IOBuf req_buf; - int body_without_attachment_size = req_size - meta.attachment_size(); + // attachment_size already retrieved and validated at line 772 + const size_t body_without_attachment_size = req_size - static_cast(attachment_size); msg->payload.cutn(&req_buf, body_without_attachment_size); - if (meta.attachment_size() > 0) { + if (attachment_size > 0) { cntl->request_attachment().swap(msg->payload); } @@ -811,7 +911,7 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { cntl->SetFailed( EREQUEST, "Fail to parse request=%s, ContentType=%s, " - "CompressType=%s, ChecksumType=%s, request_size=%d", + "CompressType=%s, ChecksumType=%s, request_size=%zu", messages->Request()->GetDescriptor()->full_name().c_str(), ContentTypeToCStr(content_type), CompressTypeToCStr(compress_type), @@ -963,16 +1063,18 @@ void ProcessRpcResponse(InputMessageBase* msg_base) { } // Parse response message iff error code from meta is 0 butil::IOBuf res_buf; - const int res_size = msg->payload.length(); + const size_t res_size = msg->payload.length(); butil::IOBuf* res_buf_ptr = &msg->payload; - if (meta.has_attachment_size()) { - if (meta.attachment_size() > res_size) { - cntl->SetFailed( - ERESPONSE, "attachment_size=%d is larger than response_size=%d", - meta.attachment_size(), res_size); - break; - } - int body_without_attachment_size = res_size - meta.attachment_size(); + const int64_t attachment_size = GetAttachmentSize(meta); + // Validate attachment_size: check for negative values and size overflow + if (attachment_size < 0 || static_cast(attachment_size) > res_size) { + cntl->SetFailed( + ERESPONSE, "attachment_size=%" PRId64 " is invalid or larger than response_size=%zu", + attachment_size, res_size); + break; + } + if (attachment_size > 0) { + const size_t body_without_attachment_size = res_size - static_cast(attachment_size); msg->payload.cutn(&res_buf, body_without_attachment_size); res_buf_ptr = &res_buf; cntl->response_attachment().swap(msg->payload); @@ -995,7 +1097,7 @@ void ProcessRpcResponse(InputMessageBase* msg_base) { cntl->SetFailed( EREQUEST, "Fail to parse response=%s, ContentType=%s, " - "CompressType=%s, ChecksumType=%s, request_size=%d", + "CompressType=%s, ChecksumType=%s, response_size=%zu", cntl->response()->GetDescriptor()->full_name().c_str(), ContentTypeToCStr(content_type), CompressTypeToCStr(compress_type), @@ -1106,7 +1208,7 @@ void PackRpcRequest(butil::IOBuf* req_buf, const size_t req_size = request_body.length(); const size_t attached_size = cntl->request_attachment().length(); if (attached_size) { - meta.set_attachment_size(attached_size); + SetAttachmentSize(&meta, attached_size); } if (FLAGS_baidu_std_protocol_deliver_timeout_ms) { diff --git a/src/brpc/rpc_dump.proto b/src/brpc/rpc_dump.proto index e3c8aabfdb..9fef24f59a 100644 --- a/src/brpc/rpc_dump.proto +++ b/src/brpc/rpc_dump.proto @@ -35,7 +35,8 @@ message RpcDumpMeta { optional ProtocolType protocol_type = 5; // baidu_std, hulu_pbrpc - optional int32 attachment_size = 6; + optional int32 attachment_size = 6; // For compatibility, use attachment_size_long if size > INT32_MAX + optional int64 attachment_size_long = 10; // For attachment size > INT32_MAX // baidu_std optional bytes authentication_data = 7; diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index 4a774fab2a..812816501b 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -21,11 +21,13 @@ #include #include +#include #include #include #include #include "butil/time.h" #include "butil/macros.h" +#include "butil/raw_pack.h" #include "butil/fd_guard.h" #include "butil/files/scoped_file.h" #include "brpc/socket.h" @@ -53,6 +55,9 @@ #include "brpc/socket_map.h" #include "brpc/controller.h" #include "brpc/compress.h" +#include "brpc/policy/baidu_rpc_meta.pb.h" +#include "brpc/policy/baidu_rpc_protocol.h" +#include "brpc/policy/most_common_message.h" #include "echo.pb.h" #include "v1.pb.h" #include "v2.pb.h" @@ -2070,4 +2075,43 @@ TEST_F(ServerTest, auth) { ASSERT_EQ(0, server.Join()); } +TEST_F(ServerTest, baidu_extended_header_parse) { + brpc::policy::RpcMeta meta; + meta.set_correlation_id(123); + std::string meta_buf; + ASSERT_TRUE(meta.SerializeToString(&meta_buf)); + + const std::string payload = "extended-format-payload"; + const uint64_t body_size = meta_buf.size() + payload.size(); + + char header[20]; + memcpy(header, "PRPC", 4); + butil::RawPacker(header + 4) + .pack32(UINT32_MAX) + .pack32(static_cast(meta_buf.size())) + .pack64(body_size); + + butil::IOBuf buf; + buf.append(header, sizeof(header)); + buf.append(meta_buf); + buf.append(payload); + + brpc::ParseResult result = brpc::policy::ParseRpcMessage(&buf, NULL, false, NULL); + ASSERT_TRUE(result.is_ok()); + + brpc::policy::MostCommonMessage* msg = + static_cast(result.message()); + std::string parsed_payload; + msg->payload.copy_to(&parsed_payload); + EXPECT_EQ(payload, parsed_payload); + + std::string parsed_meta; + msg->meta.copy_to(&parsed_meta); + brpc::policy::RpcMeta parsed_meta_pb; + ASSERT_TRUE(parsed_meta_pb.ParseFromString(parsed_meta)); + EXPECT_EQ(meta.correlation_id(), parsed_meta_pb.correlation_id()); + + msg->Destroy(); +} + } //namespace diff --git a/tools/rpc_replay/rpc_replay.cpp b/tools/rpc_replay/rpc_replay.cpp index c3cd7c4c3a..a51bc89bd1 100644 --- a/tools/rpc_replay/rpc_replay.cpp +++ b/tools/rpc_replay/rpc_replay.cpp @@ -181,13 +181,25 @@ static void* replay_thread(void* arg) { memcpy(&nshead_req.head, sample->meta.nshead().c_str(), sample->meta.nshead().length()); nshead_req.body = sample->request; req_ptr = &nshead_req; - } else if (sample->meta.attachment_size() > 0) { - sample->request.cutn( - &req.serialized_data(), - sample->request.size() - sample->meta.attachment_size()); - cntl->request_attachment() = sample->request.movable(); } else { - req.serialized_data() = sample->request.movable(); + // Get attachment size with backward compatibility + int64_t attachment_size = 0; + if (sample->meta.has_attachment_size_long()) { + attachment_size = sample->meta.attachment_size_long(); + } else if (sample->meta.has_attachment_size()) { + attachment_size = static_cast(sample->meta.attachment_size()); + } + // Validate attachment_size: check for negative values and size overflow + // Explicitly validate the range before casting to size_t + if (attachment_size > 0 && + attachment_size < static_cast(sample->request.size())) { + sample->request.cutn( + &req.serialized_data(), + sample->request.size() - static_cast(attachment_size)); + cntl->request_attachment() = sample->request.movable(); + } else { + req.serialized_data() = sample->request.movable(); + } } g_sent_count << 1; const int64_t start_time = butil::gettimeofday_us();