fetburner.core

コアダンプ

OCamlでのダイクストラ法の実装を改良する

OCamlAtCoderの問題を解くのに昔書いたダイクストラ法の実装を使っていたんですが, どうも定数倍でTLEすることが多く,もどかしい気持ちになることもしばしばでした. 本記事では行儀の良いスタイルに囚われず以前の実装に定数倍高速化を施し,実際のコンテストでの使用に堪える実装を得ることを目標とします.

隣接リストの表現を一般化

以前のダイクストラ法の実装の型*1を見てみましょう.

module WeightedDirectedGraph
  (Vertex : sig
    type t
    val compare : t -> t -> int
  end)
  (Weight : sig
    type t
    val zero : t
    val ( + ) : t -> t -> t
    val compare : t -> t -> int
  end) :
sig
  val dijkstra :
    (* 隣接リスト *)
    (Vertex.t -> (Vertex.t * Weight.t) list) ->
    (* 始点 *)
    Vertex.t ->
    (* 始点から辿り着けなければNoneを返す *)
    (Vertex.t -> Weight.t option)
end

ここでグラフの隣接リストが,始点を受け取って終点と辺の重みの二つ組のリストを返す関数として表現されていることに注目して下さい.

この表現は大変直感的で分かりやすいのですが,実際の競プロの問題を解く際には逆辺を張る必要があったり, 頂点を拡張してグラフを作り直す必要があったりと,入力された辺の情報をそのまま突っ込めないことが多いです. 我々はただ,ある頂点から伸びる辺のリストに対して畳み込みがしたいだけなのに,いちいち辺のリストを作り直してから畳み込みを行うのは非効率的ではないでしょうか?(融合変換したいですよね?)

そこで以下のように,隣接リストを渡すのではなく隣接リストに対しての畳み込みを渡すようにシグニチャを変更します.

module WeightedDirectedGraph
  (Vertex : sig
    type t
    val compare : t -> t -> int
  end)
  (Weight : sig
    type t
    val zero : t
    val ( + ) : t -> t -> t
    val compare : t -> t -> int
  end) :
sig
  type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

  val dijkstra :
    (* 隣接リスト *)
    (Vertex.t -> (Vertex.t * Weight.t) church_list) ->
    (* 始点 *)
    Vertex.t ->
    (* 始点から辿り着けなければNoneを返す *)
    (Vertex.t -> Weight.t option)
end

要するにリストをチャーチエンコーディングする訳ですね. ちなみにランク2多相が必要になるので,OCamlだと一旦レコードに包まなくてはなりません.

シグニチャ変更後の実装は以下の通りです.

= struct
  module WMap = Map.Make (Weight)
  module VMap = Map.Make (Vertex)

  type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

  let rec dijkstra_aux e (d, q) =
    match WMap.min_binding q with
    | exception Not_found -> d
    | (w, us) ->
        dijkstra_aux e @@ List.fold_left (fun (d, q) u ->
          if Weight.compare (VMap.find u d) w < 0
          then (d, q)
          (* !!! uから伸びる辺のリストを畳み込む部分の実装が変わっている !!! *)
          else (e u).fold (fun (v, c) (d, q) ->
            let open Weight in
            if
              try Weight.compare (VMap.find v d) (w + c) <= 0
              with Not_found -> false
            then (d, q)
            else
              VMap.add v (w + c) d,
              WMap.add (w + c) (v :: try WMap.find (w + c) q with Not_found -> []) q)
          (d, q)) (d, WMap.remove w q) us

  let dijkstra e s =
    let d =
      dijkstra_aux e
        (VMap.singleton s Weight.zero, WMap.singleton Weight.zero [s]) in
    fun t -> try Some (VMap.find t d) with Not_found -> None
end

新しめのバージョンのOCamlで導入された関数の利用

最近AtCoderで使えるOCaml処理系が4.02.3から4.10.0に新しくなったので, 新しめのバージョンで導入された標準ライブラリの関数と言語機能を使って書き換えます.

高速化というよりリファクタリングですが,一部定数倍高速化に効果のある変更もあります.*2

