機械翻訳と夏の日の思い出

初めて機械翻訳に触れたのは中学何年だかの夏の終わりだった。 夏休みの英語の宿題は退屈で仕方がなかった。

僕が入力した日本語から生まれた英文は知らない単語だらけで、それが正しいかどうかはよく分からなかったけれど、回答欄に難しい単語を書くのは少しカッコイイことのような気がしたのでそのまま写して提出した。

今日は機械翻訳システムの話をしようと思う。 宿題の再提出を告げられる生徒が一人でも減る未来を願いながら。


ある英文を日本語に訳すことを考えよう。

英語と日本語は語順が違うので、文頭から順に対応する単語を並べても正しい翻訳にはならない。 真面目に翻訳するためには、前後の単語や前後の文を考慮して英文の意味を捉える必要がある。

さて、英文に目を通し終えたとしよう。 どの文、どの単語から訳し始めるかはセンスの見せ所である。 翻訳文の書き出しを決めるためには、「文頭」という文脈において、英文のどの語に注目すると良いか考えなければならない。 思案の末、翻訳文の一単語目を決めたとする。 次は二単語目である。 初めの単語は決めたので、それに続く語として相応しく、かつ英文の意味を反映したものを二単語目に選ぶ。 以降、このプロセスを繰り返すことで翻訳文は作られる。

Attention-based Neural Machine Translationは、まさにこのようなプロセスで翻訳を解く機械翻訳のモデルである。 このモデルはエンコーダ、デコーダ、アテンションの3つのモジュールで構成されている。 エンコーダは入力文(翻訳元の文)を読みこんで各単語の表現(annotation vector)を得る。 アテンションは出力文(翻訳先の文)の文脈に基づいて入力文の各単語に対する注目度(attention)を計算する機構で、annotation vectorはattentionで重み付けされた後に足し合わされる(context vector)。 デコーダは出力文の文脈とcontext vectorに基づいて翻訳文の続きを出力する。


最も基本的なAttention-based NMTモデルである[Bahdanau+, 2014]を実装する。 実装にはChainerを使った。 その他の開発環境の詳細についてはGithubを参照してほしい。

  • モデル全体の枠組み

Seq2seqクラスは翻訳モデルのインタフェースとして外部から入力を受け取る。 __call__関数は、入力文と出力文のペアを受け取り、入力文から生成した翻訳文と出力文の誤差を計算する。 translate関数は、入力文のみを受け取り、翻訳文を生成する。

class Seq2seq(chainer.Chain):

    def __init__(self, n_source_vocab, n_target_vocab,
                 n_encoder_layers, n_encoder_units, n_encoder_dropout,
                 n_decoder_units, n_attention_units, n_maxout_units):
        super(Seq2seq, self).__init__()
        with self.init_scope():
            self.encoder = Encoder(
                n_source_vocab,
                n_encoder_layers,
                n_encoder_units,
                n_encoder_dropout
            )
            self.decoder = Decoder(
                n_target_vocab,
                n_decoder_units,
                n_attention_units,
                n_encoder_units * 2,  # because of bi-directional lstm
                n_maxout_units,
            )

    def __call__(self, xs, ys):
        """Calculate loss between outputs and ys.
        Args:
            xs: Source sentences.
            ys: Target sentences.
        Returns:
            loss: Cross-entoropy loss between outputs and ys.
        """
        batch_size = len(xs)

        hxs = self.encoder(xs)
        os = self.decoder(ys, hxs)

        concatenated_os = F.concat(os, axis=0)
        concatenated_ys = F.flatten(ys.T)
        n_words = len(self.xp.where(concatenated_ys.data != PAD)[0])

        loss = F.sum(
            F.softmax_cross_entropy(
                concatenated_os, concatenated_ys, reduce='no', ignore_label=PAD
            )
        )
        loss = loss / n_words
        chainer.report({'loss': loss.data}, self)
        perp = self.xp.exp(loss.data * batch_size / n_words)
        chainer.report({'perp': perp}, self)
        return loss

    def translate(self, xs, max_length=100):
        """Generate sentences based on xs.
        Args:
            xs: Source sentences.
        Returns:
            ys: Generated target sentences.
        """
        with chainer.no_backprop_mode(), chainer.using_config('train', False):
            hxs = self.encoder(xs)
            ys = self.decoder.translate(hxs, max_length)
        return ys
  • エンコーダの実装

