Menu

【鈴鹿詩子さん】AIにドラムを演奏させて叩いてみた動画自動で作れないか検証してみた(2)【機械学習】

2021年8月28日に公開(2021年8月28日に更新)

シリーズ第二回目となってまいりました。前回こちらの投稿で概要についてお伝えしましたが、今回は本格的にディープラーニングを使った実装方式について解説しようと思います。まずは下記の動画をご覧ください。

「バスドラム」「スネアドラム」「ハイハット」の3種類については認識精度は出ていますが、やはりデータ数の不足する「クラッシュ」「ライド」「タム」については誤検知しかない印象です。AIの耳コピ力はこの段階ではまだまだです。

本当はgoogleが主導している「Magenta」の機能を使った方が精度が高いとは思うんですが、将来的にはドラム音検出部分を組み込んだソフトを色々作って行きたいと思ってるので、車輪の再発明みたいなところはありますが頑張ってます(Magentaのソースコードは膨大な量なので、ドラム音認識部分だけをうまく分離するの大変そうだったので挫折しました)

以下、この機械学習モデルの解説と今後の課題について解説していきます。

どんな機械学習モデルを使ってるの?

CNN(畳み込みニューラルネットワーク)の実装版で参考にしたのは下記のページ。

https://qiita.com/cvusk/items/61cdbce80785eaf28349

CNNは基本的に画像の場合に威力を発揮するのですが、音声に対しても使用している例が多いです。kaggleで音声扱う時は基本これっぽいですね。

テストデータは下記のものを使いました。

プログラムの全体構成は?

画像をCNNで学習させる場合と比較して、圧倒的にデータの前処理が大変です。下図を見ていただけると一目瞭然ですが、多くの処理が必要となります。

音声を扱うと学習データの容量が結構大きくなるので、グラフィックボードのメモリ簡単にオーバーしてしまうんですよね…トレーニングにも時間かかるし大変
(ちなみに、私はRTX3080搭載マシンで機械学習やってます。心優しいだれかがRTX3090くれたりしないかな…?いつかグラボ2枚刺しするのが夢です)

具体的な処理は?

・MIDIデータについて

まずはMIDIファイルの仕様について解説します。MIDIっていうと結構身近な感じがしますが、意外と皆さんどんなデータの持ち方をしているかご存じないですよね。(かくいう私も今回初めて知りました…)
pythonライブラリの「pretty_midi」でファイルを読み込むと、下記の表みたいなデータの持ち方をしていました。

No.start (sec)end (sec)pitchvelocity
00.002717393750.104166760416666675560
10.0063405854166666670.106884154166666685155
20.0090579791666666670.108695750000000013664
30.191123360416666670.29166692916666673657
40.215579904166666670.315217675000000035131
50.40398587083333330.50543523754477
60.423007627083333340.52355119583333335154
・・・・・・・・・・・・・・・

上記のように、データは

  • start:音を鳴らし始める時間
  • ent:音を止める時間
  • pitch:音程(ドラムの場合は楽器ごとにコードが割り当てられる)
  • velocity:音を鳴らす強さ

といった要素を持ってます。

・クオンタイズ

上記の表を見てみると、おそらく同じ時間に音を鳴らしているはずなのに、微妙にタイミングがずれているのがわかると思います。察しの良い方はもうお気づきだと思いますが、このデータは人間が電子ドラムを叩いて作成したデータなんですよね…
なので、この微妙な時間のずれを補正してあげる必要があります。それがクオンタイズです。

クオンタイズする時間は、大体下記の想定にしてみました。

(60(秒) / 290(BPM)) / 4(16分音符) = 約0.05秒

理論上、BPM290で16分音符までは対応させてます。

・ドラム音のマッピング

pitchの値が叩いている楽器のコードになるので、元データの仕様をみて変換していきます。
今回使ったデータセットは下記の音源をつかっており、楽器数が多すぎるので少なく変換します。

https://magenta.tensorflow.org/datasets/groove

だいたいこんな感じでマッピングしました。

# [pitch⇒ドラム]マッピング
inst_map = {
    36: 0, # Kick
    38: 1, # Snare
    48: 2, # Tom
    46: 3, # HH
    49: 4, # Crash
    51: 5 # Ride
}

marge_map = {
    36: 36, # Kick
    38: 38, # Snare (Head)
    40: 38, # Snare (Rim)
    37: 38, # Snare X-Stick
    48: 48, # Tom 1
    50: 48, # Tom 1 (Rim)
    45: 48, # Tom 2
    47: 48, # Tom 2 (Rim)
    43: 48, # Tom 3 (Head)
    58: 48, # Tom 3 (Rim)
    46: 46, # HH Open (Bow)
    26: 46, # HH Open (Edge)
    42: 46, # HH Closed (Bow)
    22: 46, # HH Closed (Edge)
    44: 46, # HH Pedal
    49: 49, # Crash 1 (Bow)
    55: 49, # Crash 1 (Edge)
    57: 49, # Crash 2 (Bow)
    52: 49, # Crash 2 (Edge)
    51: 51, # Ride (Bow)
    59: 51, # Ride (Edge)
    53: 51 # Ride (Bell)
}

・オーギュメンテーション

基本的にはこちらの記事に具体的な手法が全部書いてます。
ピッチの変更、コンプレッサー、イコライザ、リバーブの4種類の処理をランダムでかけてます。
1つの音源に対して全体に効果を適用してから音源をカットする方式にしていますが、本当は順序を逆にしたほうが多様性が出るのでいいと思います。ただ、カットされた音源1つ1つに効果を適用するのは処理が重すぎて諦めました。

・メルスペクトログラム

音源はサンプリングレートは44100Hzで読み込み、そのうち切り出すフレーム数は3072(約0.07秒)としています。そして、切り出したWAVデータをメルスペクトログラムに変換します。
変換のソースコードは参考にしたQiitaの記事と同じです。

def calculate_melsp(x, n_fft=1024, hop_length=128):
    stft = np.abs(librosa.stft(x, n_fft=n_fft, hop_length=hop_length))**2
    log_stft = librosa.power_to_db(stft)
    melsp = librosa.feature.melspectrogram(S=log_stft,n_mels=128)
    return melsp

次回の内容

CNNだけでなく色んなモデルを試していて、精度を向上させるために下記の変遷をたどってます。

CNN ⇒ CNN×Frequency LSTMs ⇒ Time-Frequency LSTM(今ここ) ⇒ Transformer(予定)

細かい試行錯誤も含めると、今開発しているのは第8世代ぐらいなんですよね…めちゃくちゃトライ&エラー繰り返してます。
次の記事では各種モデルの実装について触れていけたらなと思ってます。

それでは、次回もお楽しみに!

©2023 Maya Hanada