from keras import backend as K
from keras.layers import GRU
from keras.regularizers import l2


class AttentionGRU(GRU):

  def __init__(self, atten_states, states_len, L2Strength, **kwargs):
    '''
    :param atten_states: encoder hidden states to attend over,
                         shape [batch_size, states_len, units]
    :param states_len: number of encoder states (encoder sequence length)
    :param L2Strength: L2 regularization strength for the attention weights
    :param kwargs: passed through to the underlying GRU (must include units)
    '''
    self.p_states = atten_states
    self.states_len = states_len
    self.size = kwargs['units']
    self.L2Strength = L2Strength
    super(AttentionGRU, self).__init__(**kwargs)

  def build(self, input_shape):
    # W1/b1 score each encoder state against the current decoder state;
    # the score input is [encoder_state; decoder_state], i.e. 2*units wide
    # (encoder and decoder hidden sizes are assumed equal).
    self.W1 = self.add_weight(name='attn_W1',
                              shape=(2 * self.units, 1),
                              initializer='random_uniform',
                              regularizer=l2(self.L2Strength),
                              trainable=True)
    self.b1 = self.add_weight(name='attn_b1',
                              shape=(1,),
                              initializer='zeros',
                              regularizer=l2(self.L2Strength),
                              trainable=True)
    # W2/b2 project [decoder_state; attention_context] back to the hidden size.
    self.W2 = self.add_weight(name='attn_W2',
                              shape=(2 * self.units, self.units),
                              initializer='random_uniform',
                              regularizer=l2(self.L2Strength),
                              trainable=True)
    self.b2 = self.add_weight(name='attn_b2',
                              shape=(self.units,),
                              initializer='zeros',
                              regularizer=l2(self.L2Strength),
                              trainable=True)

    super(AttentionGRU, self).build(input_shape)

  def step(self, inputs, states):
    # one step of the plain GRU gives the current decoder hidden state
    h, _ = super(AttentionGRU, self).step(inputs, states)

    # score every encoder state against the current decoder state
    alfa = K.repeat(h, self.states_len)                   # [batch_size, states_len, units]
    alfa = K.concatenate([self.p_states, alfa], axis=2)   # [batch_size, states_len, 2*units]
    scores = K.tanh(K.dot(alfa, self.W1) + self.b1)       # [batch_size, states_len, 1]
    scores = K.reshape(scores, (-1, self.states_len))     # [batch_size, states_len]
    scores = K.softmax(scores)                            # normalize over the states_len axis
    scores = K.reshape(scores, (-1, 1, self.states_len))  # [batch_size, 1, states_len]

    # attention context = weighted sum of encoder states
    attn = K.batch_dot(scores, self.p_states)             # [batch_size, 1, units]
    attn = K.reshape(attn, (-1, self.units))              # [batch_size, units]

    # combine decoder state and context, project back to the hidden size
    h = K.concatenate([h, attn], axis=-1)                 # [batch_size, 2*units]
    h = K.dot(h, self.W2) + self.b2                       # [batch_size, units]
    return h, [h]

  def compute_output_shape(self, input_shape):
    return input_shape[0], self.units

Pass the encoder's hidden states to the atten_states argument, and then use the layer just like the standard Keras GRU. Because this is a GRU rather than an LSTM, the computation in step differs slightly from the one in the paper. units is the hidden size; the code assumes the encoder and decoder share the same hidden size. A minimal usage sketch follows.
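Below is a minimal wiring sketch under the assumptions stated above (all sizes, tensor names, and the choice of Embedding inputs are hypothetical): the encoder GRU returns its full sequence of hidden states, which is passed to AttentionGRU as atten_states, and the layer is then called like any other recurrent layer.

from keras.layers import Input, Embedding, GRU
from keras.models import Model

enc_len, dec_len = 20, 15          # hypothetical encoder / decoder sequence lengths
vocab, units = 10000, 128          # hypothetical vocabulary size and shared hidden size

enc_in = Input(shape=(enc_len,))
dec_in = Input(shape=(dec_len,))

enc_emb = Embedding(vocab, units)(enc_in)
dec_emb = Embedding(vocab, units)(dec_in)

# encoder returns all of its hidden states: [batch_size, enc_len, units]
enc_states = GRU(units, return_sequences=True)(enc_emb)

# decoder attends over the encoder states at every step;
# output is [batch_size, units] (last step, matching compute_output_shape above)
dec_out = AttentionGRU(atten_states=enc_states,
                       states_len=enc_len,
                       L2Strength=1e-4,
                       units=units)(dec_emb)

model = Model(inputs=[enc_in, dec_in], outputs=dec_out)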