2. 末尾再帰
著者:梅谷 武
各言語における末尾再帰最適化の実装状況について調べる。
作成:2024-03-12
更新:2024-10-24

C++

 C++コンパイラが末尾再帰最適化を行なうことは以前から知られていたが、ここではGCCとClangでサンプルプログラムのアセンブラ出力でそれを実際に確認する。
 OCamlのループと再帰については:Loops and Recursionsを参照のこと。
 Racketのループと再帰の概要についてはThe Racket Guide>2 Racket Essentials>2.3 Lists, Iteration, and Recursionを参照のこと。
 関数型言語の特徴の一つとして再帰を積極的に利用することがある。命令型言語ではループにするところでも関数型言語では再帰が用いられることが多い。しかし再帰とは関数呼び出しが繰り返されることであり、結果としてスタック領域が大量に消費されることになる。これはループと比べると資源効率の面で不利である。これを解決するために考え出された技法が末尾再帰(tail recursion)である。これについては次節で解説する。
 本節ではまず再帰によりスタックオーバーフローが起こることを実際に確認する。そのために自然数nを与えたとき1からnまでの総和
n

k=1
k
を求める関数sum(n)を再帰で実装する。

C++

 C++の場合はループで実装するのが自然であるが、ここでは再帰で実装する。データ型としては64ビット整数を用いる。
 sum.cpp:
#include <print>
#include <stdexcept>
#include <string>
 
long sum(long n) {
  if (n < 1)
    return 0;
  else
    return (n + sum(n - 1));
}
 
int main(int argc, char* argv[]) {
  try {
    if (argc != 2) {
      throw std::invalid_argument(
        "Usage: sum <non-negative integer>");
    }
    long n = std::stol(argv[1]);
    if (n < 0l) {
      throw std::invalid_argument(
        "The argument must be a non-negative integer.");
    }
    long result = sum(n);
    std::println("sum({}) = {}", n, result);
  } catch (const std::logic_error& e) {
    std::println(stderr, "Logic error: {}", e.what());
  } catch (const std::runtime_error& e) {
    std::println(stderr, "Runtime error: {}", e.what());
  } catch (const std::exception& e) {
    std::println(stderr, "Another std::exception: {}", e.what());
  } catch (...) {
    std::println(stderr, "Unknown exception.");
  }
  return 0;
}


 最適化オプションは付けない。最適化オプション:-O2を付けるとスタックオーバーフローしないがその理由は調べていない。
 makefile:
CPPG = g++
CPPL = clang++
CPPFLAGS = -std=c++23
SRC = sum.cpp
EXES = sum_g sum_l
 
all: $(EXES)
 
sum_g: $(SRC)
	$(CPPG) $(CPPFLAGS) -o $@ $(SRC)
 
sum_l: $(SRC)
	$(CPPL) $(CPPFLAGS) -o $@ $(SRC)
 
clean:
	rm -f $(EXES)


 n = 10から始めて一桁ずつ上げていくとn = 1000000でメモリアクセス違反がでる。C++ではプログラム内部でスタックオーバーフローやメモリアクセス違反を検知するのは困難である。この場合はUbuntuが検知し、シグナルを投げている。
$ ./sum_g 10
sum(10) = 55
$ ./sum_g 100
sum(100) = 5050
$ ./sum_g 1000
sum(1000) = 500500
$ ./sum_g 10000
sum(10000) = 50005000
$ ./sum_g 100000
sum(100000) = 5000050000
$ ./sum_g 1000000
Segmentation fault
 n = 100000とn = 1000000で測定する。
$ measure ./sum_g 100000
sum(100000) = 5000050000
======================================
Process exited with status: 0
total time:  0.004729 [sec]
mem  size:       6960 [KB]
code size:        256 [KB]
$ measure ./sum_g 1000000
======================================
Process terminated by signal: 11
total time:  0.019757 [sec]
mem  size:      11476 [KB]
code size:        256 [KB]
 同じことをClangで。
