Python快乐编程:人工智能深度学习基础
上QQ阅读APP看书,第一时间看更新

2.9.1 scan循环的参数

scan循环的参数有很多,本书仅对以下几个主要参数进行讲解:fnsequences、outputs_info、non_sequences、n_steps、truncate_gradient、strict。

1.fn

该参数是一个函数,通常是一个lambda函数或def函数,fn是scan最核心的组成部分,它定义了每一次循环的处理逻辑,可以返回sequences变量的更新updates。fn对函数参数的定义顺序和函数输出有严格对应的要求,输入变量顺序为sequences、outputs_info、non_sequences。

2.sequences

sequences是一个由Theano变量或字典构成的列表,它们的值将作为参数传递给函数fn。列表中的每一个元素都是一个序列,每次迭代可以传递序列的一个元素或多个元素,具体示例代码如下所示:

上述代码中scan函数的sequences参数包含了以下3个参数:sequence1、sequence2、sequence3,这是3个输入序列。

· sequence1:通常以字典的形式表示,字典中可以包括input(输入序列)和taps(索引)两个key值。上述代码表示在第t次迭代时,sequence1传递给fn的参数有sequence1[t-1]和sequence1[t-2]。

· sequence2:以普通的Theano变量形式传递,该参数等价于下列代码:

    dict(input = squence2,taps = [0])

当忽略taps参数时,Theano会默认taps的值为0,因此,在第t次迭代时,sequence2传递给fn的参数为sequence2[t]。

· sequence3:结合前两个参数的传递过程可以看出,在第t次迭代时,sequence3传递给fn的参数为sequence3[t+3]。

3.outputs_info

与sequences的表达相似,outputs_info也是一个由Theano变量或字典构成的列表,列表中的每个元素都是函数fn的输出结果的初始值,具体示例如下所示:

上述代码的sequences参数包含3个参数:output1、output2、output3。

· output1:以字典的形式进行表示。用字典形式表示outputs_info时,可以包括initial(定义初始值)和taps(索引)两个key值。表示在第t次迭代时,output1传递给fn函数的参数为output1[t-3]和output1[t-5]。

· output2:以普通的Theano变量形式传递,该参数等价于下列代码:

    dict(initial = output2,taps = [-1])

与前面提到的sequence2情况一样,在忽略taps的值时,系统会为taps自动添加默认值,但是需要注意,这里的taps默认值为-1。表示在第t次迭代时,output2传递给fn函数的参数为output2[t-1]。

· output3:结合前两个参数的传递过程可以看出,output3表示在第t次迭代时,传递给fn函数的参数为sequence3[t+3]。

4.non_sequences

该参数是一个不变量或常数值列表,与前两个参数不同,该参数在迭代过程中不可改变。在实际应用中,一般把该参数设置为模型的权重参数列表。

5.n_steps

n_steps用来指定scan的迭代次数。sequences与n_steps两个参数中至少存在一个,否则scan无法知道迭代的步数。

6.truncate_gradient

这是一个专为循环神经网络训练设计的参数。利用scan来实现BPTT算法时,truncate_gradient用于指定向前传播的步长值,当值为-1时,表示采用的是传统的BPTT算法;当值大于0时,表示向前执行步长达到truncate_gradient设定值时,会提前结束并返回。这种截断策略可以用于处理传统的BPTT算法中的梯度消失问题。

7.strict

当该参数的值为True时,必须保证所有用到的Theano共享变量都放置在non_sequences参数中。