In this post I'll talk through the pieces of the tensorflow API most relevant to building recurrent neural networks. The tensorflow documentation is great for explaining how to build standard RNNs, but it falls a little flat when it comes to building highly customized RNNs.
I'll use the network described in Hierarchical Multiscale Recurrent Neural Networks by Chung et al. as an example of a fairly non-standard RNN. There's an open-source implementation of that network on GitHub.
In this section I'll provide a quick overview of the tools available to create standard RNNs in tensorflow.
If you need a standard RNN, GRU, or LSTM, tensorflow has you covered. The API contains these pre-written cells, all of which extend the base class RNNCell. Their tutorial on RNNs gives a good overview of how to use these cells, so I won't spend much time here. If you are completely new to RNNs in tensorflow, it may be a good idea to review that tutorial before continuing.
One thing worth pointing out about their tutorial is its use of MultiRNNCell. This class is constructed with a list of objects that extend RNNCell, and it is used to create multi-layered RNNs: the lowest layer is fed the input, each subsequent layer is fed the output of the previous layer, and the output of the last layer is the output of the network at a given time step. If you want to pass information between layers in a different way, you'll need a custom multicell. We'll come back to this later.
Note that MultiRNNCell itself extends RNNCell as well, so it can be used anywhere any other RNNCell can be used.
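To make the stacking rule concrete, here is a minimal plain-Python sketch of what one time step of a MultiRNNCell-style stack computes. It does not use tensorflow: lists of floats stand in for tensors, and a "cell" is any function with the same (inputs, state) -> (output, new_state) shape as an RNNCell. The names multi_rnn_step and accumulate_cell are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# Any function with the RNNCell-like shape (inputs, state) -> (output, new_state).
Cell = Callable[[Vector, Vector], Tuple[Vector, Vector]]


def multi_rnn_step(cells: Sequence[Cell], inputs: Vector,
                   states: Sequence[Vector]) -> Tuple[Vector, List[Vector]]:
    """One time step of a MultiRNNCell-style stack (hypothetical helper)."""
    layer_input = inputs
    new_states = []
    for cell, state in zip(cells, states):
        # Each layer's output becomes the input of the layer above it.
        layer_input, new_state = cell(layer_input, state)
        new_states.append(new_state)
    # The top layer's output is the network's output at this time step.
    return layer_input, new_states


def accumulate_cell(inputs: Vector, state: Vector) -> Tuple[Vector, Vector]:
    """Toy cell: adds the input to the state and emits the new state as output."""
    new_state = [s + x for s, x in zip(state, inputs)]
    return new_state, new_state


output, states = multi_rnn_step([accumulate_cell, accumulate_cell],
                                [1.0, 2.0], [[0.0, 0.0], [10.0, 10.0]])
print(output)  # [11.0, 12.0]
```

Any other way of routing information between layers means replacing that inner loop, which is exactly what a custom multicell does.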
I find it easiest to think about RNNs when they are unrolled. In tensorflow, this means that we rebuild an identical computational graph for each timestep, and pass the hidden state(s) from one timestep forward to the next manually, as it were.
This is the approach taken by the tensorflow tutorial model. The upside to this approach is that it is easy to think about, and it is flexible. If you want to process the hidden states from one time step in any way before you pass them on, you can easily put more nodes in the graph to do so.
There are two major downsides. First, a graph composed in this way has to be fixed length, which means you'll have to rebuild the graph for different length signals, or pad them out with zeros. Neither solution is great.
Second, large graphs take much longer to build and consume much more RAM. Depending on the constraints of your computing environment, this could be prohibitive.
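The unrolled approach, and its fixed-length limitation, can be mimicked in plain Python. In this sketch (no tensorflow; a scalar vanilla RNN with made-up weights), build_unrolled_rnn plays the role of graph construction: the number of steps is baked in once, every step reuses the same weights, and the hidden state is handed forward explicitly. Both function names are hypothetical.

```python
import math
from typing import Callable, List


def rnn_step(x: float, h: float, w_x: float, w_h: float, b: float) -> float:
    """Scalar vanilla RNN update: h_t = tanh(w_x * x_t + w_h * h_{t-1} + b)."""
    return math.tanh(w_x * x + w_h * h + b)


def build_unrolled_rnn(num_steps: int, w_x: float, w_h: float,
                       b: float) -> Callable[[List[float], float], List[float]]:
    """'Graph construction' time: the number of steps is fixed here, once."""

    def run(xs: List[float], h0: float) -> List[float]:
        if len(xs) != num_steps:
            # An unrolled graph cannot accept a sequence of a different length;
            # it must be rebuilt, or the input padded to num_steps.
            raise ValueError(f"built for {num_steps} steps, got {len(xs)}")
        hs, h = [], h0
        for t in range(num_steps):
            # Every copy of the step shares the same weights; the hidden state
            # is handed from one step to the next explicitly.
            h = rnn_step(xs[t], h, w_x, w_h, b)
            hs.append(h)
        return hs

    return run


run3 = build_unrolled_rnn(num_steps=3, w_x=1.0, w_h=0.5, b=0.0)
print(run3([0.5, 0.0, -0.5], h0=0.0))
```

Calling run3 with a sequence of any other length raises an error, which is the pure-Python analogue of having to rebuild the graph.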
The tf.nn.dynamic_rnn function takes your RNNCell and builds a loop into the graph (using tf.while_loop) that passes the state, whatever that may be, from one time step to the next and keeps track of the outputs. Because the loop runs at execution time, a single graph can handle sequences of different lengths without growing. If you have other needs, tf.scan can serve a similar purpose more flexibly, as we will see later.
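The shape of what tf.nn.dynamic_rnn hands back is worth pinning down: the output at every step, but only the final state. The plain-Python sketch below mirrors that contract only; it is not how tensorflow implements the function, and dynamic_rnn_like and running_mean_cell are hypothetical names.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

X = TypeVar("X")  # input at one time step
O = TypeVar("O")  # output at one time step
S = TypeVar("S")  # state (may be a tuple)


def dynamic_rnn_like(cell: Callable[[X, S], Tuple[O, S]],
                     inputs: Sequence[X], initial_state: S) -> Tuple[List[O], S]:
    """Mirrors the shape of tf.nn.dynamic_rnn's result: (all outputs, final state)."""
    outputs: List[O] = []
    state = initial_state
    for x in inputs:  # the loop length comes from the data, not from graph construction
        out, state = cell(x, state)
        outputs.append(out)
    # Intermediate states are not returned, only the last one.
    return outputs, state


def running_mean_cell(x: float, state: Tuple[int, float]) -> Tuple[float, Tuple[int, float]]:
    """Toy cell whose state is (count, total) and whose output is the running mean."""
    count, total = state
    count, total = count + 1, total + x
    return total / count, (count, total)


outputs, final_state = dynamic_rnn_like(running_mean_cell, [1.0, 2.0, 3.0], (0, 0.0))
print(outputs, final_state)  # [1.0, 1.5, 2.0] (3, 6.0)
```

The (count, total) pairs from the first two steps are gone by the time the function returns; only the final state survives.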
In this section, I'll review some available options for creating RNNs with less standard architectures.
If you need a network that's a little different from any of the standard implementations, you can extend RNNCell directly. You can use your own subclass with the MultiRNNCell class described above as well as the DropoutWrapper and other predefined RNNCell wrappers.
Extending RNNCell means overriding at least the state_size property, the output_size property, and the call method.

Tensorflow's prebuilt cells internally represent state either as a single tensor or as a tuple of tensors. If it is a single tensor, it gets broken down into cell and hidden states (or whatever the case may be) upon entry into the cell, and then the new states are stuck back together at the end.
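In plain Python, the single-tensor convention amounts to the split-and-concatenate round trip below. It is only a sketch with lists in place of tensors (inside tensorflow the prebuilt cells do this with tensor split and concat ops), and split_state and concat_state are hypothetical names.

```python
from typing import List, Sequence, Tuple


def split_state(flat: Sequence[float], sizes: Sequence[int]) -> Tuple[List[float], ...]:
    """Break one concatenated state vector into its parts, e.g. (cell, hidden)."""
    if len(flat) != sum(sizes):
        raise ValueError("state length does not match the given sizes")
    parts, start = [], 0
    for size in sizes:
        parts.append(list(flat[start:start + size]))
        start += size
    return tuple(parts)


def concat_state(parts: Sequence[Sequence[float]]) -> List[float]:
    """Stick the updated parts back together into a single state vector."""
    return [value for part in parts for value in part]


c, h = split_state([1.0, 2.0, 3.0, 4.0], sizes=(2, 2))
new_c = [v * 2 for v in c]  # stand-in for the real cell update
print(concat_state([new_c, h]))  # [2.0, 4.0, 3.0, 4.0]
```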
I've found it simpler to treat cell state as a tuple. In this case, the state_size property is just a tuple with the lengths of whichever states you're keeping track of. The output_size property is the length of the output of the cell.
The call method is where the logic of your cell will live. RNNCell's __call__ method will wrap your call method and help with variable scoping and other housekeeping. In order for your subclass to be a valid RNNCell, the call method must accept inputs and state, and return a tuple of (output, new_state), where new_state must have the same form as state.
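The contract above can be sketched without tensorflow. The class below has the same surface as an RNNCell subclass with tuple state (state_size, output_size, and call(inputs, state) returning (output, new_state)), but it works on plain lists. In real code the class would extend RNNCell, call would build tensorflow ops, and __call__ would be inherited; TwoRateCell and its update rule are made up for illustration and have nothing to do with the HMLSTM.

```python
from typing import List, Tuple

Vector = List[float]
State = Tuple[Vector, Vector]


class TwoRateCell:
    """Plain-Python stand-in for an RNNCell subclass; no tensorflow involved.

    The state is a tuple (fast, slow): two leaky averages of the input that
    move at different rates. The output is their difference.
    """

    def __init__(self, num_units: int, fast_rate: float = 0.5, slow_rate: float = 0.25):
        self.num_units = num_units
        self.fast_rate = fast_rate
        self.slow_rate = slow_rate

    @property
    def state_size(self) -> Tuple[int, int]:
        # A tuple state: one length per state being tracked.
        return (self.num_units, self.num_units)

    @property
    def output_size(self) -> int:
        return self.num_units

    def call(self, inputs: Vector, state: State) -> Tuple[Vector, State]:
        fast, slow = state
        new_fast = [f + self.fast_rate * (x - f) for f, x in zip(fast, inputs)]
        new_slow = [s + self.slow_rate * (x - s) for s, x in zip(slow, inputs)]
        output = [nf - ns for nf, ns in zip(new_fast, new_slow)]
        # new_state has exactly the same form as state: a (fast, slow) tuple.
        return output, (new_fast, new_slow)


cell = TwoRateCell(num_units=2)
state = tuple([0.0] * n for n in cell.state_size)  # zero state built from state_size
output, state = cell.call([1.0, 0.0], state)
print(output)  # [0.25, 0.0]
```

Keeping the state as a tuple means call can unpack it by name instead of slicing a single vector.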
Note that if you construct a new RNNCell that you want to use with tensorflow variables that already exist in your tensorflow session, you can pass a _reuse=True argument to the parent constructor within your __init__ method. If the variables already exist but you do not pass _reuse=True, you'll get an error, because tensorflow will neither overwrite existing variables nor reuse them without explicit instruction.
For reference, the HMLSTMCell class is an RNNCell used to represent one cell of the Hierarchical Multiscale RNN mentioned above. Its implementation covers all the main points above.
That code also makes use of _linear, an undocumented tensorflow function that is used in most of the built-in RNNCell subclasses. This is a little risky, because it's clearly not meant for outside use, but it's a useful little function that handles the matrix multiplication by the weights and the addition of the biases.
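To show the kind of work _linear does, here is a plain-Python version of the same computation: concatenate the argument vectors, multiply by a weight matrix, and add a bias. The name linear and its signature are hypothetical; in particular, the real _linear creates its own weight and bias variables rather than taking them as arguments, and it is not a public API.

```python
from typing import List, Sequence


def linear(args: Sequence[Sequence[float]], weights: Sequence[Sequence[float]],
           bias: Sequence[float]) -> List[float]:
    """Compute concat(args) @ weights + bias for a single example.

    weights has one row per concatenated input element and one column per output unit.
    """
    x = [value for arg in args for value in arg]  # concatenate along the feature axis
    if len(x) != len(weights):
        raise ValueError("weights needs one row per concatenated input element")
    return [sum(x[i] * weights[i][j] for i in range(len(x))) + bias[j]
            for j in range(len(bias))]


# Typical use inside a cell: one projection of [inputs, previous hidden state].
inputs, h_prev = [1.0], [2.0]
print(linear([inputs, h_prev], weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.5, -0.5]))  # [1.5, 1.5]
```

Projecting the concatenation of inputs and hidden state in one multiplication is why a helper like this shows up in nearly every cell.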
If you're building a multi-layered RNN where the layers don't simply pass their output up from layer to layer, you'll have to build your own version of a multicell. Much like the built-in MultiRNNCell, your multicell should extend RNNCell.
In this case the cell state will be a list, where each element is the cell state at the layer corresponding to its index.
Writing your own multicell is useful in two cases. The first is when you want to do something to the result of one layer before you pass it into the cell at the next layer, but you don't want to execute that operation for the lowest layer (otherwise you could just build it into the cell).
The second is when there's information from the previous time step that you need to distribute among the different layers, but that doesn't fit neatly into the paradigm of passing along state from one time step to the next.
For example, in the hierarchical multiscale LSTM, each cell expects to receive, as part of its input, the hidden state of the layer above it at the previous time step. This doesn't neatly fit the standard idea of stacked RNNs, so we can't use the usual MultiRNNCell. For reference, here is the implementation of the MultiHMLSTMCell.
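The routing rule that breaks MultiRNNCell can be sketched in plain Python. This is not the MultiHMLSTMCell: there is no boundary detection, each layer's state is collapsed to a single number, and the names top_down_multicell_step and mix_cell are hypothetical. What it does show is a list-valued state, one entry per layer, where layer i reads layer i+1's value from the previous time step before anything is overwritten.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical layer interface: (input from below, hidden state of the layer above at
# the previous time step, own previous state) -> (new hidden, new state).
# To keep the sketch short, each layer's state is just its hidden value.
LayerCell = Callable[[float, float, float], Tuple[float, float]]


def top_down_multicell_step(cells: Sequence[LayerCell], x: float,
                            states: Sequence[float]) -> Tuple[List[float], List[float]]:
    """One time step; states[i] is layer i's state from the previous time step."""
    hidden, new_states = [], []
    below = x
    for i, cell in enumerate(cells):
        # Top-down input: the layer above's value from t-1 (zero for the top layer).
        above_prev = states[i + 1] if i + 1 < len(cells) else 0.0
        h, s = cell(below, above_prev, states[i])
        hidden.append(h)
        new_states.append(s)
        below = h  # bottom-up input to the next layer comes from the current step
    # Returning every layer's hidden value lets a separate output network use them all.
    return hidden, new_states


def mix_cell(below: float, above_prev: float, state: float) -> Tuple[float, float]:
    """Toy layer: sums its three inputs."""
    h = below + above_prev + state
    return h, h


hidden, states = top_down_multicell_step([mix_cell, mix_cell], 1.0, [0.0, 10.0])
print(hidden)  # [11.0, 21.0]
```

Because the new states are collected in a separate list, the top-down read always sees the previous time step, never a value already updated at the current one.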
The Hierarchical Multiscale LSTM network calls for the hidden states to be fed into an output network. We've already seen that this HMLSTM network doesn't neatly fit into the tensorflow RNN paradigm because of how it passes information between layers; now we've hit another obstacle. Instead of treating the last output of the last layer as the output of the network, we need to pass the hidden states of all layers through another network to get the output we really care about. Not only that, we care about the values of some of the indicator neurons, which are treated as cell state.
For these reasons, we can't use the tf.nn.dynamic_rnn function, which returns only the output at each time step and the final state.

Instead, tf.scan takes an arbitrary function and an ordered collection of elements. It then applies the function to each element in the collection, keeping track of some accumulator. It returns an ordered collection of the value of the accumulator at each step in the process.
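The semantics of tf.scan fit in a few lines of plain Python. This sketch models only the behaviour described above, for the case where an initializer is supplied; tensorflow's version operates on tensors (and nested structures of them) and runs as a loop inside the graph. The function name scan here is just a local stand-in.

```python
from typing import Callable, Iterable, List, TypeVar

A = TypeVar("A")  # accumulator
E = TypeVar("E")  # element


def scan(fn: Callable[[A, E], A], elems: Iterable[E], initializer: A) -> List[A]:
    """Plain-Python analogue of tf.scan(fn, elems, initializer)."""
    acc = initializer
    history: List[A] = []
    for elem in elems:
        acc = fn(acc, elem)     # fn receives (accumulator, element), as in tf.scan
        history.append(acc)     # keep the accumulator after every element
    return history              # the initializer itself is not included


print(scan(lambda total, x: total + x, [1, 2, 3, 4], 0))  # [1, 3, 6, 10]
```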
This is perfect for a more customized RNN. Because you define the function, you can manipulate the inputs, outputs, and states however you choose. Afterwards, you get a full accounting of the state at every time step, rather than just the output.
In the case of the HMLSTM, we use these states to keep track of the boundary detection, and we also map over them to obtain the final predictions.
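Here is a toy version of that pattern in plain Python: scan a step function whose accumulator is the full state, then read the indicator out of every recorded state and map an output function over the hidden values. The step rule, its threshold, and output_network are hypothetical stand-ins, not the HMLSTM. itertools.accumulate with initial= (Python 3.8+) plays the role of tf.scan, except that it also yields the initializer, which is dropped.

```python
from itertools import accumulate
from typing import List, Tuple

State = Tuple[float, int]  # (hidden value, boundary indicator)


def cell_step(state: State, x: float) -> State:
    """Toy scan function: the hidden value is a leaky sum of the inputs,
    and the indicator fires when it exceeds 1.0 (an illustrative rule only)."""
    h, _ = state
    new_h = 0.5 * h + x
    return (new_h, 1 if new_h > 1.0 else 0)


def output_network(h: float) -> float:
    """Stand-in for the network that maps hidden states to predictions."""
    return 2.0 * h


def run(xs: List[float]) -> Tuple[List[int], List[float]]:
    # accumulate(..., initial=...) yields the initializer first; drop it so the
    # result lines up with tf.scan, which returns one accumulator per element.
    states = list(accumulate(xs, cell_step, initial=(0.0, 0)))[1:]
    boundaries = [z for _, z in states]                      # indicator at every step
    predictions = [output_network(h) for h, _ in states]     # map over all hidden states
    return boundaries, predictions


print(run([0.5, 1.0, 0.25, 1.5]))  # ([0, 1, 0, 1], [1.0, 2.5, 1.75, 3.875])
```

Because every intermediate state is kept, the indicator values and the predictions come from the same recorded history, which is what dynamic_rnn's (outputs, final state) pair cannot provide.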
Here's the code for reference.
In this post, we looked at the standard tools for dealing with RNNs in tensorflow, and explored some more flexible options to use when those tools fall short.