Browse Source

AK: Don't assert things about active union members in StringBase

This involves removing the 'invalid' union member, as it was not really
checked against properly anyway; the 'invalid' state is now simply
StringData*{nullptr}, which was previously assumed not to exist.

Note that this still accesses inactive union members, but promises the
compiler that they're fine where they are (the provided debug macro
AK_STRINGBASE_VERIFY_LAUNDER_DEBUG makes the
would-be-UB-if-not-for-launder ops verify that the operation is correct).

Should fix the GCC build.
Ali Mohammad Pur 1 week ago
parent
commit
a83145c751
7 changed files with 96 additions and 60 deletions
  1. 4 0
      AK/Debug.h.in
  2. 6 6
      AK/FlyString.cpp
  3. 3 3
      AK/FlyString.h
  4. 5 0
      AK/String.h
  5. 3 7
      AK/StringBase.cpp
  6. 74 44
      AK/StringBase.h
  7. 1 0
      Meta/CMake/all_the_debug_macros.cmake

+ 4 - 0
AK/Debug.h.in

@@ -6,6 +6,10 @@
 
 #pragma once
 
+#ifndef AK_STRINGBASE_VERIFY_LAUNDER_DEBUG
+#    cmakedefine01 AK_STRINGBASE_VERIFY_LAUNDER_DEBUG
+#endif
+
 #ifndef AUDIO_DEBUG
 #    cmakedefine01 AUDIO_DEBUG
 #endif

+ 6 - 6
AK/FlyString.cpp

@@ -57,19 +57,19 @@ FlyString::FlyString(String const& string)
         return;
     }
 
-    if (string.m_data->is_fly_string()) {
+    if (string.m_impl.data->is_fly_string()) {
         m_data = string;
         return;
     }
 
-    auto it = all_fly_strings().find(string.m_data);
+    auto it = all_fly_strings().find(string.m_impl.data);
     if (it == all_fly_strings().end()) {
         m_data = string;
-        all_fly_strings().set(string.m_data);
-        string.m_data->set_fly_string(true);
+        all_fly_strings().set(string.m_impl.data);
+        string.m_impl.data->set_fly_string(true);
     } else {
-        m_data.m_data = *it;
-        m_data.m_data->ref();
+        m_data.m_impl.data = *it;
+        m_data.m_impl.data->ref();
     }
 }
 

+ 3 - 3
AK/FlyString.h

@@ -43,7 +43,7 @@ public:
     [[nodiscard]] ReadonlyBytes bytes() const { return m_data.bytes(); }
     [[nodiscard]] StringView bytes_as_string_view() const { return m_data.bytes(); }
 
-    [[nodiscard]] ALWAYS_INLINE bool operator==(FlyString const& other) const { return m_data.raw({}) == other.m_data.raw({}); }
+    [[nodiscard]] ALWAYS_INLINE bool operator==(FlyString const& other) const { return m_data.raw(Badge<FlyString> {}) == other.m_data.raw(Badge<FlyString> {}); }
     [[nodiscard]] bool operator==(String const& other) const { return m_data == other; }
     [[nodiscard]] bool operator==(StringView) const;
     [[nodiscard]] bool operator==(char const*) const;
@@ -89,7 +89,7 @@ private:
     friend class Optional<FlyString>;
 
     explicit FlyString(nullptr_t)
-        : m_data(Detail::StringBase(nullptr))
+        : m_data(nullptr)
     {
     }
 
@@ -100,7 +100,7 @@ private:
 
     Detail::StringBase m_data;
 
-    bool is_invalid() const { return m_data.is_invalid(); }
+    bool is_invalid() const { return m_data.raw(Badge<FlyString> {}) == 0; }
 };
 
 void did_destroy_fly_string_data(Badge<Detail::StringData>, Detail::StringData const&);

+ 5 - 0
AK/String.h

@@ -224,6 +224,11 @@ private:
 
     using ShortString = Detail::ShortString;
 