= struct
  module WMap = Map.Make (Weight)
  module VMap = Map.Make (Vertex)

  type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

  let rec dijkstra_aux e (d, q) =
    (* !!! 例外を使うよりはoptionを使う方が行儀が良い !!! *)
    match WMap.min_binding_opt q with
    | None -> d
    | Some (w, us) ->
        dijkstra_aux e @@ List.fold_left (fun (d, q) u ->
          if Weight.compare (VMap.find u d) w < 0
          then (d, q)
          else (e u).fold (fun (v, c) (d, q) ->
            let open Weight in
            (* !!! 最近のOCamlはexceptionパターンの位置に融通が効く !!! *)
            match VMap.find v d with
            | x when Weight.compare x (w + c) <= 0 -> (d, q)
            | _ | exception Not_found ->
                VMap.add v (w + c) d,
                (* !!! findとaddを組み合わせて書いていた部分が,update一つで置き換えられる !!!  *)
                WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) q)
          (d, q)) (d, WMap.remove w q) us

  let dijkstra e s =
    let d =
      dijkstra_aux e
        (VMap.singleton s Weight.zero, WMap.singleton Weight.zero [s]) in
    fun t -> try Some (VMap.find t d) with Not_found -> None
end

2頂点間最短経路問題に対しての高速化

この節で話すのは,既にML Day#2で話したのと同じ内容です.

競プロでダイクストラ法を使う際,全ての頂点への最短距離が欲しい場合よりも,むしろ終点tまでの最短距離だけ欲しい場合の方が多いのではないでしょうか?ダイクストラ法では,ヒープから頂点vを取り出した段階でvまでの最短距離が確定するため,ヒープから終点tを取り出したタイミングで計算を打ち切ってやれば速くなりそうです.

このアイデアに基づいて,以下のようにダイクストラ法の実装を変更します.

= struct
  module WMap = Map.Make (Weight)
  module VMap = Map.Make (Vertex)

  type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

  let rec dijkstra_aux e t (d, q) =
    match WMap.min_binding_opt q with
    | None -> None
    | Some (w, us) ->
        (* !!! 終点までの距離が分かったので計算を切り上げる !!! *)
        if List.exists (( = ) t) us then Some w else
        dijkstra_aux e t @@ List.fold_left (fun (d, q) u ->
          if Weight.compare (VMap.find u d) w < 0
          then (d, q)
          else (e u).fold (fun (v, c) (d, q) ->
            let open Weight in
            match VMap.find v d with
            | x when Weight.compare x (w + c) <= 0 -> (d, q)
            | _ | exception Not_found ->
                VMap.add v (w + c) d,
                WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) q)
          (d, q)) (d, WMap.remove w q) us

  let dijkstra e s t =
    dijkstra_aux e t
        (VMap.singleton s Weight.zero, WMap.singleton Weight.zero [s])
end

ちなみにこの変更を施すと,調子の良い時*3は無限グラフでも最短距離を求められるようになります. 関数型言語のオタクはこういう性質を見ると嬉しくなる.

もっとも,この変更は諸刃の剣でもあります. ある終点までの距離が分かる度に途中までの計算結果を捨てているので, この実装を使って全ての頂点までの最短経路を求めようとするとO(VE \log E)かかってしまうのです. 実装を使いわけろと言うのは簡単ですが,終点が一つの場合に定数倍速く, なおかつ全頂点までの最短距離が欲しい場合でも計算量のオーダーを落とさないことは可能でしょうか?

結論から言えば可能で,与えられた終点までの最短距離が分かった時点で結果を返し, 次の終点が与えられた時は前回最短距離を返したところから実行を再開できれば良さそうです. コルーチンとか限定継続があれば綺麗に書けそうですがOCaml標準ライブラリにそんなものはないので, 計算を再開するのに必要な状態を参照として外に出してやることで実装します.*4

= struct
  module WMap = Map.Make (Weight)
  module VMap = Map.Make (Vertex)

  type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

  let rec dijkstra e s =
    (* !!! 計算の途中状態を参照として外に出す !!! *)
    let d = ref (VMap.singleton s Weight.zero) in
    let q = ref (WMap.singleton Weight.zero [s]) in
    let rec dijkstra_aux t =
      match WMap.min_binding_opt !q with
      (* !!! 既にダイクストラ法の実行が終わっていた場合は,dに最短距離が入っている !!! *)
      | None -> VMap.find_opt t !d
      | Some (w, us) ->
          match VMap.find t !d with
          (* !!! 新しく距離が確定した頂点より終点の方が近ければ,
                 終点までの距離は確定している !!! *)
          | x when Weight.compare x w <= 0 -> Some x
          | _ | exception Not_found ->
              q := WMap.remove w !q;
              List.iter (fun u ->
                if 0 <= Weight.compare (VMap.find u !d) w then
                  (e u).fold (fun (v, c) () ->
                    let open Weight in
                    match VMap.find v !d with
                    | d when Weight.compare d (w + c) <= 0 -> ()
                    | _ | exception Not_found ->
                        d := VMap.add v (w + c) !d;
                        q := WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) !q) ()) us;
              dijkstra_aux t in
    dijkstra_aux
