博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Keras实现Hierarchical Attention Network时的一些坑
阅读量:5161 次
发布时间:2019-06-13

本文共 3013 字,大约阅读时间需要 10 分钟。

Reshape

对于的张量x,x.shape=(a, b, c, d)的情况

若调用keras.layer.Reshape(target_shape=(-1, c, d)),

处理后的张量形状为(?, ?, c, d)

若调用tf.reshape(x, shape=[-1, c, d])

处理后的张量形状为(a*b, c, d)

为了在keras代码中实现tf.reshape的效果,用lambda层做,

调用Lambda(lambda x: tf.reshape(x, shape=[-1, c, d]))(x)

nice and cool.

输出Attention的打分

这里,我们希望attention层能够输出attention的score,而不只是计算weighted sum。

在使用时
score = Attention()(x)
weighted_sum = MyMerge()([score, x])

class Attention(Layer):    def __init__(self, **kwargs):        super(Attention, self).__init__(**kwargs)    def build(self, input_shape):        assert len(input_shape) == 3        self.w = self.add_weight(name="attention_weight",                                   shape=(input_shape[-1],                                          input_shape[-1]),                                   initializer='uniform',                                   trainable=True                                   )        self.b = self.add_weight(name="attention_bias",                                   shape=(input_shape[-1],),                                   initializer='uniform',                                   trainable=True                                   )        self.v = self.add_weight(name="attention_v",                                 shape=(input_shape[-1], 1),                                 initializer='uniform',                                 trainable=True                                 )        super(Attention, self).build(input_shape)    def call(self, inputs):        x = inputs        att = K.tanh(K.dot(x, self.w) + self.b)        att = K.softmax(K.dot(att, self.v))        print(att.shape)        return att    def compute_output_shape(self, input_shape):        return input_shape[0], input_shape[1], 1class MyMerge(Layer):    def __init__(self, **kwargs):        super(MyMerge, self).__init__(**kwargs)    def call(self, inputs):        att = inputs[0]        x = inputs[1]        att = tf.tile(att, [1, 1, x.shape[-1]])        outputs = tf.multiply(att, x)        outputs = K.sum(outputs, axis=1)        return outputs    def compute_output_shape(self, input_shape):        return input_shape[1][0], input_shape[1][2]

keras中Model的嵌套

这边是转载自https://github.com/uhauha2929/examples/blob/master/Hierarchical%20Attention%20Networks%20.ipynb

可以看到,sentEncoder是Model类型,在后面的时候通过TimeDistributed(sentEncoder),当成一个层那样被调用。

embedding_layer = Embedding(len(word_index) + 1,                            EMBEDDING_DIM,                            input_length=MAX_SENT_LENGTH)sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')embedded_sequences = embedding_layer(sentence_input)l_lstm = Bidirectional(LSTM(100))(embedded_sequences)sentEncoder = Model(sentence_input, l_lstm)review_input = Input(shape=(MAX_SENTS,MAX_SENT_LENGTH), dtype='int32')review_encoder = TimeDistributed(sentEncoder)(review_input)l_lstm_sent = Bidirectional(LSTM(100))(review_encoder)preds = Dense(2, activation='softmax')(l_lstm_sent)model = Model(review_input, preds)

转载于:https://www.cnblogs.com/bellz/p/11153691.html

你可能感兴趣的文章
(转)Tomcat 8 安装和配置、优化
查看>>
(转)Linxu磁盘体系知识介绍及磁盘介绍
查看>>
tkinter布局
查看>>
命令ord
查看>>
Sharepoint 2013搜索服务配置总结(实战)
查看>>
博客盈利请先考虑这七点
查看>>
使用 XMLBeans 进行编程
查看>>
写接口请求类型为get或post的时,参数定义的几种方式,如何用注解(原创)--雷锋...
查看>>
【OpenJ_Bailian - 2287】Tian Ji -- The Horse Racing (贪心)
查看>>
Java网络编程--socket服务器端与客户端讲解
查看>>
List_统计输入数值的各种值
查看>>
学习笔记-KMP算法
查看>>
Timer-triggered memory-to-memory DMA transfer demonstrator
查看>>
跨域问题整理
查看>>
[Linux]文件浏览
查看>>
64位主机64位oracle下装32位客户端ODAC(NFPACS版)
查看>>
获取国内随机IP的函数
查看>>
今天第一次写博客
查看>>
江城子·己亥年戊辰月丁丑日话凄凉
查看>>
IP V4 和 IP V6 初识
查看>>