元女子高生AI「りんな」などで知られるrinna株式会社が8月25日に公開してくださった、日本語に特化した「GPT-2」と「BERT」の事前学習モデルを使って、簡易大喜利(?)モデルを実装してみました。
日本語CC-100と日本語Wikipediaの計75ギガバイトのデータを、8つのNVIDIA Tesla V100 GPUで、45日間も掛けて学習しないと行けない事前学習モデルが、無料で簡単に使えるのは感動でした!
corp.rinna.co.jp
では、振り返っていきたいと思います。
日本語に特化した「GPT-2」「BERT」事前学習モデルをつかって簡易大喜利を実装してみた!
1.今回実装した簡易大喜利モデル
今回実装した簡易大喜利モデルは、指定した文字列のうち、隠した部分の文字列を予測する機能を使って、面白い回答を期待するというモデルです。
本来であれば、いろいろと実装して、データを集めて、長時間学習させないと行けないところ、公開された事前学習モデルを使うことで、あっという間に実装することができました。
簡易大喜利モデルの仕様
- 穴埋めさせたい文章を入力する。
例:「世界一うるさい人と言ったら、何。」
- 簡易大喜利モデルが「何」の部分を予測して、確率が高い順に出力する。
例:「あなた」
- おもしろ回答を見つけて、こっそり楽しむ。
2.全体像
コード全体は以下の通りで、Google colaboratory上で実行しました。
GitHubにも公開しています。
https://github.com/Oregin-ML/Simple_Ogiri_model
!git clone https://github.com/rinnakk/japanese-pretrained-models
%cd /content/japanese-pretrained-models
!pip install -r requirements.txt
import torch
from transformers import T5Tokenizer, RobertaForMaskedLM
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
tokenizer.do_lower_case = True
model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
model = model.eval()
print('#'*40)
print('答えさせたい内容を「〜、何。」と聞いてみてください。')
print('例;世界一うるさい人と言ったら、何。')
print('※最後の「。」を忘れずに。')
print('#'*40)
text = input()
text = "[CLS]" + text
if '何' in text:
tokens = tokenizer.tokenize(text)
masked_idx = -2
tokens[masked_idx] = tokenizer.mask_token
print('#'*40)
print('#以下の[MASK]の部分を考えます。')
print(tokens)
print('#'*40)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_tensor = torch.tensor([token_ids])
with torch.no_grad():
outputs = model(token_tensor)
predictions = outputs[0][0, masked_idx].topk(10)
print('#'*40)
print('#思いついた答えは・・・')
for i, index_t in enumerate(predictions.indices):
index = index_t.item()
token = tokenizer.convert_ids_to_tokens([index])[0]
print(i, token)
print('#以上です。')
print('#'*40)
else:
print('何か聞いてくれないとわからないです。')
では、コードを順番に見ていきます。
3.学習済みモデルをインストールする
GitHubより、学習済みモデルをダウンロードして、ダウンロードされたフォルダ内にある、「requirements.txt」を使えば、必要なライブラリが一発でインストールできます。
!git clone https://github.com/rinnakk/japanese-pretrained-models
%cd /content/japanese-pretrained-models
!pip install -r requirements.txt
4.学習済みモデルを定義する
学習済みモデル("rinna/japanese-roberta-base")を、読み込みます。
この数行を記載するだけで、莫大なデータやリソース、時間を必要とする学習を省略することができます。
import torch
from transformers import T5Tokenizer, RobertaForMaskedLM
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
tokenizer.do_lower_case = True
model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
model = model.eval()
5.大喜利させたい文章を入力する
大喜利させたい文章を標準入力で受け付けます。
ポイントは、文章の頭に[CLS]を追加するところです。
print('#'*40)
print('答えさせたい内容を「〜、何。」と聞いてみてください。')
print('例;世界一うるさい人と言ったら、何。')
print('※最後の「。」を忘れずに。')
print('#'*40)
text = input()
text = "[CLS]" + text
6.「何」の部分を予測して、出力する。
「何」の文字列が含まれていれば、「masked_idx」で指定した箇所の単語を予測して、確率が高い単語を10個出力します。今回は、文末が「何。」となっていることを前提に、後ろから2つ目の単語を予測することにしています。
赤字下線の部分で、穴埋めの予測処理を実施しています。
if '何' in text:
tokens = tokenizer.tokenize(text)
masked_idx = -2
tokens[masked_idx] = tokenizer.mask_token
print('#'*40)
print('#以下の[MASK]の部分を考えます。')
print(tokens)
print('#'*40)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_tensor = torch.tensor([token_ids])
with torch.no_grad():
outputs = model(token_tensor)
predictions = outputs[0][0, masked_idx].topk(10)
print('#'*40)
print('#思いついた答えは・・・')
for i, index_t in enumerate(predictions.indices):
index = index_t.item()
token = tokenizer.convert_ids_to_tokens([index])[0]
print(i, token)
print('#以上です。')
print('#'*40)
else:
print('何か聞いてくれないとわからないです。')
7.実行結果
このコードを実行して、「世界一の天才といったら、何。」と入力した場合の実行結果は以下の通りです。
なかなかいい味を出している回答を返してくれています。
8.感想
今回は、日本語に特化した事前学習モデルを使って、簡易大喜利モデルを作ってみました。このモデル自体が何かの役に立つかというと、そうでもないのですが、こんなに簡単に穴埋め予測モデルが作れてしまうので、この事前学習モデルの有用性を身を持って知ることができました。
今後は、このモデルをどの様に活用できるかの視点でも考えていきたいと思います。
【これまでの道のり】
oregin-ai.hatenablog.com
oregin-ai.hatenablog.com
oregin-ai.hatenablog.com
oregin-ai.hatenablog.com
oregin-ai.hatenablog.com