機械学習のトレーニングって時間もお金もすごくかかっちゃいますよね...
データ容量がそれほど必要でない学習ならGoogle Colaboratoryでなんとかしている人が多いと思います。私もよくお世話になってます。ですが、高性能なモデルを学習させようとすると必然的にデータ量が跳ね上がります...
そこで、無料で高性能な学習済みモデルを手に入れて実行できる、もしくは転移学習やファインチューニングに流用できる「TensorFlow Hub」の使い方を紹介します!
「TensorFlow Hub」って?
Googleが開発した機械学習ライブラリ「TensorFlow」を用いてトレーニングされた学習済みモデルの配布サイトです。下記サイトが日本語の解説サイトです。
https://www.tensorflow.org/hub?hl=ja
自然言語処理で圧倒的なプレゼンスを誇る「BERT」、昨年バズったアプリ「AI画伯」のような画像のスタイル変換ができる「画風変換(magenta)」など、面白そうなモデルがいっぱいあります。ちなみに、上記画像にある画風変換を検証してみた動画をニコニコ動画にアップロードしているので、よかったらご覧ください。
どうやって使うの?
まずは下準備。こちらを使うにはtensortlowとtensorflow-hubの2つのライブラリが少なくとも必要です。pipでインストールしてしまいましょう。
pip install tensorflow
pip install tensorflow-hub
TensorFlow Hubの日本語サイト
https://www.tensorflow.org/hub?hl=ja
TensorFlow Hubのモデル検索サイト(英語)
https://tfhub.dev/
上記サイトから、使ってみたいモデルを探して選択します。今回は「画風変換」を例に解説します。検索サイトのメニューの「Image」を選択して「magenta/arbitrary-image-stylization-v1-256」を探すか、もしくは日本語サイトから「画風変換」をクリックします。
https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2
モデルのページで下にスクロールしていくと、モデル概要とサンプルコードがあるので、それを参考に実装していきます。
基本的にはTensorFlow Hubの使い方はこの流れで、
- 使いたいモデルを検索する
- サンプルコードを読んで、インプットやアウトプットの形式を確認し、実装する
上記2ステップで学習済みモデルを動かしていきます。
サンプルコード
動画にもしている画風変換のコードを記載します。このモデルはTensorFlowが公式チュートリアルを公開しているので、そちらを参考に(ほとんどコピペですが…)しました。詳しい解説や論文が気になる方は公式の方を参照してください。
import matplotlib.pylab as plt
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
# param
content_image_url = './source.jpg'
style_image_url = './texture.jpg'
output_image_size = 256
output_image_url = './output.jpg'
def crop_center(image):
"""Returns a cropped square image."""
shape = image.shape
new_shape = min(shape[1], shape[2])
offset_y = max(shape[1] - shape[2], 0) // 2
offset_x = max(shape[2] - shape[1], 0) // 2
image = tf.image.crop_to_bounding_box(
image, offset_y, offset_x, new_shape, new_shape)
return image
def load_image(image_url, image_size=(256, 256), preserve_aspect_ratio=True):
"""Loads and preprocesses images."""
# Load and convert to float32 numpy array,
# add batch dimension, and normalize to range [0, 1].
img = plt.imread(image_url).astype(np.float32)[np.newaxis, ...]
if img.max() > 1.0:
img = img / 255.
if len(img.shape) == 3:
img = tf.stack([img, img, img], axis=-1)
img = crop_center(img)
img = tf.image.resize(img, image_size, preserve_aspect_ratio=True)
return img
# The content image size can be arbitrary.
content_img_size = (output_image_size, output_image_size)
# The style prediction model was trained with image size 256 and it's the
# recommended image size for the style image (though, other sizes work as
# well but will lead to different results).
style_img_size = (256, 256) # Recommended to keep it at 256.
content_image = load_image(content_image_url, content_img_size)
style_image = load_image(style_image_url, style_img_size)
style_image = tf.nn.avg_pool(style_image, ksize=[3,3], strides=[1,1], padding='SAME')
# Load TF-Hub module.
hub_handle = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
hub_module = hub.load(hub_handle)
# Stylize content image with given style image.
# This is pretty fast within a few milliseconds on a GPU.
outputs = hub_module(tf.constant(content_image), tf.constant(style_image))
stylized_image = outputs[0][0]
# Visualize input images and the generated stylized image.
result = tf.keras.preprocessing.image.array_to_img(stylized_image)
result.save(output_image_url)
TensorFlow Hubでのモデルをロードする作法は、ネットワーク上のアドレスを指定して「hub.load()」メソッドで呼び出してロードするのが基本です。一度ロードしたモデルは自動的にキャッシュされて次回以降はダウンロード済みのものを呼び出すように「tensorflow-hub」ライブラリがバックグラウンドで処理してくれるようです。
【参考】https://www.tensorflow.org/hub/caching?hl=ja
しかし、キャッシュが勝手にクリーンアップされることや、キャッシュが壊れている場合はキャッシュを手動で削除する必要があったりと面倒もあります。ロードに失敗する場合、Windows環境では「(ユーザー名)\AppData\Local\Temp\tfhub_modules」のパスのキャッシュを削除すると上手くいったりします。(環境によってキャッシュの場所は違うかもしれません)
そのため、アプリに組み込む場合などにはローカルにダウンロードしたモデルを呼び出すほうがスマートかもしれません。その場合の手順は以下の通り。
- モデルのダウンロード
- モデルをロードする部分のソースコードの修正
モデルのページの「Download」をクリックしてモデルのダウンロード、圧縮フォルダを解凍する。
ソースコードを以下のように修正
# 修正前
# Load TF-Hub module.
hub_handle = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
hub_module = hub.load(hub_handle)
# 修正後
# 「saved_model.pb」があるフォルダを指定
hub_handle = './magenta_arbitrary-image-stylization-v1-256_2'
hub_module = hub.load(hub_handle)
機械学習モデルが利用できると、アプリ作成の幅が広がります!みなさんもぜひトライしてみてくださいねー