fetburner.core

コアダンプ

OCamlでまともなダイクストラ法を実装する

遥か昔にもOCamlダイクストラ法を実装したことはあったんですが, 最も近い頂点を線形探索で求めている手抜き実装なので,O(V2)の計算時間を要する(a.k.a. 疎なグラフに対しては遅い)ものでした. 競プロだと入力が疎なグラフに限定されている場面が多々あって1, それを前提に制限時間が決まっているので結構つらいです. 破壊的代入を濫用しているのも気になってましたし,純粋関数型かつ疎なグラフに対して高速な実装を与えます.

解説のため,以前の実装を再掲します.

(*
 * qは未だ訪れていない頂点の集合
 * eは隣接リストによるグラフの表現
 * dは現時点での最短距離の入った配列
 *)
let rec dijkstra q e d =
  match q with
  | [] -> ()
  | u :: q ->
      (* 最も近い頂点uと,uをqから除いた集合を求める *)
      let u, q =
        List.fold_left (fun (u, q') u' ->
          if d.(u) < d.(u') then (u, u' :: q')
          else (u', u :: q')) (u, []) q in
      (* 最も近い頂点uに隣接する頂点への最短距離を更新 *)
      List.iter (fun (v, c) ->
        if c +. d.(u) < d.(v) then
          d.(v) <- c +. d.(u)) (e u);
      dijkstra q e d

見事に手続き的なコードですね. 最も近い頂点uを求める(加えて,qからuを取り除く)処理と,頂点uに隣接する頂点への最短距離を更新する処理がありますが, 前者はO(V2)の時間計算量2を要するのに対し後者はO(E)なので,明らかに前者がボトルネックになっています. 最も近い頂点uを求める処理とはつまり最小値を求める処理であり,最も近い頂点uをqから取り除く処理とはつまり最小値を取り除く処理ですから, ヒープを使えばボトルネックを解消できそうだと気付くことでしょう.

OCaml標準ライブラリに目を向けると, 一見ヒープの実装は存在しないように思えます. ですがご心配なく.有限集合の実装Setや有限写像の実装Mapには, 最小値や最大値をO(log N)3で求める関数が提供されているため,ヒープとして使えなくもないです4

では,Mapを用いて先ほどの実装を改造してみましょう5

module WeightMap = Map.Make (struct
  type t = float
  let compare = compare
end)


let rec dijkstra e d q =
  (* 最も近い頂点uを求める *)
  match WMap.min_binding q with
  | exception Not_found -> ()
  | (w, us) ->
      dijkstra e d @@
      (* 最も近い頂点は,複数存在しうる *)
      List.fold_left (fun q u ->
        if d.(u) < w
        (* 既に頂点uを訪れていた *)
        then q
        (* uから伸びる辺を見て,最短距離を更新 *)
        else List.fold_left (fun q (v, c) ->
          (* uへの最短距離はw *)
          if d.(v) <= w +. c
          then q
          else begin
            (* 最短距離を更新し,ヒープに突っ込む *)
            d.(v) <- w +. c;
            WMap.add (w +. c) (v :: try WMap.find (w +. c) q with Not_found -> []) q
          end) q (e u)) (WMap.remove w q) us

ここで,Mapは重複した要素を認めない点に注意が必要です. 最短距離が同じ頂点が出てきた時のために,最短距離から頂点のリストへの有限写像でヒープを表現しています.

計算量を評価してみましょう. Mapではヒープで言うところのdecrease-keyに相当する操作を効率的に実装しようがないので, 最短距離が更新された際は何も考えずにヒープに追加しています. なのでヒープの要素数はO(E)となり,最も近い頂点uを求める処理(とuをqから取り除く処理)の時間計算量はO(E log E)ですね. では最短距離を更新する処理はというと,ヒープに追加する処理がO(log E)なのでO(E log E)に増えてしまいます. よってこの実装の計算量はO(E log E)です.

ヒープに追加する処理のせいでlog Eが付いてしまっているので, 実は最短距離を保持するデータ構造を配列からMapに変えても大して計算量が変わらなかったりします. ヒープを取り回す部分で若干関数型っぽくなってますし,この際に配列を削除して純粋関数型な実装にしてしまいましょう.

module VMap = Map.Make (struct
  type t = int
  let compare = compare
end)

module WMap = Map.Make (struct
  type t = float
  let compare = compare
end)

(* d に入っていない頂点への距離は無限大とみなす *)
let rec dijkstra e (d, q) =
  (* 最も近い頂点uを求める *)
  match WMap.min_binding q with
  | exception Not_found -> d
  | (w, us) ->
      dijkstra e @@
      (* 最も近い頂点は,複数存在しうる *)
      List.fold_left (fun (d, q) u ->
        if
          (* 既に頂点uを訪れていた *)
          try VMap.find u d < w
          with Not_found -> false
        then (d, q)
        else List.fold_left (fun (d, q) (v, c) ->
          (* uへの最短距離はw *)
          if
            try VMap.find v d <= w +. c
            with Not_found -> false (* d.(v) は無限大 *)
          then (d, q)
          else
            (* 最短距離を更新し,ヒープに突っ込む *)
            VMap.add v (w +. c) d,
            WMap.add (w +. c) (v :: try WMap.find (w +. c) q with Not_found -> []) q)
        (d, q) (e u)) (d, WMap.remove w q) us

計算量を評価してみます. とは言ってもそれほど大きく変化する訳ではなくて,最短距離の読み書きがO(1)からO(log V)になるだけですね. なので最短距離の更新がO(E log V + E log E)になり,全体を通しての計算量はO(E log E + E log V)…まぁ大体O(E log V)で良いんじゃないでしょうか.

説明のために頂点はint型,辺の長さはfloat型に固定していましたが, ファンクタを用いれば汎用的な実装が得られます.

module WeightedDirectedGraph
  (Vertex : sig
    type t
    val compare : t -> t -> int
  end)
  (Weight : sig
    type t
    val zero : t
    val ( + ) : t -> t -> t
    val compare : t -> t -> int
  end) :
sig
  val dijkstra :
    (* 隣接リスト *)
    (Vertex.t -> (Vertex.t * Weight.t) list) ->
    (* 始点 *)
    Vertex.t ->
    (* 始点から辿り着けなければNoneを返す *)
    (Vertex.t -> Weight.t option)
end =
struct
  module WMap = Map.Make (Weight)
  module VMap = Map.Make (Vertex)

  (* ダイクストラ法のメインループ *)
  (* d に入っていない頂点への距離は無限大とみなす *)
  let rec dijkstra_aux e (d, q) =
    (* 最も近い頂点uを求める *)
    match WMap.min_binding q with
    | exception Not_found -> d
    | (w, us) ->
        dijkstra_aux e @@ List.fold_left (fun (d, q) u ->
        (* 最も近い頂点は,複数存在しうる *)
          if
            (* 既に頂点uを訪れていた *)
            try Weight.compare (VMap.find u d) w < 0
            with Not_found -> false
          then (d, q)
          else List.fold_left (fun (d, q) (v, c) ->
            let open Weight in
            (* uへの最短距離はw *)
            if
              try Weight.compare (VMap.find v d) (w + c) <= 0
              with Not_found -> false (* d.(u) は無限大 *)
            then (d, q)
            else
              (* 最短距離を更新し,ヒープに突っ込む *)
              VMap.add v (w + c) d,
              WMap.add (w + c) (v :: try WMap.find (w + c) q with Not_found -> []) q)
          (d, q) (e u)) (d, WMap.remove w q) us

  let dijkstra e s =
    let d =
      dijkstra_aux e
        (VMap.singleton s Weight.zero, WMap.singleton Weight.zero [s]) in
    fun t -> try Some (VMap.find t d) with Not_found -> None
end

使い方はこの通り.日本語版Wikipediaと同じ例で最短距離を求めています.

# module G = WeightedDirectedGraph
  (struct
    type t = int
    let compare = compare
  end)
  (struct
    type t = float
    let zero = 0.
    let ( + ) = ( +. )
    let compare = compare
  end);;

module G :
  sig
    val dijkstra : (int -> (int * float) list) -> int -> int -> float option
  end

# Array.init 6 (G.dijkstra [(function
      | 0 -> [ (1, 7.); (2, 9.); (5, 14.) ]
      | 1 -> [ (0, 7.); (2, 10.); (3, 15.) ]
      | 2 -> [ (0, 9.); (1, 10.); (3, 11.); (5, 2.) ]
      | 3 -> [ (1, 15.); (2, 11.); (4, 6.) ]
      | 4 -> [ (3, 6.); (5, 9.) ]
      | 5 -> [ (0, 14.); (2, 2.); (4, 9.) ]) 0);;

