atcoder::lazy_segtree に1行書き足すだけの抽象化 Segment Tree Beats

AtCoder Library (ACL)atcoder::lazy_segtree をもとにした Segment tree beats の抽象化の方法と,そのいくつかの具体的な使用例を紹介します.Segment tree beats は列に対する複雑な更新・取得処理を高速かつオンラインに実現する強力な手法ですが,実装の際に考慮すべきことが複雑でコーディング量も多く,体系立った実装方法の知見も整理されていないと筆者は認識しています.そこで本稿では,ドキュメントを含め極めて整備された ACL のコードをわずかに変更して Segment tree beats として使用する方法を紹介します.この方法では,様々な Segment tree beats の実装の大部分が共通化され,個別の問題に応じた機能は atcoder::lazy_segtree の利用時と同様にクラスや関数として組み込まれます.これにより,無用なデバッグの手間が削減され,コーダーはより本質的なロジックに対して集中できるようになります.その上,元の atcoder::lazy_segtree が持っていた機能はもちろんそのまま引き継がれるので,max_right()min_left() 関数による効率的な二分探索も可能です.また,コードが問題に依存しない共通部分と問題毎に特殊な部分で完全に分離されるため,ライブラリ化の際の管理も容易です.

本稿は「Segment tree beats の存在は知っているが実装の経験がなく,これからできるだけバグで苦しまずに実装したい」という方を念頭に置いて書かれています.Segment tree beats のオーソドックスな導入は本稿には含まれません.なお,本稿に記載したコードは各種オンラインジャッジへの提出を目的として自由に使用・改変して頂いて構いません *1 が,それに伴う結果に筆者はいかなる責任も負いません.

Segment tree beats の通常の Lazy Segtree との差異

最初に,本稿で「Segment tree beats」が指す対象と,特に通常の Lazy Segtree との差異を明確化しておきます*2.以下,説明に ACL の Lazy Segtree のドキュメント の定義・用語を使用します.

