diff --git a/base/BUILD.gn b/base/BUILD.gn index 1e0ea91d0f967477a9d5751ae21e5c2cf6bb8aa1..b935e0f584cc147f6309dfd209b06d66ea5e9a49 100644 --- a/base/BUILD.gn +++ b/base/BUILD.gn @@ -20,6 +20,7 @@ declare_args() { c_utils_print_track_at_once = false c_utils_debug_log_enabled = false c_utils_feature_intsan = true + c_utils_parcel_object_check = true } config("utils_config") { @@ -54,6 +55,11 @@ config("track_all") { defines = [ "TRACK_ALL" ] } +config("parcel_object_check") { + visibility = [ ":*" ] + defines = [ "PARCEL_OBJECT_CHECK" ] +} + sources_utils = [ "src/string_ex.cpp", "src/unicode_ex.cpp", @@ -167,6 +173,9 @@ ohos_shared_library("utils") { configs += [ ":print_track_at_once" ] } } + if (c_utils_parcel_object_check) { + configs += [ ":parcel_object_check" ] + } all_dependent_configs = [ ":utils_config" ] if (current_os != "android" && current_os != "ios") { defines = [ "CONFIG_HILOG" ] diff --git a/base/include/parcel.h b/base/include/parcel.h index 89ac76c3ad181a228d57f56bce7f749febe76b3c..08453bbdd7a23c63e51e2343c83a9c52a1c2e2e2 100644 --- a/base/include/parcel.h +++ b/base/include/parcel.h @@ -820,6 +820,12 @@ private: bool WriteParcelableOffset(size_t offset); + const uint8_t *BasicReadBuffer(size_t length); + + bool ValidateReadData(size_t upperBound); + + void ClearObjects(); + private: uint8_t *data_; size_t readCursor_; @@ -828,6 +834,7 @@ private: size_t dataCapacity_; size_t maxDataCapacity_; binder_size_t *objectOffsets_; + size_t nextObjectIdx_; size_t objectCursor_; size_t objectsCapacity_; Allocator *allocator_; diff --git a/base/src/parcel.cpp b/base/src/parcel.cpp index 991c669aca9572b03dd9fa5550695fe3a6c6736f..86f021583331178531f7ce52a49965812962049e 100644 --- a/base/src/parcel.cpp +++ b/base/src/parcel.cpp @@ -48,6 +48,7 @@ Parcel::Parcel(Allocator *allocator) maxDataCapacity_ = DEFAULT_CPACITY; objectOffsets_ = nullptr; + nextObjectIdx_ = 0; objectCursor_ = 0; objectsCapacity_ = 0; } @@ -59,6 +60,7 @@ Parcel::~Parcel() { FlushBuffer(); delete allocator_; + allocator_ = nullptr; } size_t Parcel::GetWritableBytes() const @@ -147,6 +149,33 @@ bool Parcel::EnsureWritableCapacity(size_t desireCapacity) return false; } +// ValidateReadData only works in basic type read. It doesn't work when read remote object. +// And read/write remote object has no effect on "nextObjectIdx_". +bool Parcel::ValidateReadData([[maybe_unused]]size_t upperBound) +{ +#ifdef PARCEL_OBJECT_CHECK + if (objectOffsets_ == nullptr || objectCursor_ == 0) { + return true; + } + size_t readPos = readCursor_; + size_t objSize = objectCursor_; + binder_size_t *objects = objectOffsets_; + if (nextObjectIdx_ < objSize && upperBound > objects[nextObjectIdx_]) { + size_t nextObj = nextObjectIdx_; + do { + if (readPos < objects[nextObj] + sizeof(parcel_flat_binder_object)) { + UTILS_LOGE("Non-object Read object data, readPos = %{public}zu, upperBound = %{public}zu", + readPos, upperBound); + return false; + } + nextObj++; + } while (nextObj < objSize && upperBound > objects[nextObj]); + nextObjectIdx_ = nextObj; + } +#endif + return true; +} + size_t Parcel::GetDataSize() const { return dataSize_; @@ -254,6 +283,16 @@ void Parcel::InjectOffsets(binder_size_t offsets, size_t offsetSize) } } +void Parcel::ClearObjects() +{ + objectHolder_.clear(); + free(objectOffsets_); + nextObjectIdx_ = 0; + objectCursor_ = 0; + objectOffsets_ = nullptr; + objectsCapacity_ = 0; +} + void Parcel::FlushBuffer() { if (allocator_ == nullptr) { @@ -270,11 +309,7 @@ void Parcel::FlushBuffer() } if (objectOffsets_) { - objectHolder_.clear(); - free(objectOffsets_); - objectCursor_ = 0; - objectOffsets_ = nullptr; - objectsCapacity_ = 0; + ClearObjects(); } } @@ -691,6 +726,13 @@ bool Parcel::Read(T &value) if (desireCapacity <= GetReadableBytes()) { const void *data = data_ + readCursor_; +#ifdef PARCEL_OBJECT_CHECK + size_t upperBound = readCursor_ + desireCapacity; + if (!ValidateReadData(upperBound)) { + readCursor_ += desireCapacity; + return false; + } +#endif readCursor_ += desireCapacity; value = *reinterpret_cast(data); return true; @@ -717,6 +759,11 @@ bool Parcel::ParseFrom(uintptr_t data, size_t size) dataSize_ = size; /* data is alloc by driver, can not write again */ writable_ = false; +#ifdef PARCEL_OBJECT_CHECK + if (objectOffsets_) { + ClearObjects(); + } +#endif return true; } @@ -731,6 +778,23 @@ const uint8_t *Parcel::ReadBuffer(size_t length) return nullptr; } +const uint8_t *Parcel::BasicReadBuffer([[maybe_unused]]size_t length) +{ +#ifdef PARCEL_OBJECT_CHECK + if (GetReadableBytes() >= length) { + uint8_t *buffer = data_ + readCursor_; + size_t upperBound = readCursor_ + length; + if (!ValidateReadData(upperBound)) { + readCursor_ += length; + return nullptr; + } + readCursor_ += length; + return buffer; + } +#endif + return nullptr; +} + const uint8_t *Parcel::ReadUnpadBuffer(size_t length) { if (GetReadableBytes() >= length) { @@ -763,6 +827,7 @@ bool Parcel::RewindRead(size_t newPosition) return false; } readCursor_ = newPosition; + nextObjectIdx_ = 0; return true; } @@ -778,6 +843,35 @@ bool Parcel::RewindWrite(size_t newPosition) } writeCursor_ = newPosition; dataSize_ = newPosition; +#ifdef PARCEL_OBJECT_CHECK + if (objectOffsets_ == nullptr || objectCursor_ == 0) { + return true; + } + size_t objectSize = objectCursor_; + if (objectOffsets_[objectSize - 1] + sizeof(parcel_flat_binder_object) > newPosition) { + while (objectSize > 0) { + if (objectOffsets_[objectSize - 1] + sizeof(parcel_flat_binder_object) <= newPosition) { + break; + } + objectSize--; + } + if (objectSize == 0) { + ClearObjects(); + return true; + } + size_t newBytes = objectSize * sizeof(binder_size_t); + void *newOffsets = realloc(objectOffsets_, newBytes); + if (newOffsets == nullptr) { + return false; + } + objectOffsets_ = reinterpret_cast(newOffsets); + objectCursor_ = objectSize; + objectsCapacity_ = objectCursor_; + objectHolder_.resize(objectSize); + nextObjectIdx_ = 0; + return true; + } +#endif return true; } @@ -966,7 +1060,11 @@ const std::string Parcel::ReadString() size_t readCapacity = static_cast(dataLength) + 1; if (readCapacity <= GetReadableBytes()) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *dest = BasicReadBuffer(readCapacity); +#else const uint8_t *dest = ReadBuffer(readCapacity); +#endif if (dest != nullptr) { const auto *str = reinterpret_cast(dest); SkipBytes(GetPadSize(readCapacity)); @@ -992,7 +1090,11 @@ bool Parcel::ReadString(std::string &value) size_t readCapacity = static_cast(dataLength) + 1; if (readCapacity <= GetReadableBytes()) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *dest = BasicReadBuffer(readCapacity); +#else const uint8_t *dest = ReadBuffer(readCapacity); +#endif if (dest != nullptr) { const auto *str = reinterpret_cast(dest); SkipBytes(GetPadSize(readCapacity)); @@ -1019,7 +1121,11 @@ const std::u16string Parcel::ReadString16() size_t readCapacity = (static_cast(dataLength) + 1) * sizeof(char16_t); if ((readCapacity > (static_cast(dataLength))) && (readCapacity <= GetReadableBytes())) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *str = BasicReadBuffer(readCapacity); +#else const uint8_t *str = ReadBuffer(readCapacity); +#endif if (str != nullptr) { const auto *u16Str = reinterpret_cast(str); SkipBytes(GetPadSize(readCapacity)); @@ -1045,7 +1151,11 @@ bool Parcel::ReadString16(std::u16string &value) size_t readCapacity = (static_cast(dataLength) + 1) * sizeof(char16_t); if ((readCapacity > (static_cast(dataLength))) && (readCapacity <= GetReadableBytes())) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *str = BasicReadBuffer(readCapacity); +#else const uint8_t *str = ReadBuffer(readCapacity); +#endif if (str != nullptr) { const auto *u16Str = reinterpret_cast(str); SkipBytes(GetPadSize(readCapacity)); @@ -1077,7 +1187,11 @@ const std::u16string Parcel::ReadString16WithLength(int32_t &readLength) size_t readCapacity = (static_cast(dataLength) + 1) * sizeof(char16_t); if ((readCapacity > (static_cast(dataLength))) && (readCapacity <= GetReadableBytes())) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *str = BasicReadBuffer(readCapacity); +#else const uint8_t *str = ReadBuffer(readCapacity); +#endif if (str != nullptr) { const auto *u16Str = reinterpret_cast(str); SkipBytes(GetPadSize(readCapacity)); @@ -1108,7 +1222,11 @@ const std::string Parcel::ReadString8WithLength(int32_t &readLength) size_t readCapacity = (static_cast(dataLength) + 1) * sizeof(char); if (readCapacity <= GetReadableBytes()) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *str = BasicReadBuffer(readCapacity); +#else const uint8_t *str = ReadBuffer(readCapacity); +#endif if (str != nullptr) { const auto *u8Str = reinterpret_cast(str); SkipBytes(GetPadSize(readCapacity)); diff --git a/base/test/unittest/common/utils_parcel_test.cpp b/base/test/unittest/common/utils_parcel_test.cpp index 46653d85cf8b93d24b796ea7482c196429e5fef6..a736a33ef7344cad861435957c4058c0f7f826bd 100644 --- a/base/test/unittest/common/utils_parcel_test.cpp +++ b/base/test/unittest/common/utils_parcel_test.cpp @@ -42,6 +42,38 @@ void UtilsParcelTest::TearDownTestCase(void) } } +class RemoteObject : public virtual Parcelable { +public: + RemoteObject() { asRemote_ = true; }; + bool Marshalling(Parcel &parcel) const override; + static sptr Unmarshalling(Parcel &parcel); +}; + +bool RemoteObject::Marshalling(Parcel &parcel) const +{ + parcel_flat_binder_object flat; + flat.hdr.type = 0xff; + flat.flags = 0x7f; + flat.binder = 0; + flat.handle = (uint32_t)(-1); + flat.cookie = reinterpret_cast(this); + bool status = parcel.WriteBuffer(&flat, sizeof(parcel_flat_binder_object)); + if (!status) { + return false; + } + return true; +} + +sptr RemoteObject::Unmarshalling(Parcel &parcel) +{ + const uint8_t *buffer = parcel.ReadBuffer(sizeof(parcel_flat_binder_object)); + if (buffer == nullptr) { + return nullptr; + } + sptr obj = new RemoteObject(); + return obj; +} + /*-------------------------------base data------------------------------------*/ struct TestData { @@ -1816,5 +1848,168 @@ HWTEST_F(UtilsParcelTest, test_SetMaxCapacity_002, TestSize.Level0) ret = parcel.ReadString16Vector(&val); EXPECT_EQ(false, ret); } + +HWTEST_F(UtilsParcelTest, test_ValidateReadData_001, TestSize.Level0) +{ + Parcel parcel(nullptr); + parcel.WriteBool(true); + string strWrite = "test"; + bool result = parcel.WriteString(strWrite); + EXPECT_EQ(result, true); + + RemoteObject obj1; + result = parcel.WriteRemoteObject(&obj1); + EXPECT_EQ(result, true); + parcel.WriteInt32(5); + RemoteObject obj2; + result = parcel.WriteRemoteObject(&obj2); + EXPECT_EQ(result, true); + u16string str16Write = u"12345"; + result = parcel.WriteString16(str16Write); + EXPECT_EQ(result, true); + + bool readBool = parcel.ReadBool(); + EXPECT_EQ(readBool, true); + + string strRead = parcel.ReadString(); + EXPECT_EQ(0, strcmp(strRead.c_str(), strWrite.c_str())); + + sptr readObj1 = parcel.ReadObject(); + EXPECT_EQ(true, readObj1.GetRefPtr() != nullptr); + + int32_t readInt32 = parcel.ReadInt32(); + EXPECT_EQ(readInt32, 5); + + sptr readObj2 = parcel.ReadObject(); + EXPECT_EQ(true, readObj2.GetRefPtr() != nullptr); + + u16string str16Read = parcel.ReadString16(); + EXPECT_EQ(0, str16Read.compare(str16Write)); +} + +HWTEST_F(UtilsParcelTest, test_ValidateReadData_002, TestSize.Level0) +{ + Parcel parcel(nullptr); + parcel.WriteBool(true); + string strWrite = "test"; + bool result = parcel.WriteString(strWrite); + EXPECT_EQ(result, true); + + RemoteObject obj1; + result = parcel.WriteRemoteObject(&obj1); + EXPECT_EQ(result, true); + parcel.WriteInt32(5); + RemoteObject obj2; + result = parcel.WriteRemoteObject(&obj2); + EXPECT_EQ(result, true); + u16string str16Write = u"12345"; + result = parcel.WriteString16(str16Write); + EXPECT_EQ(result, true); + + bool readBool = parcel.ReadBool(); + EXPECT_EQ(readBool, true); + + string strRead = parcel.ReadString(); + EXPECT_EQ(0, strcmp(strRead.c_str(), strWrite.c_str())); + + int32_t readInt32 = parcel.ReadInt32(); + EXPECT_EQ(readInt32, 0); + + u16string str16Read = parcel.ReadString16(); + EXPECT_EQ(0, str16Read.compare(std::u16string())); + + sptr readObj1 = parcel.ReadObject(); + EXPECT_EQ(true, readObj1.GetRefPtr() == nullptr); +} + +HWTEST_F(UtilsParcelTest, test_RewindWrite_001, TestSize.Level0) +{ + Parcel parcel(nullptr); + parcel.WriteInt32(5); + string strWrite = "test"; + parcel.WriteString(strWrite); + RemoteObject obj1; + parcel.WriteRemoteObject(&obj1); + size_t pos = parcel.GetWritePosition(); + parcel.WriteInt32(5); + RemoteObject obj2; + parcel.WriteRemoteObject(&obj2); + u16string str16Write = u"12345"; + parcel.WriteString16(str16Write); + + bool result = parcel.RewindWrite(pos); + EXPECT_EQ(result, true); + parcel.WriteInt32(5); + parcel.WriteInt32(5); + + int32_t readint32 = parcel.ReadInt32(); + EXPECT_EQ(readint32, 5); + string strRead = parcel.ReadString(); + EXPECT_EQ(0, strcmp(strRead.c_str(), strWrite.c_str())); + sptr readObj1 = parcel.ReadObject(); + EXPECT_EQ(true, readObj1.GetRefPtr() != nullptr); + readint32 = parcel.ReadInt32(); + EXPECT_EQ(readint32, 5); + sptr readObj2 = parcel.ReadObject(); + EXPECT_EQ(true, readObj2.GetRefPtr() == nullptr); + readint32 = parcel.ReadInt32(); + EXPECT_EQ(readint32, 5); +} + +HWTEST_F(UtilsParcelTest, test_RewindWrite_002, TestSize.Level0) +{ + Parcel parcel(nullptr); + parcel.WriteInt32(5); + string strWrite = "test"; + parcel.WriteString(strWrite); + RemoteObject obj1; + parcel.WriteRemoteObject(&obj1); + parcel.WriteInt32(5); + RemoteObject obj2; + parcel.WriteRemoteObject(&obj2); + size_t pos = parcel.GetWritePosition(); + u16string str16Write = u"12345"; + parcel.WriteString16(str16Write); + + bool result = parcel.RewindWrite(pos); + EXPECT_EQ(result, true); + + int32_t readint32 = parcel.ReadInt32(); + EXPECT_EQ(readint32, 5); + string strRead = parcel.ReadString(); + EXPECT_EQ(0, strcmp(strRead.c_str(), strWrite.c_str())); + uint32_t readUint32 = parcel.ReadUint32(); + EXPECT_EQ(readUint32, 0); + string strRead2 = parcel.ReadString(); + EXPECT_EQ(0, strRead2.compare(std::string())); + sptr readObj1 = parcel.ReadObject(); + EXPECT_EQ(true, readObj1.GetRefPtr() == nullptr); + double readDouble = parcel.ReadDouble(); + EXPECT_EQ(readDouble, 0); +} + +HWTEST_F(UtilsParcelTest, test_RewindWrite_003, TestSize.Level0) +{ + Parcel parcel(nullptr); + std::vector val{1, 2, 3, 4, 5}; + EXPECT_EQ(val.size(), 5); + bool result = parcel.WriteInt32Vector(val); + EXPECT_EQ(result, true); + size_t pos = parcel.GetWritePosition() - sizeof(int32_t); + result = parcel.RewindWrite(pos); + EXPECT_EQ(result, true); + RemoteObject obj; + parcel.WriteRemoteObject(&obj); + + std::vector int32Read; + result = parcel.ReadInt32Vector(&int32Read); + EXPECT_EQ(result, false); + EXPECT_EQ(int32Read.size(), 5); + EXPECT_EQ(int32Read[0], 1); + EXPECT_EQ(int32Read[1], 2); + EXPECT_EQ(int32Read[2], 3); + EXPECT_EQ(int32Read[3], 4); + EXPECT_EQ(int32Read[4], 0); +} } // namespace } // namespace OHOS \ No newline at end of file