3. 高階関数
著者:梅谷 武
各言語において高階関数の基本動作を比較する。
作成:2024-03-29
更新:2024-10-24
 プログラミング言語における高階関数(higher-order function)とは関数を引数に取るかもしくは関数を値として返す関数のことである。Racketを使って簡単な例を示す。
> (map add1 '(1 2 3))
'(2 3 4)
> (map + '(1 2 3) '(1 2 3))
'(2 4 6)
> (map + '(1 2 3) '(1 2 3) '(1 2 3))
'(3 6 9)
> (apply + '(1 2 3))
6
> (filter even? '(1 2 3))
'(2)
 よく用いられる畳み込み(fold)については注意が必要である。二変数関数:f(x,y)が与えられたとき、Racketでは
  • (foldr f 0 '(1 2 3)) := f(1, f(2, f(3, 0)))
  • (foldl f 0 '(1 2 3)) := f(3, f(2, f(1, 0)))
と定義される。これにより予想外の結果となることがある。
> (foldr + 0 '(1 2 3))
6
> (foldl + 0 '(1 2 3))
6
> (foldr cons '() '(1 2 3))
'(1 2 3)
> (foldl cons '() '(1 2 3))
'(3 2 1)
 同じことをOCamlでやってみる。OCamlの標準ライブラリにはapplyに対応するものはなく、map、map2は引数の個数が固定される。
# List.map (fun n -> n + 1) [1; 2; 3];;
- : int list = [2; 3; 4]
# List.map2 (+) [1; 2; 3] [1; 2; 3];;
- : int list = [2; 4; 6]
# List.filter (fun n -> (n mod 2 = 0)) [1; 2; 3];;
- : int list = [2]
 畳み込み(fold)の定義はRacketとは異なる。二変数関数:f(x,y)が与えられたとき、
  • List.fold_right f [1; 2; 3] 0 := f(1, f(2, f(3, 0)))
  • List.fold_left f 0 [1; 2; 3] := f(f(f(0, 1), 2), 3)
# List.fold_right (+) [1; 2; 3] 0;;
- : int = 6
# List.fold_left (+) 0 [1; 2; 3];;
- : int = 6
# List.fold_right (fun x acc -> x :: acc) [1; 2; 3] [];;
- : int list = [1; 2; 3]
# List.fold_left (fun acc x -> x :: acc) [] [1; 2; 3];;
- : int list = [3; 2; 1]

C++

 C++: C++11からラムダ式が導入された。ラムダ式は<algorithm>のアルゴリズム(std::transformなど)の引数として用いられる。
 OCaml: リストについてはオンラインマニュアル:Listsを参照のこと。
 Racket: リストについての概説はThe Racket Guide>3 Built-in Datatypes>3.8 Pairs and Lists、詳細仕様はThe Racket Reference>4 Datatypes>4.10 Pairs and Listsにある。
 値を2倍する関数:x → x + xを用いて、Ⅲ-2 整列で作成した基準データ:uniform1m.txtから作った整数リスト:lstをmapし、結果をファイル出力する。

C++

std::transformを用いる。入力データはstd::dequeに読み込むがmapした結果はstd::vectorに書き込み、それを出力ファイルに落とす。
 map.cpp:
#include <algorithm>
#include <chrono>
#include <deque>
#include <fstream>
#include <iostream>
#include <print>
#include <sstream>
#include <stdexcept>
#include <vector>
 
/// Read the integers in `infile`, double each one with std::transform (the
/// C++ counterpart of `map`), write the results to `outfile` one per line,
/// and print the wall-clock time of each phase.
///
/// @param infile  path of a file of whitespace-separated integers
/// @param outfile path of the output file (created or truncated)
/// @return true on success; every failure is reported by throwing
/// @throws std::runtime_error if a file cannot be opened, the input contains
///         a token that is not an integer, or the output cannot be written
bool test(const std::string &infile, const std::string &outfile) {
  // steady_clock is monotonic; high_resolution_clock may be an alias of the
  // adjustable system_clock, which is unsuitable for measuring intervals.
  using Clock = std::chrono::steady_clock;
  std::deque<long> lst;
  // read file
  auto start = Clock::now();
  std::ifstream fin(infile);
  if (!fin)
    throw std::runtime_error("Could not open file for reading.");
  long num; // same type as the elements, so large values are not truncated
  while (fin >> num)
    lst.push_back(num);
  // The loop stops either at end of file or at the first token that is not
  // an integer; only the former means the whole file was read.
  if (!fin.eof())
    throw std::runtime_error("Malformed or unreadable input: expected integers.");
  fin.close();
  auto end = Clock::now();
  // duration<double> counts seconds as a floating-point value, which is what
  // the {:9.6f} format specifications below require.
  std::chrono::duration<double> duration = end - start;
  std::println("length = {}", lst.size());
  std::println("read time:  {:9.6f} [sec]", duration.count());
  // map
  start = Clock::now();
  std::vector<long> lst2(lst.size());
  std::transform(lst.begin(), lst.end(), lst2.begin(),
                 [](long x) { return x + x; });
  end = Clock::now();
  duration = end - start;
  std::println("map time:   {:9.6f} [sec]", duration.count());
  // write file
  start = Clock::now();
  std::ofstream fout(outfile);
  if (!fout)
    throw std::runtime_error("Could not open file for writing.");
  // Format everything into one buffer, then hand it to the file stream in a
  // single insertion.
  std::ostringstream buf;
  for (long value : lst2)
    buf << value << '\n';
  fout << buf.str();
  fout.close();
  // close() flushes; a failed write or flush leaves the stream in a failed
  // state that would otherwise go unnoticed.
  if (!fout)
    throw std::runtime_error("Could not write output file.");
  end = Clock::now();
  duration = end - start;
  std::println("write time: {:9.6f} [sec]", duration.count());
  return true;
}
 
/// Entry point: `map <infile> <outfile>`.
/// Runs test() and reports any failure on stderr.
/// @return 0 on success, 1 on a usage error or any other failure, so that
///         shell scripts can detect the error from the exit status
int main(int argc, char *argv[]) {
  try {
    if (argc != 3) {
      throw std::invalid_argument(
          "Usage: map <infile> <outfile>");
    }
    std::string infile(argv[1]);
    std::string outfile(argv[2]);
    if (!test(infile, outfile)) {
      std::println(stderr, "An error has occurred.");
      return 1;
    }
  } catch (const std::logic_error &e) {
    // std::invalid_argument (usage error) derives from std::logic_error.
    std::println(stderr, "Logic error: {}", e.what());
    return 1;
  } catch (const std::runtime_error &e) {
    std::println(stderr, "Runtime error: {}", e.what());
    return 1;
  } catch (const std::exception &e) {
    std::println(stderr, "Another std::exception: {}", e.what());
    return 1;
  } catch (...) {
    std::println(stderr, "Unknown exception.");
    return 1;
  }
  return 0;
}


 makefile:
# Build map.cpp twice, with GCC (map_g) and with Clang (map_l), so that the
# two compilers can be compared.  <print>/std::println require -std=c++23.
CPPG = g++
CPPL = clang++
CPPFLAGS = -std=c++23 -O2
SRC = map.cpp
EXES = map_g map_l

# "all" and "clean" do not name files; declaring them phony keeps make from
# skipping them if a file with one of those names ever exists.
.PHONY: all clean

all: $(EXES)

map_g: $(SRC)
	$(CPPG) $(CPPFLAGS) -o $@ $(SRC)

map_l: $(SRC)
	$(CPPL) $(CPPFLAGS) -o $@ $(SRC)

clean:
	rm -f $(EXES)


 基準データ:uniform1m.txtで測定する。
$ measure ./map_g uniform1m.txt doubled_g.txt
length = 1000000
read time:   0.094015 [sec]
map time:    0.005328 [sec]
write time:  0.104806 [sec]
======================================
Process exited with status: 0
total time:  0.207683 [sec]
mem  size:      39748 [KB]
code size:        134 [KB]
$ measure ./map_l uniform1m.txt doubled_l.txt
length = 1000000
read time:   0.102906 [sec]
map time:    0.006330 [sec]
write time:  0.120802 [sec]
======================================
Process exited with status: 0
total time:  0.234538 [sec]
mem  size:      39028 [KB]
code size:        107 [KB]
 このプログラムで
let doubled = map (fun x -> x + x) lst in
は標準の
let doubled = List.map (fun x -> x + x) lst in
を使いたいところであるが、List.mapは(OCaml 5.1より前の版では)末尾再帰になっていないので、1000000個の整数リストを処理するとスタックオーバーフローすることがある。そこで末尾再帰で書いたmapを作成した(OCaml 5.1以降のList.mapはTRMC(tail recursion modulo cons)により末尾再帰化されている)。
 map.ml:
open Printf

(* Tail-recursive map: the results are accumulated in reverse and the
   accumulator is reversed once at the end, so the stack depth does not grow
   with the list length.  [f] is applied to the elements from first to last. *)
let map f lst =
  let rec aux acc = function
    | [] -> List.rev acc
    | x :: xs -> aux (f x :: acc) xs in
  aux [] lst

(* Read one integer per line from [filename] and return them in file order.
   Surrounding whitespace (e.g. a trailing '\r') is ignored.  The channel is
   closed even when a line is not an integer; the [Failure] raised by
   [int_of_string] still propagates to the caller. *)
let read_data filename =
  let ichan = open_in filename in
  Fun.protect ~finally:(fun () -> close_in_noerr ichan) (fun () ->
    let rec read_lines acc =
      match input_line ichan with
      | line ->
        let num = int_of_string (String.trim line) in
        read_lines (num :: acc)
      | exception End_of_file -> List.rev acc in
    read_lines [])

(* Write each integer of [lst] on its own line to [filename].  [close_out]
   on the normal path reports flush errors; the [finally] handler only
   matters when a write raised (closing an already closed channel is a
   no-op). *)
let write_data filename lst =
  let ochan = open_out filename in
  Fun.protect ~finally:(fun () -> close_out_noerr ochan) (fun () ->
    List.iter (fun x -> fprintf ochan "%d\n" x) lst;
    close_out ochan)

(* Read [infile], double every element with [map], write the result to
   [outfile], and print the wall-clock time of each phase in seconds. *)
let test infile outfile =
  (* read file *)
  let start = Unix.gettimeofday () in
  let lst = read_data infile in
  let end_ = Unix.gettimeofday () in
  printf "length = %d\n" (List.length lst);
  printf "read time:  %9.6f [sec]\n" (end_ -. start);
  (* map *)
  let start = Unix.gettimeofday () in
  let lst2 = map (fun x -> x + x) lst in
  let end_ = Unix.gettimeofday () in
  printf "map time:   %9.6f [sec]\n" (end_ -. start);
  (* write file *)
  let start = Unix.gettimeofday () in
  write_data outfile lst2;
  let end_ = Unix.gettimeofday () in
  printf "write time: %9.6f [sec]\n" (end_ -. start)

(* Entry point: expects exactly two arguments, <infile> and <outfile>.
   Errors are reported on stderr and the process exits with status 1. *)
let () =
  try
    if Array.length Sys.argv <> 3 then
      raise (Invalid_argument
        "Usage: map <infile> <outfile>")
    else
      let infile = Sys.argv.(1) in
      let outfile = Sys.argv.(2) in
      test infile outfile
  with
    | Invalid_argument msg ->
        eprintf "Logic error: %s\n" msg;
        exit 1
    | Failure msg ->
        eprintf "Runtime error: %s\n" msg;
        exit 1
    | exn ->
        eprintf "Another exception: %s\n" (Printexc.to_string exn);
        exit 1


 makefile:
# Build map.ml into a native executable, linking the unix library for
# Unix.gettimeofday.
OCAMLOPT = ocamlopt
# -O2 only has an effect when ocamlopt was built with flambda
# (check with: ocamlopt -config | grep flambda).
OCAMLFLAGS = -O2
# In OCaml 5 the unix library is installed in its own directory (+unix).
LIBPATH = -I +unix
LIBS = unix.cmxa
SRC = map.ml
EXES = map

.PHONY: all clean

all: $(EXES)

map: $(SRC)
	$(OCAMLOPT) $(OCAMLFLAGS) $(LIBPATH) -o $@ $(LIBS) $(SRC)

clean:
	rm -f $(EXES) *.o *.cmx *.cmi


 基準データ:uniform1m.txtで測定する。
$ measure ./map uniform1m.txt doubled.txt
length = 1000000
read time:   0.280672 [sec]
map time:    0.100565 [sec]
write time:  0.258753 [sec]
======================================
Process exited with status: 0
total time:  0.650027 [sec]
mem  size:      67740 [KB]
code size:       3289 [KB]
 自作mapの性能は思わしくない。
 標準のmapを用いる。
 map.rkt:
#lang racket
(require iso-printf)

;; Read one integer per line from `filename` and return them as a list in
;; file order.  (read-line port 'any) accepts "\n", "\r\n" and "\r" line
;; endings, and surrounding whitespace is trimmed.  A line that is not an
;; exact integer raises exn:fail instead of silently becoming #f.
(define (read-data filename)
  (call-with-input-file filename
    (lambda (port)
      (let loop ([lst '()])
        (let ([line (read-line port 'any)])
          (if (eof-object? line)
              (reverse lst)
              (let ([num (string->number (string-trim line))])
                (unless (exact-integer? num)
                  (error 'read-data "not an integer: ~s" line))
                (loop (cons num lst)))))))))

;; Write each element of `lst` on its own line to `filename`, replacing the
;; file if it already exists.
(define (write-data filename lst)
  (call-with-output-file filename #:exists 'replace
    (lambda (port)
      (for-each (lambda (num) (displayln num port))
                lst))))

;; Read `infile`, double every element with map, write the result to
;; `outfile`, and print the elapsed time of each phase in seconds.
(define (test infile outfile)
  ; read file
  (define start (current-inexact-milliseconds))
  (define lst (read-data infile))
  (define end (current-inexact-milliseconds))
  (printf "length = %d\n" (length lst))
  (printf "read time:  %9.6f [sec]\n" (/ (- end start) 1000.0))
  ; map
  (set! start (current-inexact-milliseconds))
  (define lst2 (map (λ (i) (+ i i)) lst))
  (set! end (current-inexact-milliseconds))
  (printf "map time:   %9.6f [sec]\n" (/ (- end start) 1000.0))
  ; write file
  (set! start (current-inexact-milliseconds))
  (write-data outfile lst2)
  (set! end (current-inexact-milliseconds))
  (printf "write time: %9.6f [sec]\n" (/ (- end start) 1000.0)))

;; Entry point: expects exactly two command-line arguments, <infile> and
;; <outfile>.  Any exception is reported on stderr and the process exits
;; with status 1.
(define (main args)
  (with-handlers
      ([exn:fail? (λ (e) (eprintf "%s\n" (exn-message e)) (exit 1))]
       [exn? (λ (e) (eprintf "Unexpected: %s\n" (exn-message e)) (exit 1))])
    (cond
      [(not (= (vector-length args) 2))
       (error "Usage: map <infile> <outfile>")]
      [else
       (let ([infile (vector-ref args 0)]
             [outfile (vector-ref args 1)])
         (test infile outfile))])))

(main (current-command-line-arguments))


 基準データ:uniform1m.txtで測定する。
$ raco exe map.rkt
$ measure ./map uniform1m.txt doubled.txt
length = 1000000
read time:   0.835495 [sec]
map time:    0.038531 [sec]
write time:  0.971377 [sec]
======================================
Process exited with status: 0
total time:  2.394380 [sec]
mem  size:     181380 [KB]
code size:      12446 [KB]
 基準データ:uniform1m.txtから偶数をフィルタリングする。

C++

std::copy_ifを用いる。入力データはstd::dequeに読み込むがfilterした結果はstd::vectorにpush_backし、それを出力ファイルに落とす。
 filter.cpp:
#include <algorithm>
#include <chrono>
#include <deque>
#include <fstream>
#include <iostream>
#include <print>
#include <sstream>
#include <stdexcept>
#include <vector>
 
/// Read the integers in `infile`, keep the even ones with std::copy_if (the
/// C++ counterpart of `filter`), write them to `outfile` one per line, and
/// print the wall-clock time of each phase.
///
/// @param infile  path of a file of whitespace-separated integers
/// @param outfile path of the output file (created or truncated)
/// @return true on success; every failure is reported by throwing
/// @throws std::runtime_error if a file cannot be opened, the input contains
///         a token that is not an integer, or the output cannot be written
bool test(const std::string &infile, const std::string &outfile) {
  // steady_clock is monotonic; high_resolution_clock may be an alias of the
  // adjustable system_clock, which is unsuitable for measuring intervals.
  using Clock = std::chrono::steady_clock;
  std::deque<long> lst;
  // read file
  auto start = Clock::now();
  std::ifstream fin(infile);
  if (!fin)
    throw std::runtime_error("Could not open file for reading.");
  long num; // same type as the elements, so large values are not truncated
  while (fin >> num)
    lst.push_back(num);
  // The loop stops either at end of file or at the first token that is not
  // an integer; only the former means the whole file was read.
  if (!fin.eof())
    throw std::runtime_error("Malformed or unreadable input: expected integers.");
  fin.close();
  auto end = Clock::now();
  // Explicit duration<double> (seconds as floating point).  Deducing the type
  // from `end - start` would give an integral nanosecond count, for which the
  // {:9.6f} format specification is invalid.
  std::chrono::duration<double> duration = end - start;
  std::println("length = {}", lst.size());
  std::println("read time:  {:9.6f} [sec]", duration.count());
  // filter
  start = Clock::now();
  std::vector<long> lst2;
  std::copy_if(lst.begin(), lst.end(), std::back_inserter(lst2),
               [](long x) { return (x & 1) == 0; });
  end = Clock::now();
  duration = end - start;
  std::println("filt time:  {:9.6f} [sec]", duration.count());
  // write file
  start = Clock::now();
  std::ofstream fout(outfile);
  if (!fout)
    throw std::runtime_error("Could not open file for writing.");
  // Format everything into one buffer, then hand it to the file stream in a
  // single insertion.
  std::ostringstream buf;
  for (long value : lst2)
    buf << value << '\n';
  fout << buf.str();
  fout.close();
  // close() flushes; a failed write or flush leaves the stream in a failed
  // state that would otherwise go unnoticed.
  if (!fout)
    throw std::runtime_error("Could not write output file.");
  end = Clock::now();
  duration = end - start;
  std::println("write time: {:9.6f} [sec]", duration.count());
  return true;
}
 
/// Entry point: `filter <infile> <outfile>`.
/// Runs test() and reports any failure on stderr.
/// @return 0 on success, 1 on a usage error or any other failure, so that
///         shell scripts can detect the error from the exit status
int main(int argc, char *argv[]) {
  try {
    if (argc != 3) {
      throw std::invalid_argument(
          "Usage: filter <infile> <outfile>");
    }
    std::string infile(argv[1]);
    std::string outfile(argv[2]);
    if (!test(infile, outfile)) {
      std::println(stderr, "An error has occurred.");
      return 1;
    }
  } catch (const std::logic_error &e) {
    // std::invalid_argument (usage error) derives from std::logic_error.
    std::println(stderr, "Logic error: {}", e.what());
    return 1;
  } catch (const std::runtime_error &e) {
    std::println(stderr, "Runtime error: {}", e.what());
    return 1;
  } catch (const std::exception &e) {
    std::println(stderr, "Another std::exception: {}", e.what());
    return 1;
  } catch (...) {
    std::println(stderr, "Unknown exception.");
    return 1;
  }
  return 0;
}


 makefile:
# Build filter.cpp twice, with GCC (filter_g) and with Clang (filter_l), so
# that the two compilers can be compared.  <print>/std::println require
# -std=c++23.
CPPG = g++
CPPL = clang++
CPPFLAGS = -std=c++23 -O2
SRC = filter.cpp
EXES = filter_g filter_l

# "all" and "clean" do not name files; declaring them phony keeps make from
# skipping them if a file with one of those names ever exists.
.PHONY: all clean

all: $(EXES)

filter_g: $(SRC)
	$(CPPG) $(CPPFLAGS) -o $@ $(SRC)

filter_l: $(SRC)
	$(CPPL) $(CPPFLAGS) -o $@ $(SRC)

clean:
	rm -f $(EXES)


 基準データ:uniform1m.txtで測定する。
$ measure ./filter_g uniform1m.txt even_g.txt
length = 1000000
read time:   0.094720 [sec]
filt time:   0.011476 [sec]
write time:  0.046589 [sec]
======================================
Process exited with status: 0
total time:  0.151554 [sec]
mem  size:      25508 [KB]
code size:        134 [KB]
$ measure ./filter_l uniform1m.txt even_l.txt
length = 1000000
read time:   0.093494 [sec]
filt time:   0.011769 [sec]
write time:  0.051909 [sec]
======================================
Process exited with status: 0
total time:  0.157288 [sec]
mem  size:      26008 [KB]
code size:        107 [KB]
 List.filterを用いる。
 filter.ml:
open Printf

(* Read one integer per line from [filename] and return them in file order.
   Surrounding whitespace (e.g. a trailing '\r') is ignored.  The channel is
   closed even when a line is not an integer; the [Failure] raised by
   [int_of_string] still propagates to the caller. *)
let read_data filename =
  let ichan = open_in filename in
  Fun.protect ~finally:(fun () -> close_in_noerr ichan) (fun () ->
    let rec read_lines acc =
      match input_line ichan with
      | line ->
        let num = int_of_string (String.trim line) in
        read_lines (num :: acc)
      | exception End_of_file -> List.rev acc in
    read_lines [])

(* Write each integer of [lst] on its own line to [filename].  [close_out]
   on the normal path reports flush errors; the [finally] handler only
   matters when a write raised (closing an already closed channel is a
   no-op). *)
let write_data filename lst =
  let ochan = open_out filename in
  Fun.protect ~finally:(fun () -> close_out_noerr ochan) (fun () ->
    List.iter (fun x -> fprintf ochan "%d\n" x) lst;
    close_out ochan)

(* Read [infile], keep the even elements with [List.filter], write them to
   [outfile], and print the wall-clock time of each phase in seconds. *)
let test infile outfile =
  (* read file *)
  let start = Unix.gettimeofday () in
  let lst = read_data infile in
  let end_ = Unix.gettimeofday () in
  printf "length = %d\n" (List.length lst);
  printf "read time:  %9.6f [sec]\n" (end_ -. start);
  (* filter: the lowest bit is 0 for even numbers, negatives included *)
  let start = Unix.gettimeofday () in
  let lst2 = List.filter (fun x -> (x land 1) = 0) lst in
  let end_ = Unix.gettimeofday () in
  printf "filt time:  %9.6f [sec]\n" (end_ -. start);
  (* write file *)
  let start = Unix.gettimeofday () in
  write_data outfile lst2;
  let end_ = Unix.gettimeofday () in
  printf "write time: %9.6f [sec]\n" (end_ -. start)

(* Entry point: expects exactly two arguments, <infile> and <outfile>.
   Errors are reported on stderr and the process exits with status 1. *)
let () =
  try
    if Array.length Sys.argv <> 3 then
      raise (Invalid_argument
        "Usage: filter <infile> <outfile>")
    else
      let infile = Sys.argv.(1) in
      let outfile = Sys.argv.(2) in
      test infile outfile
  with
    | Invalid_argument msg ->
        eprintf "Logic error: %s\n" msg;
        exit 1
    | Failure msg ->
        eprintf "Runtime error: %s\n" msg;
        exit 1
    | exn ->
        eprintf "Another exception: %s\n" (Printexc.to_string exn);
        exit 1


 makefile:
# Build filter.ml into a native executable, linking the unix library for
# Unix.gettimeofday.
OCAMLOPT = ocamlopt
# -O2 only has an effect when ocamlopt was built with flambda
# (check with: ocamlopt -config | grep flambda).
OCAMLFLAGS = -O2
# In OCaml 5 the unix library is installed in its own directory (+unix).
LIBPATH = -I +unix
LIBS = unix.cmxa
SRC = filter.ml
EXES = filter

.PHONY: all clean

all: $(EXES)

filter: $(SRC)
	$(OCAMLOPT) $(OCAMLFLAGS) $(LIBPATH) -o $@ $(LIBS) $(SRC)

clean:
	rm -f $(EXES) *.o *.cmx *.cmi


 基準データ:uniform1m.txtで測定する。
$ measure ./filter uniform1m.txt even.txt
length = 1000000
read time:   0.306795 [sec]
filt time:   0.033516 [sec]
write time:  0.124537 [sec]
======================================
Process exited with status: 0
total time:  0.450744 [sec]
mem  size:      54616 [KB]
code size:       3289 [KB]
 標準のfilterを用いる。
 filter.rkt:
#lang racket
(require iso-printf)

;; Read one integer per line from `filename` and return them as a list in
;; file order.  (read-line port 'any) accepts "\n", "\r\n" and "\r" line
;; endings, and surrounding whitespace is trimmed.  A line that is not an
;; exact integer raises exn:fail instead of silently becoming #f.
(define (read-data filename)
  (call-with-input-file filename
    (lambda (port)
      (let loop ([lst '()])
        (let ([line (read-line port 'any)])
          (if (eof-object? line)
              (reverse lst)
              (let ([num (string->number (string-trim line))])
                (unless (exact-integer? num)
                  (error 'read-data "not an integer: ~s" line))
                (loop (cons num lst)))))))))

;; Write each element of `lst` on its own line to `filename`, replacing the
;; file if it already exists.
(define (write-data filename lst)
  (call-with-output-file filename #:exists 'replace
    (lambda (port)
      (for-each (lambda (num) (displayln num port))
                lst))))

;; Read `infile`, keep the even elements with filter, write them to
;; `outfile`, and print the elapsed time of each phase in seconds.
(define (test infile outfile)
  ; read file
  (define start (current-inexact-milliseconds))
  (define lst (read-data infile))
  (define end (current-inexact-milliseconds))
  (printf "length = %d\n" (length lst))
  (printf "read time:  %9.6f [sec]\n" (/ (- end start) 1000.0))
  ; filter
  (set! start (current-inexact-milliseconds))
  (define lst2 (filter even? lst))
  (set! end (current-inexact-milliseconds))
  (printf "filt time:  %9.6f [sec]\n" (/ (- end start) 1000.0))
  ; write file
  (set! start (current-inexact-milliseconds))
  (write-data outfile lst2)
  (set! end (current-inexact-milliseconds))
  (printf "write time: %9.6f [sec]\n" (/ (- end start) 1000.0)))

;; Entry point: expects exactly two command-line arguments, <infile> and
;; <outfile>.  Any exception is reported on stderr and the process exits
;; with status 1.
(define (main args)
  (with-handlers
      ([exn:fail? (λ (e) (eprintf "%s\n" (exn-message e)) (exit 1))]
       [exn? (λ (e) (eprintf "Unexpected: %s\n" (exn-message e)) (exit 1))])
    (cond
      [(not (= (vector-length args) 2))
       (error "Usage: filter <infile> <outfile>")]
      [else
       (let ([infile (vector-ref args 0)]
             [outfile (vector-ref args 1)])
         (test infile outfile))])))

(main (current-command-line-arguments))


 基準データ:uniform1m.txtで測定する。
$ raco exe filter.rkt
$ measure ./filter uniform1m.txt even.txt
length = 1000000
read time:   0.762564 [sec]
filt time:   0.029869 [sec]
write time:  0.382791 [sec]
======================================
Process exited with status: 0
total time:  1.658260 [sec]
mem  size:     148312 [KB]
code size:      12445 [KB]
 基準データ:uniform1m.txtを加法により左畳み込み(foldl)と右畳み込み(foldr)をして総和を求めながら各々実行時間を測定し、結果が一致することを確認する。

C++

std::accumulateを用いる。左畳み込みは順方向イテレータ(begin/end)、右畳み込みは逆イテレータ(rbegin/rend)に対するaccumulateで実現する。
 fold.cpp:
#include <algorithm>
#include <chrono>
#include <deque>
#include <fstream>
#include <iostream>
#include <numeric>
#include <print>
#include <stdexcept>
 
/// Read the integers in `infile`, sum them with a left fold and with a
/// right fold (both via std::accumulate), print the time of each phase, and
/// report whether the two sums agree.
///
/// @param infile path of a file of whitespace-separated integers
/// @return true on success; every failure is reported by throwing
/// @throws std::runtime_error if the file cannot be opened or contains a
///         token that is not an integer
bool test(const std::string &infile) {
  // steady_clock is monotonic; high_resolution_clock may be an alias of the
  // adjustable system_clock, which is unsuitable for measuring intervals.
  using Clock = std::chrono::steady_clock;
  std::deque<long> lst;
  // read file
  auto start = Clock::now();
  std::ifstream fin(infile);
  if (!fin)
    throw std::runtime_error("Could not open file for reading.");
  long num; // same type as the elements, so large values are not truncated
  while (fin >> num)
    lst.push_back(num);
  // The loop stops either at end of file or at the first token that is not
  // an integer; only the former means the whole file was read.
  if (!fin.eof())
    throw std::runtime_error("Malformed or unreadable input: expected integers.");
  fin.close();
  auto end = Clock::now();
  // Explicit duration<double> (seconds as floating point).  Deducing the type
  // from `end - start` would give an integral nanosecond count, for which the
  // {:9.6f} format specification is invalid.
  std::chrono::duration<double> duration = end - start;
  std::println("length = {}", lst.size());
  std::println("read time:  {:9.6f} [sec]", duration.count());
  // foldl: std::accumulate calls op(acc, x) from the first element to the
  // last, i.e. op(...op(op(0, x1), x2)..., xn).
  start = Clock::now();
  long suml = std::accumulate(lst.begin(), lst.end(), 0L,
                              [](long acc, long x) { return acc + x; });
  end = Clock::now();
  duration = end - start;
  std::println("foldl time: {:9.6f} [sec]", duration.count());
  // foldr: accumulating over reverse iterators visits xn, ..., x1, which
  // yields f(x1, f(x2, ... f(xn, 0))) when the step is written as f(x, acc).
  start = Clock::now();
  long sumr = std::accumulate(lst.rbegin(), lst.rend(), 0L,
                              [](long acc, long x) { return x + acc; });
  end = Clock::now();
  duration = end - start;
  std::println("foldr time: {:9.6f} [sec]", duration.count());
  if (suml == sumr)
    std::println("sum = {}", suml);
  else
    std::println("suml ≠ sumr ({} vs {})", suml, sumr);
  return true;
}
 
/// Entry point: `fold <infile>`.
/// Runs test() and reports any failure on stderr.
/// @return 0 on success, 1 on a usage error or any other failure, so that
///         shell scripts can detect the error from the exit status
int main(int argc, char *argv[]) {
  try {
    if (argc != 2) {
      throw std::invalid_argument(
          "Usage: fold <infile>");
    }
    std::string infile(argv[1]);
    if (!test(infile)) {
      std::println(stderr, "An error has occurred.");
      return 1;
    }
  } catch (const std::logic_error &e) {
    // std::invalid_argument (usage error) derives from std::logic_error.
    std::println(stderr, "Logic error: {}", e.what());
    return 1;
  } catch (const std::runtime_error &e) {
    std::println(stderr, "Runtime error: {}", e.what());
    return 1;
  } catch (const std::exception &e) {
    std::println(stderr, "Another std::exception: {}", e.what());
    return 1;
  } catch (...) {
    std::println(stderr, "Unknown exception.");
    return 1;
  }
  return 0;
}


 makefile:
# Build fold.cpp twice, with GCC (fold_g) and with Clang (fold_l), so that
# the two compilers can be compared.  <print>/std::println require
# -std=c++23.
CPPG = g++
CPPL = clang++
CPPFLAGS = -std=c++23 -O2
SRC = fold.cpp
EXES = fold_g fold_l

# "all" and "clean" do not name files; declaring them phony keeps make from
# skipping them if a file with one of those names ever exists.
.PHONY: all clean

all: $(EXES)

fold_g: $(SRC)
	$(CPPG) $(CPPFLAGS) -o $@ $(SRC)

fold_l: $(SRC)
	$(CPPL) $(CPPFLAGS) -o $@ $(SRC)

clean:
	rm -f $(EXES)


 基準データ:uniform1m.txtで測定する。
$ measure ./fold_g uniform1m.txt
length = 1000000
read time:   0.092722 [sec]
foldl time:  0.001860 [sec]
foldr time:  0.001603 [sec]
sum = 50012261692650
======================================
Process exited with status: 0
total time:  0.099696 [sec]
mem  size:      11904 [KB]
code size:        133 [KB]
$ measure ./fold_l uniform1m.txt
length = 1000000
read time:   0.092388 [sec]
foldl time:  0.001379 [sec]
foldr time:  0.002116 [sec]
sum = 50012261692650
======================================
Process exited with status: 0
total time:  0.098455 [sec]
mem  size:      11320 [KB]
code size:        106 [KB]
 右畳み込みは標準のList.fold_rightを使って
let sumr = List.fold_right (fun x acc -> x + acc) lst 0 in
と書きたいところであるが、List.fold_rightは末尾再帰になっていないのでデータが1000000個のときスタックオーバーフローする。そこで末尾再帰になっているList.fold_leftを用い、リストを反転して左畳み込みすることにより右畳み込みを実装する。
let sumr = List.fold_left (fun acc x -> acc + x) 0 (List.rev lst) in
 fold.ml:
open Printf

(* Read one integer per line from [filename] and return them in file order.
   Surrounding whitespace (e.g. a trailing '\r') is ignored.  The channel is
   closed even when a line is not an integer; the [Failure] raised by
   [int_of_string] still propagates to the caller. *)
let read_data filename =
  let ichan = open_in filename in
  Fun.protect ~finally:(fun () -> close_in_noerr ichan) (fun () ->
    let rec read_lines acc =
      match input_line ichan with
      | line ->
        let num = int_of_string (String.trim line) in
        read_lines (num :: acc)
      | exception End_of_file -> List.rev acc in
    read_lines [])

(* Read [infile], sum the integers with a left fold and with a simulated
   right fold, print the wall-clock time of each phase, and report whether
   the two sums agree. *)
let test infile =
  (* read file *)
  let start = Unix.gettimeofday () in
  let lst = read_data infile in
  let end_ = Unix.gettimeofday () in
  printf "length = %d\n" (List.length lst);
  printf "read time:  %9.6f [sec]\n" (end_ -. start);
  (* foldl *)
  let start = Unix.gettimeofday () in
  let suml = List.fold_left (fun acc x -> acc + x) 0 lst in
  let end_ = Unix.gettimeofday () in
  printf "foldl time: %9.6f [sec]\n" (end_ -. start);
  (* foldr: List.fold_right is not tail-recursive, so the right fold is
     simulated by left-folding the reversed list.  The step is written as
     [x + acc] to mirror fold_right's [f x acc] argument order.  The timing
     includes the List.rev. *)
  let start = Unix.gettimeofday () in
  let sumr = List.fold_left (fun acc x -> x + acc) 0 (List.rev lst) in
  let end_ = Unix.gettimeofday () in
  printf "foldr time: %9.6f [sec]\n" (end_ -. start);
  if suml = sumr then
    printf "sum = %d\n" suml
  else
    printf "suml ≠ sumr\n"

(* Entry point: expects exactly one argument, <infile>.
   Errors are reported on stderr and the process exits with status 1. *)
let () =
  try
    if Array.length Sys.argv <> 2 then
      raise (Invalid_argument
        "Usage: fold <infile>")
    else
      let infile = Sys.argv.(1) in
      test infile
  with
    | Invalid_argument msg ->
        eprintf "Logic error: %s\n" msg;
        exit 1
    | Failure msg ->
        eprintf "Runtime error: %s\n" msg;
        exit 1
    | exn ->
        eprintf "Another exception: %s\n" (Printexc.to_string exn);
        exit 1


 makefile:
# Build fold.ml into a native executable, linking the unix library for
# Unix.gettimeofday.
OCAMLOPT = ocamlopt
# -O2 only has an effect when ocamlopt was built with flambda
# (check with: ocamlopt -config | grep flambda).
OCAMLFLAGS = -O2
# In OCaml 5 the unix library is installed in its own directory (+unix).
LIBPATH = -I +unix
LIBS = unix.cmxa
SRC = fold.ml
EXES = fold

.PHONY: all clean

all: $(EXES)

fold: $(SRC)
	$(OCAMLOPT) $(OCAMLFLAGS) $(LIBPATH) -o $@ $(LIBS) $(SRC)

clean:
	rm -f $(EXES) *.o *.cmx *.cmi


 基準データ:uniform1m.txtで測定する。
$ measure ./fold uniform1m.txt
length = 1000000
read time:   0.269756 [sec]
foldl time:  0.004815 [sec]
foldr time:  0.058270 [sec]
sum = 50012261692650
======================================
Process exited with status: 0
total time:  0.342521 [sec]
mem  size:      65332 [KB]
code size:       3289 [KB]
 右畳み込みには反転と左畳み込みの時間が含まれている。
 fold.rkt:
#lang racket
(require iso-printf)

;; Read one integer per line from `filename` and return them as a list in
;; file order.  (read-line port 'any) accepts "\n", "\r\n" and "\r" line
;; endings, and surrounding whitespace is trimmed.  A line that is not an
;; exact integer raises exn:fail instead of silently becoming #f.
(define (read-data filename)
  (call-with-input-file filename
    (lambda (port)
      (let loop ([lst '()])
        (let ([line (read-line port 'any)])
          (if (eof-object? line)
              (reverse lst)
              (let ([num (string->number (string-trim line))])
                (unless (exact-integer? num)
                  (error 'read-data "not an integer: ~s" line))
                (loop (cons num lst)))))))))

;; Read `infile`, sum the integers with foldl and with foldr, print the
;; elapsed time of each phase in seconds, and report whether the sums agree.
(define (test infile)
  ; read file
  (define start (current-inexact-milliseconds))
  (define lst (read-data infile))
  (define end (current-inexact-milliseconds))
  (printf "length = %d\n" (length lst))
  (printf "read time:  %9.6f [sec]\n" (/ (- end start) 1000.0))
  ; foldl
  (set! start (current-inexact-milliseconds))
  (define suml (foldl + 0 lst))
  (set! end (current-inexact-milliseconds))
  (printf "foldl time: %9.6f [sec]\n" (/ (- end start) 1000.0))
  ; foldr
  (set! start (current-inexact-milliseconds))
  (define sumr (foldr + 0 lst))
  (set! end (current-inexact-milliseconds))
  (printf "foldr time: %9.6f [sec]\n" (/ (- end start) 1000.0))
  (if (= suml sumr)
      (printf "sum = %d\n" suml)
      (printf "suml ≠ sumr\n")))

;; Entry point: expects exactly one command-line argument, <infile>.  Any
;; exception is reported on stderr and the process exits with status 1.
(define (main args)
  (with-handlers
      ([exn:fail? (λ (e) (eprintf "%s\n" (exn-message e)) (exit 1))]
       [exn? (λ (e) (eprintf "Unexpected: %s\n" (exn-message e)) (exit 1))])
    (cond
      [(not (= (vector-length args) 1))
       (error "Usage: fold <infile>")]
      [else
       (let ([infile (vector-ref args 0)])
         (test infile))])))

(main (current-command-line-arguments))


 基準データ:uniform1m.txtで測定する。
$ raco exe fold.rkt
$ measure ./fold uniform1m.txt
length = 1000000
read time:   0.776808 [sec]
foldl time:  0.008317 [sec]
foldr time:  0.086100 [sec]
sum = 50012261692650
======================================
Process exited with status: 0
total time:  1.372039 [sec]
mem  size:     182280 [KB]
code size:      12444 [KB]
 測定結果を表にまとめる。単位は秒[sec]。
測定項目 GCC Clang OCaml Racket
map 0.005 0.006 0.101 0.039
filter 0.011 0.012 0.034 0.030
foldl 0.002 0.001 0.005 0.008
foldr 0.002 0.002 0.058 0.086
This document is licensed under the MIT License.
Copyright (C) 2024 Takeshi Umetani