通常の Lazy Segtree では,区間の情報を集約した元 \(x \in S \) について,写像 \(f \in F\) の \(x\) への作用 \(x' = f(x)\) が常に正しく計算できる(形式的に書くと,任意の \(\mathbf{x} = [x_1, \dots, x_k] \in S^k \) と \(f \in F\) について \(f(x_1 \cdot x_2 \cdots x_k) = f(x_1) \cdot f(x_2) \cdots f(x_k)\) が成立)ことを前提に考えます.しかし,複雑な更新・取得クエリを扱いたい場合,区間を管理する \(x\) の持つ情報の不足によってこの計算に「失敗」 してしまうことがあります.例えば,Range Chmax Add Range Sum などのクエリを扱う典型的な Segment tree beats では,各モノイドが「区間に含まれる最小値の値・個数」「区間に含まれる二番目に小さい値」などの情報を持ちますが,ここで例えば \(f\) が \(x\) の「二番目に小さい値」以上の値による Chmax クエリであった場合,更新後の二番目に小さい値はこの \(x\) と \(f\) の情報だけからは計算できません.

作用 \( f(x) \) の計算に失敗してしまった場合,「とりあえず \(f\) を \(x\) の子に伝播させて子の計算を先に行い,子に関する計算結果を用いてボトムアップに \(x'\) を再構築する」という方法で問題を回避できそうです.ここで計算の失敗回数が多すぎる(例えば列の長さ \(N\) やクエリ個数 \(Q\) に関する二乗のオーダーになってしまう)と計算量が悪化してしまいますが, \(S\) や \(F\) の構造などに由来する何らかの理由で「Segtree 上の全頂点における失敗回数の合計」に良い上界(例えば \(O( (N + Q) \log^2 N)\))が与えられるならば,計算量の本質的な悪化が回避でき,高速なクエリ処理が可能となります.

以上のように,「\(S\) の元同士の二項演算と写像 \( f \) による作用の可換性が満たされず,たびたび作用の計算に失敗してしまうが,その場合には失敗した部分をボトムアップに再構築する」ような Segtree を本稿では Segment tree beats と呼び,通常の Lazy Segtree と区別することにします*3ACL の Lazy Segtree には勿論このような処理は実装されていないため,本稿の以下でこの処理を「追加」する方法を説明します.

実装方法

以下,特に ACL の 2021/2/1 時点で最新のコミット のバージョンを対象に説明します.

実装方法はシンプルです.atcoder::lazy_segtree 末尾付近の all_apply() に以下のように書き足してください.

変更前
void all_apply(int k, F f) {
    d[k] = mapping(f, d[k]);
    if (k < size) lz[k] = composition(f, lz[k]);
}
変更後(Segment tree beats)
void all_apply(int k, F f) {
    d[k] = mapping(f, d[k]);
    if (k < size) {
        lz[k] = composition(f, lz[k]);
        if (d[k].fail) push(k), update(k);
    }
}

これにより,区間を管理する元 d[k] への写像 f による作用の計算に失敗し fail フラグが立ったときに,d[k] の再計算がボトムアップに行われるようになります*4

以上の変更を施された atcoder::lazy_segtree を使う際は,S, op, F, mapping たちは以下のように実装されなければなりません(ACL のドキュメントのような厳密な形式にできていないのは申し訳ないです.少なくとも本稿の後半で挙げる実用例は全てこれに従っています):

  • Satcoder::lazy_segtree から参照可能なメンバ変数 fail を持つ.
  • mapping 関数による S の元 x への作用の結果を得る計算が(x の持つ情報の不足が原因で)失敗した場合のみ,mapping 関数が返す Sインスタンスfail の値は true となる.
  • mapping 関数による作用以外の部分(例えば,op による S の元の二項演算)で計算が失敗することはない.
  • 素数 \(1\) の区間を管理する S の元(すなわち Segtree 上の葉)に対しては,mapping 関数は計算を失敗してはならない(これに失敗したら困るので,アタリマエですが…).

具体的な問題に対してアルゴリズムを設計する際,特に S にどのような情報を持たせるか決める際には,fail が立つ回数ができるだけ少なくなるような工夫を施すのがポイントです.あなたがうまく設計できれば,この回数が \(O( (N+Q) \log (N + Q))\) や \( O( (N + Q) \log^2 (N + Q) )\) になります*5

使用例

上述の改造を施した atcoder::lazy_segtree を用いて実際のオンラインジャッジ上の問題を解いてみます*6

yukicoder No.880 "Yet Another Segment Tree Problem"

本問は,以下のクエリを効率的に処理することを問うています:

  • \(\mathrm{Update}: a_i \leftarrow x \; \mathrm{for} \; l \le i \le r\)
  • \(\mathrm{Update}: a_i \leftarrow \mathrm{gcd}(a_i, x) \; \mathrm{for} \; l \le i \le r\)
  • \(\mathrm{Output}: \; \max_{l \le i \le r} a_i \)
  • \(\mathrm{Output}: \; \sum_{i = l}^r a_i. \)

この問題を本稿の手法で解くための実装が以下のようになります:

namespace RangeUpdateChgcdRangeMaxSum {

constexpr uint32_t BINF = 1 << 30;
struct S {
    uint32_t max;    // 区間最大値
    uint32_t lcm;    // min(BINF, (区間内全要素の最小公倍数))
    uint32_t sz;     // 区間要素数
    uint64_t sum;    // 区間内全要素の総和
    bool fail;
    S() : max(0), lcm(1), sz(0), sum(0), fail(0) {}
    S(uint32_t x, uint32_t sz_ = 1) : max(x), lcm(x), sz(sz_), sum((uint64_t)x * sz_), fail(0) {}
};

S e() { return S(); }

S op(S l, S r) {
    if (r.sz == 0) return l;
    if (l.sz == 0) return r;
    S ret;
    ret.max = std::max(l.max, r.max);
    ret.sum = l.sum + r.sum;
    ret.lcm = std::min(uint64_t(BINF), (uint64_t)l.lcm * r.lcm / std::__gcd(l.lcm, r.lcm));
    ret.sz = l.sz + r.sz;
    return ret;
}

struct F {
    uint32_t dogcd, reset;
    F() : dogcd(0), reset(0) {}
    F(uint32_t g, uint32_t upd) : dogcd(g), reset(upd) {}
    static F gcd(uint32_t g) noexcept { return F(g, 0); }
    static F update(uint32_t a) noexcept { return F(0, a); }
};

F composition(F fnew, F fold) {
    if (fnew.reset) return F::update(fnew.reset);
    else if (fold.reset) {
        return F::update(std::__gcd(fnew.dogcd, fold.reset));
    } else {
        return F::gcd(std::__gcd(fnew.dogcd, fold.dogcd));
    }
}

F id() { return F(); }

S mapping(F f, S x) {
    if (x.fail) return x;
    if (f.reset) x = S(f.reset, x.sz);
    if (f.dogcd) {
        if (x.sz == 1) {
            x = S(std::__gcd(x.max, f.dogcd));
        } else if (x.lcm == BINF or f.dogcd % x.lcm) {
            // 区間 gcd クエリによって,複数個の要素からなる区間で
            // いずれかの値が変更を受ける場合のみ計算失敗
            x.fail = true;
        }
    }
    return x;
}
using segtree = atcoder::lazy_segtree<S, op, e, F, mapping, composition, id>;
} // namespace RangeUpdateChgcdRangeMaxSum

コードから分かるように,区間を管理するモノイド \(S\) には以下の情報を持たせています:

  • max : 区間に含まれる値の最大値
  • lcm : 区間に含まれる値の最小公倍数
  • sz : 区間に含まれ得る要素数
  • sum : 区間に含まれる値の総和

区間 \(\gcd\) という(見慣れない)クエリに対処するため変数 lcm を保持するのがこのデータ構造のポイントです.上述のように定義された S について作用の計算が失敗するのは,値 F::dogcd による gcd クエリに対して lcm がその約数ではないときに限られます.しかし,そのような状況では区間に含まれる一個以上の元の値が必ず減少します.全ての頂点について gcd クエリによる値の減少回数の総和が上から評価できることから,作用の計算が失敗する回数も同様に上から抑えられます.詳しくは 本問の公式解説(要ログイン) をご参照ください.

一つ目の例なので main() 関数も貼っておきます.

#include <iostream>
using namespace std;

int main() {
    cin.tie(nullptr), ios::sync_with_stdio(false);
    uint32_t N, Q;
    cin >> N >> Q;
    vector<RangeUpdateChgcdRangeMaxSum::S> A(N);
    for (auto &a : A) {
        uint32_t tmp;
        cin >> tmp, a = {tmp, 1};
    }

    RangeUpdateChgcdRangeMaxSum::segtree segtree(A);
    uint32_t q, l, r, x;
    while (Q--) {
        cin >> q >> l >> r;
        l--;
        if (q <= 2) {
            cin >> x;
            if (q == 1) segtree.apply(l, r, RangeUpdateChgcdRangeMaxSum::F::update(x));
            if (q == 2) segtree.apply(l, r, RangeUpdateChgcdRangeMaxSum::F::gcd(x));
        } else {
            auto v = segtree.prod(l, r);
            if (q == 3) cout << v.max << '\n';
            if (q == 4) cout << v.sum << '\n';
        }
    }
}

参考提出 (C++11 (gcc 4.8.5), 769 ms)

CS Academy Round #70 "And or Max"

まず本問題で扱うべきクエリ及び実装例を以下に示します:

  • \(\mathrm{Update}: \; a_i \leftarrow \mathrm{bitwise\_and}(a_i, b) \; \mathrm{for} \; l \le i \le r\)
  • \(\mathrm{Update}: \; a_i \leftarrow \mathrm{bitwise\_or}(a_i, b)\; \mathrm{for} \; l \le i \le r\)
  • \(\mathrm{Output}: \; \max_{l \le i \le r} a_i. \)
namespace RangeBitwiseAndOrRangeMax {
using UINT = uint32_t;
constexpr UINT digit = 20;
constexpr int mask = (1 << digit) - 1;

struct S {
    UINT max;    // 区間最大値
    UINT upper;  // 区間内全要素の bitwise or
    UINT lower;  // 区間内全要素の bitwise and
    bool fail;
    S(UINT x = 0) : max(x), upper(x), lower(x), fail(false) {}
};

S e() { return S(); }

S op(S l, S r) {
    l.max = std::max(l.max, r.max);
    l.upper |= r.upper;
    l.lower &= r.lower;
    return l;
}

struct F {
    UINT bit_and;
    UINT bit_or;
    F() : bit_and(mask), bit_or(0) {}
    F(UINT a, UINT o) : bit_and(a), bit_or(o) {}
    static F b_and(UINT a) noexcept { return {a, 0}; }
    static F b_or(UINT a) noexcept { return {mask, a}; }
};

F composition(F fnew, F fold) {
    return F{fnew.bit_and & fold.bit_and, fnew.bit_or | (fnew.bit_and & fold.bit_or)};
}

F id() { return F(); }

S mapping(F f, S x) {
    if ((x.upper - x.lower) & (~f.bit_and | f.bit_or)) {
        // 区間内で立っている要素と立っていない要素が混在するような bit で
        // 変更が起きた場合のみ計算失敗(新たな最大値が不明なので)
        x.fail = true;
        return x;
    }
    x.upper = (x.upper & f.bit_and) | f.bit_or;
    x.lower = (x.lower & f.bit_and) | f.bit_or;
    x.max = (x.max & f.bit_and) | f.bit_or;
    return x;
}
using segtree = atcoder::lazy_segtree<S, op, e, F, mapping, composition, id>;
} // namespace RangeBitwiseAndOrRangeMax

この問題では,区間を管理するモノイドに「区間最大値」に加え「区間内全要素の bitwise and / bitwise or」を持たせます.このとき,作用の計算に失敗するのは「区間全体の bitwise and と bitwise or の状態が異なるような bit に対する更新が起きた場合」のみです(bitwise and と bitwise or の状態が一致しているような bit に対しては,更新が起きても区間内の全要素の値に同じ値が足されたり引かれたりするだけなので,最大値の追跡が可能です).

ところで作用の計算が失敗するとき,その区間における bitwise and と bitwise or の状態が異なっているような bit の個数(Hamming 距離)が必ず減少します.この事実から,作用の計算が失敗する回数を上から抑えることが可能です.詳細は 本問の公式解説 をご参照ください.

参考実装 (C++, 449 ms)

[2021/2/3 追記,お気持ちパート]以上二問の例題を比較して眺めると,「作用の計算が失敗するとき,必ずその頂点が管理する区間に含まれる要素の『複雑さ』や『多様性』のようなものが減少する」という共通的な観点が見えてきます(という気持ちに筆者はなります).例えば最初の例題では,gcd クエリに失敗することは区間内のいずれかの値に対して減少が起きることに他なりません.値 \(A\) に対して,gcd クエリによる値の減少はせいぜい \(O(\log A)\) 回しか起きないわけですから,このような「失敗」を重ねればほどなく区間内の全要素の値が \(1\) になるか,あるいは少なくとも全要素が同じ値へと収斂していくでしょう.二問目も同様に,「失敗」が一度起きる毎にその区間内の全要素の bitwise and と bitwise or の Hamming 距離が減少していくのですから,いずれ(本問の制約では各要素 \(A_i\) は \(0 \le A_i \le 2^{20}\)を満たすので,たかだか20回の失敗で)「区間内の全要素が同じ値」という極めてシンプルな状態に落ち着くはずです. このような観点で考えることが,Segment tree beats における S の設計や計算量保証に対する直観的な理解を深める上で有効かもしれません(と,筆者は漠然と感じています).

Library Checker "Range Chmin Chmax Add Range Sum"

Segment tree beats として最も有名なものかと思います.本稿のアプローチは,「二番目に小さい・大きい値の更新」など,場合分けが少々ややこしい部分から本質的に解放されるものではありませんが,少なくともコードの見通しの良さ・再利用性はかなり改善されるのではないかと思っています.

namespace RangeChMinMaxAddSum {
template <typename Num> inline Num second_lowest(Num a, Num a2, Num c, Num c2) noexcept {
    // a < a2, c < c2 のとき全引数を昇順ソートして二番目の値
    return a == c ? std::min(a2, c2) : a2 <= c ? a2 : c2 <= a ? c2 : std::max(a, c);
}
template <typename Num> inline Num second_highest(Num a, Num a2, Num b, Num b2) noexcept {
    // a > a2, b > b2 のとき全引数を降順ソートして二番目の値
    return a == b ? std::max(a2, b2) : a2 >= b ? a2 : b2 >= a ? b2 : std::min(a, b);
}

using BNum = long long;
constexpr BNum BINF = 1LL << 61;

struct S {
    BNum lo, hi, lo2, hi2, sum;  // 区間最小・最大値,区間最小・最大から二番目の値,区間総和
    unsigned sz, nlo, nhi;       // 区間要素数,区間最小・最大値をとる要素の個数
    bool fail;
    S() : lo(BINF), hi(-BINF), lo2(BINF), hi2(-BINF), sum(0), sz(0), nlo(0), nhi(0), fail(0) {}
    S(BNum x, unsigned sz_ = 1)
        : lo(x), hi(x), lo2(BINF), hi2(-BINF), sum(x * sz_), sz(sz_), nlo(sz_), nhi(sz_), fail(0) {}
};

S e() { return S(); }

S op(S l, S r) {
    S ret;
    ret.lo = std::min(l.lo, r.lo), ret.hi = std::max(l.hi, r.hi);
    ret.lo2 = second_lowest(l.lo, l.lo2, r.lo, r.lo2);
    ret.hi2 = second_highest(l.hi, l.hi2, r.hi, r.hi2);
    ret.sum = l.sum + r.sum, ret.sz = l.sz + r.sz;
    ret.nlo = l.nlo * (l.lo <= r.lo) + r.nlo * (r.lo <= l.lo);
    ret.nhi = l.nhi * (l.hi >= r.hi) + r.nhi * (r.hi >= l.hi);
    return ret;
}

struct F {
    BNum lb, ub, bias;
    F(BNum chmax_ = -BINF, BNum chmin_ = BINF, BNum add = 0) : lb(chmax_), ub(chmin_), bias(add) {}
    static F chmin(BNum x) noexcept { return F(-BINF, x, BNum(0)); }
    static F chmax(BNum x) noexcept { return F(x, BINF, BNum(0)); }
    static F add(BNum x) noexcept { return F(-BINF, BINF, x); };
};

F composition(F fnew, F fold) {
    F ret;
    ret.lb = std::max(std::min(fold.lb + fold.bias, fnew.ub), fnew.lb) - fold.bias;
    ret.ub = std::min(std::max(fold.ub + fold.bias, fnew.lb), fnew.ub) - fold.bias;
    ret.bias = fold.bias + fnew.bias;
    return ret;
}

F id() { return F(); }

S mapping(F f, S x) {
    if (x.sz == 0) return e();
    else if (x.lo == x.hi or f.lb == f.ub or f.lb >= x.hi or f.ub < x.lo) {
        return S(std::min(std::max(x.lo, f.lb), f.ub) + f.bias, x.sz);
    } else if (x.lo2 == x.hi) {
        x.lo = x.hi2 = std::max(x.lo, f.lb) + f.bias;
        x.hi = x.lo2 = std::min(x.hi, f.ub) + f.bias;
        x.sum = x.lo * x.nlo + x.hi * x.nhi;
        return x;
    } else if (f.lb < x.lo2 and f.ub > x.hi2) {
        BNum nxt_lo = std::max(x.lo, f.lb), nxt_hi = std::min(x.hi, f.ub);
        x.sum += (nxt_lo - x.lo) * x.nlo - (x.hi - nxt_hi) * x.nhi + f.bias * x.sz;
        x.lo = nxt_lo + f.bias, x.hi = nxt_hi + f.bias, x.lo2 += f.bias, x.hi2 += f.bias;
        return x;
    }
    x.fail = 1;
    return x;
}
using segtree = atcoder::lazy_segtree<S, op, e, F, mapping, composition, id>;
} // namespace RangeChMinMaxAddSum

参考提出 (cpp, 872 ms)

その他知見,定数倍高速化について

  • 本稿で作った Segment tree beats は,残念ながら定数倍も最強という訳ではありません(当たり前ではありますが……).定量的な比較はまだできていませんが,オンラインジャッジ上では最速の提出と比べ 2 倍から 3 倍程度の実行時間となることが多いようです.とはいえ,筆者は実際のコンテストで定数倍が原因で落とされるまでは自分で一から Segment tree beats を書かずこれを使用するつもりです.
  • 筆者のプログラミングコンテスト用ライブラリ では執筆時点で segtree_beats クラスを atcoder::lazy_segtree からの継承を用いて記述しています(そのままではコンパイルが通らないので, private なメンバ変数・関数を一部 protected にしたり仮想関数化したりしています).ライブラリ内のコードの重複が抑えられてすっきりしますが,問題や実行環境によっては仮想関数の呼び出しに伴うオーバーヘッドによって定数倍が悪化する可能性があるのが難点です.例えば上で例題として紹介した "And or Max" では,atcoder::lazy_segtree を直接改造する場合に比べ実行時間は(CS Academy のジャッジ上で)少々悪化(449 ms -> 527 ms)しました.ただし yukicoder の問題例では差はほぼ確認できませんでした.

今後,逐次追記するかもしれません.

おわりに

璃奈「バグなく作用を伝えることって難しい」

璃奈「Beats の場合は特にそう……」

本稿では,Segment tree beats を「作用の計算がある程度失敗しうる Lazy Segtree」と特徴付けた上で,ACLlazy_segtree の一か所を改変し,抽象化された Segment tree beats として使用する方法を述べました.またその具体的な使用方法をいくつかの問題例を通して説明しました.本稿の方法について,atcoder::lazy_segtree をある程度使用した経験があればスムースに Segment tree beats の実装が行えるという点や,ACL のドキュメントの内容が今回の Segment tree beats にもほぼそのまま通用する(自作ライブラリの使用方法を暗記したり,ドキュメントを自分で整備する必要がない)という点は,プログラミングコンテストでの実使用においても極めて大きなメリットであると筆者は考えています. 筆者は当面の間 Segment tree beats はこれを使用していく予定ですので,今後もし何か大きな知見が得られた際には継続的に報告していこうと思っています.

璃奈「侑さん,頼まれてた抽象化 Segment tree beats できたよ」

侑「うひょーっ! 璃奈ちゃんありがとう! 早速 Range Chmin Chmax Add Range Sum 埋めてくるね!」シュタタタッ

璃奈「くれぐれも悪用しないでね」フリフリ

*1:ACL のヘッダファイル群には CC0 ライセンスが適用されているため,ACL の改変行為に対する心配は不要です.

*2:競技プログラミングのコミュニティで Segment tree beats として呼称されるテクニックの範疇が,本稿の抽象化の範囲と一致しない可能性がある点に注意してください.例えば,この Codeforces の記事 で紹介されているあらゆるデータ構造が本稿の手法で実現できるかどうかについて筆者はまだ確認していません.

*3:繰り返しますが,この定義は本稿の抽象化のアプローチに基づいて導入されたものであることや,必ずしも広く同意が得られているものではないということに注意してください.

*4:ACL のコードを読むと確認できるように,composition 関数は lazy_segtree 内の二か所でしか呼ばれません.しかも,そのうち(今回手を加えた方ではない)一方ではこの関数は Segtree 上の葉(単一の要素からなる区間を表す頂点)に対して作用するため,情報の不足による計算の失敗は起こりえません.よって,本文中に挙げた一か所を修正するだけで所望の処理が実装できます.

*5:理論的な保証をあまり考えずに実装に突っ走った結果,実はそもそも最悪計算量が二乗から改善されていなかったという事例が筆者の場合よくあります.

*6:本稿のコードは基本的に gcc, C++11 でコンパイルが通るように実装されています.最近の標準規格で導入された std::clamp() などの関数を使うことで,例示したコードのいくつかの部分はより簡潔に記述できます.