Warning 8: this pattern-matching is not exhaustive.
Here is an example of a value that is not matched:
6
- : float option array =
[|Some 0.; Some 7.; Some 9.; Some 20.; Some 20.; Some 11.|]

Mapを直接返すのではなく,終点を受け取って距離を返す関数を返しているのは, 実装を隠蔽したいみたいな気持ちがあったりします.

辺の長さを工夫してやれば経路とかも手に入ります.

# module G = WeightedDirectedGraph
  (struct
    type t = int
    let compare = compare
  end)
  (struct
    type t = float * (string list -> string list)
    let zero = (0., fun xs -> xs)
    let ( + ) (c, f) (d, g) = (c +. d, fun xs -> f (g xs))
    let compare (c, _) (d, _) = compare c d
  end);;

module G :
  sig
    val dijkstra :
      (int -> (int * (float * (string list -> string list))) list) ->
      int -> int -> (float * (string list -> string list)) option
  end

# let e =
  Array.mapi (fun u ->
    List.map (fun (v, c) ->
      let s = Printf.sprintf "%d->%d" u v in
      (v, (c, fun xs -> s :: xs))))
  [|[ (1, 7.); (2, 9.); (5, 14.) ];
    [ (0, 7.); (2, 10.); (3, 15.) ];
    [ (0, 9.); (1, 10.); (3, 11.); (5, 2.) ];
    [ (1, 15.); (2, 11.); (4, 6.) ];
    [ (3, 6.); (5, 9.) ];
    [ (0, 14.); (2, 2.); (4, 9.) ]|];;

