Tensorflow.jsを使いたくて仕方なくなって試しに動かしてみたところ、意外と動かすのに苦労したのでやった内容をメモ。
チュートリアルを動かしたときには、簡単に動かせる〜〜って思ってたのに自分のModelで動かそうとしたら自分の理解不足で結構苦労してしまいました・・・
もうすこし良い方法があったらぜひともご教授いただきたく
はまったポイント
今回はまって時間を消費したのは以下の点でした。
- KerasのモデルをTensorflow.jsで使える形式に変換
- Tensorflow.jsでモデルに入力するデータの作成
環境
- Typescript
- Tensorflow.js
- tensorflowjs(Python) 1.2.6
KerasでModelを構築
CNNによるクラス分類のモデルを使いました。
昔ソシャゲ画像とイラスト画像を分類するモデル(参考:
CNNを活用しTwitterの画像欄からソシャゲと写真を避けてイラストの抽出を行う - Qiita)
でやった内容をKerasに書き直しました。
やった内容自体は上の記事と同じなのでこの記事では省略。
モデルの保存形式は".h5"形式で保存しました。
Tensorflow.jsで使用できるモデルに変換
この工程でかなりはまりました…
内容としては、以下の記事のチュートリアルを実施して変換を実施しました。(
Importing a Keras model into TensorFlow.js | TensorFlow)
ここのPythonライブラリのバージョンを最新(9.23段階だと1.2.9)にするとAttributeErrorが発生してうまく動かない問題がありました。 調べてみると、1.2.9で発生したバグらしく同じような症状が報告されていました。('EnumTypeWrapper' object has no attribute 'DT_FLOAT' while importing in Python 3 · Issue #2014 · tensorflow/tfjs · GitHub)
以下のように、1.26をインストールするように指定すると問題なく変換することができました。
$ pip install tensorflowjs==1.2.6
Issueのページをみていると、この問題を解決したらしき修正がすでにマージされていたので次のバージョンではこの問題は解決しそうです。(Fix access to proto enum fields (#2040) · tensorflow/tfjs@dd5d9ed · GitHub)
Tensorflow.jsで推論
最後に画像データをモデルで使用可能な形式に整形して、モデルにいれることで推論します。 どういうデータにすればよいのかを整形するのに時間がかかってしまいました。 Tensorflowのデータ形式に触れたのが初めてだったせいもある気がするので、なれた人ならそんなにつまらないのかな?と思います。
画像データ(CanvasやImageDataなど)を入力できるデータに変換するのに tf.browser.fromPixelsを使用しました。
ここで、モデルに入力するためには画像サイズを合わせる必要があり、
tf.image.resizeBilinearという関数を利用してサイズを変更する必要があります。
自分はこの関数の存在に気づかなくてここで時間をくってしまいました。
データの次元数があっていないときには、expandDimsを使って合わせる必要もありました。
一応、動作確認のために書いたコードをはっておきます。ただ単体では動きません。
import * as tf from '@tensorflow/tfjs'; function loadImg(id: string): tf.Tensor<tf.Rank> | undefined { let img = document.getElementById(id) as HTMLImageElement; if( !img ) { return; } const resizeImg = tf.image.resizeBilinear(tf.browser.fromPixels(img,3),[224,224]); return resizeImg.expandDims(0).cast('float32').div(tf.scalar(255)); } tf.loadLayersModel('../model/model.json').then(model => { // htmlにあらかじめimgを追加しておいてから動作させている const batched = loadImg('img'); if( !!batched ) { console.log('result'); (model.predict(batched) as tf.Tensor<tf.Rank>).print(); });
やってみて
最先端のモデルとかをブラウザで動作させたい!!でもどんな値いれるのかはあんまり知らん!って感じで始めたのですが、モデルの変換後にどんな値を入れるのは自分で実装する必要があるのでやっぱりちゃんと理解するのは重要だなぁって思いました。 せっかく動かせるようになったので、何かのアプリに組み込んでみようかなって思ってます。