suzuzusu日記

(´・ω・`)

fork, select, poll, epoll, io_uringのecho server

fork, select, poll, epoll, io_uringなどを使用してそれぞれecho serverを実装したのでそれぞれの仕様などを忘備録としてまとめていきます。エラー処理とか雑な部分があると思いますがご容赦ください。

環境構築

macOS上での実行を前提にします。io_uringを使う場合にはlinux kernel 5.1以上でないと動かないので、multipassというubuntuVMを手軽に作成できるツールを使って実行環境を構築します。

multipassのインストール

brew install --cask multipass

20.10のubuntuを用意して、gccやio_uringのライブラリであるliburingを入れます。

multipass launch 20.10 -n primary
multipass shell # login
sudo apt update -y
sudo apt install gcc liburing -y
uname -a
> Linux primary 5.8.0-43-generic #49-Ubuntu SMP Fri Feb 5 03:01:28 UTC 2021 x86_64 x86_64 x86_64 GNU/Linux

fork

fork(2) - Linux manual page

forkを使用したサーバでは、1クライアントに1プロセスが対応することになり、メモリが枯渇して性能に余裕があるのにもかかわらずレスポンス性能が大きく落ちます。

有名なC10K問題を引き起こします。

C10K問題(英語: C10K problem)とは、Apache HTTP ServerなどのWebサーバソフトウェアとクライアントの通信において、クライアントが約1万台に達すると、Webサーバーのハードウェア性能に余裕があるにも関わらず、レスポンス性能が大きく下がる問題である。

C10K問題 - Wikipedia

#include <arpa/inet.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#define BACKLOG_SIZE 5
#define BUF_SIZE 1024
#define INF_TIME -1
#define DISABLE -1

int listen_fd;

void int_handle(int n) {
  close(listen_fd);
  exit(EXIT_SUCCESS);
}

// wirte n byte
ssize_t write_n(int fd, char *ptr, size_t n) {
  ssize_t n_left = n, n_written;

  while (n_left > 0) {
    if ((n_written = write(fd, ptr, n_left)) <= 0) {
      return n_written;
    }
    n_left -= n_written;
    ptr += n_written;
  }

  return EXIT_SUCCESS;
}

int main(int argc, char **argv) {
  // Create listen socket
  if ((listen_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    fprintf(stderr, "Error: socket\n");
    return EXIT_FAILURE;
  }

  // TCP port number
  int port = 8080;

  // Initialize server socket address
  struct sockaddr_in server_addr;
  memset(&server_addr, 0, sizeof(server_addr));
  server_addr.sin_family = AF_INET;
  server_addr.sin_addr.s_addr = INADDR_ANY;
  server_addr.sin_port = htons(port);

  // Bind socket to an address
  if (bind(listen_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) <
      0) {
    fprintf(stderr, "Error: bind\n");
    return EXIT_FAILURE;
  }

  // Listen
  if (listen(listen_fd, BACKLOG_SIZE) < 0) {
    fprintf(stderr, "Error: listen\n");
    return EXIT_FAILURE;
  }

  // Set INT signal handler
  signal(SIGINT, int_handle);

  fprintf(stderr, "listen on port %d\n", port);

  while (1) {
    // Check new connection
    struct sockaddr_in client_addr;
    socklen_t len_client = sizeof(client_addr);

    int conn_fd;
    if ((conn_fd = accept(listen_fd, (struct sockaddr *)&client_addr,
                          &len_client)) < 0) {
      fprintf(stderr, "Error: accept\n");
      return EXIT_FAILURE;
    }

    printf("Accept socket %d (%s : %hu)\n", conn_fd,
           inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));

    pid_t pid = fork();

    if (pid < 0) {
      fprintf(stderr, "Error: fork\n");
      return EXIT_FAILURE;
    }

    if (pid == 0) {
      // child
      char buf[BUF_SIZE];
      close(listen_fd);
      while (1) {
        ssize_t n = read(conn_fd, buf, BUF_SIZE);

        if (n < 0) {
          fprintf(stderr, "Error: read from socket %d\n", conn_fd);
          close(conn_fd);
          exit(-1);
        } else if (n == 0) {  // connection closed by client
          printf("Close socket %d\n", conn_fd);
          close(conn_fd);
          exit(0);
        } else {
          printf("Read %zu bytes from socket %d\n", n, conn_fd);
          write_n(conn_fd, buf, n);
        }
      }
    } else {
      // parent
      close(conn_fd);
    }
  }

  close(listen_fd);
  return EXIT_SUCCESS;
}

I/O多重化

上記のようなマルチプロセスにより起きるC10K問題を解決する方法として、クライアント数にかかわらず1プロセスで処理するイベント駆動型プログラミングが提案されています。

イベント駆動型プログラミング(イベントくどうがたプログラミング、英: event-driven programming)は、コンピュータプログラムが起動すると共にイベントを待機し、発生したイベントに従って受動的に処理を行うプログラミングパラダイムのこと。

イベント駆動型プログラミング - Wikipedia

以降はI/O多重化によるecho serverを実装していきます。I/O多重化とは複数のI/Oデバイスの状態を同時に監視する仕組みです。例えば1プロセスで複数のクライアントを制御するためにこの技術は使用されます。

select

select(2) - Linux manual page

ディスクリプタを線形に探索する必要があるので計算量がO(n)かかります。管理できるディスクリプタ数に上限があるのが特徴です。

#include <arpa/inet.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <unistd.h>

#define BACKLOG_SIZE 5
#define BUF_SIZE 1024
#define N_CLIENT 256
#define INF_TIME -1
#define DISABLE -1

int listen_fd;

void int_handle(int n) {
  close(listen_fd);
  exit(EXIT_SUCCESS);
}

// wirte n byte
ssize_t write_n(int fd, char *ptr, size_t n) {
  ssize_t n_left = n, n_written;

  while (n_left > 0) {
    if ((n_written = write(fd, ptr, n_left)) <= 0) {
      return n_written;
    }
    n_left -= n_written;
    ptr += n_written;
  }

  return EXIT_SUCCESS;
}

int main(int argc, char **argv) {
  char buf[BUF_SIZE];
  fd_set fds;
  FD_ZERO(&fds);
  int clients[N_CLIENT];
  for (int i = 0; i < N_CLIENT; i++) {
    clients[i] = DISABLE;
  }
  memset(&fds, 0, sizeof(fds));

  // Create listen socket
  if ((listen_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    fprintf(stderr, "Error: socket\n");
    return EXIT_FAILURE;
  }

  // Set INT signal handler
  signal(SIGINT, int_handle);

  // TCP port number
  int port = 8080;

  // Initialize server socket address
  struct sockaddr_in server_addr;
  memset(&server_addr, 0, sizeof(server_addr));
  server_addr.sin_family = AF_INET;
  server_addr.sin_addr.s_addr = INADDR_ANY;
  server_addr.sin_port = htons(port);

  // Bind socket to an address
  if (bind(listen_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) <
      0) {
    fprintf(stderr, "Error: bind\n");
    return EXIT_FAILURE;
  }

  // Listen
  if (listen(listen_fd, BACKLOG_SIZE) < 0) {
    fprintf(stderr, "Error: listen\n");
    return EXIT_FAILURE;
  }

  fprintf(stderr, "listen on port %d\n", port);

  FD_SET(listen_fd, &fds);

  int max_fd = listen_fd;  // max fd
  int max_i = 0;           // max client into clients[] array

  while (1) {
    FD_ZERO(&fds);
    FD_SET(listen_fd, &fds);
    for (int i = 0; i < N_CLIENT; i++) {
      if (clients[i] != DISABLE) {
        FD_SET(clients[i], &fds);
      }
    }

    int res_select = select(max_fd + 1, &fds, NULL, NULL, NULL);

    if (res_select < 0) {
      fprintf(stderr, "Error: select");
      return EXIT_FAILURE;
    }

    // Check new connection
    if (FD_ISSET(listen_fd, &fds)) {
      struct sockaddr_in client_addr;
      socklen_t len_client = sizeof(client_addr);

      int connfd;
      if ((connfd = accept(listen_fd, (struct sockaddr *)&client_addr,
                           &len_client)) < 0) {
        fprintf(stderr, "Error: accept\n");
        return EXIT_FAILURE;
      }

      printf("Accept socket %d (%s : %hu)\n", connfd,
             inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));

      // Save client socket into clients array
      int i;
      for (i = 0; i < N_CLIENT; i++) {
        if (clients[i] == DISABLE) {
          clients[i] = connfd;
          break;
        }
      }

      // No enough space in clients array
      if (i == N_CLIENT) {
        fprintf(stderr, "Error: too many clients\n");
        close(connfd);
      }

      if (i > max_i) {
        max_i = i;
      }
      if (connfd > max_fd) {
        max_fd = connfd;
      }
    }

    // Check all clients to read data
    for (int i = 0; i <= max_i; i++) {
      int sock_fd;
      if ((sock_fd = clients[i]) == DISABLE) {
        continue;
      }

      // If the client is readable or errors occur
      ssize_t n = read(sock_fd, buf, BUF_SIZE);
      if (n < 0) {
        fprintf(stderr, "Error: read from socket %d\n", sock_fd);
        close(sock_fd);
        clients[i] = DISABLE;
      } else if (n == 0) {  // connection closed by client
        printf("Close socket %d\n", sock_fd);
        close(sock_fd);
        clients[i] = DISABLE;
      } else {
        printf("Read %zu bytes from socket %d\n", n, sock_fd);
        write_n(sock_fd, buf, n);
        write_n(1, buf, n);
      }
    }
  }

  close(listen_fd);
  return EXIT_SUCCESS;
}

poll

poll(2) - Linux manual page

selectとほとんど同じ機能であり、計算量も同じO(n)だけかかります。ただし、管理するディスクリプタの制限がない点で上位互換と捉えることができます。

#include <arpa/inet.h>
#include <errno.h>
#include <poll.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#define BACKLOG_SIZE 5
#define BUF_SIZE 1024
#define N_CLIENT 256
#define INF_TIME -1
#define DISABLE -1

int listen_fd;

void int_handle(int n) {
  close(listen_fd);
  exit(EXIT_SUCCESS);
}

// wirte n byte
ssize_t write_n(int fd, char *ptr, size_t n) {
  ssize_t n_left = n, n_written;

  while (n_left > 0) {
    if ((n_written = write(fd, ptr, n_left)) <= 0) {
      return n_written;
    }
    n_left -= n_written;
    ptr += n_written;
  }

  return EXIT_SUCCESS;
}

int main(int argc, char **argv) {
  char buf[BUF_SIZE];
  struct pollfd clients[N_CLIENT];

  // Create listen socket
  if ((listen_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    fprintf(stderr, "Error: socket\n");
    return EXIT_FAILURE;
  }

  // TCP port number
  int port = 8080;

  // Initialize server socket address
  struct sockaddr_in server_addr;
  memset(&server_addr, 0, sizeof(server_addr));
  server_addr.sin_family = AF_INET;
  server_addr.sin_addr.s_addr = INADDR_ANY;
  server_addr.sin_port = htons(port);

  // Bind socket to an address
  if (bind(listen_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) <
      0) {
    fprintf(stderr, "Error: bind\n");
    return EXIT_FAILURE;
  }

  // Listen
  if (listen(listen_fd, BACKLOG_SIZE) < 0) {
    fprintf(stderr, "Error: listen\n");
    return EXIT_FAILURE;
  }

  // Set INT signal handler
  signal(SIGINT, int_handle);

  fprintf(stderr, "listen on port %d\n", port);

  clients[0].fd = listen_fd;
  clients[0].events = POLLIN;

  for (int i = 1; i < N_CLIENT; i++) {
    clients[i].fd = DISABLE;
  }
  int max_i = 0;  // max index into clients[] array

  while (1) {
    int n_ready = poll(clients, max_i + 1, INF_TIME);

    // Time out
    if (n_ready == 0) {
      continue;
    }

    // Error poll
    if (n_ready < 0) {
      fprintf(stderr, "Error: poll %d\n", errno);
      return errno;
    }

    // Check new connection
    if (clients[0].revents & POLLIN) {
      struct sockaddr_in client_addr;
      socklen_t len_client = sizeof(client_addr);

      int connfd;
      if ((connfd = accept(listen_fd, (struct sockaddr *)&client_addr,
                           &len_client)) < 0) {
        fprintf(stderr, "Error: accept\n");
        return EXIT_FAILURE;
      }

      printf("Accept socket %d (%s : %hu)\n", connfd,
             inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));

      // Save client socket into clients array
      int i;
      for (i = 0; i < N_CLIENT; i++) {
        if (clients[i].fd == DISABLE) {
          clients[i].fd = connfd;
          break;
        }
      }

      // No enough space in clients array
      if (i == N_CLIENT) {
        fprintf(stderr, "Error: too many clients\n");
        close(connfd);
      }

      clients[i].events = POLLIN;

      if (i > max_i) {
        max_i = i;
      }
    }

    // Check all clients to read data
    for (int i = 1; i <= max_i; i++) {
      int sock_fd;
      if ((sock_fd = clients[i].fd) == DISABLE) {
        continue;
      }

      // If the client is readable or errors occur
      if (clients[i].revents & (POLLIN | POLLERR)) {
        ssize_t n = read(sock_fd, buf, BUF_SIZE);

        if (n < 0) {
          fprintf(stderr, "Error: read from socket %d\n", sock_fd);
          close(sock_fd);
          clients[i].fd = DISABLE;
        } else if (n == 0) {  // connection closed by client
          printf("Close socket %d\n", sock_fd);
          close(sock_fd);
          clients[i].fd = DISABLE;
        } else {
          printf("Read %zu bytes from socket %d\n", n, sock_fd);
          write_n(sock_fd, buf, n);
          write_n(1, buf, n);
        }
      }
    }
  }

  close(listen_fd);
  return EXIT_SUCCESS;
}

epoll

epoll(7) - Linux manual page

linux kernel 2.6以降であれば使用可能なAPIです。

ディスクリプタの状態をカーネル空間で監視しているため、直接変更のあるディスクリプタを返してくれるので計算量がO(1)となり高速に処理できます。

ただし、ディスクリプタが1の場合はepollにオーバーヘッドが存在して、pollの方が早い場合もあるので注意が必要です。

stackoverflow.com

https://monkey.org/~provos/libevent/libevent-benchmark2.jpg

#include <arpa/inet.h>
#include <errno.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/socket.h>
#include <unistd.h>

#define BACKLOG_SIZE 5
#define BUF_SIZE 1024
#define N_CLIENT 256
#define INF_TIME -1
#define DISABLE -1

int listen_fd;

void int_handle(int n) {
  close(listen_fd);
  exit(EXIT_SUCCESS);
}

// wirte n byte
ssize_t write_n(int fd, char *ptr, size_t n) {
  ssize_t n_left = n, n_written;

  while (n_left > 0) {
    if ((n_written = write(fd, ptr, n_left)) <= 0) {
      return n_written;
    }
    n_left -= n_written;
    ptr += n_written;
  }

  return EXIT_SUCCESS;
}

int main(int argc, char **argv) {
  char buf[BUF_SIZE];

  // Create listen socket
  if ((listen_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    fprintf(stderr, "Error: socket\n");
    return EXIT_FAILURE;
  }

  // TCP port number
  int port = 8080;

  // Initialize server socket address
  struct sockaddr_in server_addr;
  memset(&server_addr, 0, sizeof(server_addr));
  server_addr.sin_family = AF_INET;
  server_addr.sin_addr.s_addr = INADDR_ANY;
  server_addr.sin_port = htons(port);

  // Bind socket to an address
  if (bind(listen_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) <
      0) {
    fprintf(stderr, "Error: bind\n");
    return EXIT_FAILURE;
  }

  // Listen
  if (listen(listen_fd, BACKLOG_SIZE) < 0) {
    fprintf(stderr, "Error: listen\n");
    return EXIT_FAILURE;
  }

  // Set INT signal handler
  signal(SIGINT, int_handle);

  fprintf(stderr, "listen on port %d\n", port);

  // Create epoll
  int epfd = epoll_create1(0);
  if (epfd < 0) {
    fprintf(stderr, "Error: epoll create\n");
    close(listen_fd);
    return EXIT_FAILURE;
  }

  struct epoll_event listen_ev;
  memset(&listen_ev, 0, sizeof(listen_ev));
  listen_ev.events = EPOLLIN;
  listen_ev.data.fd = listen_fd;
  if (epoll_ctl(epfd, EPOLL_CTL_ADD, listen_fd, &listen_ev) < 0) {
    fprintf(stderr, "Error: epoll ctl add listen\n");
    close(listen_fd);
    return EXIT_FAILURE;
  }

  struct epoll_event evs[N_CLIENT];
  while (1) {
    // Wait epoll listener
    int n_fds = epoll_wait(epfd, evs, N_CLIENT, -1);

    // Error epoll
    if (n_fds < 0) {
      fprintf(stderr, "Error: epoll wait\n");
      close(listen_fd);
      return EXIT_FAILURE;
    }

    for (int i = 0; i < n_fds; i++) {
      if (evs[i].data.fd == listen_fd) {  // Add epoll listener
        struct sockaddr_in client_addr;
        socklen_t len_client = sizeof(client_addr);

        int conn_fd;
        if ((conn_fd = accept(listen_fd, (struct sockaddr *)&client_addr,
                              &len_client)) < 0) {
          fprintf(stderr, "Error: accept\n");
          return EXIT_FAILURE;
        }

        printf("Accept socket %d (%s : %hu)\n", conn_fd,
               inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));

        struct epoll_event conn_ev;
        memset(&conn_ev, 0, sizeof(listen_ev));
        conn_ev.events = EPOLLIN;
        conn_ev.data.fd = conn_fd;
        if (epoll_ctl(epfd, EPOLL_CTL_ADD, conn_fd, &conn_ev) < 0) {
          fprintf(stderr, "Error: epoll ctl add listen\n");
          close(listen_fd);
          return EXIT_FAILURE;
        }
      } else if (evs[i].events & EPOLLIN) {  // Read data from client
        int sock_fd = evs[i].data.fd;
        ssize_t n = read(sock_fd, buf, BUF_SIZE);

        if (n < 0) {
          fprintf(stderr, "Error: read from socket %d\n", sock_fd);
          close(sock_fd);
        } else if (n == 0) {  // connection closed by client
          printf("Close socket %d\n", sock_fd);
          struct epoll_event sock_ev;
          memset(&sock_ev, 0, sizeof(listen_ev));
          sock_ev.events = EPOLLIN;
          sock_ev.data.fd = sock_fd;
          if (epoll_ctl(epfd, EPOLL_CTL_DEL, sock_fd, &sock_ev) < 0) {
            fprintf(stderr, "Error: epoll ctl dell\n");
            close(listen_fd);
            return EXIT_FAILURE;
          }
          close(sock_fd);
        } else {
          printf("Read %zu bytes from socket %d\n", n, sock_fd);
          write_n(sock_fd, buf, n);
          write_n(1, buf, n);
        }
      }
    }
  }

  close(listen_fd);
  return EXIT_SUCCESS;
}

io_uring

https://kernel.dk/io_uring.pdf

linux kernel 5.1以降であれば使用可能な新しい非同期I/O APIです。submission queue (SQ)とcompletion queue (CQ)の二つのring bufferを操作することで処理をします。例えばソケットからメッセージを受け取るrecvmsg を実行するためには IORING_OP_RECVMSG opcodeのsubmission queue entry (SQE)をSQに投入することで処理が非同期で実行されます。またそれらの処理の完了通知はCQからcompletion queue entry (CQE)を受け取ることで実現できます。

今回はliburing というio_uringのシンプルなインターフェイスを提供しているライブラリを使用して実装します。

#include <arpa/inet.h>
#include <errno.h>
#include <liburing.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>

#define BACKLOG_SIZE 5
#define BUF_SIZE 1024
#define N_CLIENT 256
#define N_ENTRY 2048
#define GID 1

int listen_fd;

enum {
  ACCEPT,
  READ,
  WRITE,
};

typedef struct UserData {
  __u32 fd;
  __u16 type;
} UserData;

void int_handle(int n) {
  close(listen_fd);
  exit(EXIT_SUCCESS);
}

// wirte n byte
ssize_t write_n(int fd, char *ptr, size_t n) {
  ssize_t n_left = n, n_written;

  while (n_left > 0) {
    if ((n_written = write(fd, ptr, n_left)) <= 0) {
      return n_written;
    }
    n_left -= n_written;
    ptr += n_written;
  }

  return EXIT_SUCCESS;
}

int main(int argc, char **argv) {
  char buf[BUF_SIZE] = {0};

  // Create listen socket
  if ((listen_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
    fprintf(stderr, "Error: socket\n");
    return EXIT_FAILURE;
  }

  // TCP port number
  int port = 8080;

  // Initialize server socket address
  struct sockaddr_in server_addr, client_addr;
  socklen_t client_len = sizeof(client_addr);
  memset(&server_addr, 0, sizeof(server_addr));
  server_addr.sin_family = AF_INET;
  server_addr.sin_addr.s_addr = INADDR_ANY;
  server_addr.sin_port = htons(port);

  // Bind socket to an address
  if (bind(listen_fd, (struct sockaddr *)&server_addr, sizeof(server_addr)) <
      0) {
    fprintf(stderr, "Error: bind\n");
    return EXIT_FAILURE;
  }

  // Listen
  if (listen(listen_fd, BACKLOG_SIZE) < 0) {
    fprintf(stderr, "Error: listen\n");
    return EXIT_FAILURE;
  }

  // Set INT signal handler
  signal(SIGINT, int_handle);

  fprintf(stderr, "listen on port %d\n", port);

  // Initialize io_uring
  struct io_uring ring;
  struct io_uring_sqe *sqe;
  struct io_uring_cqe *cqe;
  int init_ret = io_uring_queue_init(N_ENTRY, &ring, 0);
  if (init_ret < 0) {
    fprintf(stderr, "Error: init io_uring queue %d\n", init_ret);
    close(listen_fd);
    return EXIT_FAILURE;
  }

  // Setup first accept
  sqe = io_uring_get_sqe(&ring);
  io_uring_prep_accept(sqe, listen_fd, (struct sockaddr *)&client_addr,
                       &client_len, 0);
  io_uring_sqe_set_flags(sqe, 0);
  UserData conn_info = {
      .fd = listen_fd,
      .type = ACCEPT,
  };
  memcpy(&sqe->user_data, &conn_info, sizeof(conn_info));

  while (1) {
    io_uring_submit(&ring);
    io_uring_wait_cqe(&ring, &cqe);
    struct UserData conn_info;
    memcpy(&conn_info, &cqe->user_data, sizeof(conn_info));
    int type = conn_info.type;

    if (cqe->res == -ENOBUFS) {
      fprintf(stderr, "Error: no buffer %d\n", cqe->res);
      close(listen_fd);
      return EXIT_FAILURE;
    } else if (type == ACCEPT) {
      int conn_fd = cqe->res;
      printf("Accept socket %d \n", conn_fd);
      if (conn_fd >= 0) {  // no error
        // Read from client
        sqe = io_uring_get_sqe(&ring);
        io_uring_prep_recv(sqe, conn_fd, buf, BUF_SIZE, 0);

        UserData read_info = {
            .fd = conn_fd,
            .type = READ,
        };
        memcpy(&sqe->user_data, &read_info, sizeof(read_info));
      }

      // Add new client
      sqe = io_uring_get_sqe(&ring);
      io_uring_prep_accept(sqe, listen_fd, (struct sockaddr *)&client_addr,
                           &client_len, 0);
      io_uring_sqe_set_flags(sqe, 0);
      UserData conn_info = {
          .fd = listen_fd,
          .type = ACCEPT,
      };
      memcpy(&sqe->user_data, &conn_info, sizeof(conn_info));
    } else if (type == READ) {
      int n_byte = cqe->res;
      if (cqe->res <= 0) {  // connection closed by client
        printf("Close socket %d\n", conn_info.fd);
        close(conn_info.fd);
      } else {
        // Add Write
        printf("Read %d bytes from socket %d\n", n_byte, conn_info.fd);
        sqe = io_uring_get_sqe(&ring);
        io_uring_prep_send(sqe, conn_info.fd, buf, n_byte, 0);
        write_n(1, buf, n_byte);  // output stdout
        io_uring_sqe_set_flags(sqe, 0);
        UserData write_info = {
            .fd = conn_info.fd,
            .type = WRITE,
        };
        memcpy(&sqe->user_data, &write_info, sizeof(write_info));
      }
    } else if (type == WRITE) {
      // Add read
      sqe = io_uring_get_sqe(&ring);
      io_uring_prep_recv(sqe, conn_info.fd, buf, BUF_SIZE, 0);

      UserData read_info = {
          .fd = conn_info.fd,
          .type = READ,
      };
      memcpy(&sqe->user_data, &read_info, sizeof(read_info));
    }

    io_uring_cqe_seen(&ring, cqe);
  }

  close(listen_fd);
  return EXIT_SUCCESS;
}

repository

実装したecho serverは、以下のrepositoryに置きました

github.com

参考

Navigable Small Worldによる近似最近傍探索

Small World Networkのグラフ特性を利用したNavigable Small World(NSW)というグラフベースの近似最近傍探索をjuliaで実装します

Navigable Small Worldとは?

f:id:suzuzusu:20201214014102p:plain

上記の画像のようにSmall World Networkの特性を持つグラフベースの検索インデックスからqueryに対して近傍のノードを返すアルゴリズムです。 これを階層的に拡張したHierarchical Navigable Small World(HNSW)は非常に性能が良い近似最近傍探索アルゴリズムとして知られています。

実装

まず、n次元のデータ data と隣接ノード friend を持つ Node 構造体を作ります。

using Random
using LinearAlgebra
using DataStructures
using Base

mutable struct Node
    data
    friend::Set{Node}
end

function show(io::IO, n::Node)
    friend_str = join(map((x) -> string(x.data), collect(n.friend)), ", ")
    println(io, "data: {", n.data, "}, friend: {", friend_str, "}")
end

グラフ上のノードにおいて、queryに対してk近傍のデータを検索する knn_search を実装します。

function knn_search(nodes::Vector{Node}, q, m::Int, k::Int)
    visited_set = Set()
    canditates = PriorityQueue()
    result = PriorityQueue()
    for _ in 1:m
        if length(visited_set) == length(nodes)
            break
        end
        tmp_result = Vector{Node}()
        while true
            ri = rand(1:length(nodes), 1)[1]
            node = nodes[ri]
            if !in(node, visited_set)
                push!(visited_set, node)
                enqueue!(canditates, node=>norm(node.data .- q))
                push!(tmp_result, node)
                break
            end
        end
        while true
            if length(canditates) == 0
                break
            end
            c = dequeue!(canditates)
            c_d = norm(c.data .- q)
            result_collect = collect(result)
            if length(result_collect) >= k
                n = result_collect[k]
                if c_d >= n[2]
                    break
                end
            end
            for node in c.friend
                if !in(node, visited_set)
                    push!(visited_set, node)
                    enqueue!(canditates, node=>norm(node.data .- q))
                    push!(tmp_result, node)
                end
            end
        end
        for node in tmp_result
            enqueue!(result, node=>norm(node.data .- q))
        end
    end
    result_collect = collect(result)[1:k]
    result_v = Vector()
    for r in result_collect
        push!(result_v, r[1])
    end
    return result_v
end

新しいノードの追加は knn_search を使ってk近傍のノードを接続します。複数のノード群からnswを構築する nsw_build も作成しておきます。 初期の knn_search によって長距離のリンクが作られることがポイントです。

function nearest_neighbor_insert(nodes, new_node, f, w)
    neighbors = knn_search(nodes, new_node.data, w, f)
    for node in neighbors
        push!(node.friend, new_node)
        push!(new_node.friend, node)
    end
end

function nsw_build(nodes, f, w)
    for new_node in nodes
        nearest_neighbor_insert(nodes, new_node, f, w)
    end
end

queryに対する最近傍探索は隣接ノードから最も近いノードをgreedyに探索していきます。

function greedy_searh(nodes, q)
    ri = rand(1:length(nodes), 1)[1]
    near_node = nodes[ri]
    min_d = norm(near_node.data .- q)
    while true
        break_flg = true
        for node in near_node.friend 
            d = norm(node.data - q)
            if d < min_d
                min_d = d
                near_node = node
                break_flg = false
            end
        end
        if break_flg
            break
        end
    end
    return near_node
end

以上で実装は終わりです。 試しに以下のように2次元の範囲[0, 1)の乱数データ1000個に対してテストしてみます。

Random.seed!(1234)
n = 1000 # number of node
dim = 2 # dimension

nodes = [Node(rand(dim), Set()) for i in 1:1000]
nsw_build(nodes, 10, 10)

q = rand(dim)
println("query: ", q)
res = greedy_searh(nodes, q)
println("result: ", res.data)
l2 = norm(res.data .- q)
println("l2 distance: ", l2)
query: [0.10217918147914129, 0.6093148458481783]
result: [0.09006444691972337, 0.6142456375838046]
l2 distance: 0.013079736258246004

queryに対して近傍のデータが検索できていることが分かります。

gist

こちらにコードをおいておきます

gist.github.com

参考

スチューデントのt分布の最尤推定

スチューデントのt分布は外れ値のあるデータに対してもロバストな推定を可能とする分布として知られている。忘備録としてこの記事では、そのt分布の最尤推定をする。

スチューデントのt分布

一次元のt分布は、パラメータ \mu, \lambda, \nuを用いて次のような確率密度関数で表すことができる。

f:id:suzuzusu:20200803061703p:plain

EMアルゴリズム

最尤推定には、EMアルゴリズムを用いる。 Xをデータ、 Zを潜在変数、 \Thetaをパラメータとする。 次式で尤度関数と対数尤度関数を示す。

f:id:suzuzusu:20200803062219p:plain

f:id:suzuzusu:20200803063006p:plain

Eステップ

 Zの事後分布を計算する。

f:id:suzuzusu:20200804072030p:plain

 Zの事後分布による \eta, \log \etaの期待値は以下のようになる。

f:id:suzuzusu:20200804073114p:plain

f:id:suzuzusu:20200804073050p:plain

Mステップ

対数尤度関数 Q(\Theta, \Theta_0)の期待値を最大化するように、それぞれのパラメータを更新する。

f:id:suzuzusu:20200804073748p:plain

f:id:suzuzusu:20200804074107p:plain

f:id:suzuzusu:20200804074343p:plain

f:id:suzuzusu:20200804074804p:plain

ただし、パラメータ \nuに関しては、上の式のように解析解が得られないのでニュートン法による数値解析によってパラメータを更新する。

f:id:suzuzusu:20200804075200p:plain

f:id:suzuzusu:20200804075429p:plain

このようにEステップ、Mステップを繰り返すことによって尤度関数を最大化させて最尤推定をする。

実装

juliaを用いて実装をする。外れ値による影響を正規分布と比較するために、 \mathcal{N}(0.0, 1.0)から10のデータのサンプルに \mathcal{N}(100.0, 1.0)から生起されるデータを1つ加えたデータセットでそれぞれを最尤推定して比較する。

using Plots
using StatsPlots
using Random
using Distributions
using SpecialFunctions
using Statistics
using StatsBase
using Optim
using ReverseDiff: gradient

Random.seed!(1234)

function e_eta(x, nu, lambda_, mu)
    return (nu + 1)/(nu + lambda_ * (x - mu)^2)
end

function e_etas(X, nu, lambda_, mu)
    return [ e_eta(x, nu, lambda_, mu) for x in X]
end

function e_log_eta(x, nu, lambda_, mu)
    return digamma((nu + 1.0)/2.0) - log((nu + lambda_ * (x - mu)^2)/2.0)
end

function e_log_etas(X, nu, lambda_, mu)
    return [ e_log_eta(x, nu, lambda_, mu) for x in X]
end

function fit_t(X)
    n = length(X)
    
    # init
    mu = median(X)
    lambda_ = iqr(X)/2.0
    nu = 1.0
   
    for i in 1:1000
        # E step
        e_etas_ = e_etas(X, nu, lambda_, mu)
        e_log_etas_ = e_log_etas(X, nu[1], lambda_, mu)
        
        
        # M step
        # mu
        mu = (X' * e_etas_) / sum(e_etas_)
        # lambda_
        lambda_ = 1.0 / ( (((X .- mu).^2)' * e_etas_) / n)
        # nu
        function f(nu)
            return digamma(nu[1]/2.0) - log(nu[1]/2.0) - (1.0 + sum(e_log_etas_)/n - sum(e_etas_)/n)
        end
        for j in 1:1000
            tmp = nu - f([nu])/gradient(f, [nu])[1]
            if !isnan(tmp)
                nu = max(tmp, 1e-6)
                if abs(f([nu])) < 1e-6
                    break
                end
            else
                break
            end
        end
    end
    
    return mu, lambda_, nu
end

d = Normal(0.0, 1.0)
noise_d = Normal(100.0, 1.0)
data = rand(d, 10)
outlier = rand(noise_d, 1)
X = vcat(data, outlier)
mu, lambda_, nu = fit_t(X)

scatter(data, zeros(length(data)), label="data")
scatter!(outlier, zeros(length(outlier)), label="outlier")
plot!(fit_mle(Normal, X), label="Normal")

Xs = Array(range(-100, 120, step=0.1))
function t_pdf(x, mu, lamda_, nu)
    return gamma((nu+1)/2)./gamma(nu/2).*sqrt(lambda_/(pi*nu)).*(1 .+ lambda_*(x .- mu).^2/nu).^(-(nu+1)/2)
end
Ys = t_pdf(Xs, mu, lambda_, nu)
plot!(Xs, Ys.*0.035, label="Student t Distribution", title="MLE of the Student t distribution using EM Algorithm")

f:id:suzuzusu:20200804220738p:plain

上記の結果のように正規分布は、外れ値に引っ張られた推定をしていることが分かる。一方、t分布は外れ値に対してロバストに推定可能であることが分かる。

notebook

Jupyter Notebook Viewer

参考

TerraformでFizzBuzz

ネタです。

Terraformとは

f:id:suzuzusu:20200610061323p:plain

AWSGCP、Azureなどのクラウドやサービスのプロビジョニングと管理を行うInfrastructure as Code(IaC)ツールです。 今回はこのTerraformを使ってFizzBuzzを書いていきます。

Terraformの環境構築

0.13.0-beta10.12.1 の2つの環境を用意します。理由は後述します。

wget https://releases.hashicorp.com/terraform/0.13.0-beta1/terraform_0.13.0-beta1_linux_amd64.zip
unzip terraform_0.13.0-beta1_linux_amd64.zip
mv terraform terraform_0.13

wget https://releases.hashicorp.com/terraform/0.12.1/terraform_0.12.1_linux_amd64.zip   
unzip terraform_0.12.1_linux_amd64.zip   
mv terraform terraform_0.12.1

rangeを使った方法(IaC感がない)

TerraformにはいくつかのBuilt-in Functionsがあり、中にはCollectionの関数でrangeがあります。 さらには三項演算子もあり、for文もあるので以下のように簡単に書くことができます。

main.tf

output "out" {
  value = [
    for i in range(1, 101): i % 15 == 0 ? "FizzBuzz" : i % 5 == 0 ? "Buzz" : i % 3 == 0 ? "Fizz" : i
  ]
}

これを 0.13.0-beta1 で実行すると以下のようにFuzzBuzzが実装されていることが分かると思います。

❯ ./terraform_0.13 apply -auto-approve

Apply complete! Resources: 0 added, 0 changed, 0 destroyed.

Outputs:

out = [
  "1",
  "2",
  "Fizz",
  "4",
  "Buzz",
  "Fizz",
  "7",
  "8",
  "Fizz",
  "Buzz",
  "11",
  "Fizz",
  "13",
  "14",
  "FizzBuzz",
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  "Buzz",
  "Fizz",
  "97",
  "98",
  "Fizz",
  "Buzz",
]

しかし、このコードは 0.12.1 では実行できません。

❯ ./terraform_0.12.1 apply -auto-approve

Error: Error locking state: Error acquiring the state lock: state snapshot was created by Terraform v0.13.0, which is newer than current v0.12.1; upgrade to Terraform v0.13.0 or greater to work with this state

Terraform acquires a state lock to protect the state from being written
by multiple users at the same time. Please resolve the issue above and try
again. For most commands, you can disable locking with the "-lock=false"
flag, but this is not recommended.

なぜなら range 関数は 0.12.2 から追加された関数だからです。 何よりインフラを構築するツールなのに何も構築していないのでIaC感がないのが気になります。

Release v0.12.2 · hashicorp/terraform · GitHub

countを使った方法

そこでcountというパラメータを使用します。これは同じリソースを複数作成する際に用いるパラメータです。例えば100個同じリソースを作りたい場合はcount を100に設定します。 また、count.index を使うことでそれぞれのリソースのindexを取得できるので、1から100までの連番の生成が可能となります。以上のものを使って0.12.2 でも動くようにするのとIaC感があるコードになりそうですね!!

main.tf

今回はproviderにdockerを使ってdockerのリソースを複数作成して実装していきます。 dockerのコンテナを複数作成するように実装していきます。

terraform {
  required_providers {
    docker = {
      source = "terraform-providers/docker"
    }
  }
  required_version = ">= 0.13"
}

provider "docker" {
  host = "unix:///var/run/docker.sock"
}

resource "docker_container" "hello" {
  count = 15
  image = docker_image.hello.latest
  name  = "name-${count.index}"
  labels {
    label = "label-${count.index}"
    value = (count.index + 1) % 15 == 0 ? "FizzBuzz" : (count.index + 1) % 5 == 0 ? "Buzz" : (count.index + 1) % 3 == 0 ? "Fizz" : (count.index + 1)
  }
}

resource "docker_image" "hello" {
  name = "hello-world:latest"
}

output "out" {
  value = [for o in docker_container.hello : o.labels.*.value]
}

とりあえず1から15までを実行した結果が以下のようになります。

❯ terraform apply -auto-approve
❯ terraform output
out = [
  [
    "1",
  ],
  [
    "2",
  ],
  [
    "Fizz",
  ],
  [
    "4",
  ],
  [
    "Buzz",
  ],
  [
    "Fizz",
  ],
  [
    "7",
  ],
  [
    "8",
  ],
  [
    "Fizz",
  ],
  [
    "Buzz",
  ],
  [
    "11",
  ],
  [
    "Fizz",
  ],
  [
    "13",
  ],
  [
    "14",
  ],
  [
    "FizzBuzz",
  ],
]

一見成功してそうに見えますが、上記の count を100にして1から100で実行すると以下のようにコンテナが建てられず、Errorが出てしまいFizzBuzzが実装できないです。 running状態にならなかった場合にはnullが入ってしまいます。

❯ terraform apply -auto-approve
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Error: Container 33111f01371fcff7a1de675ffb77523006fda97f869665e61d3e9dc5c3fc19d0 failed to be in running state

  on main.tf line 14, in resource "docker_container" "hello":
  14: resource "docker_container" "hello" {
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
❯ terraform output
out = [
  null,
  null,
  null,
  null,
  null,
  [
    "Fizz",
  ],
  null,
  null,
  null,
  null,
  [
    "11",
  ],
  null,
  null,
  null,
  null,
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  null,
  null,
  [
    "92",
  ],
  null,
  null,
  null,
  null,
  [
    "97",
  ],
]

main.tf

回避する策としてdockerのvolumeを複数作成するように変更します。

terraform {
  required_providers {
    docker = {
      source = "terraform-providers/docker"
    }
  }
  required_version = ">= 0.13"
}

provider "docker" {
  host = "unix:///var/run/docker.sock"
}

resource "docker_volume" "volume" {
  count = 100
  name  = "name-${count.index}"
  labels {
    label = "label-${count.index}"
    value = (count.index + 1) % 15 == 0 ? "FizzBuzz" : (count.index + 1) % 5 == 0 ? "Buzz" : (count.index + 1) % 3 == 0 ? "Fizz" : (count.index + 1)
  }
}

output "out" {
  value = [for o in docker_volume.volume : o.labels.*.value]
}
❯ terraform apply -auto-approve
❯ terraform output
out = [
  [
    "1",
  ],
  [
    "2",
  ],
  [
    "Fizz",
  ],
  [
    "4",
  ],
  [
    "Buzz",
  ],
  [
    "Fizz",
  ],
  [
    "7",
  ],
  [
    "8",
  ],
  [
    "Fizz",
  ],
  [
    "Buzz",
  ],
  [
    "11",
  ],
  [
    "Fizz",
  ],
  [
    "13",
  ],
  [
    "14",
  ],
  [
    "FizzBuzz",
  ],
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  [
    "91",
  ],
  [
    "92",
  ],
  [
    "Fizz",
  ],
  [
    "94",
  ],
  [
    "Buzz",
  ],
  [
    "Fizz",
  ],
  [
    "97",
  ],
  [
    "98",
  ],
  [
    "Fizz",
  ],
  [
    "Buzz",
  ],
]

うまくいきましたね!!これで古い環境でもFizzBuzzが実装できるようになりました!!

このコードをクラウドの高いGPUインスタンスで実行してみたいので誰か僕に寄付してください(やりません)。

Repository

github.com

効率的フロンティアの解析解

現代ポートフォリオ理論の効率的フロンティアの解析解について調べたことをまとめました。

卵は一つのカゴに盛るな

ファイナンスの世界では、先人たちの経験をもとにした「卵は一つのカゴに盛るな」という格言があります。 これは一つのカゴに全ての卵を盛ると落としたときに全ての卵が割れてしまいますが、複数のカゴに分けて盛ると全ての卵が割れる事態を回避できることを示したものになります。 このようなリスクを最小化する分散投資が重要であることは昔から経験的に知られていました。

現代ポートフォリオ理論とは?

現代ポートフォリオ理論とは、リスクのある投資する場合のポートフォリオの配分をどのように合理的に決定すべきかを示した理論です。 この理論では、投資家は同じリターンが見込まれるのであればリスクを回避するという仮定をおいています。 これはポートフォリオの投資収益率の平均・分散にのみを考慮して、投資収益率の分散(リスク)を最小化することを意味します。 分散(リスク)を最小化するためには「卵は一つのカゴに盛るな」という先人の知恵を活かします。

卵を複数のカゴに分けるようにそれぞれのポートフォリオの配分を x_iとします。 それぞれ投資収益率の平均 E ・分散  \sigma^{2} は次のように示されます。 ポートフォリオ数は m ポートフォリオの共分散行列の要素を \sigma_{ij} とします

f:id:suzuzusu:20200415000146p:plain

この式からポートフォリオの配分 x_i を変えることによって、投資収益率の分散(リスク)をコントロールすることが可能であることが分かります。

効率的フロンティアの解析解

分散を最小化した場合の、リスク・リターン平面における曲線は最小分散フロンティアと呼ばれており、その中でリターンが高いものを効率的フロンティアと言います。 その解析解を求めます。 以下のように、等式制約がある最適化問題として定式化します。

f:id:suzuzusu:20200415000257p:plain

この問題は、ラグランジュの未定乗数法を使用することで解析解を求めることが可能となります。 ラグランジュ関数は以下のようになります。

f:id:suzuzusu:20200415000453p:plain

これをそれぞれ偏微分して、極値を求めます。

f:id:suzuzusu:20200415000811p:plain

f:id:suzuzusu:20200415000940p:plain

 (1a)  x に対して線形なので共分散行列の逆行列の要素 v_{ij} を使って以下のように書き換えることが可能です。

f:id:suzuzusu:20200415001119p:plain

 (2)  E_k をかけて k について合計すると以下のようになります。

f:id:suzuzusu:20200415001226p:plain

さらに、式 (2)  k について合計すると以下のようになります。

f:id:suzuzusu:20200415001234p:plain

それぞれ A, B, C を以下のように定義すると式  (1b), (1c), (3), (4) から  \gamma_1, \gamma_2 についての線形式が導出されます。

f:id:suzuzusu:20200415001402p:plain

f:id:suzuzusu:20200415001350p:plain

 D = BC - A^{2} とすると、 \gamma_1, \gamma_2 は以下のように解けます。

f:id:suzuzusu:20200415001323p:plain

次に式  (1a) x_i をかけて i について合計したものが次のようになります。

f:id:suzuzusu:20200415001436p:plain

式(6)が最小分散フロンティアです。 このようにして解析解を得ることができました。 次に、テストデータを使って効率的フロンティアをプロットしていきます。

実装

この ポートフォリオのデータセットを使用します。以下例

f:id:suzuzusu:20200408235815p:plain

以下は、Juliaで実装したコードです。

using Plots
using Statistics
using DataFrames
using CSV
using LinearAlgebra

open("./49_Industry_Portfolios.CSV", "r") do f
    global df
    csv = join(readlines(f)[2268:2361], '\n')
    df = CSV.read(IOBuffer(csv))
end

# preprocessing
inds = describe(df).min .> -99 # remove missing value
inds[1] = 0 # remove year
inds = inds .* range(1; stop=length(inds))
filter!(x -> x > 0, inds)
df_sub = df[:, inds]

mat = convert(Matrix, df_sub)
E = mean(mat; dims=1)
v = inv(cov(mat))
n = size(mat)[2]

A = 0.0
for i in 1:n
    for j in 1:n
        A += v[i,j]E[j]
    end
end

B = 0.0
for i in 1:n
    for j in 1:n
        B += v[i,j]E[i]E[j]
    end
end

C = 0.0
for i in 1:n
    for j in 1:n
        C += v[i,j]
    end
end

D = B*C - A^2

function sigma(E)
    sqrt((C * E^2 - 2A * E + B)/D)
end


# portfolio
scatter(sqrt.(diag(cov(mat))), E', label="portforio")

# global minimum variance portfolio
y = A/C
x = sigma(y)
scatter!([x], [y], label="global minimum variance portfolio")

# effirencent frontier
E_min = A/C
E_max = max(E...)
ys = range(E_min; stop=E_max, length=1000)
xs = [sigma(y) for y in ys]
plot!(xs, ys, label="efficient frontier")

# minimum variance frontier
E_min = 2 * A/C - max(E...)
E_max= A/C
ys = range(E_min; stop=E_max, length=1000)
xs = [sigma(y) for y in ys]
plot!(xs, ys, linestyle=:dash, xlabel="risk", ylabel="return", label="", title="Modern portfolio theory")

結果

f:id:suzuzusu:20200409000102p:plain

それぞれのポートフォリオが青い点、効率的フロンティア(efficient frontier)は緑の線、分散が最も小さくなる点である最小分散フロンティア(global minimum variance portfolio)はオレンジ色で示されています。 既存のポートフォリオよりリスクが少なくなるような曲線が描かれていることが分かります。 このようにポートフォリオの配分によってリスクを最小化してより良い資産運用をすることが可能となります。

notebook

https://nbviewer.jupyter.org/github/suzusuzu/blog/blob/master/finance/EfficientFrontier.ipynb

参考

  • Markowitz, Harry. “Portfolio Selection.” The Journal of Finance, vol. 7, no. 1, 1952, pp. 77–91., doi:10.2307/2975974. Accessed 14 2020.
  • Merton, Robert C. "An analytic derivation of the efficient portfolio frontier." Journal of financial and quantitative analysis 7.4 (1972): 1851-1872.
  • ウォール街のランダムウォーカー
  • Kenneth R. French - Data Library

AV1で採用されているChroma from Luma Prediction (CfL)を使ってイントラ予測

Chroma from Luma Prediction (CfL) はAlliance for Open Mediaが開発したオープンかつロイヤリティフリーな動画圧縮コーデックAV1などで採用されているイントラ予測手法です。イントラ予測というのは動画圧縮で他のフレームを参照しないフレーム符号化のことを言います。今回は、Juliaを用いてその効果を検証していきます。

Chroma(彩度)とLuma(輝度)の関係

CfLはLumaからChromaを推定する手法です。これにより保持すべきデータ容量が削減できます。ここでは、レナの画像を使ってChromaとLumaの関係を可視化します。

まず必要なパッケージを使ってレナの画像を準備します。

using Images, ImageView, TestImages
using Plots

lena = testimage("lena_color_256")
plot(lena)

f:id:suzuzusu:20200304055448p:plain

次にYCbCr色空間に変換します。

lena = YCbCr.(lena)
lena_arr = channelview(lena)

p_y = heatmap(reverse(lena_arr[1,:,:], dims=1), title="Y(Luma)", color=:grays)
p_cb = heatmap(reverse(lena_arr[2,:,:], dims=1), title="Cb", color=:grays)
p_cr = heatmap(reverse(lena_arr[3,:,:], dims=1), title="Cr", color=:grays)
plot(p_y, p_cb, p_cr, layout=(1, 3), size=(1200, 300), fmt=:png)

f:id:suzuzusu:20200304055512p:plain

左上の32x32のタイル上における、ビットごとのそれぞれChromaとLumaの関係を見ます。横軸にLuma、縦軸にChromaの散布図を作成します。

bs = 32 # block size
s = 1
e = s + bs - 1

plot(lena[s:e,s:e], fmt=:png)



x_min = minimum(lena_arr[1,s:e,s:e])
x_max = maximum(lena_arr[1,s:e,s:e])

# Cb
x = vec(lena_arr[1,s:e,s:e])
y = vec(lena_arr[2,s:e,s:e])
scatter(x, y, label="Cb", xlabel="Luma", ylabel="Chroma")

cb_a = (bs*bs*sum(x.*y) - sum(x)*sum(y)) / (sum(bs*bs*(x.^2)) - sum(x)^2)
cb_b = (sum(y) - cb_a * sum(x)) / (bs*bs)
x = range(x_min, x_max, length=1000)
y = cb_a .* x .+ cb_b
plot!(x, y, label="Cb prediction")

# Cr
x = vec(lena_arr[1,s:e,s:e])
y = vec(lena_arr[3,s:e,s:e])
scatter!(x, y, label="Cr")

cr_a = (bs*bs*sum(x.*y) - sum(x)*sum(y)) / (sum(bs*bs*(x.^2)) - sum(x)^2)
cr_b = (sum(y) - cr_a * sum(x)) / (bs*bs)
x = range(x_min, x_max, length=1000)
y = cr_a .* x .+ cr_b
plot!(x, y, label="Cr prediction", fmt=:png)

f:id:suzuzusu:20200304055625p:plain

f:id:suzuzusu:20200304055638p:plain

上記の結果から、ChromaのそれぞれのCr, CbはLumaに対してのある程度線形な関係にあることが分かりました。つまりCr, Cbはそれぞれ傾きと切片が分かっているならLumaからある程度推定可能ということです。この関係を用いてCfLを実装します。

Chroma from Luma Prediction (CfL)

先程の例でChromaとLumaの関係が分かっているのであとは定式化してみましょう。 Predicting Chroma from Luma in AV1 から式を引用するとCfLは以下のようなパラメータ \alpha, \betaの線形モデルを使って推定します。  C^{p}, L^{r} はChroma, Lumaです。

f:id:suzuzusu:20200304061247p:plain

 \alpha, \beta パラメータは次のように最小二乗法で求めます。

f:id:suzuzusu:20200304061319p:plain

最後に、8, 16, 32のブロックサイズごとにCfLをした画像とオリジナル画像を比較してみましょう。

function cfl(img_arr, bs=8)
    nc, h, w = size(img_arr)
    img_arr_est = zeros((nc, h, w))
    
    for i in 1:Int(h/bs)
        si = (i-1)*bs + 1
        ei = i*bs
        for j in 1:Int(w/bs)
            sj = (j-1)*bs + 1
            ej = j*bs
            img_arr_tmp = img_arr[:, si:ei, sj:ej]

            # Cb
            x = vec(img_arr_tmp[1, :, :])
            y = vec(img_arr_tmp[2, :, :])
            cb_a = (bs*bs*sum(x.*y) - sum(x)*sum(y)) / (sum(bs*bs*(x.^2)) - sum(x)^2)
            cb_b = (sum(y) - cb_a * sum(x)) / (bs*bs)

            # Cr
            x = vec(img_arr_tmp[1, :, :])
            y = vec(img_arr_tmp[3, :, :])
            cr_a = (bs*bs*sum(x.*y) - sum(x)*sum(y)) / (sum(bs*bs*(x.^2)) - sum(x)^2)
            cr_b = (sum(y) - cr_a * sum(x)) / (bs*bs)

            x = img_arr_tmp[1, :, :]
            cb_est = cb_a .* x .+ cb_b
            cr_est = cr_a .* x .+ cr_b
            img_arr_est_tmp = zeros((3, bs, bs))
            img_arr_est_tmp[1,:,:] .= x
            img_arr_est_tmp[2,:,:] .= cb_est
            img_arr_est_tmp[3,:,:] .= cr_est

            img_arr_est[:, si:ei, sj:ej] .= img_arr_est_tmp
        end
    end
    img_arr_est
end

p_org = plot(colorview(YCbCr, lena_arr), title="origin")
p8 = plot(colorview(YCbCr, cfl(lena_arr, 8)), title="CfL(block size 8)")
p16 = plot(colorview(YCbCr, cfl(lena_arr, 16)), title="CfL(block size 16)")
p32 = plot(colorview(YCbCr, cfl(lena_arr, 32)), title="CfL(block size 32)")
plot(p_org, p8, p16, p32, layout = (1, 4), size = (1200, 300), fmt=:png)

f:id:suzuzusu:20200304055719p:plain

上記の結果からオリジナル画像に対して、CfLでかなりきれいに推定できることが分かりました。 AV1では、上記のような基本的なCfLが用いられているわけではないのですが原理は同じです。 さらに今後の取り組みとして非線形モデルを使って推定するなども考えられているらしいです。

f:id:suzuzusu:20200304062355p:plain

notebook

nbviewer.jupyter.org

参考

RustのHashMapの再現性を確保する

Rustの標準のHashMapのハッシュアルゴリズムでは再現性を確保することができないので,別のハッシュアルゴリズムを使って再現性を確保する方法をメモしておく.これはHashSetの場合も同様に再現性を確保することができます.

HashMapのランダム性

By default, HashMap uses a hashing algorithm selected to provide resistance against HashDoS attacks. The algorithm is randomly seeded, and a reasonable best-effort is made to generate this seed from a high quality, secure source of randomness provided by the host without blocking the program.

doc.rust-lang.org

上記の引用のように,標準のHashMapのハッシュアルゴリズムではHashDoS攻撃を防ぐためにランダム性があり,同じコードでも再現不可能となる場合があります.

再現不可能なコード

例えば以下のようなコードは実行ごとに結果が異なります.

use std::collections::HashMap;

fn main() {
    let mut map = HashMap::new();
    for i in 0..10 {
        map.insert(i, i);
    }
    let _ = map.iter().map(|x| print!("{:?},", x)).collect::<Vec<_>>();
    println!("");
}

実行1

(1, 1),(7, 7),(8, 8),(6, 6),(3, 3),(4, 4),(0, 0),(9, 9),(2, 2),(5, 5),

実行2

(0, 0),(3, 3),(5, 5),(6, 6),(2, 2),(4, 4),(7, 7),(8, 8),(9, 9),(1, 1),

rust-fnv

fnvはkeyのサイズが小さい場合に効率的に動作するハッシュアルゴリズムです. このライブラリには,標準のHashMapをfnvハッシュアルゴリズムに代替してエイリアスされたものがあります.今回はこれを使用して再現性を確保します.

再現可能なコード

use fnv::FnvHashMap;

fn main() {
    let mut map = FnvHashMap::default();
    for i in 0..10 {
        map.insert(i, i);
    }
    let _ = map.iter().map(|x| print!("{:?},", x)).collect::<Vec<_>>();
    println!("");
}

実行1

(5, 5),(4, 4),(7, 7),(6, 6),(1, 1),(0, 0),(3, 3),(2, 2),(9, 9),(8, 8),

実行2

(5, 5),(4, 4),(7, 7),(6, 6),(1, 1),(0, 0),(3, 3),(2, 2),(9, 9),(8, 8),

参考