cpp-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub shino16/cpp-library

:heavy_check_mark: ds/binary_trie.hpp

Depends on

Verified with

Code

#pragma once
#include "util/pool.hpp"

template <class T>
struct xor_mask_enabled {
  void set_mask(T s2) { s = s2; }
  void add_mask(T s2) { s ^= s2; }
  T mask() const { return s; }

 private:
  T s = 0;
};

template <class T>
struct xor_mask_disabled {
  constexpr T mask() const { return 0; }
};

// [0, U]
template <
    class T, T U, class Alloc = pool<>,
    template <class> class XorMask = xor_mask_disabled>
class binary_trie : public XorMask<T> {
 private:
  struct node : Alloc::template alloc<node> {
    node *l = nullptr, *r = nullptr;
    int cnt = 0;
  };

 public:
  binary_trie() = default;

  int size() const { return root->cnt; }
  bool empty() const { return size() == 0; }
  void insert(T x) { insert(root, x, B - 1); }
  void erase(T x) { erase(root, x, B - 1); }
  int count(T x) const { return count(root, x, B - 1); }
  // returns -1 if empty
  T max() const { return empty() ? -1 : max(root, 0, B - 1); }
  // returns -1 if empty
  T min() const { return empty() ? -1 : min(root, 0, B - 1); }
  T find(int k) const { return find(root, 0, k, B - 1); }

 private:
  static constexpr int calc_b(T u) {
    int b = 0;
    while (u) u >>= 1, b++;
    return b;
  }
  static constexpr int B = calc_b(U);

  node* root = new node;

  T bit_at(T x, int i) const { return (x ^ this->mask()) >> i & 1; }
  node* insert(node* p, T x, int i) {
    if (!p)
      p = new node;
    p->cnt++;
    if (i != -1) {
      if (bit_at(x, i)) p->r = insert(p->r, x, i - 1);
      else p->l = insert(p->l, x, i - 1);
    }
    return p;
  }
  void erase(node* p, T x, int i) {
    p->cnt--;
    if (i != -1) {
      if (bit_at(x, i)) erase(p->r, x, i - 1);
      else erase(p->l, x, i - 1);
    }
  }
  int count_tree(const node* p) const { return p ? p->cnt : 0; }
  int count(const node* p, T x, int i) const {
    return !p             ? 0
           : i == -1      ? p->cnt
           : bit_at(x, i) ? count(p->r, x, i - 1)
                          : count(p->l, x, i - 1);
  }
  T max(const node* p, T x, int i) const {
    if (i == -1) return x;
    if (this->mask() >> i & 1)
      return count_tree(p->l) ? max(p->l, x | 1 << i, i - 1)
                              : max(p->r, x, i - 1);
    else
      return count_tree(p->r) ? max(p->r, x | 1 << i, i - 1)
                              : max(p->l, x, i - 1);
  }
  T min(const node* p, T x, int i) const {
    if (i == -1) return x;
    if (this->mask() >> i & 1)
      return count_tree(p->r) ? min(p->r, x, i - 1)
                              : min(p->l, x | 1 << i, i - 1);
    else
      return count_tree(p->l) ? min(p->l, x, i - 1)
                              : min(p->r, x | 1 << i, i - 1);
  }
  T find(const node* p, T x, int k, int i) const {
    if (i == -1) {
      assert(k == 0 && p->cnt != 0);
      return x;
    }
    if (this->mask() >> i & 1)
      return count_tree(p->r) < k
                 ? find(p->r, x, k, i - 1)
                 : find(p->l, x | 1 << i, k - count_tree(p->r), i - 1);
    else
      return count_tree(p->l) < k
                 ? find(p->l, x, k, i - 1)
                 : find(p->r, x | 1 << i, k - count_tree(p->l), i - 1);
  }
};
#line 2 "prelude.hpp"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
using vi = vector<int>;
using vvi = vector<vector<int>>;
using vll = vector<ll>;
using vvll = vector<vector<ll>>;
using vc = vector<char>;
#define rep2(i, m, n) for (auto i = (m); i < (n); i++)
#define rep(i, n) rep2(i, 0, n)
#define repr2(i, m, n) for (auto i = (n); i-- > (m);)
#define repr(i, n) repr2(i, 0, n)
#define all(x) begin(x), end(x)
auto ndvec(int n, auto e) { return vector(n, e); }
auto ndvec(int n, auto ...e) { return vector(n, ndvec(e...)); }
auto comp_key(auto&& f) { return [&](auto&& a, auto&& b) { return f(a) < f(b); }; }
auto& max(const auto& a, const auto& b) { return a < b ? b : a; }
auto& min(const auto& a, const auto& b) { return b < a ? b : a; }
#if __cpp_lib_ranges
namespace R = std::ranges;
namespace V = std::views;
#endif
#line 3 "util/pool.hpp"

