webとモバイルアプリの逆転

webアプリとモバイルアプリの開発に関する話や本・ゲームなどの趣味の話を雑多にしていきたい

Kerasで構築したモデルをTensorflow.jsで動かす際にハマったこと

Tensorflow.jsを使いたくて仕方なくなって試しに動かしてみたところ、意外と動かすのに苦労したのでやった内容をメモ。
チュートリアルを動かしたときには、簡単に動かせる〜〜って思ってたのに自分のModelで動かそうとしたら自分の理解不足で結構苦労してしまいました・・・

もうすこし良い方法があったらぜひともご教授いただきたく

はまったポイント

今回はまって時間を消費したのは以下の点でした。

  1. KerasのモデルをTensorflow.jsで使える形式に変換
  2. Tensorflow.jsでモデルに入力するデータの作成

環境

  • Typescript
  • Tensorflow.js
  • tensorflowjs(Python) 1.2.6

KerasでModelを構築

CNNによるクラス分類のモデルを使いました。 昔ソシャゲ画像とイラスト画像を分類するモデル(参考: CNNを活用しTwitterの画像欄からソシャゲと写真を避けてイラストの抽出を行う - Qiita) でやった内容をKerasに書き直しました。
やった内容自体は上の記事と同じなのでこの記事では省略。
モデルの保存形式は".h5"形式で保存しました。

Tensorflow.jsで使用できるモデルに変換

この工程でかなりはまりました…
内容としては、以下の記事のチュートリアルを実施して変換を実施しました。( Importing a Keras model into TensorFlow.js  |  TensorFlow)

ここのPythonライブラリのバージョンを最新(9.23段階だと1.2.9)にするとAttributeErrorが発生してうまく動かない問題がありました。 調べてみると、1.2.9で発生したバグらしく同じような症状が報告されていました。('EnumTypeWrapper' object has no attribute 'DT_FLOAT' while importing in Python 3 · Issue #2014 · tensorflow/tfjs · GitHub

以下のように、1.26をインストールするように指定すると問題なく変換することができました。

$ pip install tensorflowjs==1.2.6

Issueのページをみていると、この問題を解決したらしき修正がすでにマージされていたので次のバージョンではこの問題は解決しそうです。(Fix access to proto enum fields (#2040) · tensorflow/tfjs@dd5d9ed · GitHub)

Tensorflow.jsで推論

最後に画像データをモデルで使用可能な形式に整形して、モデルにいれることで推論します。 どういうデータにすればよいのかを整形するのに時間がかかってしまいました。 Tensorflowのデータ形式に触れたのが初めてだったせいもある気がするので、なれた人ならそんなにつまらないのかな?と思います。

画像データ(CanvasやImageDataなど)を入力できるデータに変換するのに tf.browser.fromPixelsを使用しました。
ここで、モデルに入力するためには画像サイズを合わせる必要があり、 tf.image.resizeBilinearという関数を利用してサイズを変更する必要があります。 自分はこの関数の存在に気づかなくてここで時間をくってしまいました。
データの次元数があっていないときには、expandDimsを使って合わせる必要もありました。

一応、動作確認のために書いたコードをはっておきます。ただ単体では動きません。

import * as tf from '@tensorflow/tfjs';
function loadImg(id: string): tf.Tensor<tf.Rank> | undefined {
    let img = document.getElementById(id) as HTMLImageElement;
    if( !img ) {
        return;
    }
    const resizeImg = tf.image.resizeBilinear(tf.browser.fromPixels(img,3),[224,224]);
    return resizeImg.expandDims(0).cast('float32').div(tf.scalar(255));
}

tf.loadLayersModel('../model/model.json').then(model => {    
// htmlにあらかじめimgを追加しておいてから動作させている
    const batched = loadImg('img');

    if( !!batched ) {
        console.log('result');
        (model.predict(batched) as tf.Tensor<tf.Rank>).print();
});

やってみて

最先端のモデルとかをブラウザで動作させたい!!でもどんな値いれるのかはあんまり知らん!って感じで始めたのですが、モデルの変換後にどんな値を入れるのは自分で実装する必要があるのでやっぱりちゃんと理解するのは重要だなぁって思いました。 せっかく動かせるようになったので、何かのアプリに組み込んでみようかなって思ってます。