読者です 読者をやめる 読者になる 読者になる

チラシの裏

プログラミングとか色々

Coqで証明付きのマージソートを書く

Coq

この記事はTheorem Prover Advent Calendar 2016の4日目のために書かれました。

少し季節外れの記事になりますが、前期はプロ演A^1の季節でしたね。
僕のTLでもC言語の課題に苦しめられた学部生のツイートが良く回ってきましたが、 とりわけ彼らが苦戦していたのはマージソートを書く課題のようでした。
面白そうなので僕もCoqで実装してみましょう。 もちろん、証明付きで。

実装

とりあえず比較関数等の準備

Require Import Arith Div2 List Orders Sorted Permutation Program.
Require Omega.

Section MergeSort.
  Local Coercion is_true : bool >-> Sortclass.
  Local Hint Constructors Permutation StronglySorted.

  Variable t : Set.
  Variable leb : t -> t -> bool.
  Variable leb_total : forall x y, leb x y \/ leb y x.
  Variable leb_trans : Transitive leb.
  Hint Resolve leb_trans.

まずはマージから実装します。
CPDT風にプログラムと正当性の証明を同時に書いています。

Definition merge :
  forall xs, StronglySorted leb xs ->
  forall ys, StronglySorted leb ys ->
  { zs | StronglySorted leb zs /\ Permutation zs (xs ++ ys) }.