$ ./sum_l 10
sum(10) = 55
$ ./sum_l 100
sum(100) = 5050
$ ./sum_l 1000
sum(1000) = 500500
$ ./sum_l 10000
sum(10000) = 50005000
$ ./sum_l 100000
sum(100000) = 5000050000
$ ./sum_l 1000000
Segmentation fault
$ measure ./sum_l 100000
sum(100000) = 5000050000
======================================
Process exited with status: 0
total time:  0.005790 [sec]
mem  size:       8516 [KB]
code size:        230 [KB]
$ measure ./sum_l 1000000
======================================
Process terminated by signal: 11
total time:  0.017040 [sec]
mem  size:      11376 [KB]
code size:        230 [KB]
 sum.ml:
open Printf
 
let rec sum n =
  if n < 1 then
    0
  else
    n + sum (n - 1)
 
let () =
  try
    if Array.length Sys.argv <> 2 then
      raise (Invalid_argument
        "Usage: sum <non-negative integer>")
    else
      let n = int_of_string Sys.argv.(1) in
      if n < 0 then
        raise (Invalid_argument
          "The argument must be a non-negative integer.")
      else
        let result = sum n in
        printf "sum(%d) = %d\n" n result
  with
    | Invalid_argument msg ->
        eprintf "Logic error: %s\n" msg
    | Failure msg ->
        eprintf "Runtime error: %s\n" msg
    | exn ->
        eprintf "Another exception: %s\n" (Printexc.to_string exn)


 最適化オプションを付ける。
 makefile:
OCAMLOPT = ocamlopt
OCAMLFLAGS = -O2
LIBS = 
SRC = sum.ml
EXES = sum
 
all: $(EXES)
 
sum: $(SRC)
	$(OCAMLOPT) $(OCAMLFLAGS) -o $@ $(LIBS) $(SRC)
 
clean:
	rm -f $(EXES) *.o *.cmx *.cmi


 n = 10から始めて一桁ずつ上げていくとn = 1000000でスタックオーバーフローし、OCamlから例外が投げられる。
$ ./sum 10
sum(10) = 55
$ ./sum 100
sum(100) = 5050
$ ./sum 1000
sum(1000) = 500500
$ ./sum 10000
sum(10000) = 50005000
$ ./sum 100000
sum(100000) = 5000050000
$ ./sum 1000000
Another exception: Stack overflow
 n = 100000とn = 1000000で測定する。
$ measure ./sum 100000
sum(100000) = 5000050000
======================================
Process exited with status: 0
total time:  0.002213 [sec]
mem  size:       4416 [KB]
code size:       1111 [KB]
$ measure ./sum 1000000
Another exception: Stack overflow
======================================
Process exited with status: 0
total time:  0.007845 [sec]
mem  size:      10884 [KB]
code size:       1111 [KB]
 OCamlはプログラム内部でスタックオーバーフローを検知し、例外を投げ、プロセスは正常に終了する。
 プログラム:sum.rkt
#lang racket
(require iso-printf)
 
(define (sum n)
  (if (< n 1)
    0
    (+ n (sum (- n 1)))))
 
(define (main args)
  (with-handlers
    ([exn:fail? (λ (e) (eprintf "%s\n" (exn-message e)))]
     [exn? (λ (e) (eprintf "Unexpected: %s\n" (exn-message e)))])
  (cond
    [(not (= (vector-length args) 1))
      (error "Usage: sum <non-negative integer>")]
    [else
      (let ([n (string->number (vector-ref args 0))])
        (cond
          [(not (integer? n))
            (error "The argument must be an integer.")]
          [(< n 0)
            (error "The argument must be a non-negative integer.")]
          [else
            (let ([result (sum n)])
              (printf "sum(~a) = ~a\n" n result))]))])))
 
(main (current-command-line-arguments))


 n = 10から始めて一桁ずつ上げていくとn = 100000000でもスタックオーバーフローしない。しかし桁数が増えると計算終了までかなり時間がかかる。
$ ./sum 10
sum 10 = 55
$ ./sum 100
sum 100 = 5050
$ ./sum 1000
sum 1000 = 500500
$ ./sum 10000
sum 10000 = 50005000
$ ./sum 100000
sum 100000 = 5000050000
$ ./sum 1000000
sum 1000000 = 500000500000
$ ./sum 10000000
sum 10000000 = 50000005000000
$ ./sum 100000000
sum 100000000 = 5000000050000000
 n = 100000で測定する。
