博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
关于google深度学习框架中PTB数据的batch方法中参数的理解
阅读量:3786 次
发布时间:2019-05-22

本文共 1709 字,大约阅读时间需要 5 分钟。

简介

在《实战google深度学习框架》中的9.2.2节中,介绍了如何对文本数据进行batching的方法,主要包含两种,一种是填充式(padding),另一种则是batching方法。为了方便查看代码中的参数的含义,这里做一下简单的记录。

内容

代码中的各参数如下图所示:

说明:在图中,假设数据为data_size=100,batch_size的大小为4,每个batch中截取的片段包含的字符数为num_step=5,则每个batch中包含的单词数量为batch_size*num_step=4*5=20,总的batch的数量为num_batch = data_size/(batch_size*num_step)=100/(4*5)=5。这里面主要需要理解num_step、batch的大小、batch_size、num_batches的含义。num_step可以理解为每个序列截取片段中包含的字符数,batch的大小可以理解为,神经网络每训练一次,可以同时处理的字符数量,num_batches可以理解为整个训练过程总共需要经过多少次迭代,batches_size可以理解为将整个训练数据分为多少个子序列。

以上仅为个人理解,共勉。

《实战google深度学习框架》中的9.2.2节中涉及的代码如下

import numpy as npimport tensorflow as tfTRAIN_DATA = 'ptb.train'TRAIN_BATCH_SIZE = 20TRAIN_NUM_STEP = 35#从文件中获取数据,并返回包含单词编号的数组def read_data(file_path):    with open(TRAIN_DATA,"r") as fin:        id_string = ' '.join([line.strip() for line in fin.readlines()])    id_list = [int(w) for w in id_string.split()]    return id_listdef make_batches(id_list, batch_size, num_step):    num_batches = (len(id_list)-1)//(batch_size*num_step)    #整理数据为一个维度为[batch_size, num_batches*num_step]的二维数组    data = np.array(id_list[:num_batches*batch_size*num_step])    data = np.reshape(data, [batch_size, num_batches*num_step])    #沿着第二个维度,将数据切分成num_batches个batch, 存入数组    data_batches = np.split(data, num_batches, axis=1)    #重复上述操作,但是每个位置向右移动一位,这里得到的是RNN每一步输出所需要预测的下一个单词    label = np.array(id_list[1:num_batches*batch_size*num_step+1])    label = np.reshap(label, [batch_size, num_batches*num_step])    label_batches = np.split(label, num_batches, axis=1)    return list(zip(data_batches,label_batches))def main():    train_batchws = make_batches(read_data(TRAIN_DATA), TRAIN_BATCH_SIZE, TRAIN_NUM_STEP)    #在这里插入训练代码,训练代码将在下一节中介绍if __name__ == '__main__':    main()

 

转载地址:http://kdktn.baihongyu.com/

你可能感兴趣的文章
类加载器
查看>>
数据库设计
查看>>
Java虚拟机的内存分配和运行机制(粗谈)
查看>>
web开发之BaseServlet的使用
查看>>
初识Maven
查看>>
Maven分模块构建项目
查看>>
MyBatis初识
查看>>
MyBatis【进阶详解】
查看>>
面试题集锦(七)
查看>>
注解开发——Spring整合dao/service/web
查看>>
架构的演进
查看>>
Elastic-Job的基础使用
查看>>
策略过滤器的灵活性分析
查看>>
POI的使用
查看>>
Anaconda和PyCharm的下载、安装和配置
查看>>
Mockito单元测试简述
查看>>
GUAVA的常用方法汇总
查看>>
装饰器和门面设计模式介绍
查看>>
创建型模式——克隆模式
查看>>
JVM关闭和Hook钩子
查看>>