convert Lasagne to Keras code (CNN -gt; LSTM)(将千层面转换为KERAS代码(CNN-GT;LSTM))
问题描述
我想转换此千层面代码:
et = {}
net['input'] = lasagne.layers.InputLayer((100, 1, 24, 113))
net['conv1/5x1'] = lasagne.layers.Conv2DLayer(net['input'], 64, (5, 1))
net['shuff'] = lasagne.layers.DimshuffleLayer(net['conv1/5x1'], (0, 2, 1, 3))
net['lstm1'] = lasagne.layers.LSTMLayer(net['shuff'], 128)
在Kera代码中。目前我想到了这个:
multi_input = Input(shape=(1, 24, 113), name='multi_input')
y = Conv2D(64, (5, 1), activation='relu', data_format='channels_first')(multi_input)
y = LSTM(128)(y)
但我收到错误:Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=4
推荐答案
解决方案
from keras.layers import Input, Conv2D, LSTM, Permute, Reshape
multi_input = Input(shape=(1, 24, 113), name='multi_input')
print(multi_input.shape) # (?, 1, 24, 113)
y = Conv2D(64, (5, 1), activation='relu', data_format='channels_first')(multi_input)
print(y.shape) # (?, 64, 20, 113)
y = Permute((2, 1, 3))(y)
print(y.shape) # (?, 20, 64, 113)
# This line is what you missed
# ==================================================================
y = Reshape((int(y.shape[1]), int(y.shape[2]) * int(y.shape[3])))(y)
# ==================================================================
print(y.shape) # (?, 20, 7232)
y = LSTM(128)(y)
print(y.shape) # (?, 128)
说明
我把Lasagne和Kera的文件放在这里,这样您就可以对照了:
Lasagne
重复层的使用方式与前馈层类似,但以下情况除外
输入形状应为(batch_size, sequence_length, num_inputs)
Keras
输入形状
形状为
(batch_size, timesteps, input_dim)
的3D张量。
API基本上是相同的,但是Lasagne可能会为您重塑(稍后我需要检查源代码)。这就是您收到此错误的原因:
Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=4
,因为Conv2D
后面的张量形状是ndim=4
的(?, 64, 20, 113)
因此,解决方案是将其重塑为(?, 20, 7232)
。
编辑
与千层面source code确认后,它为您做到了:
num_inputs = np.prod(input_shape[2:])
因此,作为LSTM输入的正确张量形状为(?, 20, 64 * 113)
=(?, 20, 7232)
注意:
Permute
在Kera中是多余的,因为无论如何您都必须重塑。我之所以把它放在这里,是为了有一个从千层面到凯拉斯的"完整翻译",它所做的就是DimshuffleLaye
在千层面中所做的事情。
DimshuffleLaye
由于我在编辑中提到的原因,在Lasagne中需要<2-10]>,Lasagne LSTM创建的新维度来自"最后两个"维度的相乘。
这篇关于将千层面转换为KERAS代码(CNN->;LSTM)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
本文标题为:将千层面转换为KERAS代码(CNN->;LSTM)


基础教程推荐
- 求两个直方图的卷积 2022-01-01
- 无法导入 Pytorch [WinError 126] 找不到指定的模块 2022-01-01
- 在同一图形上绘制Bokeh的烛台和音量条 2022-01-01
- 在Python中从Azure BLOB存储中读取文件 2022-01-01
- 使用大型矩阵时禁止 Pycharm 输出中的自动换行符 2022-01-01
- PANDA VALUE_COUNTS包含GROUP BY之前的所有值 2022-01-01
- Plotly:如何设置绘图图形的样式,使其不显示缺失日期的间隙? 2022-01-01
- 修改列表中的数据帧不起作用 2022-01-01
- 包装空间模型 2022-01-01
- PermissionError: pip 从 8.1.1 升级到 8.1.2 2022-01-01