end

相変わらず無限グラフでも動いてくれるので嬉しいですね.

ここで,既に最短距離が確定している頂点へのクエリばかり行われた場合,!qが変化していないのに毎回min_binding_optが呼び出されて効率が悪いことに気付きます.こういうのはメモ化してやりましょう.*5

= struct
  module WMap = Map.Make (Weight)
  module VMap = Map.Make (Vertex)

  type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

  let rec dijkstra e s =
    let d = ref (VMap.singleton s Weight.zero) in
    let q = ref (WMap.singleton Weight.zero [s]) in
    (* !!! WMap.min_binding !qの結果をメモしておく !!! *)
    let min_binding_opt = ref (Some (Weight.zero, [s])) in
    let rec dijkstra_aux t =
      match !min_binding_opt with
      | None -> VMap.find_opt t !d
      | Some (w, us) ->
          match VMap.find t !d with
          | x when Weight.compare x w <= 0 -> Some x
          | _ | exception Not_found ->
              q := WMap.remove w !q;
              List.iter (fun u ->
                if 0 <= Weight.compare (VMap.find u !d) w then
                  (e u).fold (fun (v, c) () ->
                    let open Weight in
                    match VMap.find v !d with
                    | d when Weight.compare d (w + c) <= 0 -> ()
                    | _ | exception Not_found ->
                        d := VMap.add v (w + c) !d;
                        q := WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) !q) ()) us;
              (* !!! qが変更されたので,min_binding_optを更新 !!! *)
              min_binding_opt := WMap.min_binding_opt !q;
              dijkstra_aux t in
    dijkstra_aux
end

データ構造の使い分け

計算量変わらんしええやろって感じで,今まではヒープだけでなく最短距離を格納するデータ構造にもMapを使ってたんですが*6, 実際に使ってみると無視できないぐらい遅いことが分かってきました.*7 しかしMapを用いたことで生じた,座圧が不要なばかりか様々な型を頂点としたグラフに対して直接最短距離を求められる利点も捨てがたい…

そこで,配列を用いた実装,ハッシュテーブルを用いた実装,Mapを用いた実装の三種類を用意し,その時々で使い分けられるようにします. 毎回書いてたらダルいのでファンクタで上手くやりますが.

クソ長くなったので変更後の実装は隠しておきます.

module WeightedDirectedGraph
: sig
  (* 配列を用いたダイクストラ法の実装
     単純な速さでは一番だが,インターフェースが不便 *)
  module ByArray : sig
    module Make :
      functor (Weight : sig
        type t
        val inf : t
        val zero : t
        val ( + ) : t -> t -> t
        val compare : t -> t -> int
      end) ->
      sig
        type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

        (* 頂点を[0, n)の自然数に限定したグラフに対してのダイクストラ法
           時間計算量O(E log E)なので,疎なグラフなら速い *)
        val dijkstra :
        (* 頂点数n *)
        int ->
        (* 隣接リスト *)
        (int -> (int * Weight.t) church_list) ->
        (* 始点 *)
        int ->
        (* 始点から辿り着けなければinfを返す関数
           この関数を覚えておけば,呼び出しごとの途中までの計算結果がシェアされる *)
        (int -> Weight.t)
      end
  end

  (* ハッシュテーブルを用いたダイクストラ法の実装
     配列を用いた実装よりは扱いやすいインターフェースを持つ
     ハッシュ関数を上手く選べば配列を用いた実装より1.5倍遅い程度ですむ *)
  module ByHashtbl : sig
    module Make :
      functor
      (* 頂点 *)
      (Vertex : Hashtbl.HashedType)
      (* 辺の重み *)
      (Weight : sig
        type t
        val zero : t
        val ( + ) : t -> t -> t
        val compare : t -> t -> int
      end) ->
      sig
        type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

        val dijkstra :
        (* 頂点数(Hashtbl.tを用いるので目安程度) *)
        int ->
        (* 隣接リスト *)
        (Vertex.t -> (Vertex.t * Weight.t) church_list) ->
        (* 始点 *)
        Vertex.t ->
        (* 始点から辿り着けなければNot_foundを投げる関数
           この関数を覚えておけば,呼び出しごとの途中までの計算結果がシェアされる *)
        (Vertex.t -> Weight.t)
      end
  end

  (* Mapを用いたダイクストラ法の実装
     配列を用いた実装より4倍ぐらい遅いが,
     一番扱いやすいインターフェースを持ち,無限グラフにも対応可能 *)
  module ByMap : sig
    module Make :
      functor
      (* 頂点 *)
      (Vertex : Map.OrderedType)
      (* 辺の重み *)
      (Weight : sig
        type t
        val zero : t
        val ( + ) : t -> t -> t
        val compare : t -> t -> int
      end) ->
      sig
        type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

        val dijkstra :
        (* 隣接リスト *)
        (Vertex.t -> (Vertex.t * Weight.t) church_list) ->
        (* 始点 *)
        Vertex.t ->
        (* 始点から辿り着けなければNot_foundを投げる関数
           この関数を覚えておけば,呼び出しごとの途中までの計算結果がシェアされる *)
        (Vertex.t -> Weight.t)
      end
  end
