fetburner.core

コアダンプ

OCamlでまともなダイクストラ法を実装する

遥か昔にもOCamlダイクストラ法を実装したことはあったんですが, 最も近い頂点を線形探索で求めている手抜き実装なので,O(V2)の計算時間を要する(a.k.a. 疎なグラフに対しては遅い)ものでした. 競プロだと入力が疎なグラフに限定されている場面が多々あって1, それを前提に制限時間が決まっているので結構つらいです. 破壊的代入を濫用しているのも気になってましたし,純粋関数型かつ疎なグラフに対して高速な実装を与えます.

解説のため,以前の実装を再掲します.

(*
 * qは未だ訪れていない頂点の集合
 * eは隣接リストによるグラフの表現
 * dは現時点での最短距離の入った配列
 *)
let rec dijkstra q e d =
  match q with
  | [] -> ()
  | u :: q ->
      (* 最も近い頂点uと,uをqから除いた集合を求める *)
      let u, q =
        List.fold_left (fun (u, q') u' ->
          if d.(u) < d.(u') then (u, u' :: q')
          else (u', u :: q')) (u, []) q in
      (* 最も近い頂点uに隣接する頂点への最短距離を更新 *)
      List.iter (fun (v, c) ->
        if c +. d.(u) < d.(v) then
          d.(v) <- c +. d.(u)) (e u);
      dijkstra q e d

見事に手続き的なコードですね. 最も近い頂点uを求める(加えて,qからuを取り除く)処理と,頂点uに隣接する頂点への最短距離を更新する処理がありますが, 前者はO(V2)の時間計算量2を要するのに対し後者はO(E)なので,明らかに前者がボトルネックになっています. 最も近い頂点uを求める処理とはつまり最小値を求める処理であり,最も近い頂点uをqから取り除く処理とはつまり最小値を取り除く処理ですから, ヒープを使えばボトルネックを解消できそうだと気付くことでしょう.

OCaml標準ライブラリに目を向けると, 一見ヒープの実装は存在しないように思えます. ですがご心配なく.有限集合の実装Setや有限写像の実装Mapには, 最小値や最大値をO(log N)3で求める関数が提供されているため,ヒープとして使えなくもないです4

では,Mapを用いて先ほどの実装を改造してみましょう5

module WeightMap = Map.Make (struct
  type t = float
  let compare = compare
end)


let rec dijkstra e d q =
  (* 最も近い頂点uを求める *)
  match WMap.min_binding q with
  | exception Not_found -> ()
  | (w, us) ->
      dijkstra e d @@
      (* 最も近い頂点は,複数存在しうる *)
      List.fold_left (fun q u ->
        if d.(u) < w
        (* 既に頂点uを訪れていた *)
        then q
        (* uから伸びる辺を見て,最短距離を更新 *)
        else List.fold_left (fun q (v, c) ->
          (* uへの最短距離はw *)
          if d.(v) <= w +. c
          then q
          else begin
            (* 最短距離を更新し,ヒープに突っ込む *)
            d.(v) <- w +. c;
            WMap.add (w +. c) (v :: try WMap.find (w +. c) q with Not_found -> []) q
          end) q (e u)) (WMap.remove w q) us

ここで,Mapは重複した要素を認めない点に注意が必要です. 最短距離が同じ頂点が出てきた時のために,最短距離から頂点のリストへの有限写像でヒープを表現しています.

計算量を評価してみましょう. Mapではヒープで言うところのdecrease-keyに相当する操作を効率的に実装しようがないので, 最短距離が更新された際は何も考えずにヒープに追加しています. なのでヒープの要素数はO(E)となり,最も近い頂点uを求める処理(とuをqから取り除く処理)の時間計算量はO(E log E)ですね. では最短距離を更新する処理はというと,ヒープに追加する処理がO(log E)なのでO(E log E)に増えてしまいます. よってこの実装の計算量はO(E log E)です.

ヒープに追加する処理のせいでlog Eが付いてしまっているので, 実は最短距離を保持するデータ構造を配列からMapに変えても大して計算量が変わらなかったりします. ヒープを取り回す部分で若干関数型っぽくなってますし,この際に配列を削除して純粋関数型な実装にしてしまいましょう.

module VMap = Map.Make (struct
  type t = int
  let compare = compare
end)

module WMap = Map.Make (struct
  type t = float
  let compare = compare
end)

(* d に入っていない頂点への距離は無限大とみなす *)
let rec dijkstra e (d, q) =
  (* 最も近い頂点uを求める *)
  match WMap.min_binding q with
  | exception Not_found -> d
  | (w, us) ->
      dijkstra e @@
      (* 最も近い頂点は,複数存在しうる *)
      List.fold_left (fun (d, q) u ->
        if VMap.find u d < w
        (* 既に頂点uを訪れていた *)
        then (d, q)
        else List.fold_left (fun (d, q) (v, c) ->
          (* uへの最短距離はw *)
          if
            try VMap.find v d <= w +. c
            with Not_found -> false (* d.(v) は無限大 *)
          then (d, q)
          else
            (* 最短距離を更新し,ヒープに突っ込む *)
            VMap.add v (w +. c) d,
            WMap.add (w +. c) (v :: try WMap.find (w +. c) q with Not_found -> []) q)
        (d, q) (e u)) (d, WMap.remove w q) us

計算量を評価してみます. とは言ってもそれほど大きく変化する訳ではなくて,最短距離の読み書きがO(1)からO(log V)になるだけですね. なので最短距離の更新がO(E log V + E log E)になり,全体を通しての計算量はO(E log E + E log V)…まぁ大体O(E log V)で良いんじゃないでしょうか.

説明のために頂点はint型,辺の長さはfloat型に固定していましたが, ファンクタを用いれば汎用的な実装が得られます.

module WeightedDirectedGraph
  (Vertex : sig
    type t
    val compare : t -> t -> int
  end)
  (Weight : sig
    type t
    val zero : t
    val ( + ) : t -> t -> t
    val compare : t -> t -> int
  end) :
sig
  val dijkstra :
    (* 隣接リスト *)
    (Vertex.t -> (Vertex.t * Weight.t) list) ->
    (* 始点 *)
    Vertex.t ->
    (* 始点から辿り着けなければNoneを返す *)
    (Vertex.t -> Weight.t option)
end =
struct
  module WMap = Map.Make (Weight)
  module VMap = Map.Make (Vertex)

  (* ダイクストラ法のメインループ *)
  (* d に入っていない頂点への距離は無限大とみなす *)
  let rec dijkstra_aux e (d, q) =
    (* 最も近い頂点uを求める *)
    match WMap.min_binding q with
    | exception Not_found -> d
    | (w, us) ->
        dijkstra_aux e @@ List.fold_left (fun (d, q) u ->
        (* 最も近い頂点は,複数存在しうる *)
          if Weight.compare (VMap.find u d) w < 0
          (* 既に頂点uを訪れていた *)
          then (d, q)
          else List.fold_left (fun (d, q) (v, c) ->
            let open Weight in
            (* uへの最短距離はw *)
            if
              try Weight.compare (VMap.find v d) (w + c) <= 0
              with Not_found -> false (* d.(u) は無限大 *)
            then (d, q)
            else
              (* 最短距離を更新し,ヒープに突っ込む *)
              VMap.add v (w + c) d,
              WMap.add (w + c) (v :: try WMap.find (w + c) q with Not_found -> []) q)
          (d, q) (e u)) (d, WMap.remove w q) us

  let dijkstra e s =
    let d =
      dijkstra_aux e
        (VMap.singleton s Weight.zero, WMap.singleton Weight.zero [s]) in
    fun t -> try Some (VMap.find t d) with Not_found -> None
end

使い方はこの通り.日本語版Wikipediaと同じ例で最短距離を求めています.

# module G = WeightedDirectedGraph
  (struct
    type t = int
    let compare = compare
  end)
  (struct
    type t = float
    let zero = 0.
    let ( + ) = ( +. )
    let compare = compare
  end);;

module G :
  sig
    val dijkstra : (int -> (int * float) list) -> int -> int -> float option
  end

# Array.init 6 (G.dijkstra [(function
      | 0 -> [ (1, 7.); (2, 9.); (5, 14.) ]
      | 1 -> [ (0, 7.); (2, 10.); (3, 15.) ]
      | 2 -> [ (0, 9.); (1, 10.); (3, 11.); (5, 2.) ]
      | 3 -> [ (1, 15.); (2, 11.); (4, 6.) ]
      | 4 -> [ (3, 6.); (5, 9.) ]
      | 5 -> [ (0, 14.); (2, 2.); (4, 9.) ]) 0);;