Encoderクラスは入力文の各単語のannotation vectorを計算する。 Bi-directional LSTMは入力を順方向だけでなく逆方向からも読み込むモデルで、前後の単語を考慮した表現が獲得されると言われている。 ChainerのNStepBiLSTMは可変長の入力に対応しているので、そのための前処理・後処理も行っている。

class Encoder(chainer.Chain):

    def __init__(self, n_vocab, n_layers, n_units, dropout):
        super(Encoder, self).__init__()
        with self.init_scope():
            self.embed_x = L.EmbedID(n_vocab, n_units, ignore_label=-1)
            self.bilstm = L.NStepBiLSTM(n_layers, n_units, n_units, dropout)

    def __call__(self, xs):
        """Encode source sequences into the representations.
        Args:
            xs: Source sequences.
        Returns:
            hxs: Hidden states for source sequences.
        """
        batch_size, max_length = xs.shape

        exs = self.embed_x(xs)
        exs = F.separate(exs, axis=0)
        masks = self.xp.vsplit(xs != -1, batch_size)
        masked_exs = [ex[mask.reshape((-1, ))] for ex, mask in zip(exs, masks)]

        _, _, hxs = self.bilstm(None, None, masked_exs)
        hxs = F.pad_sequence(hxs, length=max_length, padding=0.0)
        return hxs

Decoderクラスは出力文を生成する。 出力文の生成は、previous_embedding(前回の単語の表現)、h(現在の状態)、context(context vector)に基づいて次の単語を決定することを繰り返して行われる。 ただし、学習時にはys(正しい出力文)が与えられるため、previous_embeddingは前回の自身の出力ではなく、ysに基づいて計算される。これはteacher forcingと呼ばれている。

__call__関数の中で予測を出力するたびに正しい出力単語との誤差を取ることもできるが、 学習時間の短縮のために行列計算の回数をなるべく減らしたいので、全ての予測結果を結合しSeq2seq側で誤差をまとめて計算している。

class Decoder(chainer.Chain):

    def __init__(self, n_vocab, n_units, n_attention_units,
                 n_encoder_output_units, n_maxout_units, n_maxout_pools=2):
        super(Decoder, self).__init__()
        with self.init_scope():
            self.embed_y = L.EmbedID(n_vocab, n_units, ignore_label=-1)
            self.lstm = L.StatelessLSTM(
                n_units + n_encoder_output_units,
                n_units
            )
            self.maxout = L.Maxout(
                n_units + n_encoder_output_units + n_units,
                n_maxout_units,
                n_maxout_pools
            )
            self.w = L.Linear(n_units, n_vocab)
            self.attention = AttentionModule(
                n_encoder_output_units,
                n_attention_units,
                n_units
            )
            self.bos_state = Parameter(
                initializer=self.xp.random.randn(1, n_units).astype('f')
            )
        self.n_units = n_units

    def __call__(self, ys, hxs):
        """Calculate cross-entoropy loss between predictions and ys.
        Args:
            ys: Target sequences.
            hxs: Hidden states for source sequences.
        Returns:
            os: Probability density for output sequences.
        """
        batch_size, max_length, encoder_output_size = hxs.shape

        compute_context = self.attention(hxs)
        # initial cell state
        c = Variable(self.xp.zeros((batch_size, self.n_units), 'f'))
        # initial hidden state
        h = F.broadcast_to(self.bos_state, ((batch_size, self.n_units)))
        # initial character's embedding
        previous_embedding = self.embed_y(
            Variable(self.xp.full((batch_size, ), EOS, 'i'))
        )

        os = []
        for y in self.xp.hsplit(ys, ys.shape[1]):
            y = y.reshape((batch_size, ))
            context = compute_context(h)
            concatenated = F.concat((previous_embedding, context))

            c, h = self.lstm(c, h, concatenated)
            concatenated = F.concat((concatenated, h))
            o = self.w(self.maxout(concatenated))

            os.append(o)
            previous_embedding = self.embed_y(y)
        return os

    def translate(self, hxs, max_length):
        """Generate target sentences given hidden states of source sentences.
        Args:
            hxs: Hidden states for source sequences.
        Returns:
            ys: Generated sequences.
        """
        batch_size, _, _ = hxs.shape
        compute_context = self.attention(hxs)
        c = Variable(self.xp.zeros((batch_size, self.n_units), 'f'))
        h = F.broadcast_to(self.bos_state, ((batch_size, self.n_units)))
        # first character's embedding
        previous_embedding = self.embed_y(
            Variable(self.xp.full((batch_size, ), EOS, 'i'))
        )

        results = []
        for _ in range(max_length):
            context = compute_context(h)
            concatenated = F.concat((previous_embedding, context))

            c, h = self.lstm(c, h, concatenated)
            concatenated = F.concat((concatenated, h))

            logit = self.w(self.maxout(concatenated))
            y = F.reshape(F.argmax(logit, axis=1), (batch_size, ))

            results.append(y)
            previous_embedding = self.embed_y(y)
        else:
            results = F.separate(F.transpose(F.vstack(results)), axis=0)

        ys = []
        for result in results:
            index = np.argwhere(result.data == EOS)
            if len(index) > 0:
                result = result[:index[0, 0] + 1]
            ys.append(result.data)
        return ys
  • アテンションの実装