end
= struct
  module type Weight = sig
    type t
    val zero : t
    val ( + ) : t -> t -> t
    val compare : t -> t -> int
  end

  (* 最短距離を格納するデータ構造を抽象化したダイクストラ法の実装 *)
  module Core
    (W : Weight)
    (* グラフの頂点を添字とした配列 *)
    (VArray : sig
      type t
      type vertex (* グラフの頂点 *)
      val find : t -> vertex -> W.t (* 最短距離が格納されていなければNot_foundを投げる *)
      val update : t -> vertex -> W.t -> unit
    end) =
  struct
    type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b }

    module WMap = Map.Make (W)

    let dijkstra d e s =
      VArray.update d s W.zero;
      let q = ref (WMap.singleton W.zero [s]) in
      (* 既に最短距離が確定した辺へのクエリを高速化するため,
         ヒープの最小要素をメモしておく *)
      let min_binding_opt = ref (Some (W.zero, [s])) in
      let rec dijkstra_aux t =
        match !min_binding_opt with
        (* もう既に全ての頂点までの距離が分かっている *)
        | None -> VArray.find d t
        | Some (w, us) ->
            match VArray.find d t with
            (* 既に終点までの距離が分かっているので返す *)
            | x when W.compare x w <= 0 -> x
            (* 終点までの距離が分かっていないので,ダイクストラ法を続行 *)
            | _ | exception Not_found ->
                q := WMap.remove w !q;
                Fun.flip List.iter us (fun u ->
                  if 0 <= W.compare (VArray.find d u) w then
                    (* 未だ頂点uを訪れていない *)
                    Fun.flip (e u).fold () @@ fun (v, c) () ->
                      let open W in
                      match VArray.find d v with
                      | d when W.compare d (w + c) <= 0 -> ()
                      | _ | exception Not_found ->
                          VArray.update d v (w + c);
                          q := WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) !q);
                min_binding_opt := WMap.min_binding_opt !q;
                dijkstra_aux t in
      dijkstra_aux
  end

  module ByArray = struct
    module Make (W : sig include Weight val inf : t end) = struct
      module C = Core (W) (struct
        type t = W.t array
        type vertex = int
        let find = Array.get
        let update = Array.set
      end)
      include C

      let dijkstra n e s = C.dijkstra (Array.make n W.inf) e s
    end
  end

  module ByHashtbl = struct
    module Make (V : Hashtbl.HashedType) (W : Weight) = struct
      module VHash = Hashtbl.Make (V)
      module C = Core (W) (struct
        type t = W.t VHash.t
        type vertex = V.t
        let find = VHash.find
        let update = VHash.replace
      end)
      include C

      let dijkstra n e s = C.dijkstra (VHash.create n) e s
    end
  end

  module ByMap = struct
    module Make (V : Map.OrderedType) (W : Weight) = struct
      module VMap = Map.Make (V)
      module C = Core (W) (struct
        type t = W.t VMap.t ref
        type vertex = V.t
        let find d v = VMap.find v !d
        let update d v w = d := VMap.add v w !d
      end)
      include C

      let dijkstra e s = C.dijkstra (ref VMap.empty) e s
    end
  end
end

まとめ

以前のダイクストラ法の実装に対して定数倍高速化を施したほか,使い勝手を妥協して更に定数倍高速化を推し進めた実装も作り分けました. ここでのダイクストラ法の実装はGitHubにもアップロードされており,実際の使用例なども見ることができます.

*1:正確にはダイクストラ法の実装が入っているファンクタのシグニチャ?

*2:従来はヒープに見立てたMapに要素を追加する際にfindとaddの二つの関数が必要だったが,新しく導入されたupdateを使えば一つの関数で行える.Mapの実装的にも二分木の検索を二回行なっていた部分が一回の検索を行うだけになって良い感じ.

*3:終点tより近い頂点が有限個しかない時

*4:遂に関数プログラミングを諦めてしまった…

*5:こいついっつも手続き型プログラミングしてんな

*6:バカ

*7:配列で実装したものより4倍ぐらい遅い