template <size_t N = 1 << 20>
struct pool {
  template <class T>
  struct alloc {
    static inline T* ptr =
        (T*)new (align_val_t(alignof(T))) unsigned char[sizeof(T) * N];
    static void* operator new(size_t) noexcept { return ptr++; }
    static void operator delete(void* p) { destroy_at((T*)p); }

   private:
    alloc() = default;
    friend T;
  };
};

struct no_pool {
  template <class T>
  struct alloc {};
};
#line 3 "ds/binary_trie.hpp"

template <class T>
struct xor_mask_enabled {
  void set_mask(T s2) { s = s2; }
  void add_mask(T s2) { s ^= s2; }
  T mask() const { return s; }

 private:
  T s = 0;
};

template <class T>
struct xor_mask_disabled {
  constexpr T mask() const { return 0; }
};

// [0, U]
template <
    class T, T U, class Alloc = pool<>,
    template <class> class XorMask = xor_mask_disabled>
class binary_trie : public XorMask<T> {
 private:
  struct node : Alloc::template alloc<node> {
    node *l = nullptr, *r = nullptr;
    int cnt = 0;
  };

 public:
  binary_trie() = default;

  int size() const { return root->cnt; }
  bool empty() const { return size() == 0; }
  void insert(T x) { insert(root, x, B - 1); }
  void erase(T x) { erase(root, x, B - 1); }
  int count(T x) const { return count(root, x, B - 1); }
  // returns -1 if empty
  T max() const { return empty() ? -1 : max(root, 0, B - 1); }
  // returns -1 if empty
  T min() const { return empty() ? -1 : min(root, 0, B - 1); }
  T find(int k) const { return find(root, 0, k, B - 1); }

 private:
  static constexpr int calc_b(T u) {
    int b = 0;
    while (u) u >>= 1, b++;
    return b;
  }
  static constexpr int B = calc_b(U);

  node* root = new node;

  T bit_at(T x, int i) const { return (x ^ this->mask()) >> i & 1; }
  node* insert(node* p, T x, int i) {
    if (!p)
      p = new node;
    p->cnt++;
    if (i != -1) {
      if (bit_at(x, i)) p->r = insert(p->r, x, i - 1);
      else p->l = insert(p->l, x, i - 1);
    }
    return p;
  }
  void erase(node* p, T x, int i) {
    p->cnt--;
    if (i != -1) {
      if (bit_at(x, i)) erase(p->r, x, i - 1);
      else erase(p->l, x, i - 1);
    }
  }
  int count_tree(const node* p) const { return p ? p->cnt : 0; }
  int count(const node* p, T x, int i) const {
    return !p             ? 0
           : i == -1      ? p->cnt
           : bit_at(x, i) ? count(p->r, x, i - 1)
                          : count(p->l, x, i - 1);
  }
  T max(const node* p, T x, int i) const {
    if (i == -1) return x;
    if (this->mask() >> i & 1)
      return count_tree(p->l) ? max(p->l, x | 1 << i, i - 1)
                              : max(p->r, x, i - 1);
    else
      return count_tree(p->r) ? max(p->r, x | 1 << i, i - 1)
                              : max(p->l, x, i - 1);
  }
  T min(const node* p, T x, int i) const {
    if (i == -1) return x;
    if (this->mask() >> i & 1)
      return count_tree(p->r) ? min(p->r, x, i - 1)
                              : min(p->l, x | 1 << i, i - 1);
    else
      return count_tree(p->l) ? min(p->l, x, i - 1)
                              : min(p->r, x | 1 << i, i - 1);
  }
  T find(const node* p, T x, int k, int i) const {
    if (i == -1) {
      assert(k == 0 && p->cnt != 0);
      return x;
    }
    if (this->mask() >> i & 1)
      return count_tree(p->r) < k
                 ? find(p->r, x, k, i - 1)
                 : find(p->l, x | 1 << i, k - count_tree(p->r), i - 1);
    else
      return count_tree(p->l) < k
                 ? find(p->l, x, k, i - 1)
                 : find(p->r, x | 1 << i, k - count_tree(p->l), i - 1);
  }
};
Back to top page