OCamlでのダイクストラ法の実装を改良する
OCamlでAtCoderの問題を解くのに昔書いたダイクストラ法の実装を使っていたんですが, どうも定数倍でTLEすることが多く,もどかしい気持ちになることもしばしばでした. 本記事では行儀の良いスタイルに囚われず以前の実装に定数倍高速化を施し,実際のコンテストでの使用に堪える実装を得ることを目標とします.
隣接リストの表現を一般化
module WeightedDirectedGraph (Vertex : sig type t val compare : t -> t -> int end) (Weight : sig type t val zero : t val ( + ) : t -> t -> t val compare : t -> t -> int end) : sig val dijkstra : (* 隣接リスト *) (Vertex.t -> (Vertex.t * Weight.t) list) -> (* 始点 *) Vertex.t -> (* 始点から辿り着けなければNoneを返す *) (Vertex.t -> Weight.t option) end
ここでグラフの隣接リストが,始点を受け取って終点と辺の重みの二つ組のリストを返す関数として表現されていることに注目して下さい.
この表現は大変直感的で分かりやすいのですが,実際の競プロの問題を解く際には逆辺を張る必要があったり, 頂点を拡張してグラフを作り直す必要があったりと,入力された辺の情報をそのまま突っ込めないことが多いです. 我々はただ,ある頂点から伸びる辺のリストに対して畳み込みがしたいだけなのに,いちいち辺のリストを作り直してから畳み込みを行うのは非効率的ではないでしょうか?(融合変換したいですよね?)
そこで以下のように,隣接リストを渡すのではなく隣接リストに対しての畳み込みを渡すようにシグニチャを変更します.
module WeightedDirectedGraph (Vertex : sig type t val compare : t -> t -> int end) (Weight : sig type t val zero : t val ( + ) : t -> t -> t val compare : t -> t -> int end) : sig type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } val dijkstra : (* 隣接リスト *) (Vertex.t -> (Vertex.t * Weight.t) church_list) -> (* 始点 *) Vertex.t -> (* 始点から辿り着けなければNoneを返す *) (Vertex.t -> Weight.t option) end
要するにリストをチャーチエンコーディングする訳ですね. ちなみにランク2多相が必要になるので,OCamlだと一旦レコードに包まなくてはなりません.
シグニチャ変更後の実装は以下の通りです.
= struct module WMap = Map.Make (Weight) module VMap = Map.Make (Vertex) type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } let rec dijkstra_aux e (d, q) = match WMap.min_binding q with | exception Not_found -> d | (w, us) -> dijkstra_aux e @@ List.fold_left (fun (d, q) u -> if Weight.compare (VMap.find u d) w < 0 then (d, q) (* !!! uから伸びる辺のリストを畳み込む部分の実装が変わっている !!! *) else (e u).fold (fun (v, c) (d, q) -> let open Weight in if try Weight.compare (VMap.find v d) (w + c) <= 0 with Not_found -> false then (d, q) else VMap.add v (w + c) d, WMap.add (w + c) (v :: try WMap.find (w + c) q with Not_found -> []) q) (d, q)) (d, WMap.remove w q) us let dijkstra e s = let d = dijkstra_aux e (VMap.singleton s Weight.zero, WMap.singleton Weight.zero [s]) in fun t -> try Some (VMap.find t d) with Not_found -> None end
新しめのバージョンのOCamlで導入された関数の利用
最近AtCoderで使えるOCaml処理系が4.02.3から4.10.0に新しくなったので, 新しめのバージョンで導入された標準ライブラリの関数と言語機能を使って書き換えます.
高速化というよりリファクタリングですが,一部定数倍高速化に効果のある変更もあります.*2
= struct module WMap = Map.Make (Weight) module VMap = Map.Make (Vertex) type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } let rec dijkstra_aux e (d, q) = (* !!! 例外を使うよりはoptionを使う方が行儀が良い !!! *) match WMap.min_binding_opt q with | None -> d | Some (w, us) -> dijkstra_aux e @@ List.fold_left (fun (d, q) u -> if Weight.compare (VMap.find u d) w < 0 then (d, q) else (e u).fold (fun (v, c) (d, q) -> let open Weight in (* !!! 最近のOCamlはexceptionパターンの位置に融通が効く !!! *) match VMap.find v d with | x when Weight.compare x (w + c) <= 0 -> (d, q) | _ | exception Not_found -> VMap.add v (w + c) d, (* !!! findとaddを組み合わせて書いていた部分が,update一つで置き換えられる !!! *) WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) q) (d, q)) (d, WMap.remove w q) us let dijkstra e s = let d = dijkstra_aux e (VMap.singleton s Weight.zero, WMap.singleton Weight.zero [s]) in fun t -> try Some (VMap.find t d) with Not_found -> None end
2頂点間最短経路問題に対しての高速化
この節で話すのは,既にML Day#2で話したのと同じ内容です.
競プロでダイクストラ法を使う際,全ての頂点への最短距離が欲しい場合よりも,むしろ終点までの最短距離だけ欲しい場合の方が多いのではないでしょうか?ダイクストラ法では,ヒープから頂点を取り出した段階でまでの最短距離が確定するため,ヒープから終点を取り出したタイミングで計算を打ち切ってやれば速くなりそうです.
このアイデアに基づいて,以下のようにダイクストラ法の実装を変更します.
= struct module WMap = Map.Make (Weight) module VMap = Map.Make (Vertex) type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } let rec dijkstra_aux e t (d, q) = match WMap.min_binding_opt q with | None -> None | Some (w, us) -> (* !!! 終点までの距離が分かったので計算を切り上げる !!! *) if List.exists (( = ) t) us then Some w else dijkstra_aux e t @@ List.fold_left (fun (d, q) u -> if Weight.compare (VMap.find u d) w < 0 then (d, q) else (e u).fold (fun (v, c) (d, q) -> let open Weight in match VMap.find v d with | x when Weight.compare x (w + c) <= 0 -> (d, q) | _ | exception Not_found -> VMap.add v (w + c) d, WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) q) (d, q)) (d, WMap.remove w q) us let dijkstra e s t = dijkstra_aux e t (VMap.singleton s Weight.zero, WMap.singleton Weight.zero [s]) end
ちなみにこの変更を施すと,調子の良い時*3は無限グラフでも最短距離を求められるようになります. 関数型言語のオタクはこういう性質を見ると嬉しくなる.
もっとも,この変更は諸刃の剣でもあります. ある終点までの距離が分かる度に途中までの計算結果を捨てているので, この実装を使って全ての頂点までの最短経路を求めようとするとかかってしまうのです. 実装を使いわけろと言うのは簡単ですが,終点が一つの場合に定数倍速く, なおかつ全頂点までの最短距離が欲しい場合でも計算量のオーダーを落とさないことは可能でしょうか?
結論から言えば可能で,与えられた終点までの最短距離が分かった時点で結果を返し, 次の終点が与えられた時は前回最短距離を返したところから実行を再開できれば良さそうです. コルーチンとか限定継続があれば綺麗に書けそうですがOCaml標準ライブラリにそんなものはないので, 計算を再開するのに必要な状態を参照として外に出してやることで実装します.*4
= struct module WMap = Map.Make (Weight) module VMap = Map.Make (Vertex) type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } let rec dijkstra e s = (* !!! 計算の途中状態を参照として外に出す !!! *) let d = ref (VMap.singleton s Weight.zero) in let q = ref (WMap.singleton Weight.zero [s]) in let rec dijkstra_aux t = match WMap.min_binding_opt !q with (* !!! 既にダイクストラ法の実行が終わっていた場合は,dに最短距離が入っている !!! *) | None -> VMap.find_opt t !d | Some (w, us) -> match VMap.find t !d with (* !!! 新しく距離が確定した頂点より終点の方が近ければ, 終点までの距離は確定している !!! *) | x when Weight.compare x w <= 0 -> Some x | _ | exception Not_found -> q := WMap.remove w !q; List.iter (fun u -> if 0 <= Weight.compare (VMap.find u !d) w then (e u).fold (fun (v, c) () -> let open Weight in match VMap.find v !d with | d when Weight.compare d (w + c) <= 0 -> () | _ | exception Not_found -> d := VMap.add v (w + c) !d; q := WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) !q) ()) us; dijkstra_aux t in dijkstra_aux end
相変わらず無限グラフでも動いてくれるので嬉しいですね.
ここで,既に最短距離が確定している頂点へのクエリばかり行われた場合,!q
が変化していないのに毎回min_binding_opt
が呼び出されて効率が悪いことに気付きます.こういうのはメモ化してやりましょう.*5
= struct module WMap = Map.Make (Weight) module VMap = Map.Make (Vertex) type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } let rec dijkstra e s = let d = ref (VMap.singleton s Weight.zero) in let q = ref (WMap.singleton Weight.zero [s]) in (* !!! WMap.min_binding !qの結果をメモしておく !!! *) let min_binding_opt = ref (Some (Weight.zero, [s])) in let rec dijkstra_aux t = match !min_binding_opt with | None -> VMap.find_opt t !d | Some (w, us) -> match VMap.find t !d with | x when Weight.compare x w <= 0 -> Some x | _ | exception Not_found -> q := WMap.remove w !q; List.iter (fun u -> if 0 <= Weight.compare (VMap.find u !d) w then (e u).fold (fun (v, c) () -> let open Weight in match VMap.find v !d with | d when Weight.compare d (w + c) <= 0 -> () | _ | exception Not_found -> d := VMap.add v (w + c) !d; q := WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) !q) ()) us; (* !!! qが変更されたので,min_binding_optを更新 !!! *) min_binding_opt := WMap.min_binding_opt !q; dijkstra_aux t in dijkstra_aux end
データ構造の使い分け
計算量変わらんしええやろって感じで,今まではヒープだけでなく最短距離を格納するデータ構造にもMapを使ってたんですが*6, 実際に使ってみると無視できないぐらい遅いことが分かってきました.*7 しかしMapを用いたことで生じた,座圧が不要なばかりか様々な型を頂点としたグラフに対して直接最短距離を求められる利点も捨てがたい…
そこで,配列を用いた実装,ハッシュテーブルを用いた実装,Mapを用いた実装の三種類を用意し,その時々で使い分けられるようにします.
毎回書いてたらダルいのでファンクタで上手くやりますが.
クソ長くなったので変更後の実装は隠しておきます.
module WeightedDirectedGraph : sig (* 配列を用いたダイクストラ法の実装 単純な速さでは一番だが,インターフェースが不便 *) module ByArray : sig module Make : functor (Weight : sig type t val inf : t val zero : t val ( + ) : t -> t -> t val compare : t -> t -> int end) -> sig type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } (* 頂点を[0, n)の自然数に限定したグラフに対してのダイクストラ法 時間計算量O(E log E)なので,疎なグラフなら速い *) val dijkstra : (* 頂点数n *) int -> (* 隣接リスト *) (int -> (int * Weight.t) church_list) -> (* 始点 *) int -> (* 始点から辿り着けなければinfを返す関数 この関数を覚えておけば,呼び出しごとの途中までの計算結果がシェアされる *) (int -> Weight.t) end end (* ハッシュテーブルを用いたダイクストラ法の実装 配列を用いた実装よりは扱いやすいインターフェースを持つ ハッシュ関数を上手く選べば配列を用いた実装より1.5倍遅い程度ですむ *) module ByHashtbl : sig module Make : functor (* 頂点 *) (Vertex : Hashtbl.HashedType) (* 辺の重み *) (Weight : sig type t val zero : t val ( + ) : t -> t -> t val compare : t -> t -> int end) -> sig type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } val dijkstra : (* 頂点数(Hashtbl.tを用いるので目安程度) *) int -> (* 隣接リスト *) (Vertex.t -> (Vertex.t * Weight.t) church_list) -> (* 始点 *) Vertex.t -> (* 始点から辿り着けなければNot_foundを投げる関数 この関数を覚えておけば,呼び出しごとの途中までの計算結果がシェアされる *) (Vertex.t -> Weight.t) end end (* Mapを用いたダイクストラ法の実装 配列を用いた実装より4倍ぐらい遅いが, 一番扱いやすいインターフェースを持ち,無限グラフにも対応可能 *) module ByMap : sig module Make : functor (* 頂点 *) (Vertex : Map.OrderedType) (* 辺の重み *) (Weight : sig type t val zero : t val ( + ) : t -> t -> t val compare : t -> t -> int end) -> sig type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } val dijkstra : (* 隣接リスト *) (Vertex.t -> (Vertex.t * Weight.t) church_list) -> (* 始点 *) Vertex.t -> (* 始点から辿り着けなければNot_foundを投げる関数 この関数を覚えておけば,呼び出しごとの途中までの計算結果がシェアされる *) (Vertex.t -> Weight.t) end end end = struct module type Weight = sig type t val zero : t val ( + ) : t -> t -> t val compare : t -> t -> int end (* 最短距離を格納するデータ構造を抽象化したダイクストラ法の実装 *) module Core (W : Weight) (* グラフの頂点を添字とした配列 *) (VArray : sig type t type vertex (* グラフの頂点 *) val find : t -> vertex -> W.t (* 最短距離が格納されていなければNot_foundを投げる *) val update : t -> vertex -> W.t -> unit end) = struct type 'a church_list = { fold : 'b. ('a -> 'b -> 'b) -> 'b -> 'b } module WMap = Map.Make (W) let dijkstra d e s = VArray.update d s W.zero; let q = ref (WMap.singleton W.zero [s]) in (* 既に最短距離が確定した辺へのクエリを高速化するため, ヒープの最小要素をメモしておく *) let min_binding_opt = ref (Some (W.zero, [s])) in let rec dijkstra_aux t = match !min_binding_opt with (* もう既に全ての頂点までの距離が分かっている *) | None -> VArray.find d t | Some (w, us) -> match VArray.find d t with (* 既に終点までの距離が分かっているので返す *) | x when W.compare x w <= 0 -> x (* 終点までの距離が分かっていないので,ダイクストラ法を続行 *) | _ | exception Not_found -> q := WMap.remove w !q; Fun.flip List.iter us (fun u -> if 0 <= W.compare (VArray.find d u) w then (* 未だ頂点uを訪れていない *) Fun.flip (e u).fold () @@ fun (v, c) () -> let open W in match VArray.find d v with | d when W.compare d (w + c) <= 0 -> () | _ | exception Not_found -> VArray.update d v (w + c); q := WMap.update (w + c) (fun vs -> Some (v :: Option.value ~default:[] vs)) !q); min_binding_opt := WMap.min_binding_opt !q; dijkstra_aux t in dijkstra_aux end module ByArray = struct module Make (W : sig include Weight val inf : t end) = struct module C = Core (W) (struct type t = W.t array type vertex = int let find = Array.get let update = Array.set end) include C let dijkstra n e s = C.dijkstra (Array.make n W.inf) e s end end module ByHashtbl = struct module Make (V : Hashtbl.HashedType) (W : Weight) = struct module VHash = Hashtbl.Make (V) module C = Core (W) (struct type t = W.t VHash.t type vertex = V.t let find = VHash.find let update = VHash.replace end) include C let dijkstra n e s = C.dijkstra (VHash.create n) e s end end module ByMap = struct module Make (V : Map.OrderedType) (W : Weight) = struct module VMap = Map.Make (V) module C = Core (W) (struct type t = W.t VMap.t ref type vertex = V.t let find d v = VMap.find v !d let update d v w = d := VMap.add v w !d end) include C let dijkstra e s = C.dijkstra (ref VMap.empty) e s end end end
まとめ
以前のダイクストラ法の実装に対して定数倍高速化を施したほか,使い勝手を妥協して更に定数倍高速化を推し進めた実装も作り分けました. ここでのダイクストラ法の実装はGitHubにもアップロードされており,実際の使用例なども見ることができます.