Warning 8: this pattern-matching is not exhaustive.
Here is an example of a value that is not matched:
6
- : float option array =
[|Some 0.; Some 7.; Some 9.; Some 20.; Some 20.; Some 11.|]

Mapを直接返すのではなく,終点を受け取って距離を返す関数を返しているのは, 実装を隠蔽したいみたいな気持ちがあったりします.

辺の長さを工夫してやれば経路とかも手に入ります. 2020年5月19日追記 この方法だと重さが同じ頂点への経路が全て同一視されてしまうので,最短距離が同じ頂点が複数あるとうまく動きません.

# module G = WeightedDirectedGraph
  (struct
    type t = int
    let compare = compare
  end)
  (struct
    type t = float * (string list -> string list)
    let zero = (0., fun xs -> xs)
    let ( + ) (c, f) (d, g) = (c +. d, fun xs -> f (g xs))
    let compare (c, _) (d, _) = compare c d
  end);;

module G :
  sig
    val dijkstra :
      (int -> (int * (float * (string list -> string list))) list) ->
      int -> int -> (float * (string list -> string list)) option
  end

# let e =
  Array.mapi (fun u ->
    List.map (fun (v, c) ->
      let s = Printf.sprintf "%d->%d" u v in
      (v, (c, fun xs -> s :: xs))))
  [|[ (1, 7.); (2, 9.); (5, 14.) ];
    [ (0, 7.); (2, 10.); (3, 15.) ];
    [ (0, 9.); (1, 10.); (3, 11.); (5, 2.) ];
    [ (1, 15.); (2, 11.); (4, 6.) ];
    [ (3, 6.); (5, 9.) ];
    [ (0, 14.); (2, 2.); (4, 9.) ]|];;