$ measure ./sum 100000
sum(100000) = 5000050000
======================================
Process exited with status: 0
total time:  0.470441 [sec]
mem  size:     133644 [KB]
code size:      12442 [KB]
 末尾再帰とは自分自身を末尾で呼び出す再帰関数のことである。通常の再帰においては関数呼び出しが行なわれるので呼び出す毎に新しいスタックフレームが確保される。しかし、末尾再帰においては実行中のスタックフレームにある局所変数は末尾で再帰関数を呼び出す時点以後に使われることがないのでこれを再利用することができる。この発想が公にされたのは1977年にGuy L. Steeleにより発表された論文[1]が最初とされている。
 関数型言語のコンパイラは末尾再帰を検知すると局所変数を再利用するような最適化コードを生成する。GCCやClangにおいてもコンパイルオプション:-O2で末尾再帰最適化が行なわれる。本節では実験によりそれを確認する。

C++

 前節のsum(n)は最後の行で再帰関数を呼び出した後にn + sum(n - 1)という演算を行なっているので末尾再帰ではない。これを末尾再帰にするためには補助関数:sum_auxを使って次のように書き直す。
 プログラム:sum2.cpp
#include <print>
#include <stdexcept>
#include <string>
 
long sum_aux(long n, long acc) {
  if (n == 0)
    return acc;
  else
    return sum_aux(n - 1, acc + n);
}
 
int main(int argc, char* argv[]) {
  try {
    if (argc != 2) {
      throw std::invalid_argument(
        "Usage: sum2 <non-negative integer>");
    }
    long n = std::stol(argv[1]);
    if (n < 0l) {
      throw std::invalid_argument(
        "The argument must be a non-negative integer.");
    }
    long result = sum_aux(n, 0);
    std::println("sum({}) = {}", n, result);
  } catch (const std::logic_error& e) {
    std::println(stderr, "Logic error: {}", e.what());
  } catch (const std::runtime_error& e) {
    std::println(stderr, "Runtime error: {}", e.what());
  } catch (const std::exception& e) {
    std::println(stderr, "Another std::exception: {}", e.what());
  } catch (...) {
    std::println(stderr, "Unknown exception.");
  }
  return 0;
}


 このプログラムを普通にコンパイルすると末尾再帰最適化は行なわれない。
$ g++ -std=c++23 -o sum2_g sum2.cpp
 どのようなコードが生成されたのかをsum_aux関数だけのアセンブラソースで見てみる。C++23とは無縁なので-std=c++23は必要ない。
$ g++ -S -masm=intel sum_aux.cpp
 アセンブラソース:sum_aux.s
_Z7sum_auxll:
.LFB0:
	.cfi_startproc
	endbr64
	push	rbp
	.cfi_def_cfa_offset 16
	.cfi_offset 6, -16
	mov	rbp, rsp
	.cfi_def_cfa_register 6
	sub	rsp, 16
	mov	QWORD PTR -8[rbp], rdi
	mov	QWORD PTR -16[rbp], rsi
	cmp	QWORD PTR -8[rbp], 0
	jne	.L2
	mov	rax, QWORD PTR -16[rbp]
	jmp	.L3
.L2:
	mov	rdx, QWORD PTR -16[rbp]
	mov	rax, QWORD PTR -8[rbp]
	add	rdx, rax
	mov	rax, QWORD PTR -8[rbp]
	sub	rax, 1
	mov	rsi, rdx
	mov	rdi, rax
	call	_Z7sum_auxll
	nop
.L3:
	leave
	.cfi_def_cfa 7, 8
	ret
 次の行で自分自身を呼び出しているのがわかる。
	call	_Z7sum_auxll
 このプログラムを-O2オプションでコンパイルすると末尾再帰最適化が行なわれる。
$ g++ -std=c++23 -O2 -o sum2_g sum2.cpp
 どのようなコードが生成されたのかをsum_aux関数だけのアセンブラソースで見てみる。
$ g++ -S -O2 -masm=intel sum_aux.cpp
 アセンブラソース:sum_aux.s
_Z7sum_auxll:
.LFB0:
	.cfi_startproc
	endbr64
	mov	rax, rsi
	test	rdi, rdi
	je	.L5
	lea	rdx, -1[rdi]
	test	dil, 1
	je	.L2
	add	rax, rdi
	mov	rdi, rdx
	test	rdx, rdx
	je	.L17
	.p2align 4,,10
	.p2align 3
