読者です 読者をやめる 読者になる 読者になる

AIプログラムとかUnityゲーム開発について

探索や学習などを活用したAI系ゲームを作りたいと思います。

昔作ったバックプロパゲーションのニューラルネットワークでXORの学習

世の中、深層学習なのにどんだけ周回遅れだよって感じですが、
昔、作ったバックプロパゲーションニューラルネットワークC言語プログラムをHDDの片隅から発掘してきた。

XORは非線形

x[1]=rand()%2;
x[2]=rand()%2;

というランダムな入力値(0か1)を与えて、教師は三人

t[0] = (x[1]^x[2]);
t[1] = (x[1]&x[2]);
t[2] = (x[1]|x[2]);

0番はXOR、1番はAND、2番はOR
食わせ物はXORで、これは線形なネットワークでは学習できない。
だから、一度ニューラルネットワークは全否定された。
XORは1^1=0になるので、線形には表せずに、非線形なネットワークしか模倣できない。

XOR
0^0=0
0^1=1
1^0=1
1^1=0

ORみたいだけど、1^1が0になるので、排他的論理和と呼ばれている。


しかし、そこにバックプロパゲーションという学習法が提案されて、隠れ層を入れた非線形ニューラルネットワークを学習できるようになった。
これでやっとXORが学習できるように成った(非線形な演算で近似できた)

シグモイド関数

ニューラルネットワークは、通常 Z=シグモイド関数(X)と表現できる。
シグモイド関数は、f(x)=1/(1+exp(-x))みたいな関数で、グラフで見ると、生物の神経が励起してサチるような図形になっている。
これを使う理由は、実は微分がやりやすい。このf(x)を微分するとf'(x)=f(x)(1-f(x))になる。
微分が元の関数で表せる。
シグモイド関数を微分する - のんびりしているエンジニアの日記

ニューラルネットワークの学習は、二乗誤差をエネルギ関数とみなして最急降下法を用いる。
これは、分散が減るように重みを変化させることで実現させている。
平たく言えば、微分した値を引けば、だんだん底に落ちていく。
例えば、放物線の微分は接線の傾きだけど、放物線のどこにいても、微分を引いていけば、頂点に向かう。
これは線形のニューラルネットワークの話ですが、
出力層は、 h += alpha * 微分関数(x)で学習できる。
でも、これだとANDとORは表現できてもXORは無理

隠れ層を入れて三層構造だと、出力層は上記の単なる微分を使う重み増減で学習できるけど、
問題は隠れ層の学習をどうするか?
中間層の学習に、バックプロパゲーションが使われる。これでやっとXORが表現できる。

バックプロパゲーション

バックプロパゲーションはコスト関数を中間層の重みで微分したものを重みに加える。

 g = f(w・x)
 z = f(h・g) = f(h・f(w・x))