val e : (int * (float * (string list -> string list))) list array =
  [|[(1, (7., <fun>)); (2, (9., <fun>)); (5, (14., <fun>))];
    [(0, (7., <fun>)); (2, (10., <fun>)); (3, (15., <fun>))];
    [(0, (9., <fun>)); (1, (10., <fun>)); (3, (11., <fun>));
     (5, (2., <fun>))];
    [(1, (15., <fun>)); (2, (11., <fun>)); (4, (6., <fun>))];
    [(3, (6., <fun>)); (5, (9., <fun>))];
    [(0, (14., <fun>)); (2, (2., <fun>)); (4, (9., <fun>))]|]

# Array.map (fun (Some (c, f)) -> (c, f [])) @@ Array.init 6 (G.dijkstra (fun u -> e.(u)) 0)

Warning 8: this pattern-matching is not exhaustive.
Here is an example of a value that is not matched:
None
- : (float * string list) array =
[|(0., []); (7., ["0->1"]); (9., ["0->2"]); (20., ["0->2"; "2->3"]);
  (20., ["0->2"; "2->5"; "5->4"]); (11., ["0->2"; "2->5"])|]

計算量を落とすために差分リストを使っていて読みづらいですが…

ちなみに,ダイクストラ法と同じ要領でプリム法も純粋関数型に書けます.

module WeightedGraph
  (Vertex : sig
    type t
    val compare : t -> t -> int
  end)
  (Weight : sig
    type t
    val compare : t -> t -> int
  end) :
sig
  val prim :
    (* 隣接リスト *)
    (Vertex.t -> (Vertex.t * Weight.t) list) ->
    (* 始点 *)
    Vertex.t ->
    (* 最小全域木に含まれる辺のリスト *)
    (Vertex.t * Vertex.t * Weight.t) list
end =
struct
  module VSet = Set.Make (Vertex)
  module WMap = Map.Make (Weight)

  (*
   * プリム法のメインループ
   * es : 隣接リスト
   * vs : 訪れた頂点の集合
   * q : 訪れた頂点から伸びる辺が重み順に入ったヒープ
   * acc : 最小全域木に使うのが確定した辺を入れるやつ
   *)
  let rec prim_aux es acc vs q =
    match WMap.min_binding q with
    | exception Not_found -> acc
    | (w, []) -> prim_aux es acc vs (WMap.remove w q)
    | (w, (u, v) :: rest) ->
        if VSet.mem v vs then
          (* vは既に訪れていた *)
          prim_aux es acc vs (WMap.add w rest q)
        else
          (* vはまだ訪れていなかった *)
          prim_aux es ((u, v, w) :: acc) (VSet.add v vs) @@
            (* vから伸びる辺をキューに追加 *)
            List.fold_left (fun q (u, w) ->
              (* 現時点で既に訪れている頂点への辺は追加しない *)
              if VSet.mem u vs then q
              else WMap.add w ((v, u) :: try WMap.find w q with Not_found -> []) q) (WMap.add w rest q) (es v)

  let prim es s =
    prim_aux es [] (VSet.singleton s) @@
      (* 始点から伸びる辺をキューに入れておく *)
      List.fold_left (fun q (v, w) ->
        WMap.add w ((s, v) :: try WMap.find w q with Not_found -> []) q) WMap.empty (es s)
end

まぁ,副作用使いまくってクラスカル法を実装した方が速いですが…


  1. e.g. https://beta.atcoder.jp/contests/abc035/submissions/676838

  2. 頂点1つあたり(つまり,再帰一回あたり)の計算量ではなく,頂点全てを処理した際の(つまり,dijkstra再帰全体を通しての)計算量です

  3. AVL木の変種で実装されているため

  4. 真面目なヒープだと最小値はO(1)で求められると思いますが,まぁ最小値の削除にO(log N)を要するので大して変わらないでしょう

  5. 最短距離を求めるだけでなく,どの頂点への距離なのかも知りたいのでMapを使っています