Convex Hull Trick

概要

問題 : 「多数の f_i(x) = a_i * x + b_i と,多数のクエリ x に対し min f_i(x) を求めよ.」


この問題を解くためにConvex Hull Trickというものがある.
蟻本において,K-anonymous Sequenceという問題の解法として紹介されている(Convex Hull Trickという名前は出てきていない).
蟻本で紹介されているパターンには以下の条件がある.

  • 追加される直線の係数が単調減少
  • クエリが単調増加

この記事では,これらの条件が無いパターンの解き方とその実装を紹介する.
http://wcipeg.com/wiki/Convex_hull_trickアルゴリズムの説明はあるので英語を読むのが苦でない人はこれを見たほうがいい.

注意:蟻本のパターンに比べて条件がない場合は log N 分計算量が悪くなる.

クエリが単調増加でない場合

蟻本パターンと同様に適当に直線を管理する.
クエリに対しては二分探索する.
おわり.

#include <bits/stdc++.h>
using namespace std;
typedef pair<int,int> pii;

struct CHT {
  vector<pii> deq;              // first * x + second
  int s,t;
  CHT(int n) {                  // n : クエリ数
    deq.resize(n);
    s=0, t=0;
  }
  void add(int a, int b) {      // a : 単調減少
    const pii p(a,b);
    while(s+1<t && check(deq[t-2],deq[t-1],p)) t--;
    deq[t++] = p;
  }
  int incl_query(int x) {            // x : 単調増加
    while(s+1<t && f(deq[s], x) >= f(deq[s+1], x)) s++;
    return f(deq[s], x);
  }
  int query(int x) {           // 条件なし
    int low = s-1, high = t-1;
    while(low+1<high) {
      int mid = low+high>>1;
      if (isright(deq[mid], deq[mid+1], x)) low = mid;
      else high = mid;
    }
    return f(deq[high], x);
  }
private:
  bool isright(const pii &p1, const pii &p2, int x) {
    return (p1.second-p2.second) >= x * (p2.first-p1.first);
  }
  bool check(const pii &p1, const pii &p2, const pii &p3) {
    return (p2.first-p1.first)*(p3.second-p2.second) >=
      (p2.second-p1.second)*(p3.first-p2.first);
  }
  int f(const pii &p, int x) {
    return p.first * x + p.second;
  }
};
int main() {
  CHT cht(3);
  cht.add(2,3);
  cht.add(-1,4);
  cout << cht.query(1) << endl; // 3
}

直線の係数が単調減少でない場合

直線の追加
係数が単調減少の場合,直線は後ろに追加すると決まっていたので嬉しかった.
そうでない場合,追加する場所を二分探索木(Sとする)で探す.
Sは係数をキーとする.

Sに直線を追加したとき,その左右の直線が不要になれば,不要な分だけ削除する.

クエリ
Sにおけるどの直線がクエリに対して最小を達成するかを探すために,別の二分探索木(Cとする)を用意する.
Cにおけるキーは,Sにおいて隣接する直線の交点とする.

実装のコメント

  • 簡潔にするため,番兵となる直線を追加した.
  • doubleにしてもいいけど,全部整数型で扱うようにした.
  • イテレータを(+1,-1)動かす関数を用意した.
  • a,bの条件は |ab| < LLONG_MAX/4 くらいだと思うけどよく分からない.
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll INF = LLONG_MAX;

struct CHT2 {  
  CHT2() {
    // 番兵
    S.insert({L(INF,0), L(-INF,0)});
    C.insert(cp(L(INF,0),L(-INF,0)));
  }
  // for debug
  void print() {
    cout << "S : "; for (auto it : S) printf("(%lld,%lld)", it.a, it.b); puts("");
    cout << "C : "; for (auto it : C) printf("(%lld,%lld)", it.n, it.d); puts("");
  }
  // |ab| < LLONG_MAX/4 ???
  void add(ll a, ll b) {
    const L p(a,b);
    It pos = S.insert(p).first;
    if (check(*it_m1(pos), p, *it_p1(pos))) {
      // 直線(a,b)が不要
      S.erase(pos);
      return;
    }
    C.erase(cp(*it_m1(pos), *it_p1(pos)));
    {
      // 右方向の削除
      It it = it_m1(pos);
      while(it!=S.begin() && check(*it_m1(it), *it, p)) --it;
      C_erase(it, it_m1(pos));
      S.erase(++it,pos);
      pos = S.find(p);
    }
    {
      // 左方向の削除
      It it = it_p1(pos);
      while(it_p1(it)!=S.end() && check(p,*it, *it_p1(it))) ++it;
      C_erase(++pos, it);
      S.erase(pos, it);
      pos = S.find(p);
    }
    C.insert(cp(*it_m1(pos), *pos));
    C.insert(cp(*pos, *it_p1(pos)));
  }
  ll query(ll x) {
    const L &p = (--C.lower_bound(CP(x,1,L(0,0))))->p;
    return p.a*x + p.b;
  }
  
private:
  
  template<class T> T it_p1(T a) { return ++a; }
  template<class T> T it_m1(T a) { return --a; }  
  struct L {
    ll a, b;
    L(ll a, ll b) : a(a),b(b) {}
    bool operator<(const L &rhs) const {
      return a != rhs.a ? a > rhs.a : b < rhs.b;
    }
  };
  struct CP {
    ll n,d;
    L p;
    CP(ll _n, ll _d, const L &p) : n(_n),d(_d),p(p) {
      if (d < 0) { n *= -1; d *= -1; }
    };
    bool operator<(const CP &rhs) const {
      if (n == INF || rhs.n == -INF) return 0;
      if (n == -INF || rhs.n == INF) return 1;      
      return n * rhs.d < rhs.n * d;
    }
  };
  set<L> S;
  set<CP> C;

  typedef set<L>::iterator It;
  
  void C_erase(It a, It b) {
    for (It it=a; it!=b; ++it)
      C.erase(cp(*it, *it_p1(it)));
  }
  CP cp(const L &p1, const L &p2) {
    if (p1.a == INF) return CP(-INF,1,p2);
    if (p2.a == -INF) return CP(INF,1,p2);
    return CP(p1.b-p2.b, p2.a-p1.a, p2);
  }
  bool check(const L &p1, const L &p2, const L &p3) {
    if (p1.a==p2.a && p1.b <= p2.b) return 1;
    if (p1.a == INF || p3.a == -INF) return 0;
    return (p2.a-p1.a)*(p3.b-p2.b) >= (p2.b-p1.b)*(p3.a-p2.a);
  }
};

int main() {
  CHT2 cht;
  cht.add(2,2);
  cht.add(-1,4);
  cht.add(3,4);
  cout << cht.query(-1) << endl; // 0
  cht.print();
}

バグってたり,実装を改善できそうだったり,条件がない場合を使う必要のある問題があったりしたら教えてください.