cost = (1/2)||x-z||^2

 ∂cost/∂h = (x-z)(-f'(h・g))・g
 ∂cost/∂w = (x-z)(-f'(h・f(w・x))・h・f'(w・x)・x

gが中間層で、zがgを用いた出力層の時、教師と出力の誤差の自乗をコストと定義する
そのコストを、出力層の重みで微分したのが、 ∂cost/∂h
中間層の重みで微分したのが∂cost/∂w
ちょっとややこしいけど、高校数学3Cの合成関数の微分の知識で微分できる。

あとは、このコストを微分したものを計算して重みに加えれば学習できる。
これで非線形なものを表現できるようになる。重みを用いた非線形演算が、非線形な現象を近似する能力を持つ

しかし、層を増やしていくと、このバックプロパゲーションの伝播が弱くなって、ちっとも入力層の近くのほうが
学習できない(ちょっと強化学習の問題点を思い出すんですけど)
それで多層ニューラルネットワークは「だめじゃん!」と打ち捨てられた。

深層学習

深層学習は、いきなり多層で学習せずに、一層づつ学習させる
それも入力から得られた出力を、出力から戻すと入力に同じ値が戻るような教師なし学習をしてまず慣らす。
それから学習すると多層なのに学習が進んだということらしい。自分の理解としては
層が増えれば増えるほど、非線形模倣能力が向上する。
だから、猫の顔も認識できるし、囲碁の局面まで認識できたらしい。
この1層づつ学習が進んでいく多層構造は、実際の生物の視覚野の構造がヒントになっているらしい


バックプロパゲーションの学習結果

XORがちゃんと学習できている!

86000 in=[0 0] out=[XOR=0.04 AND=0.00 OR=0.03] err=0.0024
88000 in=[0 0] out=[XOR=0.04 AND=0.00 OR=0.03] err=0.0023
90000 in=[0 0] out=[XOR=0.04 AND=0.00 OR=0.03] err=0.0023
92000 in=[1 1] out=[XOR=0.05 AND=0.97 OR=1.00] err=0.0028
94000 in=[1 0] out=[XOR=0.97 AND=0.02 OR=0.98] err=0.0021
96000 in=[1 0] out=[XOR=0.97 AND=0.02 OR=0.98] err=0.0020
98000 in=[1 0] out=[XOR=0.97 AND=0.02 OR=0.98] err=0.0020


以下ソース、けっこう短いよ

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <assert.h>
#include <sys/types.h>
#include <windows.h>
#include <io.h>
#include <fcntl.h>
#include <math.h>
#include <process.h>

int x[3];
int t[3];
double g[3];
double z[3];
double h[30][30];
double w[30][30];

#define sigmoid(a) (1.0/(1.0+exp(-(a))))
#define dsigmoid(a) ((a)*(1.0-(a)))
void main()
{
	srand((unsigned) time(NULL));

	double alpha=0.06;

// 重みの初期化
	for(int j=0;j<=2;j++) for(int i=0;i<=2;i++)
		w[j][i] = (double)((rand()%2000)-1000)/1000;
	for(int j=0;j<=2;j++) for(int i=0;i<=2;i++)
		h[j][i] = (double)((rand()%2000)-1000)/1000;

	for(int count=0;count<100000;count++)
	{
		x[0]=1;
		x[1]=rand()%2;
		x[2]=rand()%2;
		t[0] = (x[1]^x[2]);
		t[1] = (x[1]&x[2]);
		t[2] = (x[1]|x[2]);

// 中間層の出力
		for(int i=0;i<=2;i++) {
			double sum=0;
			for(int j=0;j<=2;j++) sum += w[j][i]*x[j];
			g[i] = sigmoid( sum );
		}
// 出力層の出力
		for(int i=0;i<=2;i++) {
			double sum=0;
			for(int j=0;j<=2;j++) sum += h[j][i]*g[j];
			z[i] = sigmoid( sum );
		}

// 出力層の学習
		alpha=0.06;
		for(int j=0;j<=2;j++) for(int i=0;i<=2;i++)
			h[i][j] += alpha * g[i] * (t[j]-z[j])*z[j]*(1.0-z[j]);
// 中間層の学習
		for(int k=0;k<=2;k++) for(int j=0;j<=2;j++) {
			double dj = (t[k]-z[k])*z[k]*(1.0-z[k]) * h[j][k] * g[j]*(1.0-g[j]);
			for(int i=0;i<=2;i++) w[i][j] += alpha * x[i] * dj;
		}
//誤差の計算
		double error = 0;
		for(int i=0;i<=2;i++) error += (t[i]-z[i])*(t[i]-z[i]);

		if( (count%2000)!=0 ) continue;
		printf( "%6d in=[%d %d] out=[XOR=%.2f AND=%.2f OR=%.2f] err=%.4f \n",count,x[1],x[2],z[0],z[1],z[2],error );
	}
//どんな重みになった?
	printf("w\n");
	for(int i=0;i<=2;i++)
	{
		for(int j=0;j<=2;j++)
			printf( "%+.2f ",w[j][i]);
		printf("\n");
	}
	printf("h\n");
	for(int i=0;i<=2;i++)
	{
		for(int j=0;j<=2;j++)
			printf( "%+.2f ",h[j][i]);
		printf("\n");
	}
}