val e : (int * (float * (string list -> string list))) list array =
  [|[(1, (7., <fun>)); (2, (9., <fun>)); (5, (14., <fun>))];
    [(0, (7., <fun>)); (2, (10., <fun>)); (3, (15., <fun>))];
    [(0, (9., <fun>)); (1, (10., <fun>)); (3, (11., <fun>));
     (5, (2., <fun>))];
    [(1, (15., <fun>)); (2, (11., <fun>)); (4, (6., <fun>))];
    [(3, (6., <fun>)); (5, (9., <fun>))];
    [(0, (14., <fun>)); (2, (2., <fun>)); (4, (9., <fun>))]|]

# Array.map (fun (Some (c, f)) -> (c, f [])) @@ Array.init 6 (G.dijkstra (fun u -> e.(u)) 0)

Warning 8: this pattern-matching is not exhaustive.
Here is an example of a value that is not matched:
None
- : (float * string list) array =
[|(0., []); (7., ["0->1"]); (9., ["0->2"]); (20., ["0->2"; "2->3"]);
  (20., ["0->2"; "2->5"; "5->4"]); (11., ["0->2"; "2->5"])|]

計算量を落とすために差分リストを使っていて読みづらいですが…

ちなみに,ダイクストラ法と同じ要領でプリム法も純粋関数型に書けます.

module WeightedGraph
  (Vertex : sig
    type t
    val compare : t -> t -> int
  end)
  (Weight : sig
    type t
    val compare : t -> t -> int
  end) :
sig
  val prim :
    (* 隣接リスト *)
    (Vertex.t -> (Vertex.t * Weight.t) list) ->
    (* 始点 *)
    Vertex.t ->
    (* 最小全域木に含まれる辺のリスト *)
    (Vertex.t * Vertex.t * Weight.t) list
end =
struct
  module VSet = Set.Make (Vertex)
  module WMap = Map.Make (Weight)

  (*
   * プリム法のメインループ
   * es : 隣接リスト
   * vs : 訪れた頂点の集合
   * q : 訪れた頂点から伸びる辺が重み順に入ったヒープ
   * acc : 最小全域木に使うのが確定した辺を入れるやつ
   *)
  let rec prim_aux es acc vs q =
    match WMap.min_binding q with
    | exception Not_found -> acc
    | (w, []) -> prim_aux es acc vs (WMap.remove w q)
    | (w, (u, v) :: rest) ->
        if VSet.mem v vs then
          (* vは既に訪れていた *)
          prim_aux es acc vs (WMap.add w rest q)
        else
          (* vはまだ訪れていなかった *)
          prim_aux es ((u, v, w) :: acc) (VSet.add v vs) @@
            (* vから伸びる辺をキューに追加 *)
            List.fold_left (fun q (u, w) ->
              (* 現時点で既に訪れている頂点への辺は追加しない *)
              if VSet.mem u vs then q
              else WMap.add w ((v, u) :: try WMap.find w q with Not_found -> []) q) (WMap.add w rest q) (es v)

  let prim es s =
    prim_aux es [] (VSet.singleton s) @@
      (* 始点から伸びる辺をキューに入れておく *)
      List.fold_left (fun q (v, w) ->
        WMap.add w ((s, v) :: try WMap.find w q with Not_found -> []) q) WMap.empty (es s)
end

まぁ,副作用使いまくってクラスカル法を実装した方が速いですが…


  1. e.g. https://beta.atcoder.jp/contests/abc035/submissions/676838

  2. 頂点1つあたり(つまり,再帰一回あたり)の計算量ではなく,頂点全てを処理した際の(つまり,dijkstra再帰全体を通しての)計算量です

  3. AVL木の変種で実装されているため

  4. 真面目なヒープだと最小値はO(1)で求められると思いますが,まぁ最小値の削除にO(log N)を要するので大して変わらないでしょう

  5. 最短距離を求めるだけでなく,どの頂点への距離なのかも知りたいのでMapを使っています