TadaoYamaokaの日記

山岡忠夫Homeで公開しているプログラムの開発ネタを中心に書いていきます。

C++でディリクレ分布による乱数生成

C++にディリクレ分布で乱数生成する標準関数は用意されていない。

ガンマ分布で乱数生成する標準関数std::gamma_distributionが用意されているので、
Dirichlet distribution - Wikipedia
に書かれている方法を使って、ガンマ分布で乱数y1,...,yKを生成し、
x_i=\frac{y_i}{\sum_{j=1}^K y_j}
で、ディリクレ分布の乱数に変換できる。

これを実装した。

#include <iostream>
#include <random>
#include <vector>
#include <algorithm>

void random_dirichlet(std::mt19937_64 &mt, std::vector<double> &x, const double alpha) {
	std::gamma_distribution<double> gamma(alpha, 1.0);
	
	double sum_y = 0;
	for (int i = 0; i < x.size(); i++) {
		double y = gamma(mt);
		sum_y += y;
		x[i] = y;
	}
	std::for_each(x.begin(), x.end(), [sum_y](double &v) { v /= sum_y; });
}

int main() {
	std::random_device rd;
	std::mt19937_64 mt(rd());
	const int K = 5;
	const double alpha = 0.15;

	for (int i = 0; i < 10; i++) {
		std::vector<double> x(K);
		random_dirichlet(mt, x, alpha);

		for (int j = 0; j < x.size(); j++) {
			std::cout << x[j];
			if (j < x.size() - 1) std::cout << ", ";
		}
		std::cout << std::endl;
	}
}
実行結果

K=5, alpha=0.15の対称ディリクレ分布

0.00820617, 1.6067e-08, 1.82496e-06, 0.991569, 0.000223472
0.100336, 0.104018, 0.027749, 0.0115689, 0.756328
0.965557, 0.00401107, 7.88541e-18, 0.0285327, 0.00189949
0.721145, 0.0584562, 0.00153585, 2.04515e-05, 0.218842
0.834343, 0.00332245, 5.07e-09, 0.00101873, 0.161316
0.00615545, 0.993583, 0.000253227, 8.34935e-06, 3.51258e-07
0.893519, 0.0714379, 0.000327648, 0.00228227, 0.0324336
0.000104421, 0.00124077, 0.99865, 5.90869e-07, 4.54077e-06
0.0158551, 0.00067078, 0.0668421, 0.882654, 0.033978
8.98511e-09, 0.016774, 7.59551e-05, 0.0137781, 0.969372