自然言語処理と機械学習こそが最強の学問

CSの大学院生が真の最強学問である自然言語処理と機械学習の内容を書いていく予定。時々全然違う分野の記事も書くかもしれない。

ニューラルネットワークの入力正規化について

ニューラルネットワークのすばらしさをさらに体感すべく、有名な手書き文字認識のデータセットMNISTを入手して、実験をしてみました。このデータセットには28*28ピクセルの入力データとそれが指し示す0〜9のターゲット数字が入っていて、ニューラルネットワークの訓練データとしてよく使われています。実験前は、先人たちのように簡単に正解率95%叩き出せるだろうと意気込んでいましたが、ものの見事にまったく動きませんでした。

まっさきに考えられる原因としては隠れニューロン数と学習率の2つで、どちらもさまざまな条件のもとで訓練してみました変えてみたが、誤差がいつもほぼ同じところで止まってしまってしまい、どうやら違うところが原因だったようです。

いやいやデバッグをしてみたら、原因の出処はパラメータの正規化でした。私の最初の実装では、入力がすべて0~1に入るようにして、重みの乱数も0~1に入るようにしました。普通サイズの入力ならそれで良かったのですが、このデータセットの入力は784個もあり、重みとの積和を取ると簡単に大きな数になってしまって、この数が隠れ層で非線形関数にかけられると、すべて限りなく1.0になりました。この状態で、バックプロパゲーションのパラメータ更新は簡単に止まり、結果として学習が動きませんでした。

原因を突き止めたら、あとは結構楽で、積和があまり大きくならないように、乱数発生を調整したら、ちゃんと動きました!

正解率は90%程度しか出ていないのは残念ですが、学習率をまた微調整して、もっとよい正解率になるようにがんばります。