俺人〜OREGIN〜俺、バカだから人工知能に代わりに頑張ってもらうまでのお話

俺って、おバカさんなので、とっても優秀な人工知能を作って代わりに頑張ってもらうことにしました。世界の端っこでおバカな俺が夢の達成に向けてチマチマ頑張る、そんな小さなお話です。現在はG検定、E資格に合格し、KaggleやProbSpaceのコンペに参画しながら、Pythonや機械学習、統計学、Dockerなどの勉強中です。学習したことをブログにアウトプットすることで、自分の身に着けていきたいと思います。まだまだ道半ばですが、お時間がありましたら見て行ってください。

rinna社が公開した、日本語に特化した「GPT-2」「BERT」事前学習モデルをつかって簡易大喜利を実装してみた!

元女子高生AI「りんな」などで知られるrinna株式会社が8月25日に公開してくださった、日本語に特化した「GPT-2」と「BERT」の事前学習モデルを使って、簡易大喜利(?)モデルを実装してみました。

日本語CC-100と日本語Wikipediaの計75ギガバイトのデータを、8つのNVIDIA Tesla V100 GPUで、45日間も掛けて学習しないと行けない事前学習モデルが、無料で簡単に使えるのは感動でした!

corp.rinna.co.jp

 では、振り返っていきたいと思います。

日本語に特化した「GPT-2」「BERT」事前学習モデルをつかって簡易大喜利を実装してみた!

1.今回実装した簡易大喜利モデル

今回実装した簡易大喜利モデルは、指定した文字列のうち、隠した部分の文字列を予測する機能を使って、面白い回答を期待するというモデルです。

本来であれば、いろいろと実装して、データを集めて、長時間学習させないと行けないところ、公開された事前学習モデルを使うことで、あっという間に実装することができました。

簡易大喜利モデルの仕様

  1. 穴埋めさせたい文章を入力する。
    例:「世界一うるさい人と言ったら、。」
  2. 簡易大喜利モデルが「」の部分を予測して、確率が高い順に出力する。
    例:「あなた
  3. おもしろ回答を見つけて、こっそり楽しむ。

2.全体像

コード全体は以下の通りで、Google colaboratory上で実行しました。

GitHubにも公開しています。

https://github.com/Oregin-ML/Simple_Ogiri_model

 

!git clone https://github.com/rinnakk/japanese-pretrained-models

%cd /content/japanese-pretrained-models

!pip install -r requirements.txt
import torch from transformers import T5Tokenizer, RobertaForMaskedLM # load tokenizer tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base") tokenizer.do_lower_case = True # due to some bug of tokenizer config loading # load model model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base") model = model.eval() print('#'*40) print('答えさせたい内容を「〜、何。」と聞いてみてください。') print('例;世界一うるさい人と言ったら、何。') print('※最後の「。」を忘れずに。') print('#'*40) # original text text = input() # prepend [CLS] text = "[CLS]" + text if '何' in text: # tokenize tokens = tokenizer.tokenize(text) # mask a token masked_idx = -2 tokens[masked_idx] = tokenizer.mask_token print('#'*40) print('#以下の[MASK]の部分を考えます。') print(tokens) print('#'*40) # convert to ids token_ids = tokenizer.convert_tokens_to_ids(tokens) # convert to tensor token_tensor = torch.tensor([token_ids]) # get the top 10 predictions of the masked token with torch.no_grad(): outputs = model(token_tensor) predictions = outputs[0][0, masked_idx].topk(10) print('#'*40) print('#思いついた答えは・・・') for i, index_t in enumerate(predictions.indices): index = index_t.item() token = tokenizer.convert_ids_to_tokens([index])[0] print(i, token) print('#以上です。') print('#'*40) else: print('何か聞いてくれないとわからないです。')

では、コードを順番に見ていきます。

3.学習済みモデルをインストールする

GitHubより、学習済みモデルをダウンロードして、ダウンロードされたフォルダ内にある、「requirements.txt」を使えば、必要なライブラリが一発でインストールできます。

 

!git clone https://github.com/rinnakk/japanese-pretrained-models 

%cd /content/japanese-pretrained-models

!pip install -r requirements.txt

4.学習済みモデルを定義する

学習済みモデル("rinna/japanese-roberta-base")を、読み込みます。

この数行を記載するだけで、莫大なデータやリソース、時間を必要とする学習を省略することができます。

import torch
from transformers import T5Tokenizer, RobertaForMaskedLM

# load tokenizer
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
tokenizer.do_lower_case = True  # due to some bug of tokenizer config loading

# load model
model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base")
model = model.eval()

5.大喜利させたい文章を入力する

大喜利させたい文章を標準入力で受け付けます。

ポイントは、文章の頭に[CLS]を追加するところです。

print('#'*40)
print('答えさせたい内容を「〜、何。」と聞いてみてください。')
print('例;世界一うるさい人と言ったら、何。')
print('※最後の「。」を忘れずに。')
print('#'*40)

# original text
text = input()

# prepend [CLS]
text = "[CLS]" + text

6.「何」の部分を予測して、出力する。

「何」の文字列が含まれていれば、「masked_idx」で指定した箇所の単語を予測して、確率が高い単語を10個出力します。今回は、文末が「何。」となっていることを前提に、後ろから2つ目の単語を予測することにしています。

赤字下線の部分で、穴埋めの予測処理を実施しています。

if '何' in text:
  # tokenize
  tokens = tokenizer.tokenize(text)
  # mask a token
  masked_idx = -2
  tokens[masked_idx] = tokenizer.mask_token
  print('#'*40)
  print('#以下の[MASK]の部分を考えます。')
  print(tokens)  
  print('#'*40)
  # convert to ids
  token_ids = tokenizer.convert_tokens_to_ids(tokens)

  # convert to tensor
  token_tensor = torch.tensor([token_ids])

  # get the top 10 predictions of the masked token
  with torch.no_grad():
      outputs = model(token_tensor)
      predictions = outputs[0][0, masked_idx].topk(10)

  print('#'*40)
  print('#思いついた答えは・・・')
  for i, index_t in enumerate(predictions.indices):
      index = index_t.item()
      token = tokenizer.convert_ids_to_tokens([index])[0]
      print(i, token)
  print('#以上です。')
  print('#'*40)

else:
  print('何か聞いてくれないとわからないです。')  

7.実行結果

このコードを実行して、「世界一の天才といったら、何。」と入力した場合の実行結果は以下の通りです。

なかなかいい味を出している回答を返してくれています。

f:id:kanriyou_h004:20210826203630p:plain

 

8.感想

今回は、日本語に特化した事前学習モデルを使って、簡易大喜利モデルを作ってみました。このモデル自体が何かの役に立つかというと、そうでもないのですが、こんなに簡単に穴埋め予測モデルが作れてしまうので、この事前学習モデルの有用性を身を持って知ることができました。

今後は、このモデルをどの様に活用できるかの視点でも考えていきたいと思います。

 

 

【これまでの道のり】

oregin-ai.hatenablog.com

oregin-ai.hatenablog.com

oregin-ai.hatenablog.com

oregin-ai.hatenablog.com

oregin-ai.hatenablog.com

f:id:kanriyou_h004:20210826185619p:plain