Attentionクラスはannotation vectorデコーダの状態に基づいてcontext vectorを計算する。 __call__関数は、デコーダの状態(hxs)に基づいて各単語に対するattention(attention)を計算し、その出力に応じてcontext vectorcontext)を計算する関数を返している。 broadcast_toreshapeを連発しているが、これらも行列計算の回数をなるべく減らすための工夫である。

class AttentionModule(chainer.Chain):

    def __init__(self, n_encoder_output_units,
                 n_attention_units, n_decoder_units):
        super(AttentionModule, self).__init__()
        with self.init_scope():
            self.h = L.Linear(n_encoder_output_units, n_attention_units)
            self.s = L.Linear(n_decoder_units, n_attention_units)
            self.o = L.Linear(n_attention_units, 1)
        self.n_encoder_output_units = n_encoder_output_units
        self.n_attention_units = n_attention_units

    def __call__(self, hxs):
        """Returns a function that calculates context given decoder's state.
        Args:
            hxs: Encoder's hidden states.
        Returns:
            compute_context: A function to calculate attention.
        """
        batch_size, max_length, encoder_output_size = hxs.shape

        encoder_factor = F.reshape(
            self.h(
                F.reshape(
                    hxs,
                    (batch_size * max_length, self.n_encoder_output_units)
                )
            ),
            (batch_size, max_length, self.n_attention_units)
        )

        def compute_context(previous_state):
            decoder_factor = F.broadcast_to(
                F.reshape(
                    self.s(previous_state),
                    (batch_size, 1, self.n_attention_units)
                ),
                (batch_size, max_length, self.n_attention_units)
            )

            attention = F.softmax(
                F.reshape(
                    self.o(
                        F.reshape(
                            F.tanh(encoder_factor + decoder_factor),
                            (batch_size * max_length, self.n_attention_units)
                        )
                    ),
                    (batch_size, max_length)
                )
            )

            context = F.reshape(
                F.batch_matmul(attention, hxs, transa=True),
                (batch_size, encoder_output_size)
            )
            return context

現在使われている機械翻訳システムの多くはAttention-based NMTがベースであり、いくつかの言語対では人間と同等の翻訳性能に迫っている。

その一方で、重要な課題も残っている。

一つは学習に非常に長い時間がかかることである。 計算時間のボトルネックになっているのはRecurrent Neural Network(RNN)による特徴抽出である。 RNNは入力を時系列順に受け付けるため、GPUによる並列計算の恩恵を受けづらい。 最近はConvolutional Neural NetworkやFeed-Forward Neural Networkによって系列特徴を高速に得られる手法が提案されている。

また、モデルが何を学んでいるかが不明瞭なことも課題である。 機械翻訳を含む近年の深層学習の枠組みでは、正しい入力と出力のペアを用意し、それらの写像を直接学習する。 重要なのは、入力のどのような特徴に着目し、何を学んで写像を解くかは数理的な最適化に委ねている点である。 例えば、ハヤシライスとカレーライスを識別するモデルを学習するとき、我々は当然ハヤシライスらしさとカレーライスらしさが学習されることを期待する。 しかし、もしハヤシライスとカレーライスで盛り付けの皿が違ったなら、皿に注目して識別するよう学習してしまう可能性もある。 この例はあまりに極端だが、複雑な問題が一見上手く解けているように見えるとき、それが我々の期待する方法で解決されているのか確認することは非常に難しい。 近年はこうした深層学習のブラックボックスに迫る研究も盛んに行われている。


参考文献

Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. 2014. NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE. CoRR abs/1409.0473. http://arxiv.org/abs/1409.0473 .

その他

実装: https://github.com/kiyomaro927/chainer-attention-nmt