2. 末尾再帰
著者:梅谷 武
各言語における末尾再帰最適化の実装状況について調べる。
作成:2024-03-12
更新:2024-10-24
更新:2024-10-24
C++コンパイラが末尾再帰最適化を行なうことは以前から知られていたが、ここではGCCとClangでサンプルプログラムのアセンブラ出力でそれを実際に確認する。
OCamlのループと再帰については:Loops and Recursionsを参照のこと。
Racketのループと再帰の概要についてはThe Racket Guide>2 Racket Essentials>2.3 Lists, Iteration, and Recursionを参照のこと。
関数型言語の特徴の一つとして再帰を積極的に利用することがある。命令型言語ではループにするところでも関数型言語では再帰が用いられることが多い。しかし再帰とは関数呼び出しが繰り返されることであり、結果としてスタック領域が大量に消費されることになる。これはループと比べると資源効率の面で不利である。これを解決するために考え出された技法が末尾再帰(tail recursion)である。これについては次節で解説する。
本節ではまず再帰によりスタックオーバーフローが起こることを実際に確認する。そのために自然数nを与えたとき1からnまでの総和
を求める関数sum(n)を再帰で実装する。
|
C++の場合はループで実装するのが自然であるが、ここでは再帰で実装する。データ型としては64ビット整数を用いる。
sum.cpp:
#include <print>
#include <stdexcept>
#include <string>
long sum(long n) {
if (n < 1)
return 0;
else
return (n + sum(n - 1));
}
int main(int argc, char* argv[]) {
try {
if (argc != 2) {
throw std::invalid_argument(
"Usage: sum <non-negative integer>");
}
long n = std::stol(argv[1]);
if (n < 0l) {
throw std::invalid_argument(
"The argument must be a non-negative integer.");
}
long result = sum(n);
std::println("sum({}) = {}", n, result);
} catch (const std::logic_error& e) {
std::println(stderr, "Logic error: {}", e.what());
} catch (const std::runtime_error& e) {
std::println(stderr, "Runtime error: {}", e.what());
} catch (const std::exception& e) {
std::println(stderr, "Another std::exception: {}", e.what());
} catch (...) {
std::println(stderr, "Unknown exception.");
}
return 0;
}
最適化オプションは付けない。最適化オプション:-O2を付けるとスタックオーバーフローしないがその理由は調べていない。
makefile:
CPPG = g++
CPPL = clang++
CPPFLAGS = -std=c++23
SRC = sum.cpp
EXES = sum_g sum_l
all: $(EXES)
sum_g: $(SRC)
$(CPPG) $(CPPFLAGS) -o $@ $(SRC)
sum_l: $(SRC)
$(CPPL) $(CPPFLAGS) -o $@ $(SRC)
clean:
rm -f $(EXES)
n = 10から始めて一桁ずつ上げていくとn = 1000000でメモリアクセス違反がでる。C++ではプログラム内部でスタックオーバーフローやメモリアクセス違反を検知するのは困難である。この場合はUbuntuが検知し、シグナルを投げている。
$ ./sum_g 10
sum(10) = 55
$ ./sum_g 100
sum(100) = 5050
$ ./sum_g 1000
sum(1000) = 500500
$ ./sum_g 10000
sum(10000) = 50005000
$ ./sum_g 100000
sum(100000) = 5000050000
$ ./sum_g 1000000
Segmentation fault
n = 100000とn = 1000000で測定する。
$ measure ./sum_g 100000
sum(100000) = 5000050000
======================================
Process exited with status: 0
total time: 0.004729 [sec]
mem size: 6960 [KB]
code size: 256 [KB]
$ measure ./sum_g 1000000
======================================
Process terminated by signal: 11
total time: 0.019757 [sec]
mem size: 11476 [KB]
code size: 256 [KB]
同じことをClangで。
$ ./sum_l 10
sum(10) = 55
$ ./sum_l 100
sum(100) = 5050
$ ./sum_l 1000
sum(1000) = 500500
$ ./sum_l 10000
sum(10000) = 50005000
$ ./sum_l 100000
sum(100000) = 5000050000
$ ./sum_l 1000000
Segmentation fault
$ measure ./sum_l 100000
sum(100000) = 5000050000
======================================
Process exited with status: 0
total time: 0.005790 [sec]
mem size: 8516 [KB]
code size: 230 [KB]
$ measure ./sum_l 1000000
======================================
Process terminated by signal: 11
total time: 0.017040 [sec]
mem size: 11376 [KB]
code size: 230 [KB]
sum.ml:
open Printf
let rec sum n =
if n < 1 then
0
else
n + sum (n - 1)
let () =
try
if Array.length Sys.argv <> 2 then
raise (Invalid_argument
"Usage: sum <non-negative integer>")
else
let n = int_of_string Sys.argv.(1) in
if n < 0 then
raise (Invalid_argument
"The argument must be a non-negative integer.")
else
let result = sum n in
printf "sum(%d) = %d\n" n result
with
| Invalid_argument msg ->
eprintf "Logic error: %s\n" msg
| Failure msg ->
eprintf "Runtime error: %s\n" msg
| exn ->
eprintf "Another exception: %s\n" (Printexc.to_string exn)
最適化オプションを付ける。
makefile:
OCAMLOPT = ocamlopt
OCAMLFLAGS = -O2
LIBS =
SRC = sum.ml
EXES = sum
all: $(EXES)
sum: $(SRC)
$(OCAMLOPT) $(OCAMLFLAGS) -o $@ $(LIBS) $(SRC)
clean:
rm -f $(EXES) *.o *.cmx *.cmi
n = 10から始めて一桁ずつ上げていくとn = 1000000でスタックオーバーフローし、OCamlから例外が投げられる。
$ ./sum 10
sum(10) = 55
$ ./sum 100
sum(100) = 5050
$ ./sum 1000
sum(1000) = 500500
$ ./sum 10000
sum(10000) = 50005000
$ ./sum 100000
sum(100000) = 5000050000
$ ./sum 1000000
Another exception: Stack overflow
n = 100000とn = 1000000で測定する。
$ measure ./sum 100000
sum(100000) = 5000050000
======================================
Process exited with status: 0
total time: 0.002213 [sec]
mem size: 4416 [KB]
code size: 1111 [KB]
$ measure ./sum 1000000
Another exception: Stack overflow
======================================
Process exited with status: 0
total time: 0.007845 [sec]
mem size: 10884 [KB]
code size: 1111 [KB]
OCamlはプログラム内部でスタックオーバーフローを検知し、例外を投げ、プロセスは正常に終了する。
プログラム:sum.rkt
#lang racket
(require iso-printf)
(define (sum n)
(if (< n 1)
0
(+ n (sum (- n 1)))))
(define (main args)
(with-handlers
([exn:fail? (λ (e) (eprintf "%s\n" (exn-message e)))]
[exn? (λ (e) (eprintf "Unexpected: %s\n" (exn-message e)))])
(cond
[(not (= (vector-length args) 1))
(error "Usage: sum <non-negative integer>")]
[else
(let ([n (string->number (vector-ref args 0))])
(cond
[(not (integer? n))
(error "The argument must be an integer.")]
[(< n 0)
(error "The argument must be a non-negative integer.")]
[else
(let ([result (sum n)])
(printf "sum(~a) = ~a\n" n result))]))])))
(main (current-command-line-arguments))
n = 10から始めて一桁ずつ上げていくとn = 100000000でもスタックオーバーフローしない。しかし桁数が増えると計算終了までかなり時間がかかる。
$ ./sum 10
sum 10 = 55
$ ./sum 100
sum 100 = 5050
$ ./sum 1000
sum 1000 = 500500
$ ./sum 10000
sum 10000 = 50005000
$ ./sum 100000
sum 100000 = 5000050000
$ ./sum 1000000
sum 1000000 = 500000500000
$ ./sum 10000000
sum 10000000 = 50000005000000
$ ./sum 100000000
sum 100000000 = 5000000050000000
n = 100000で測定する。
$ measure ./sum 100000
sum(100000) = 5000050000
======================================
Process exited with status: 0
total time: 0.470441 [sec]
mem size: 133644 [KB]
code size: 12442 [KB]
末尾再帰とは自分自身を末尾で呼び出す再帰関数のことである。通常の再帰においては関数呼び出しが行なわれるので呼び出す毎に新しいスタックフレームが確保される。しかし、末尾再帰においては実行中のスタックフレームにある局所変数は末尾で再帰関数を呼び出す時点以後に使われることがないのでこれを再利用することができる。この発想が公にされたのは1977年にGuy L. Steeleにより発表された論文[1]が最初とされている。
関数型言語のコンパイラは末尾再帰を検知すると局所変数を再利用するような最適化コードを生成する。GCCやClangにおいてもコンパイルオプション:-O2で末尾再帰最適化が行なわれる。本節では実験によりそれを確認する。
前節のsum(n)は最後の行で再帰関数を呼び出した後にn + sum(n - 1)という演算を行なっているので末尾再帰ではない。これを末尾再帰にするためには補助関数:sum_auxを使って次のように書き直す。
プログラム:sum2.cpp
#include <print>
#include <stdexcept>
#include <string>
long sum_aux(long n, long acc) {
if (n == 0)
return acc;
else
return sum_aux(n - 1, acc + n);
}
int main(int argc, char* argv[]) {
try {
if (argc != 2) {
throw std::invalid_argument(
"Usage: sum2 <non-negative integer>");
}
long n = std::stol(argv[1]);
if (n < 0l) {
throw std::invalid_argument(
"The argument must be a non-negative integer.");
}
long result = sum_aux(n, 0);
std::println("sum({}) = {}", n, result);
} catch (const std::logic_error& e) {
std::println(stderr, "Logic error: {}", e.what());
} catch (const std::runtime_error& e) {
std::println(stderr, "Runtime error: {}", e.what());
} catch (const std::exception& e) {
std::println(stderr, "Another std::exception: {}", e.what());
} catch (...) {
std::println(stderr, "Unknown exception.");
}
return 0;
}
このプログラムを普通にコンパイルすると末尾再帰最適化は行なわれない。
$ g++ -std=c++23 -o sum2_g sum2.cpp
どのようなコードが生成されたのかをsum_aux関数だけのアセンブラソースで見てみる。C++23とは無縁なので-std=c++23は必要ない。
$ g++ -S -masm=intel sum_aux.cpp
アセンブラソース:sum_aux.s
_Z7sum_auxll:
.LFB0:
.cfi_startproc
endbr64
push rbp
.cfi_def_cfa_offset 16
.cfi_offset 6, -16
mov rbp, rsp
.cfi_def_cfa_register 6
sub rsp, 16
mov QWORD PTR -8[rbp], rdi
mov QWORD PTR -16[rbp], rsi
cmp QWORD PTR -8[rbp], 0
jne .L2
mov rax, QWORD PTR -16[rbp]
jmp .L3
.L2:
mov rdx, QWORD PTR -16[rbp]
mov rax, QWORD PTR -8[rbp]
add rdx, rax
mov rax, QWORD PTR -8[rbp]
sub rax, 1
mov rsi, rdx
mov rdi, rax
call _Z7sum_auxll
nop
.L3:
leave
.cfi_def_cfa 7, 8
ret
次の行で自分自身を呼び出しているのがわかる。
call _Z7sum_auxll
このプログラムを-O2オプションでコンパイルすると末尾再帰最適化が行なわれる。
$ g++ -std=c++23 -O2 -o sum2_g sum2.cpp
どのようなコードが生成されたのかをsum_aux関数だけのアセンブラソースで見てみる。
$ g++ -S -O2 -masm=intel sum_aux.cpp
アセンブラソース:sum_aux.s
_Z7sum_auxll:
.LFB0:
.cfi_startproc
endbr64
mov rax, rsi
test rdi, rdi
je .L5
lea rdx, -1[rdi]
test dil, 1
je .L2
add rax, rdi
mov rdi, rdx
test rdx, rdx
je .L17
.p2align 4,,10
.p2align 3
.L2:
lea rax, -1[rax+rdi*2]
sub rdi, 2
jne .L2
.L5:
ret
.L17:
ret
関数呼び出しがループに代わっていることがわかる。
最適化オプション:-O2を付けてメイクする。
makefile:
CPPG = g++
CPPL = clang++
CPPFLAGS = -std=c++23 -O2
SRC = sum2.cpp
EXES = sum2_g sum2_l
all: $(EXES)
sum2_g: $(SRC)
$(CPPG) $(CPPFLAGS) -o $@ $(SRC)
sum2_l: $(SRC)
$(CPPL) $(CPPFLAGS) -o $@ $(SRC)
clean:
rm -f $(EXES)
n = 100000000で測定する。
$ measure ./sum2_g 100000000
sum(100000000) = 5000000050000000
======================================
Process exited with status: 0
total time: 0.064444 [sec]
mem size: 3692 [KB]
code size: 119 [KB]
$ measure ./sum2_l 100000000
sum(100000000) = 5000000050000000
======================================
Process exited with status: 0
total time: 0.000000 [sec]
mem size: 3680 [KB]
code size: 92 [KB]
スタックオーバーフローせず、メモリ消費量も少ない。
プログラム:sum2.ml
open Printf
let rec sum_aux n acc =
if n < 1 then
acc
else
sum_aux (n - 1) (acc + n)
let () =
try
if Array.length Sys.argv <> 2 then
raise (Invalid_argument
"Usage: sum2 <non-negative integer>")
else
let n = int_of_string Sys.argv.(1) in
if n < 0 then
raise (Invalid_argument
"The argument must be a non-negative integer.")
else
let result = sum_aux n 0 in
printf "sum(%d) = %d\n" n result
with
| Invalid_argument msg ->
eprintf "Logic error: %s\n" msg
| Failure msg ->
eprintf "Runtime error: %s\n" msg
| exn ->
eprintf "Another exception: %s\n" (Printexc.to_string exn)
n = 100000000で測定する。
$ ocamlopt -O2 -o sum2 sum2.ml
$ measure ./sum2 100000000
sum(100000000) = 5000000050000000
======================================
Process exited with status: 0
total time: 0.126208 [sec]
mem size: 2748 [KB]
code size: 1111 [KB]
スタックオーバーフローせず、メモリ消費量も少ない。
プログラム:sum2.rkt
#lang racket
(require iso-printf)
(define (sum_aux n acc)
(if (= n 0)
acc
(sum_aux (- n 1) (+ acc n))))
(define (main args)
(with-handlers
([exn:fail? (λ (e) (eprintf "%s\n" (exn-message e)))]
[exn? (λ (e) (eprintf "Unexpected: %s\n" (exn-message e)))])
(cond
[(not (= (vector-length args) 1))
(error "Usage: sum2 <non-negative integer>")]
[else
(let ([n (string->number (vector-ref args 0))])
(cond
[(not (integer? n))
(error "The argument must be an integer.")]
[(< n 0)
(error "The argument must be a non-negative integer.")]
[else
(let ([result (sum_aux n 0)])
(printf "sum(%d) = %d\n" n result))]))])))
(main (current-command-line-arguments))
n = 100000000で測定する。
$ raco exe sum2.rkt
$ measure ./sum2 100000000
sum(100000000) = 5000000050000000
======================================
Process exited with status: 0
total time: 0.679221 [sec]
mem size: 135140 [KB]
code size: 12442 [KB]
メモリ消費量が減り、実行速度も速くなっている。Racketはスタックオーバーフローが起こりにくい設計になっているが、メモリ消費量と実行速度で末尾再帰の優位性が確認できた。
[1] Steele, Guy Lewis, Debunking the "expensive procedure call" myth or, procedure call implementations considered harmful or, LAMBDA: The Ultimate GOTO, Proceedings of the 1977 annual conference on - ACM '77, 153–162, 1977
This document is licensed under the MIT License.
Copyright (C) 2024 Takeshi Umetani