Proof.
  refine (fix merge xs :=
    match xs as xs0 return StronglySorted leb xs0 ->
      forall ys, StronglySorted leb ys ->
      { zs | StronglySorted leb zs
          /\ Permutation zs (xs0 ++ ys) } with
    | [] => fun _ ys _ => exist _ ys _
    | x :: xs' => fun _ =>
        fix merge' ys :=
          match ys with
          | [] => fun _ => exist _ (x :: xs') _
          | y :: ys' => fun _ =>
              if Sumbool.sumbool_of_bool (leb x y) then
                let (zs, _) := merge xs' _ (y :: ys') _ in
                exist _ (x :: zs) _
              else
                let (zs, _) := merge' ys' _ in
                exist _ (y :: zs) _
          end
    end);
    simpl in *;
    try clear merge';
    repeat (rewrite app_nil_r in *
      || match goal with
         | H : StronglySorted _ (_ :: _) |- _ =>
             inversion H; clear H; subst
         | H : _ /\ _ |- _ => destruct H
         end
      || split
      || apply SSorted_cons);
    eauto;
    repeat match goal with
    | H : Forall ?P ?l |- _ =>
        assert (forall x, In x l -> P x)
          by (apply Forall_forall; eauto);
        clear H
    | |- Forall _ _ =>
        apply Forall_forall; intros
    end.
  - apply (Permutation_in _ H0) in H2.
    apply in_app_iff in H2.
    destruct H2 as [ | [ ] ]; subst; eauto.
  - destruct (leb_total x y); [ congruence | ].
    apply (Permutation_in _ H0) in H2.
    apply in_app_iff with (l := x :: xs') in H2.
    destruct H2 as [ [ ] | ]; subst; eauto.
  - apply Permutation_cons_app with (l1 := x :: xs').
    eauto.
Defined.

マージソートといえば再帰的に分解していくものが有名ですが、 今回は最初にソート済のリストに分解した後にボトムアップにマージしていく実装とします。
こうする事である程度ソートされた入力に対しては高速に処理できるようになります。

ボトムアップにマージしていく操作、つまりソート済のリストのリストを受け取って隣り合ったリスト同士をマージして新しいソート済のリストのリストを作る操作は次のように実装できます。

Definition meld : forall xss, Forall (StronglySorted leb) xss ->
  { xss' | length xss' = div2 (S (length xss))
      /\ Forall (StronglySorted leb) xss'
      /\ Permutation
           (fold_right (@app _) [] xss')
           (fold_right (@app _) [] xss) }.
Proof.
  refine (fix meld xss :=
    match xss with
    | [] => fun _ => exist _ [] _
    | [xs] => fun _ => exist _ [xs] _
    | xs :: xs' :: xss' => fun _ =>
        let (ys, _) := merge xs _ xs' _ in
        let (yss, _) := meld xss' _ in
        exist _ (ys :: yss) _
    end);
    simpl in *; clear meld;
    repeat (rewrite app_nil_r in *
      || match goal with
         | H : Forall _ (_ :: _) |- _ =>
             inversion H; clear H; subst
         | H : _ /\ _ |- _ => destruct H
         end
      || split);
    eauto.
  rewrite app_assoc.
  apply Permutation_app; eauto.
Defined.

ソート済のリストが一つになるまでこの操作を繰り返せばソートができます。
再帰呼び出しの度にソート済のリストの数が半分程度になるため停止性が保証されますが、停止性の自動判定に失敗するため面倒な記述が必要になります。

最初にソート済のリストに分解する操作も必要ですが、今は手抜きの実装にしておきましょう。

Definition merge_sort xs :
  { xs' | StronglySorted leb xs' /\ Permutation xs' xs }.
Proof.
  refine (let (xs', _) := Fix
      (well_founded_ltof _ (@length _))
      (fun xss =>
        Forall (StronglySorted leb) xss ->
        Permutation (fold_right (@app _) [] xss) xs ->
        { xs' | StronglySorted leb xs' /\ Permutation xs' xs })
      (fun xss =>
        match xss with
        | [] => fun _ _ _ => exist _ [] _
        | [xs] => fun _ _ _ => exist _ xs _
        | xs :: ys :: xss => fun iter_meld _ _ =>
            let (xss', _) := meld (xs :: ys :: xss) _ in
            iter_meld xss' _ _ _
        end)
      (map (fun x => [x]) xs) _ _ in
    exist _ xs' _);
    simpl in *;
    repeat (rewrite app_nil_r in *
      || match goal with
         | H : Forall _ (_ :: _) |- _ =>
             inversion H; clear H; subst
         | H : _ /\ _ |- _ => destruct H
         end
      || split);
    eauto.
  - unfold ltof.
    simpl in *.
    destruct (length xss0).
    + omega.
    + rewrite H.
      apply le_n_S.
      apply lt_div2 with (n := S (S n)).
      omega.
  - apply Forall_forall.
    intros ? HIn.
    apply in_map_iff in HIn.
    destruct HIn as [? [ ]]; subst; eauto.
  - induction xs; simpl; eauto.
Defined.

ベンチマーク

速度に気を遣ってアルゴリズムを実装したら実際の速度も気になるのが人情。 OCamlにextractしてOCaml標準ライブラリのList.sort及びCoq標準ライブラリのMergesort.NatSort.sortと比較してみます。

ベンチマークを取るための準備として、まずはOCamlのプログラムをextractします。

Require Import ExtrOcamlBasic ExtrOcamlNatInt Mergesort.
Extract Constant map => "List.map".
Extract Constant Sumbool.sumbool_of_bool => "(fun b -> b)".
Extract Constant NatOrder.leb => "( <= )".
Extraction "merge_sort.ml" merge_sort NatSort.

次に実行時間を計測するOCamlの関数を雑に用意します。

# #use "./merge_sort.ml";;
# let measure f =
    let start = Sys.time () in
    f ();
    Sys.time () -. start;;
val measure : (unit -> 'a) -> float = <fun>

ランダムな入力

準備ができたのでランダムな入力に対してのソートの実行時間を測っていきましょう。
とりあえずOCamlのList.sortから。

# Random.self_init ();;
- : unit = ()
# let l = Array.to_list (Array.init 114514 (fun _ -> Random.bits ()));;
val l : int list = ...
# measure (fun _ -> ignore (List.sort compare l));;
- : float = 0.122534999999999977

List.sortで約10万個のランダムな要素をソートするのに要した時間は0.1秒ほどでした。

次はCoq標準ライブラリのマージソートを試してみます。

# measure (fun _ -> ignore (NatSort.sort l));;
- : float = 0.240472

こちらは約10万個のランダムな要素をソートするのに0.24秒ほど要しました。
やっぱList.sort速いっすね…

では今回の実装はと言うと、

# measure (fun _ -> ignore (merge_sort ( <= ) l));;
- : float = 0.305214

0.3秒程度。ダメだー

ランダムな大きい入力

また、今回の実装とCoq標準ライブラリのマージソートには問題があって、

# let l = Array.to_list (Array.init 1919810 (fun _ -> Random.bits ()));;
val l : int list = ...
# measure (fun _ -> ignore (NatSort.sort l));;
Stack overflow during evaluation (looping recursion?).
# measure (fun _ -> ignore (merge_sort ( <= ) l));;
Stack overflow during evaluation (looping recursion?).

ある程度大きな入力ではスタックオーバーフローしてしまいます。僕の実装はともかく、Coq標準ライブラリのまで末尾再帰になってないんですね…

実装の改良

とりあえずアキュムレーターを導入してマージを末尾再帰に直します

Definition rev_merge :
  forall xs, StronglySorted leb xs ->
  forall ys, StronglySorted leb ys ->
  forall acc,
  { ws | exists zs, ws = rev zs ++ acc /\ StronglySorted leb zs /\ Permutation zs (xs ++ ys) }.
Proof.
  refine (fix rev_merge xs :=
    match xs as xs0 return StronglySorted leb xs0 ->
      forall ys, StronglySorted leb ys ->
      forall acc,
      { ws | exists zs, ws = rev zs ++ acc
          /\ StronglySorted leb zs
          /\ Permutation zs (xs0 ++ ys) } with
    | [] => fun _ ys _ acc => exist _ (rev_append ys acc) (ex_intro _ ys _)
    | x :: xs' => fun _ =>
        fix rev_merge' ys :=
          match ys as ys0 return StronglySorted leb ys0 ->
            forall acc,
            { ws | exists zs, ws = rev zs ++ acc
                /\ StronglySorted leb zs
                /\ Permutation zs (x :: xs' ++ ys0) }
          with
          | [] => fun _ acc => exist _ (rev_append xs' (x :: acc)) (ex_intro _ (x :: xs') _)
          | y :: ys' => fun _ acc => 
              if Sumbool.sumbool_of_bool (leb x y) then
                let (zs, H) := rev_merge xs' _ (y :: ys') _ (x :: acc) in
                exist _ zs (let (x0, _) := H in ex_intro _ (x :: x0) _)
              else
                let (zs, H) := rev_merge' ys' _ (y :: acc) in
                exist _ zs (let (x0, _) := H in ex_intro _ (y :: x0) _)
          end
    end); try clear rev_merge';
    repeat (simpl in *;
        ( rewrite app_nil_r in *
       || rewrite rev_append_rev in *
       || rewrite <- app_assoc
       || match goal with
          | H : StronglySorted _ (_ :: _) |- _ => inversion H; clear H; subst
          | H : _ /\ _ |- _ => destruct H
          | H : leb ?x ?y = false |- _ => destruct (leb_total x y); [ congruence | clear H ]
          end
       || split
       || apply SSorted_cons); subst);
    eauto;
    repeat match goal with
    | H : Forall ?P ?l |- _ =>
        assert (forall x, In x l -> P x)
          by (apply Forall_forall; eauto);
        clear H
    | |- Forall _ _ =>
        apply Forall_forall; intros
    end.
  - apply (Permutation_in _ H1) in H3.
    apply in_app_iff in H3.
    destruct H3 as [ | [ ] ]; subst; eauto.
  - apply (Permutation_in _ H1) in H3.
    apply in_app_iff with (l := x :: xs') in H3.
    destruct H3 as [ [ ] | ]; subst; eauto.
  - apply Permutation_cons_app with (l1 := x :: xs').
    eauto.
Defined.

ボトムアップにマージしていく操作も末尾再帰に直します。

Definition meld : forall xss,
  Forall (StronglySorted leb) xss ->
  forall acc,
  { zss | exists xss', zss = rev (map (@rev t) xss') ++ acc /\ length xss' = div2 (S (length xss))
      /\ Forall (StronglySorted leb) xss'
      /\ Permutation (concat _ xss') (concat _ xss) }.
Proof.
  refine (fix meld xss :=
    match xss as xss0 return
      Forall (StronglySorted leb) xss0 ->
      forall acc,
      { zss | exists xss', zss = rev (map (@rev t) xss') ++ acc /\ length xss' = div2 (S (length xss0))
          /\ Forall (StronglySorted leb) xss'
          /\ Permutation (concat _ xss') (concat _ xss0) }
    with
    | [] => fun _ acc => exist _ acc (ex_intro _ [] _)
    | [xs] => fun _ acc => exist _ (rev xs :: acc) (ex_intro _ [xs] _)
    | xs :: xs' :: xss' => fun _ acc =>
        let (ys, H1) := rev_merge xs _ xs' _ [] in
        let (yss, H2) := meld xss' _ (ys :: acc) in
        exist _ yss (let (zs, _) := H1 in let (zss, _) := H2 in ex_intro _ (zs :: zss) _)
    end);
    repeat (simpl in *;
       ( rewrite app_nil_r in *
      || rewrite rev_involutive in *
      || rewrite <- app_assoc
      || match goal with
         | H : Forall _ (_ :: _) |- _ => inversion H; clear H
         | H : _ /\ _ |- _ => destruct H
         | H : exists _, _ |- _ => destruct H
         end
      || split); subst); eauto.
  rewrite app_assoc.
  apply Permutation_app; eauto.
Defined.

前の実装ではリストをソート済の部分リストのリストに直す部分は手抜きしてましたが、既にソートされている部分についてはそれを使う事にしましょう。

Definition splitting : forall xs acc prev,
  Forall (StronglySorted leb) acc ->
  { xss | Permutation (concat _ xss) (rev (concat _ acc) ++ prev :: xs)
       /\ Forall (StronglySorted leb) xss }.
Proof.
  refine (fix neutral xs :=
    match xs as xs0 return
      forall acc prev,
      Forall (StronglySorted leb) acc ->
      { xss | Permutation (concat _ xss) (rev (concat _ acc) ++ prev :: xs0)
           /\ Forall (StronglySorted leb) xss }
    with 
    | [] => fun acc prev _ => exist _ ([prev] :: acc) _
    | x :: xs => fun acc prev _ =>
        if Sumbool.sumbool_of_bool (leb prev x) then _
        else _
    end
  with incr xs :=
    match xs as xs0 return
      forall acc curr prev,
      Forall (StronglySorted leb) acc ->
      StronglySorted (fun x y => leb y x) (prev :: curr) ->
      { xss | Permutation (concat _ xss) (rev (concat _ acc) ++ rev curr ++ prev :: xs0)
           /\ Forall (StronglySorted leb) xss }
    with
    | [] => fun acc curr prev _ _ => exist _ (rev_append curr [prev] :: acc) _
    | x :: xs => fun acc curr prev _ _ =>
        if Sumbool.sumbool_of_bool (leb prev x) then _
        else _
    end
  with decr xs :=
    match xs as xs0 return
      forall acc curr prev,
      Forall (StronglySorted leb) acc ->
      StronglySorted leb (prev :: curr) ->
      { xss | Permutation (concat _ xss) (rev (concat _ acc) ++ prev :: curr ++ xs0)
           /\ Forall (StronglySorted leb) xss }
    with
    | [] => fun acc curr prev _ _ => exist _ ((prev :: curr) :: acc) _
    | x :: xs => fun acc curr prev _ _ =>
        if Sumbool.sumbool_of_bool (leb x prev) then _
        else _
    end
  for neutral);
  [
  | refine (let (xs, _) := incr xs0 acc [prev] x _ _ in exist _ xs _)
  | refine (let (xs, _) := decr xs0 acc [prev] x _ _ in exist _ xs _)
  | 
  | refine (let (xs, _) := incr xs0 acc (prev :: curr) x _ _ in exist _ xs _)
  | refine (let (xs, _) := neutral xs0 (rev_append curr [prev] :: acc) x _ in exist _ xs _)
  | 
  | refine (let (xs, _) := decr xs0 acc (prev :: curr) x _ _ in exist _ xs _)
  | refine (let (xs, _) := neutral xs0 ((prev :: curr) :: acc) x _ in exist _ xs _) ];
  repeat
    (( rewrite app_nil_r in *
    || rewrite <- app_assoc in *
    || rewrite rev_app_distr in *
    || rewrite rev_involutive in *
    || rewrite rev_append_rev in *
    || apply Forall_rev
    || apply StronglySorted_app
    || apply StronglySorted_rev
    || match goal with
       | H : _ /\ _ |- _ => destruct H
       | H : exists _, _ |- _ => destruct H
       | H : leb ?x ?y = false |- _ => destruct (leb_total x y); [ congruence | clear H ]
       | H : Forall _ ?l |- Forall _ ?l => eapply Forall_impl; [| apply H ]; intros
       | H : StronglySorted _ (_ :: _) |- _ => inversion H; subst; clear H
       | |- Forall _ (_ :: _) => constructor
       | |- StronglySorted _ (_ :: _) => constructor
       | _ => split
       end
    || split); subst; simpl in *);
  eauto.
- apply Permutation_rev with (l := prev :: concat t acc).
- etransitivity.
  + eassumption.
  + eapply Permutation_app_head.
    eauto.
- etransitivity.
  + eapply Permutation_app_comm.
  + simpl.
    etransitivity.
    * apply Permutation_cons_append.
    * rewrite <- app_assoc.
      apply Permutation_app_tail.
      apply Permutation_rev.
- etransitivity.
  + eassumption.
  + apply Permutation_app_head.
    etransitivity.
    * apply Permutation_middle.
    * apply Permutation_app_tail.
      apply Permutation_rev.
- etransitivity.
  + apply Permutation_app_comm with (l := prev :: curr).
  + apply Permutation_app_tail.
    apply Permutation_rev.
- etransitivity.
  + eassumption.
  + apply Permutation_app_head.
    apply Permutation_middle with (l1 := prev :: curr).
- etransitivity.
  + eassumption.
  + apply Permutation_app_head.
    symmetry.
    etransitivity.
    * apply Permutation_middle.
    * apply Permutation_app_tail.
      apply Permutation_rev.
Defined.

ソート済のリストが1つなるまでマージを繰り返す関数は次のように末尾再帰にできます。 ここで、マージを行う関数を末尾再帰に直したために要素が逆順になってしまいますが、 リストの反転を二回行えば元に戻る事を利用して計算時間を抑えましょう。

Definition merge_sort (t : Set) (leb : t -> t -> bool)
  (leb_total : forall x y, leb x y \/ leb y x)
  (leb_trans : Transitive leb) xs :
  { xs' | StronglySorted leb xs' /\ Permutation xs' xs }.
Proof.
  refine (match xs as xs0 return { xs' | StronglySorted leb xs' /\ Permutation xs' xs0 } with
    | [] => exist _ [] _
    | x :: xs =>
        let (xss, _) := splitting _ leb _ _ xs [] x _ in
        (fix iter_meld xss (Hwf : Acc (ltof _ (@length _)) xss) { struct Hwf } :=
          match xss as xss0 return
            Acc (ltof _ (@length _)) xss0 ->
            Forall (StronglySorted leb) xss0 ->
            Permutation (concat _ xss0) (x :: xs) ->
            { xs' | StronglySorted leb xs' /\ Permutation xs' (x :: xs) }
          with
          | [] => fun _ _ _ => exist _ [] _
          | [xs] => fun _ _ _ => exist _ xs _
          | xs :: xs' :: xss => fun Hwf _ _ =>
              let (xss', _) := meld _ leb _ _ (xs :: xs' :: xss) _ [] in
              match Hwf with
              | Acc_intro Hwf' => _
              end
          end Hwf
        with iter_meld_rev xss (Hwf : Acc (ltof _ (@length _)) xss) { struct Hwf } :=
          match xss as xss0 return
            Acc (ltof _ (@length _)) xss0 ->
            Forall (StronglySorted (fun x y => leb y x)) xss0 ->
            Permutation (concat _ xss0) (x :: xs) ->
            { xs' | StronglySorted leb xs' /\ Permutation xs' (x :: xs) }
          with
          | [] => fun _ _ _ => exist _ [] _
          | [xs] => fun _ _ _ => exist _ (rev xs) _
          | xs :: xs' :: xss => fun Hwf _ _ =>
              let (xss', _) := meld _ (fun x y => leb y x) _ _ (xs :: xs' :: xss) _ [] in
              match Hwf with
              | Acc_intro Hwf' => _
              end
          end Hwf
        for iter_meld) xss (well_founded_ltof _ (@length _) _) _ _
  end);
  [ | | | | | | | refine (iter_meld_rev xss' (Hwf' _ _) _ _) | | | | | | refine (iter_meld xss' (Hwf' _ _) _ _) | | ];
    repeat (simpl in *;
       ( rewrite app_nil_r in *
      || rewrite rev_length in *
      || rewrite map_length in *
      || rewrite concat_rev in *
      || match goal with
         | H : Forall _ (_ ++ _) |- _ => apply Forall_app_iff in H
         | H : Forall _ (_ :: _) |- _ => inversion H; clear H; subst
         | H : _ /\ _ |- _ => destruct H
         | H : exists _, _ |- _ => destruct H
         | H : Forall _ ?l |- Forall _ ?l => eapply Forall_impl; [| apply H ]
         | |- Permutation (rev ?xs) _ =>
             apply Permutation_trans with (l' := rev (rev xs));
             [ apply Permutation_rev
             | rewrite rev_involutive ]
         end
      || apply Forall_rev
      || apply Forall_map
      || apply StronglySorted_rev
      || split); subst);
    eauto.
  - destruct (length xss1).
    + omega.
    + rewrite H2.
      apply le_n_S.
      apply lt_div2 with (n := S (S n)).
      omega.
  - destruct (length xss1).
    + omega.
    + rewrite H2.
      apply le_n_S.
      apply lt_div2 with (n := S (S n)).
      omega.
Defined.

ベンチマーク再び

ひたすら末尾再帰に直したお陰でスタックオーバーフローは発生しなくなりました。

# let l = Array.to_list (Array.init 1919810 (fun _ -> Random.bits ()));;
val l : int list = ...
# measure (fun _ -> ignore (merge_sort ( <= ) l));;
- : float = 4.59001299999999901

では、これらの改良を施した事で実行時間はどのように変化したかと言うと、

# let l = Array.to_list (Array.init 114514 (fun _ -> Random.bits ()));;
val l : int list = ...
# measure (fun _ -> ignore (merge_sort ( <= ) l));;
- : float = 0.206737000000000393

1.5倍程度の処理速度向上が見られました! しかし依然としてList.sortの半分程度の速度しか出ていないのも事実。どうしたものか…
まぁ今回はCoq標準ライブラリより優秀そうなマージソートの実装が得られたので良しとしましょう。

結び

やっぱりList.sortには勝てなかったよ… 比較的高速な証明付きのマージソートを実装できました。 末尾再帰に直した後のCoqのソースコードこちらに置いてあります。

Coqを用いる事による利点は正しさが保証される事に留まりません。 今までバグを恐れて行えなかった積極的な最適化を行う勇気も我々に与えてくれるのです。