連載第7回の目的

この回では、前回に引き続き手書きの数字を認識するサンプルを紹介します。この回では、コンバータのインストールと、TensorFlow.jsで利用できるモデルへ変換して予測を実行するサンプルを紹介します。

  • 図1 完成サンプルイメージ

    図1 完成サンプルイメージ

完成サンプル
https://github.com/wateryinhare62/mynavi_tensorflowjs/

今回のテーマも画像認識です。前回は、Pythonプラットフォームで手書き文字認識のモデルを構築しました。今回は、そのモデルをTensorFlow.jsすなわちWebプラットフォームで利用できることを理解します。利用する手書き文字データベースのMNISTと動作確認環境については前回を参照してください。

[NOTE]サンプルについて
本記事のサンプルは、Keras Teamのデモプログラム(https://github.com/keras-team/keras/tree/master/examples/demo_mnist_convnet.py)を一部改変して作成、実行しています。

モデルの変換

前回は、Python環境で手書き文字認識のモデルを構築し、ファイルに保存するところまでを行いました。続けて、このモデルをTensorFlow.jsで利用可能な形式に変換する作業からはじめます。Windows環境でターミナルを開き、WSL(Windows Subsystem for Linux)からUbuntuを起動し、作業ディレクトリに移動してください。

> wsl -d Ubuntu-24.04
…略…
~$ cd tfjs  作業ディレクトリへ移動

前回の手順で作業を進めると、Pythonの仮想環境であるvenvディレクトリがあるはずなので、以下のように仮想環境を有効化してください。

tfjs$ source venv/bin/activate
(venv) tfjs$    プロンプトが変化する

変換は、専用のコンバータで行います。このコンバータはtensorflowjsパッケージにてインストール済みなので、直ちに変換作業に入れます。以下のコマンドを実行してください。

(venv) tfjs$ tensorflowjs_converter --input_format=tf_saved_model ./mnist_model ./tfjs

tensorflowjs_converterプログラムは、TensorFlowやKeras形式で保存されたモデルデータを変換します。--input_formatオプションには、入力フォーマットを指定します。この場合は、demo_mnist_convnet.pyのmodel.export()で保存されたTensorFlowのSavedModel形式を指定しています。続く2つのパラメータは、入力ファイルのあるディレクトリ、出力ファイルを書き出すディレクトリで、それぞれ./mnist_modelと./tfjsとしています。

[NOTE]変換用シェルスクリプトconvert.sh
変換のコマンドは長いので、配布サンプルには変換用シェルスクリプトをconvert.shファイルとして含めています。Ubuntu上で実行する場合、実行権限を以下のように付与して使用してください。 (venv) tfjs$ chmod +x convert.sh

変換を開始すると経過や警告が出力されますが、出力ディレクトリ(./tfjs)にファイルが以下のように2つ書き出されていれば成功です。

  • group1-shard1of1.bin:バイナリ形式の重みファイルのコレクション
  • model.json:データフローグラフと重みマニフェスト

モデルの読み込みと確認

ここ以降の作業は、Windows上で行います。先ほど作成した変換後のモデルデータを、Windows側に作業フォルダ(ここではtfjs_win)を用意して、コピーしておきます。ここでUbuntuは不要になるので、exitコマンドで終了しておきます。

(venv) tfjs$ mkdir /mnt/c/Users/nao/Documents/tfjs_win
(venv) tfjs$ cp -r ./tfjs /mnt/c/Users/nao/Documents/tfjs_win
(venv) tfjs$ exit

サーバ機能を用意する

変換後のモデルデータをJavaScriptから読み込むために、サーバ機能を用意します。といっても大げさなものではなく、第2回でも少し触れたVS Code(Visual Studio Code)と拡張機能Live Serverを使います。
VS Codeの準備の詳細は省きますが、[「Download Visual Studio Code - Mac, Linux, Windows」](https://code.visualstudio.com/download)からインストーラをダウンロードしインストール後、拡張機能から「Live Server」を探してインストールしておきます。

予測できるか確認する

変換後のモデルを読み込み、数字に相当する画像を読み込ませて、予測できるか確認してみます。リスト1とリスト2のように、index.htmlファイルとscript.jsファイルを作成します。

リスト1 index.html

<html>
<head>
    <meta charset="utf-8">
    <title>手書き数字認識</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>   (1)
</head>
<body>
    <h1>手書き数字認識</h1>
    <img id="image" src="number_5.png" width="240" height="240">    (2)
    <script src="script.js"></script>   (3)
</body>
</html>

リスト2 script.js

let model;  // モデルを保持するオブジェクト(tf.GraphModel)
const imageElement = document.getElementById('image');

// 予測関数
function predict() {
    // (1)画像をテンソル化する
    const input = tf.browser
        .fromPixels(imageElement, 1)    // 1チャネル化してテンソル化
        .toFloat()  // 浮動小数点数化
        .resizeNearestNeighbor([28, 28])    // 28×28にリサイズ
        .div(tf.scalar(255))    // 正規化
        .expandDims();  // 次元拡張
    // (2)予測の実行とコンソールへの出力
    const score = model.predict(input).dataSync();
    console.log(score);
}

async function app() {
    // (3)モデルの読み込み
    const FILE = "tfjs/model.json";
    model = await tf.loadGraphModel(FILE);
    // (4)予測関数の呼び出し
    predict();
}
app();

index.htmlファイルをブラウザで表示させると、ただちにモデルの読み込みと予測が実行され、結果がコンソールに出力されます(図2)。結果は10個の要素からなる浮動小数点数の配列(Float32Array)で、その6番目(インデックス=5)の確率が最も高いことがわかります。

  • 図2 モデルの読み込みと予測のテスト

    図2 モデルの読み込みと予測のテスト

[NOTE]Live Serverを使ったHTMLの表示
Live Serverを起動してHTMLファイルを開くには、VS Codeのエクスプローラーでindex.htmlファイルを右クリック、メニューから[Open with Live Server]を選択してください。

最低限、予測を行うだけならコードはこのようにシンプルです。
HTMLでは、(1)のようにTensorFlow.jsをCDN(Contents Delivery Network)から読み込み、(2)に予測対象の画像を配置し(学習画像に合わせて黒地である必要があります)、(3)でJavaScriptを読み込んでモデルの読み込みと予測を実行させています。
JavaScriptのコードの意味は以下の通りです。

・(3)でモデルを読み込む。読み込みの実行はloadGraphModelメソッド。TensorFlowのSavedModelを読み込むときには、このメソッドを使う。戻り値は専用のtf.GraphModelオブジェクトとなる
・(1)で画像を複数のメソッドを用いて、予測のためにテンソル化する。入力画像は4チャネルであるので、1チャネル化してテンソル化し、整数値から浮動小数点数化し、学習画像に合わせて28×28にリサイズする。正規化の後、次元拡張し3Dのテンソルとする。
・(2)で予測を実行する。予測はpredictメソッドで、(2)で生成したテンソルを入力とする。結果は10要素の浮動小数点数の配列として得られる。それぞれの浮動小数点数は確率を意味する

LayersModelとGraphModel

ここで使ったモデルはGraphModelです。GraphModelとは、TensorFlow.jsで利用できるモデルの一つですが、これまでの回で主に使ってきたのはLayersModelというモデル(tf.LayersModelオブジェクト)です。LayersModelは文字通り、層(レイヤ)を持ったモデルで、GraphModelはグラフ構造を持ったモデルです。これらのモデルの違いを表1にまとめておきます。

表1 LayersModelとGraphModelの違い

LayersModel GraphModel
特徴 Keras APIを使用して、層(レイヤ)を積み重ねて構築される。モデルの層構造に関する情報が含まれる TensorFlowのSavedModel形式からインポートされる、より低レベルな表現。操作のデータフローグラフとして表現される
機能 JavaScript環境で追加学習(再訓練)が可能。fitメソッドでモデルを更新できる 予測(推論)のみをサポートする。学習はできない
パフォーマンス GraphModelより速度が遅くなる可能性がある グラフが最適化されているため、LayersModelより推論速度が速い場合がある
モデル変換 GraphModelをLayersModelに変換することはできない LayersModelをSavedModel形式で保存し、GraphModelとして読み込むことが可能

分かりやすいのは、GraphModelは予測のみに利用可能で、最適化によりLayersModelより高速に予測できることがある、ということでしょう。このため、今回のように学習済みモデルをSavedModel形式で保存したものを用意すれば、追加学習こそできませんが、モデルの学習成果を高速かつ簡単に利用できる、ということになります。

手書き文字認識への拡張

固定の数字画像から認識できることを確認したので、手書き文字を認識できるようにサンプルを拡張します。以降で、HTMLとJavaScriptのマークアップおよびスクリプトを解説していきますが、外観を整えるためにCSSを一部適用しています。このCSSについては、サンプルの動作に直接的な関係はないので掲載は割愛します。具体的な内容は配布サンプルを参照してください。

HTMLにキャンバスとボタン、テーブルを追加する

リスト1の<img>タグを無効にして、替わりに<canvas>タグを配置します。予測とキャンバスクリアのためのボタンと、結果(各数字の確率)を表示するためのテーブルも配置します(リスト3)。

リスト3 index.html

…略…
<h1>手書き数字認識</h1>
<!--<img id="image" src="number_5.png" width="240" height="240">--> 無効化
<div class="left">  ここから追加
    <canvas id="image" width="240" height="240"></canvas>
    <div>
        <button id="predict">予測</button>&nbsp;
        <button id="clear">クリア</button>
    </div>
</div>
<div class="right">
    <h3>判定結果</h3>
    <table>
        <tr><th>数字</th><th>確率</th></tr>
        <tr><td>0</td><td><span id="result-probability-0"></span></td></tr>
        <tr><td>1</td><td><span id="result-probability-1"></span></td></tr>
        <tr><td>2</td><td><span id="result-probability-2"></span></td></tr>
        <tr><td>3</td><td><span id="result-probability-3"></span></td></tr>
        <tr><td>4</td><td><span id="result-probability-4"></span></td></tr>
        <tr><td>5</td><td><span id="result-probability-5"></span></td></tr>
        <tr><td>6</td><td><span id="result-probability-6"></span></td></tr>
        <tr><td>7</td><td><span id="result-probability-7"></span></td></tr>
        <tr><td>8</td><td><span id="result-probability-8"></span></td></tr>
        <tr><td>9</td><td><span id="result-probability-9"></span></td></tr>
    </table>
</div>  ここまで
<script src="script.js"></script>
…略…

JavaScriptに数字描画と結果表示のスクリプトを追加する

リスト2に、数字描画のためのスクリプト、結果表示のためのスクリプトを追加します。簡略化のために、数字描画のコードは一筆書きのみに対応しています。

リスト4 script.js

…略…
// (1)マウスの座標とボタン押下状態の保持
let mouse = { x: 0, y: 0, down: false };

// (2)キャンバスへの描画関数。ボタン押下状態で直前の座標から直線を引く
function draw() {
    if (mouse.down) {
        const ctx = imageElement.getContext('2d');
        ctx.lineTo(mouse.x, mouse.y);
        ctx.strokeStyle = 'white';
        ctx.lineWidth = 10;
        ctx.stroke();
    }
}

// 「予測」ボタンで呼び出される予測関数
function predict() {
    …略…
    const score = model.predict(input).dataSync();
    console.log(score);
    // (3)テーブルに確率を設定する。最大の値のセルは赤字とする
    const pos = score.indexOf(Math.max(...score));
    for (let i = 0; i < score.length; i++) {
        const accurance = document.getElementById(`result-probability-${i}`);
        accurance.innerText = score[i].toFixed(4);
        accurance.style.color = (i === pos) ? 'red' : 'black';
    }
}

async function app() {
    // (4)マウス操作のイベントハンドラを設定する
    imageElement.addEventListener('mousedown', (e) => {
        const rect = imageElement.getBoundingClientRect();
        mouse.x = e.pageX - rect.left;
        mouse.y = e.pageY - rect.top;
        mouse.down = true;
    });
    imageElement.addEventListener('mouseup', (e) => {
        mouse.down = false;
    });
    imageElement.addEventListener('mousemove', (e) => {
        const rect = imageElement.getBoundingClientRect();
        mouse.x = e.pageX - rect.left;
        mouse.y = e.pageY - rect.top;
        draw();
    });
    // (5)ボタンクリックのイベントハンドラを設定する
    document.getElementById('predict').addEventListener('click', () => {
        predict();
    });
    document.getElementById('clear').addEventListener('click', () => {
        const ctx = imageElement.getContext('2d');
        ctx.clearRect(0, 0, 240, 240);
        ctx.beginPath();
    });
    const FILE = "tfjs/model.json";
    model = await tf.loadGraphModel(FILE);
}
app();

(1)(2)(4)は、マウス操作でキャンバスに数字を描画するためのスクリプトです。(1)はマウスポインタの座標やボタンの押下状態を保持するための作業オブジェクト、(2)は現在の座標までの描画関数で、マウスポインタの移動によって呼び出されます。(4)からの3つのイベントハンドラ設定は、ボタンプッシュとリリース、移動に対応したものです。
(3)は予測関数に追記されたスクリプトで、予測結果をテーブルに設定します。単純に各値をセルに設定しているだけですが、Math.max関数で最大値を取得し、それに対応するインデックスをindexOfメソッドで求め、そのセルだけ文字色を赤にするという処理を行っています。
(5)からの2つのイベントハンドラ設定は、それぞれ「予測」ボタン、「クリア」ボタンのクリックに対応します。
index.htmlファイルをブラウザで表示させると、黒地のキャンバスと、空のテーブルが表示されます。キャンバス上にマウス操作で適当に数字を描き(一筆書きである必要があることに注意)、「予測」ボタンをクリックすると、結果がテーブルに表示されます(図3)。

  • 図3 手書き数字の認識

    図3 手書き数字の認識

まとめ

今回は、MNISTデータベースを利用した手書き数字の認識の続編として、保存したモデルデータをTensorFlow.jsで読み込み、Canvas上に手書きした数字を認識させるサンプルを作成しました。モデルの利用そのものは、非常にシンプルな手順で行えることをお伝えできたのではないかと思います。
今回を以て、TensorFlow.js活用の本連載は終了です。ブラウザで環境構築いらずで利用できるメリットを活かした、学習モデルの構築や既存モデルの利用の一助になったとすれば幸いです。

WINGSプロジェクト 山内直(著) 山田 祥寛(監修)
有限会社 WINGSプロジェクトが運営する、テクニカル執筆コミュニティ(代表山田祥寛)。主にWeb開発分野の書籍/記事執筆、翻訳、講演等を幅広く手がける。現在も執筆メンバーを募集中。興味のある方は、どしどし応募頂きたい。著書、記事多数。
RSS
X:@WingsPro_info(公式)@WingsPro_info/wings(メンバーリスト)
Facebook

<著者について>
WINGSプロジェクト所属のテクニカルライター。出版社を経てフリーランスとして独立。ライター、エディター、デベロッパー、講師業に従事。屋号は「たまデジ。」。