.L2:
	lea	rax, -1[rax+rdi*2]
	sub	rdi, 2
	jne	.L2
.L5:
	ret
.L17:
	ret
 関数呼び出しがループに代わっていることがわかる。
 最適化オプション:-O2を付けてメイクする。
 makefile:
CPPG = g++
CPPL = clang++
CPPFLAGS = -std=c++23 -O2
SRC = sum2.cpp
EXES = sum2_g sum2_l
 
all: $(EXES)
 
sum2_g: $(SRC)
	$(CPPG) $(CPPFLAGS) -o $@ $(SRC)
 
sum2_l: $(SRC)
	$(CPPL) $(CPPFLAGS) -o $@ $(SRC)
 
clean:
	rm -f $(EXES)


 n = 100000000で測定する。
$ measure ./sum2_g 100000000
sum(100000000) = 5000000050000000
======================================
Process exited with status: 0
total time:  0.064444 [sec]
mem  size:       3692 [KB]
code size:        119 [KB]
$ measure ./sum2_l 100000000
sum(100000000) = 5000000050000000
======================================
Process exited with status: 0
total time:  0.000000 [sec]
mem  size:       3680 [KB]
code size:         92 [KB]
 スタックオーバーフローせず、メモリ消費量も少ない。
 プログラム:sum2.ml
open Printf
 
let rec sum_aux n acc =
  if n < 1 then
    acc
  else
    sum_aux (n - 1) (acc + n)
 
let () =
  try
    if Array.length Sys.argv <> 2 then
      raise (Invalid_argument
        "Usage: sum2 <non-negative integer>")
    else
      let n = int_of_string Sys.argv.(1) in
      if n < 0 then
        raise (Invalid_argument
          "The argument must be a non-negative integer.")
      else
        let result = sum_aux n 0 in
        printf "sum(%d) = %d\n" n result
  with
    | Invalid_argument msg ->
        eprintf "Logic error: %s\n" msg
    | Failure msg ->
        eprintf "Runtime error: %s\n" msg
    | exn ->
        eprintf "Another exception: %s\n" (Printexc.to_string exn)


 n = 100000000で測定する。
$ ocamlopt -O2 -o sum2 sum2.ml
$ measure ./sum2 100000000
sum(100000000) = 5000000050000000
======================================
Process exited with status: 0
total time:  0.126208 [sec]
mem  size:       2748 [KB]
code size:       1111 [KB]
 スタックオーバーフローせず、メモリ消費量も少ない。
 プログラム:sum2.rkt
#lang racket
(require iso-printf)
 
(define (sum_aux n acc)
  (if (= n 0)
    acc
    (sum_aux (- n 1) (+ acc n))))
 
(define (main args)
  (with-handlers
    ([exn:fail? (λ (e) (eprintf "%s\n" (exn-message e)))]
     [exn? (λ (e) (eprintf "Unexpected: %s\n" (exn-message e)))])
  (cond
    [(not (= (vector-length args) 1))
      (error "Usage: sum2 <non-negative integer>")]
    [else
      (let ([n (string->number (vector-ref args 0))])
        (cond
          [(not (integer? n))
            (error "The argument must be an integer.")]
          [(< n 0)
            (error "The argument must be a non-negative integer.")]
          [else
            (let ([result (sum_aux n 0)])
              (printf "sum(%d) = %d\n" n result))]))])))
 
(main (current-command-line-arguments))


 n = 100000000で測定する。
$ raco exe sum2.rkt
$ measure ./sum2 100000000
sum(100000000) = 5000000050000000
======================================
Process exited with status: 0
total time:  0.679221 [sec]
mem  size:     135140 [KB]
code size:      12442 [KB]
 メモリ消費量が減り、実行速度も速くなっている。Racketはスタックオーバーフローが起こりにくい設計になっているが、メモリ消費量と実行速度で末尾再帰の優位性が確認できた。
[1] Steele, Guy Lewis, Debunking the "expensive procedure call" myth or, procedure call implementations considered harmful or, LAMBDA: The Ultimate GOTO, Proceedings of the 1977 annual conference on - ACM '77, 153–162, 1977
This document is licensed under the MIT License.
Copyright (C) 2024 Takeshi Umetani