From a90e273e671d2cb76bc24ff80033487fefd8158c Mon Sep 17 00:00:00 2001 From: lijincheng Date: Sat, 19 Aug 2023 20:45:41 +0800 Subject: [PATCH] Add security verification in Parcel to ensure data security during read operations Add data verification within the basic read data operation to avoid using non object read operations to read object data and avoid illegal data leakage to external attackers for attack opportunities. Issue:https://gitee.com/openharmony/commonlibrary_c_utils/issues/I7SMYA Signed-off-by: lijincheng --- base/BUILD.gn | 9 + base/include/parcel.h | 7 + base/src/parcel.cpp | 128 +++++++++++- .../unittest/common/utils_parcel_test.cpp | 195 ++++++++++++++++++ 4 files changed, 334 insertions(+), 5 deletions(-) diff --git a/base/BUILD.gn b/base/BUILD.gn index 1e0ea91..b935e0f 100644 --- a/base/BUILD.gn +++ b/base/BUILD.gn @@ -20,6 +20,7 @@ declare_args() { c_utils_print_track_at_once = false c_utils_debug_log_enabled = false c_utils_feature_intsan = true + c_utils_parcel_object_check = true } config("utils_config") { @@ -54,6 +55,11 @@ config("track_all") { defines = [ "TRACK_ALL" ] } +config("parcel_object_check") { + visibility = [ ":*" ] + defines = [ "PARCEL_OBJECT_CHECK" ] +} + sources_utils = [ "src/string_ex.cpp", "src/unicode_ex.cpp", @@ -167,6 +173,9 @@ ohos_shared_library("utils") { configs += [ ":print_track_at_once" ] } } + if (c_utils_parcel_object_check) { + configs += [ ":parcel_object_check" ] + } all_dependent_configs = [ ":utils_config" ] if (current_os != "android" && current_os != "ios") { defines = [ "CONFIG_HILOG" ] diff --git a/base/include/parcel.h b/base/include/parcel.h index 89ac76c..08453bb 100644 --- a/base/include/parcel.h +++ b/base/include/parcel.h @@ -820,6 +820,12 @@ private: bool WriteParcelableOffset(size_t offset); + const uint8_t *BasicReadBuffer(size_t length); + + bool ValidateReadData(size_t upperBound); + + void ClearObjects(); + private: uint8_t *data_; size_t readCursor_; @@ -828,6 +834,7 @@ private: size_t dataCapacity_; size_t maxDataCapacity_; binder_size_t *objectOffsets_; + size_t nextObjectIdx_; size_t objectCursor_; size_t objectsCapacity_; Allocator *allocator_; diff --git a/base/src/parcel.cpp b/base/src/parcel.cpp index 991c669..86f0215 100644 --- a/base/src/parcel.cpp +++ b/base/src/parcel.cpp @@ -48,6 +48,7 @@ Parcel::Parcel(Allocator *allocator) maxDataCapacity_ = DEFAULT_CPACITY; objectOffsets_ = nullptr; + nextObjectIdx_ = 0; objectCursor_ = 0; objectsCapacity_ = 0; } @@ -59,6 +60,7 @@ Parcel::~Parcel() { FlushBuffer(); delete allocator_; + allocator_ = nullptr; } size_t Parcel::GetWritableBytes() const @@ -147,6 +149,33 @@ bool Parcel::EnsureWritableCapacity(size_t desireCapacity) return false; } +// ValidateReadData only works in basic type read. It doesn't work when read remote object. +// And read/write remote object has no effect on "nextObjectIdx_". +bool Parcel::ValidateReadData([[maybe_unused]]size_t upperBound) +{ +#ifdef PARCEL_OBJECT_CHECK + if (objectOffsets_ == nullptr || objectCursor_ == 0) { + return true; + } + size_t readPos = readCursor_; + size_t objSize = objectCursor_; + binder_size_t *objects = objectOffsets_; + if (nextObjectIdx_ < objSize && upperBound > objects[nextObjectIdx_]) { + size_t nextObj = nextObjectIdx_; + do { + if (readPos < objects[nextObj] + sizeof(parcel_flat_binder_object)) { + UTILS_LOGE("Non-object Read object data, readPos = %{public}zu, upperBound = %{public}zu", + readPos, upperBound); + return false; + } + nextObj++; + } while (nextObj < objSize && upperBound > objects[nextObj]); + nextObjectIdx_ = nextObj; + } +#endif + return true; +} + size_t Parcel::GetDataSize() const { return dataSize_; @@ -254,6 +283,16 @@ void Parcel::InjectOffsets(binder_size_t offsets, size_t offsetSize) } } +void Parcel::ClearObjects() +{ + objectHolder_.clear(); + free(objectOffsets_); + nextObjectIdx_ = 0; + objectCursor_ = 0; + objectOffsets_ = nullptr; + objectsCapacity_ = 0; +} + void Parcel::FlushBuffer() { if (allocator_ == nullptr) { @@ -270,11 +309,7 @@ void Parcel::FlushBuffer() } if (objectOffsets_) { - objectHolder_.clear(); - free(objectOffsets_); - objectCursor_ = 0; - objectOffsets_ = nullptr; - objectsCapacity_ = 0; + ClearObjects(); } } @@ -691,6 +726,13 @@ bool Parcel::Read(T &value) if (desireCapacity <= GetReadableBytes()) { const void *data = data_ + readCursor_; +#ifdef PARCEL_OBJECT_CHECK + size_t upperBound = readCursor_ + desireCapacity; + if (!ValidateReadData(upperBound)) { + readCursor_ += desireCapacity; + return false; + } +#endif readCursor_ += desireCapacity; value = *reinterpret_cast(data); return true; @@ -717,6 +759,11 @@ bool Parcel::ParseFrom(uintptr_t data, size_t size) dataSize_ = size; /* data is alloc by driver, can not write again */ writable_ = false; +#ifdef PARCEL_OBJECT_CHECK + if (objectOffsets_) { + ClearObjects(); + } +#endif return true; } @@ -731,6 +778,23 @@ const uint8_t *Parcel::ReadBuffer(size_t length) return nullptr; } +const uint8_t *Parcel::BasicReadBuffer([[maybe_unused]]size_t length) +{ +#ifdef PARCEL_OBJECT_CHECK + if (GetReadableBytes() >= length) { + uint8_t *buffer = data_ + readCursor_; + size_t upperBound = readCursor_ + length; + if (!ValidateReadData(upperBound)) { + readCursor_ += length; + return nullptr; + } + readCursor_ += length; + return buffer; + } +#endif + return nullptr; +} + const uint8_t *Parcel::ReadUnpadBuffer(size_t length) { if (GetReadableBytes() >= length) { @@ -763,6 +827,7 @@ bool Parcel::RewindRead(size_t newPosition) return false; } readCursor_ = newPosition; + nextObjectIdx_ = 0; return true; } @@ -778,6 +843,35 @@ bool Parcel::RewindWrite(size_t newPosition) } writeCursor_ = newPosition; dataSize_ = newPosition; +#ifdef PARCEL_OBJECT_CHECK + if (objectOffsets_ == nullptr || objectCursor_ == 0) { + return true; + } + size_t objectSize = objectCursor_; + if (objectOffsets_[objectSize - 1] + sizeof(parcel_flat_binder_object) > newPosition) { + while (objectSize > 0) { + if (objectOffsets_[objectSize - 1] + sizeof(parcel_flat_binder_object) <= newPosition) { + break; + } + objectSize--; + } + if (objectSize == 0) { + ClearObjects(); + return true; + } + size_t newBytes = objectSize * sizeof(binder_size_t); + void *newOffsets = realloc(objectOffsets_, newBytes); + if (newOffsets == nullptr) { + return false; + } + objectOffsets_ = reinterpret_cast(newOffsets); + objectCursor_ = objectSize; + objectsCapacity_ = objectCursor_; + objectHolder_.resize(objectSize); + nextObjectIdx_ = 0; + return true; + } +#endif return true; } @@ -966,7 +1060,11 @@ const std::string Parcel::ReadString() size_t readCapacity = static_cast(dataLength) + 1; if (readCapacity <= GetReadableBytes()) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *dest = BasicReadBuffer(readCapacity); +#else const uint8_t *dest = ReadBuffer(readCapacity); +#endif if (dest != nullptr) { const auto *str = reinterpret_cast(dest); SkipBytes(GetPadSize(readCapacity)); @@ -992,7 +1090,11 @@ bool Parcel::ReadString(std::string &value) size_t readCapacity = static_cast(dataLength) + 1; if (readCapacity <= GetReadableBytes()) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *dest = BasicReadBuffer(readCapacity); +#else const uint8_t *dest = ReadBuffer(readCapacity); +#endif if (dest != nullptr) { const auto *str = reinterpret_cast(dest); SkipBytes(GetPadSize(readCapacity)); @@ -1019,7 +1121,11 @@ const std::u16string Parcel::ReadString16() size_t readCapacity = (static_cast(dataLength) + 1) * sizeof(char16_t); if ((readCapacity > (static_cast(dataLength))) && (readCapacity <= GetReadableBytes())) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *str = BasicReadBuffer(readCapacity); +#else const uint8_t *str = ReadBuffer(readCapacity); +#endif if (str != nullptr) { const auto *u16Str = reinterpret_cast(str); SkipBytes(GetPadSize(readCapacity)); @@ -1045,7 +1151,11 @@ bool Parcel::ReadString16(std::u16string &value) size_t readCapacity = (static_cast(dataLength) + 1) * sizeof(char16_t); if ((readCapacity > (static_cast(dataLength))) && (readCapacity <= GetReadableBytes())) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *str = BasicReadBuffer(readCapacity); +#else const uint8_t *str = ReadBuffer(readCapacity); +#endif if (str != nullptr) { const auto *u16Str = reinterpret_cast(str); SkipBytes(GetPadSize(readCapacity)); @@ -1077,7 +1187,11 @@ const std::u16string Parcel::ReadString16WithLength(int32_t &readLength) size_t readCapacity = (static_cast(dataLength) + 1) * sizeof(char16_t); if ((readCapacity > (static_cast(dataLength))) && (readCapacity <= GetReadableBytes())) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *str = BasicReadBuffer(readCapacity); +#else const uint8_t *str = ReadBuffer(readCapacity); +#endif if (str != nullptr) { const auto *u16Str = reinterpret_cast(str); SkipBytes(GetPadSize(readCapacity)); @@ -1108,7 +1222,11 @@ const std::string Parcel::ReadString8WithLength(int32_t &readLength) size_t readCapacity = (static_cast(dataLength) + 1) * sizeof(char); if (readCapacity <= GetReadableBytes()) { +#ifdef PARCEL_OBJECT_CHECK + const uint8_t *str = BasicReadBuffer(readCapacity); +#else const uint8_t *str = ReadBuffer(readCapacity); +#endif if (str != nullptr) { const auto *u8Str = reinterpret_cast(str); SkipBytes(GetPadSize(readCapacity)); diff --git a/base/test/unittest/common/utils_parcel_test.cpp b/base/test/unittest/common/utils_parcel_test.cpp index 46653d8..a736a33 100644 --- a/base/test/unittest/common/utils_parcel_test.cpp +++ b/base/test/unittest/common/utils_parcel_test.cpp @@ -42,6 +42,38 @@ void UtilsParcelTest::TearDownTestCase(void) } } +class RemoteObject : public virtual Parcelable { +public: + RemoteObject() { asRemote_ = true; }; + bool Marshalling(Parcel &parcel) const override; + static sptr Unmarshalling(Parcel &parcel); +}; + +bool RemoteObject::Marshalling(Parcel &parcel) const +{ + parcel_flat_binder_object flat; + flat.hdr.type = 0xff; + flat.flags = 0x7f; + flat.binder = 0; + flat.handle = (uint32_t)(-1); + flat.cookie = reinterpret_cast(this); + bool status = parcel.WriteBuffer(&flat, sizeof(parcel_flat_binder_object)); + if (!status) { + return false; + } + return true; +} + +sptr RemoteObject::Unmarshalling(Parcel &parcel) +{ + const uint8_t *buffer = parcel.ReadBuffer(sizeof(parcel_flat_binder_object)); + if (buffer == nullptr) { + return nullptr; + } + sptr obj = new RemoteObject(); + return obj; +} + /*-------------------------------base data------------------------------------*/ struct TestData { @@ -1816,5 +1848,168 @@ HWTEST_F(UtilsParcelTest, test_SetMaxCapacity_002, TestSize.Level0) ret = parcel.ReadString16Vector(&val); EXPECT_EQ(false, ret); } + +HWTEST_F(UtilsParcelTest, test_ValidateReadData_001, TestSize.Level0) +{ + Parcel parcel(nullptr); + parcel.WriteBool(true); + string strWrite = "test"; + bool result = parcel.WriteString(strWrite); + EXPECT_EQ(result, true); + + RemoteObject obj1; + result = parcel.WriteRemoteObject(&obj1); + EXPECT_EQ(result, true); + parcel.WriteInt32(5); + RemoteObject obj2; + result = parcel.WriteRemoteObject(&obj2); + EXPECT_EQ(result, true); + u16string str16Write = u"12345"; + result = parcel.WriteString16(str16Write); + EXPECT_EQ(result, true); + + bool readBool = parcel.ReadBool(); + EXPECT_EQ(readBool, true); + + string strRead = parcel.ReadString(); + EXPECT_EQ(0, strcmp(strRead.c_str(), strWrite.c_str())); + + sptr readObj1 = parcel.ReadObject(); + EXPECT_EQ(true, readObj1.GetRefPtr() != nullptr); + + int32_t readInt32 = parcel.ReadInt32(); + EXPECT_EQ(readInt32, 5); + + sptr readObj2 = parcel.ReadObject(); + EXPECT_EQ(true, readObj2.GetRefPtr() != nullptr); + + u16string str16Read = parcel.ReadString16(); + EXPECT_EQ(0, str16Read.compare(str16Write)); +} + +HWTEST_F(UtilsParcelTest, test_ValidateReadData_002, TestSize.Level0) +{ + Parcel parcel(nullptr); + parcel.WriteBool(true); + string strWrite = "test"; + bool result = parcel.WriteString(strWrite); + EXPECT_EQ(result, true); + + RemoteObject obj1; + result = parcel.WriteRemoteObject(&obj1); + EXPECT_EQ(result, true); + parcel.WriteInt32(5); + RemoteObject obj2; + result = parcel.WriteRemoteObject(&obj2); + EXPECT_EQ(result, true); + u16string str16Write = u"12345"; + result = parcel.WriteString16(str16Write); + EXPECT_EQ(result, true); + + bool readBool = parcel.ReadBool(); + EXPECT_EQ(readBool, true); + + string strRead = parcel.ReadString(); + EXPECT_EQ(0, strcmp(strRead.c_str(), strWrite.c_str())); + + int32_t readInt32 = parcel.ReadInt32(); + EXPECT_EQ(readInt32, 0); + + u16string str16Read = parcel.ReadString16(); + EXPECT_EQ(0, str16Read.compare(std::u16string())); + + sptr readObj1 = parcel.ReadObject(); + EXPECT_EQ(true, readObj1.GetRefPtr() == nullptr); +} + +HWTEST_F(UtilsParcelTest, test_RewindWrite_001, TestSize.Level0) +{ + Parcel parcel(nullptr); + parcel.WriteInt32(5); + string strWrite = "test"; + parcel.WriteString(strWrite); + RemoteObject obj1; + parcel.WriteRemoteObject(&obj1); + size_t pos = parcel.GetWritePosition(); + parcel.WriteInt32(5); + RemoteObject obj2; + parcel.WriteRemoteObject(&obj2); + u16string str16Write = u"12345"; + parcel.WriteString16(str16Write); + + bool result = parcel.RewindWrite(pos); + EXPECT_EQ(result, true); + parcel.WriteInt32(5); + parcel.WriteInt32(5); + + int32_t readint32 = parcel.ReadInt32(); + EXPECT_EQ(readint32, 5); + string strRead = parcel.ReadString(); + EXPECT_EQ(0, strcmp(strRead.c_str(), strWrite.c_str())); + sptr readObj1 = parcel.ReadObject(); + EXPECT_EQ(true, readObj1.GetRefPtr() != nullptr); + readint32 = parcel.ReadInt32(); + EXPECT_EQ(readint32, 5); + sptr readObj2 = parcel.ReadObject(); + EXPECT_EQ(true, readObj2.GetRefPtr() == nullptr); + readint32 = parcel.ReadInt32(); + EXPECT_EQ(readint32, 5); +} + +HWTEST_F(UtilsParcelTest, test_RewindWrite_002, TestSize.Level0) +{ + Parcel parcel(nullptr); + parcel.WriteInt32(5); + string strWrite = "test"; + parcel.WriteString(strWrite); + RemoteObject obj1; + parcel.WriteRemoteObject(&obj1); + parcel.WriteInt32(5); + RemoteObject obj2; + parcel.WriteRemoteObject(&obj2); + size_t pos = parcel.GetWritePosition(); + u16string str16Write = u"12345"; + parcel.WriteString16(str16Write); + + bool result = parcel.RewindWrite(pos); + EXPECT_EQ(result, true); + + int32_t readint32 = parcel.ReadInt32(); + EXPECT_EQ(readint32, 5); + string strRead = parcel.ReadString(); + EXPECT_EQ(0, strcmp(strRead.c_str(), strWrite.c_str())); + uint32_t readUint32 = parcel.ReadUint32(); + EXPECT_EQ(readUint32, 0); + string strRead2 = parcel.ReadString(); + EXPECT_EQ(0, strRead2.compare(std::string())); + sptr readObj1 = parcel.ReadObject(); + EXPECT_EQ(true, readObj1.GetRefPtr() == nullptr); + double readDouble = parcel.ReadDouble(); + EXPECT_EQ(readDouble, 0); +} + +HWTEST_F(UtilsParcelTest, test_RewindWrite_003, TestSize.Level0) +{ + Parcel parcel(nullptr); + std::vector val{1, 2, 3, 4, 5}; + EXPECT_EQ(val.size(), 5); + bool result = parcel.WriteInt32Vector(val); + EXPECT_EQ(result, true); + size_t pos = parcel.GetWritePosition() - sizeof(int32_t); + result = parcel.RewindWrite(pos); + EXPECT_EQ(result, true); + RemoteObject obj; + parcel.WriteRemoteObject(&obj); + + std::vector int32Read; + result = parcel.ReadInt32Vector(&int32Read); + EXPECT_EQ(result, false); + EXPECT_EQ(int32Read.size(), 5); + EXPECT_EQ(int32Read[0], 1); + EXPECT_EQ(int32Read[1], 2); + EXPECT_EQ(int32Read[2], 3); + EXPECT_EQ(int32Read[3], 4); + EXPECT_EQ(int32Read[4], 0); +} } // namespace } // namespace OHOS \ No newline at end of file -- Gitee