+    bool is_invalid() const
+    {
+        return raw(Badge<String> {}) == 0;
+    }
+
     explicit constexpr String(StringBase&& base)
         : StringBase(move(base))
     {

+ 3 - 7
AK/StringBase.cpp

@@ -13,7 +13,6 @@ namespace AK::Detail {
 
 void StringBase::replace_with_string_builder(StringBuilder& builder)
 {
-    ASSERT(!is_invalid());
     if (builder.length() <= MAX_SHORT_STRING_BYTE_COUNT) {
         return replace_with_new_short_string(builder.length(), [&](Bytes buffer) {
             builder.string_view().bytes().copy_to(buffer);
@@ -22,24 +21,22 @@ void StringBase::replace_with_string_builder(StringBuilder& builder)
 
     destroy_string();
 
-    m_data = &StringData::create_from_string_builder(builder).leak_ref();
+    m_impl = { .data = &StringData::create_from_string_builder(builder).leak_ref() };
 }
 
 ErrorOr<Bytes> StringBase::replace_with_uninitialized_buffer(size_t byte_count)
 {
-    ASSERT(!is_invalid());
     if (byte_count <= MAX_SHORT_STRING_BYTE_COUNT)
         return replace_with_uninitialized_short_string(byte_count);
 
     u8* buffer = nullptr;
     destroy_string();
-    m_data = &TRY(StringData::create_uninitialized(byte_count, buffer)).leak_ref();
+    m_impl = { .data = &TRY(StringData::create_uninitialized(byte_count, buffer)).leak_ref() };
     return Bytes { buffer, byte_count };
 }
 
 ErrorOr<StringBase> StringBase::substring_from_byte_offset_with_shared_superstring(size_t start, size_t length) const
 {
-    ASSERT(!is_invalid());
     VERIFY(start + length <= byte_count());
 
     if (length == 0)
@@ -49,7 +46,6 @@ ErrorOr<StringBase> StringBase::substring_from_byte_offset_with_shared_superstri
         bytes().slice(start, length).copy_to(result.replace_with_uninitialized_short_string(length));
         return result;
     }
-    return StringBase { TRY(Detail::StringData::create_substring(*m_data, start, length)) };
+    return StringBase { TRY(Detail::StringData::create_substring(*m_impl.data, start, length)) };
 }
-
 }

+ 74 - 44
AK/StringBase.h

@@ -40,10 +40,9 @@ public:
     StringBase(StringBase const&);
 
     constexpr StringBase(StringBase&& other)
-        : m_short_string(other.m_short_string)
+        : m_impl(other.m_impl)
     {
-        other.m_short_string = ShortString {};
-        other.m_short_string.byte_count_and_short_string_flag = SHORT_STRING_FLAG;
+        other.m_impl = { .short_string = { .byte_count_and_short_string_flag = SHORT_STRING_FLAG } };
     }
 
     StringBase& operator=(StringBase&&);
@@ -58,7 +57,9 @@ public:
     // NOTE: This is primarily interesting to unit tests.
     [[nodiscard]] constexpr bool is_short_string() const
     {
-        return (m_short_string.byte_count_and_short_string_flag & SHORT_STRING_FLAG) != 0;
+        if (is_constant_evaluated())
+            return (m_impl.short_string.byte_count_and_short_string_flag & SHORT_STRING_FLAG) != 0;
+        return (short_string_without_union_member_assertion().byte_count_and_short_string_flag & SHORT_STRING_FLAG) != 0;
     }
 
     // Returns the underlying UTF-8 encoded bytes.
@@ -69,11 +70,10 @@ public:
 
     [[nodiscard]] bool operator==(StringBase const&) const;
 
-    [[nodiscard]] ALWAYS_INLINE FlatPtr raw(Badge<FlyString>) const { return bit_cast<FlatPtr>(m_data); }
+    [[nodiscard]] ALWAYS_INLINE FlatPtr raw(Badge<FlyString>) const { return bit_cast<FlatPtr>(m_impl); }
+    [[nodiscard]] ALWAYS_INLINE FlatPtr raw(Badge<String>) const { return bit_cast<FlatPtr>(m_impl); }
 
 protected:
-    bool is_invalid() const { return m_invalid_tag == UINTPTR_MAX; }
-
     template<typename Func>
     ErrorOr<void> replace_with_new_string(size_t byte_count, Func&& callback)
     {
@@ -109,12 +109,12 @@ private:
     explicit StringBase(NonnullRefPtr<Detail::StringData const>);
 
     explicit constexpr StringBase(nullptr_t)
-        : m_invalid_tag(UINTPTR_MAX)
+        : m_impl { .data = nullptr }
     {
     }
 
     explicit constexpr StringBase(ShortString short_string)
-        : m_short_string(short_string)
+        : m_impl { .short_string = short_string }
     {
     }
 
@@ -125,18 +125,45 @@ private:
         VERIFY(is_short_string());
         VERIFY(byte_count <= MAX_SHORT_STRING_BYTE_COUNT);
 
-        m_short_string = ShortString {};
-        m_short_string.byte_count_and_short_string_flag = (byte_count << SHORT_STRING_BYTE_COUNT_SHIFT_COUNT) | SHORT_STRING_FLAG;
-        return { m_short_string.storage, byte_count };
+        m_impl = { .short_string = {} };
+        m_impl.short_string.byte_count_and_short_string_flag = (byte_count << SHORT_STRING_BYTE_COUNT_SHIFT_COUNT) | SHORT_STRING_FLAG;
+        return { m_impl.short_string.storage, byte_count };
     }
 
     void destroy_string();
 
// Debug variants: read through a laundered pointer to a union member that may not be the active one, and VERIFY that
// the bits of both members match the value that was read; note that this guarantees nothing and just checks whatever
// state we're in - not all.
// NOTE(review): AK/Debug.h.in declares this macro with `#cmakedefine01`, which always defines it (as 0 or 1), so `#ifdef`
// selects the debug variants whenever that definition is visible - confirm whether `#if` was intended.
+#ifdef AK_STRINGBASE_VERIFY_LAUNDER_DEBUG
+    ShortString short_string_without_union_member_assertion() const
+    {
+        auto laundered_value = *__builtin_launder(&m_impl.short_string);
+        auto bitcast_value1 = bit_cast<FlatPtr>(*__builtin_launder(&m_impl.data));
+        auto bitcast_value2 = bit_cast<FlatPtr>(*__builtin_launder(&m_impl.short_string)); // one of these is the active one :P
+        VERIFY(bit_cast<FlatPtr>(laundered_value) == bitcast_value1 && bit_cast<FlatPtr>(laundered_value) == bitcast_value2);
+        return *__builtin_launder(&laundered_value);
+    }
+    StringData const* data_without_union_member_assertion() const
+    {
+        auto laundered_value = *__builtin_launder(&m_impl.data);
+        auto bitcast_value1 = bit_cast<FlatPtr>(*__builtin_launder(&m_impl.data));
+        auto bitcast_value2 = bit_cast<FlatPtr>(*__builtin_launder(&m_impl.short_string)); // one of these is the active one :P
+        VERIFY(bit_cast<FlatPtr>(laundered_value) == bitcast_value1 && bit_cast<FlatPtr>(laundered_value) == bitcast_value2);
+        return *__builtin_launder(&laundered_value);
+    }
+#else
+    // This is technically **invalid**!
+    // Inactive union members are not required to exist at all, so there might not be any real object at this address.
+    // Empirically though, both members live at the same address, so we ask the compiler to trust us and launder the pointer,
+    // even though the launder itself is possibly invalid.
+    // The debug variants above assert that we're reading the right value; these variants skip those checks for speed.
+    ALWAYS_INLINE ShortString short_string_without_union_member_assertion() const { return *__builtin_launder(&m_impl.short_string); }
+    ALWAYS_INLINE StringData const* data_without_union_member_assertion() const { return *__builtin_launder(&m_impl.data); }
+#endif
+
     union {
-        ShortString m_short_string;
-        Detail::StringData const* m_data { nullptr };
-        uintptr_t m_invalid_tag;
-    };
+        ShortString short_string;
+        StringData const* data;
+    } m_impl;
 };
 
 inline ReadonlyBytes ShortString::bytes() const
@@ -151,81 +178,84 @@ inline size_t ShortString::byte_count() const
 
 inline ReadonlyBytes StringBase::bytes() const
 {
-    ASSERT(!is_invalid());
     if (is_short_string())
-        return m_short_string.bytes();
-    return m_data->bytes();
+        return m_impl.short_string.bytes();
+    if (!m_impl.data)
+        return {};
+    return data_without_union_member_assertion()->bytes();
 }
 
 inline u32 StringBase::hash() const
 {
-    ASSERT(!is_invalid());
     if (is_short_string()) {
         auto bytes = this->bytes();
         return string_hash(reinterpret_cast<char const*>(bytes.data()), bytes.size());
     }
-    return m_data->hash();
+    if (!m_impl.data)
+        return string_hash(nullptr, 0);
+    return data_without_union_member_assertion()->hash();
 }
 
 inline size_t StringBase::byte_count() const
 {
-    ASSERT(!is_invalid());
     if (is_short_string())
-        return m_short_string.byte_count_and_short_string_flag >> StringBase::SHORT_STRING_BYTE_COUNT_SHIFT_COUNT;
-    return m_data->byte_count();
+        return m_impl.short_string.byte_count_and_short_string_flag >> StringBase::SHORT_STRING_BYTE_COUNT_SHIFT_COUNT;
+
+    if (!m_impl.data)
+        return 0;
+    return data_without_union_member_assertion()->byte_count();
 }
 
 inline void StringBase::destroy_string()
 {
-    if (!is_short_string())
-        m_data->unref();
+    if (!is_short_string() && m_impl.data)
+        data_without_union_member_assertion()->unref();
 }
 
 inline StringBase::StringBase(NonnullRefPtr<Detail::StringData const> data)
-    : m_data(&data.leak_ref())
+    : m_impl { .data = &data.leak_ref() }
 {
 }
 
 inline StringBase::StringBase(StringBase const& other)
-    : m_data(other.m_data)
+    : m_impl(other.m_impl)
 {
-    if (!is_short_string())
-        m_data->ref();
+    if (!is_short_string() && m_impl.data)
+        data_without_union_member_assertion()->ref();
 }
 
 inline StringBase& StringBase::operator=(StringBase&& other)
 {
-    if (!is_short_string())
-        m_data->unref();
+    if (!is_short_string() && m_impl.data)
+        data_without_union_member_assertion()->unref();
 
-    m_data = exchange(other.m_data, nullptr);
-    other.m_short_string.byte_count_and_short_string_flag = SHORT_STRING_FLAG;
+    m_impl = exchange(other.m_impl, { .short_string = { .byte_count_and_short_string_flag = SHORT_STRING_FLAG } });
     return *this;
 }
 
 inline StringBase& StringBase::operator=(StringBase const& other)
 {
     if (&other != this) {
-        if (!is_short_string())
-            m_data->unref();
+        if (!is_short_string() && m_impl.data)
+            data_without_union_member_assertion()->unref();
 
-        m_data = other.m_data;
-        if (!is_short_string())
-            m_data->ref();
+        m_impl = other.m_impl;
+        if (!is_short_string() && m_impl.data)
+            data_without_union_member_assertion()->ref();
     }
     return *this;
 }
 
 inline bool StringBase::operator==(StringBase const& other) const
 {
-    ASSERT(!is_invalid());
     if (is_short_string())
-        return m_data == other.m_data;
+        return bit_cast<FlatPtr>(m_impl) == bit_cast<FlatPtr>(other.m_impl);
     if (other.is_short_string())
         return false;
-    if (m_data->is_fly_string() && other.m_data->is_fly_string())
-        return m_data == other.m_data;
+    if (m_impl.data == nullptr || other.m_impl.data == nullptr)
+        return m_impl.data == other.m_impl.data;
+    if (data_without_union_member_assertion()->is_fly_string() && other.data_without_union_member_assertion()->is_fly_string())
+        return m_impl.data == other.m_impl.data;
     return bytes() == other.bytes();
 }
-
 }

+ 1 - 0
Meta/CMake/all_the_debug_macros.cmake

@@ -1,3 +1,4 @@
+set(AK_STRINGBASE_VERIFY_LAUNDER_DEBUG ON)
 set(AUDIO_DEBUG ON)
 set(BMP_DEBUG ON)
 set(CACHE_DEBUG ON)