摘要:經過第一步的處理已經把古詩詞詞語轉換為可以機器學習建模的數字形式,因為我們采用算法進行古詩詞生成,所以還需要構建輸入到輸出的映射處理。
LSTM 介紹
序列化數據即每個樣本和它之前的樣本存在關聯,前一數據和后一個數據有順序關系。深度學習中有一個重要的分支是專門用來處理這樣的數據的——循環神經網絡。循環神經網絡廣泛應用在自然語言處理領域(NLP),今天我們帶你從一個實際的例子出發,介紹循環神經網絡一個重要的改進算法模型-LSTM。本文章不對LSTM的原理進行深入,想詳細了解LSTM可以參考這篇 [譯] 理解 LSTM 網絡。本文重點從古詩詞自動生成的實例出發,一步一步帶你從數據處理到模型搭建,再到訓練出古詩詞生成模型,最后實現從古詩詞自動生成新春祝福詩詞。
數據處理我們使用76748首古詩詞作為數據集,數據集下載鏈接,原始的古詩詞的存儲形式如下:
我們可以看到原始的古詩詞是文本符號的形式,無法直接進行機器學習,所以我們第一步需要把文本信息轉換為數據形式,這種轉換方式就叫詞嵌入(word embedding),我們采用一種常用的詞嵌套(word embedding)算法-Word2vec對古詩詞進行編碼。關于Word2Vec這里不詳細講解,感興趣可以參考 [NLP] 秒懂詞向量Word2vec的本質。在詞嵌套過程中,為了避免最終的分類數過于龐大,可以選擇去掉出現頻率較小的字,比如可以去掉只出現過一次的字。Word2vec算法經過訓練后會產生一個模型文件,我們就可以利用這個模型文件對古詩詞文本進行詞嵌套編碼。
經過第一步的處理已經把古詩詞詞語轉換為可以機器學習建模的數字形式,因為我們采用LSTM算法進行古詩詞生成,所以還需要構建輸入到輸出的映射處理。例如:
“[長河落日圓]”作為train_data,而相應的train_label就是“長河落日圓]]”,也就是
“[”->“長”,“長”->“河”,“河”->“落”,“落”->“日”,“日”->“圓”,“圓”->“]”,“]”->“]”,這樣子先后順序一一對相。這也是循環神經網絡的一個重要的特征。
這里的“[”和“]”是開始符和結束符,用于生成古詩的開始與結束標記。
總結一下數據處理的步驟:
讀取原始的古詩詞文本,統計出所有不同的字,使用 Word2Vec 算法進行對應編碼;
對于每首詩,將每個字、標點都轉換為字典中對應的編號,構成神經網絡的輸入數據 train_data;
將輸入數據左移動構成輸出標簽 train_label;
經過數據處理后我們得到以下數據文件:
poems_edge_split.txt:原始古詩詞文件,按行排列,每行為一首詩詞;
vectors_poem.bin:利用 Word2Vec訓練好的詞向量模型,以開頭,按詞頻排列,去除低頻詞;
poem_ids.txt:按輸入輸出關系映射處理之后的語料庫文件;
rhyme_words.txt: 押韻詞存儲,用于押韻詩的生成;
在提供的源碼中已經提供了以上四個數據文件放在data文件夾下,數據處理代碼見 data_loader.py 文件,源碼鏈接
模型構建及訓練這里我們使用2層的LSTM框架,每層有128個隱藏層節點,我們使用tensorflow.nn模塊庫來定義網絡結構層,其中RNNcell是tensorflow中實現RNN的基本單元,是一個抽象類,在實際應用中多用RNNcell的實現子類BasicRNNCell或者BasicLSTMCell,BasicGRUCell;如果需要構建多層的RNN,在TensorFlow中,可以使用tf.nn.rnn_cell.MultiRNNCell函數對RNNCell進行堆疊。模型網絡的第一層要對輸入數據進行 embedding,可以理解為數據的維度變換,經過兩層LSTM后,接著softMax得到一個在全字典上的輸出概率。
模型網絡結構如下:
定義網絡的類的程序代碼如下:
class CharRNNLM(object): def __init__(self, is_training, batch_size, vocab_size, w2v_model, hidden_size, max_grad_norm, embedding_size, num_layers, learning_rate, cell_type, dropout=0.0, input_dropout=0.0, infer=False): self.batch_size = batch_size self.hidden_size = hidden_size self.vocab_size = vocab_size self.max_grad_norm = max_grad_norm self.num_layers = num_layers self.embedding_size = embedding_size self.cell_type = cell_type self.dropout = dropout self.input_dropout = input_dropout self.w2v_model = w2v_model if embedding_size <= 0: self.input_size = vocab_size self.input_dropout = 0.0 else: self.input_size = embedding_size # 輸入和輸入定義 self.input_data = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name="inputs") self.targets = tf.placeholder(tf.int64, [self.batch_size, self.num_unrollings], name="targets") # 根據定義選擇不同的循環神經網絡內核單元 if self.cell_type == "rnn": cell_fn = tf.nn.rnn_cell.BasicRNNCell elif self.cell_type == "lstm": cell_fn = tf.nn.rnn_cell.LSTMCell elif self.cell_type == "gru": cell_fn = tf.nn.rnn_cell.GRUCell params = dict() if self.cell_type == "lstm": params["forget_bias"] = 1.0 cell = cell_fn(self.hidden_size, **params) cells = [cell] for i in range(self.num_layers-1): higher_layer_cell = cell_fn(self.hidden_size, **params) cells.append(higher_layer_cell) # 訓練時是否進行 Dropout if is_training and self.dropout > 0: cells = [tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=1.0-self.dropout) for cell in cells] # 對lstm層進行堆疊 multi_cell = tf.nn.rnn_cell.MultiRNNCell(cells) # 定義網絡模型初始狀態 with tf.name_scope("initial_state"): self.zero_state = multi_cell.zero_state(self.batch_size, tf.float32) if self.cell_type == "rnn" or self.cell_type == "gru": self.initial_state = tuple( [tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx]], "initial_state_"+str(idx+1)) for idx in range(self.num_layers)]) elif self.cell_type == "lstm": self.initial_state = tuple( [tf.nn.rnn_cell.LSTMStateTuple( tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][0]], "initial_lstm_state_"+str(idx+1)), tf.placeholder(tf.float32, [self.batch_size, multi_cell.state_size[idx][1]], "initial_lstm_state_"+str(idx+1))) for idx in range(self.num_layers)]) # 定義 embedding 層 with tf.name_scope("embedding_layer"): if embedding_size > 0: # self.embedding = tf.get_variable("embedding", [self.vocab_size, self.embedding_size]) self.embedding = tf.get_variable("word_embeddings", initializer=self.w2v_model.vectors.astype(np.float32)) else: self.embedding = tf.constant(np.eye(self.vocab_size), dtype=tf.float32) inputs = tf.nn.embedding_lookup(self.embedding, self.input_data) if is_training and self.input_dropout > 0: inputs = tf.nn.dropout(inputs, 1-self.input_dropout) # 創建每個切分通道網絡層 with tf.name_scope("slice_inputs"): sliced_inputs = [tf.squeeze(input_, [1]) for input_ in tf.split( axis = 1, num_or_size_splits = self.num_unrollings, value = inputs)] outputs, final_state = tf.nn.static_rnn( cell = multi_cell, inputs = sliced_inputs, initial_state=self.initial_state) self.final_state = final_state # 數據變換層,把經過循環神經網絡的數據拉伸降維 with tf.name_scope("flatten_outputs"): flat_outputs = tf.reshape(tf.concat(axis = 1, values = outputs), [-1, hidden_size]) with tf.name_scope("flatten_targets"): flat_targets = tf.reshape(tf.concat(axis = 1, values = self.targets), [-1]) # 定義 softmax 輸出層 with tf.variable_scope("softmax") as sm_vs: softmax_w = tf.get_variable("softmax_w", [hidden_size, vocab_size]) softmax_b = tf.get_variable("softmax_b", [vocab_size]) self.logits = tf.matmul(flat_outputs, softmax_w) + softmax_b self.probs = tf.nn.softmax(self.logits) # 定義 loss 損失函數 with tf.name_scope("loss"): loss = tf.nn.sparse_softmax_cross_entropy_with_logits( logits = self.logits, labels = flat_targets) self.mean_loss = tf.reduce_mean(loss) # tensorBoard 損失函數可視化 with tf.name_scope("loss_montor"): count = tf.Variable(1.0, name="count") sum_mean_loss = tf.Variable(1.0, name="sum_mean_loss") self.reset_loss_monitor = tf.group(sum_mean_loss.assign(0.0), count.assign(0.0), name="reset_loss_monitor") self.update_loss_monitor = tf.group(sum_mean_loss.assign(sum_mean_loss+self.mean_loss), count.assign(count+1), name="update_loss_monitor") with tf.control_dependencies([self.update_loss_monitor]): self.average_loss = sum_mean_loss / count self.ppl = tf.exp(self.average_loss) average_loss_summary = tf.summary.scalar( name = "average loss", tensor = self.average_loss) ppl_summary = tf.summary.scalar( name = "perplexity", tensor = self.ppl) self.summaries = tf.summary.merge( inputs = [average_loss_summary, ppl_summary], name="loss_monitor") self.global_step = tf.get_variable("global_step", [], initializer=tf.constant_initializer(0.0)) self.learning_rate = tf.placeholder(tf.float32, [], name="learning_rate") if is_training: tvars = tf.trainable_variables() grads, _ = tf.clip_by_global_norm(tf.gradients(self.mean_loss, tvars), self.max_grad_norm) optimizer = tf.train.AdamOptimizer(self.learning_rate) self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)
訓練時可以定義batch_size的值,是否進行dropout,為了結果的多樣性,訓練時在softmax輸出層每次可以選擇topK概率的字符作為輸出。訓練完成后可以使用tensorboard 對網絡結構和訓練過程可視化展示。這里推薦大家一個在線人工智能建模平臺momodel.cn,帶有完整的Python和機器學習框架運行環境,并且有免費的GPU可以使用,大家可以訓練的時候可以在這個平臺上試一下。訓練部分的代碼和訓練好的模型見鏈接。
詩詞生成調用前面訓練好的模型我們就可以實現一個古詩詞的應用了,我這里利用 Mo平臺 實現了藏頭詩和藏子詩自動生成的功能,運行的效果如下:
新年快到了,趕緊利用算法作詩,給親朋好友送去“最聰明”的祝福吧!
PC端查看完整代碼
參考文章:
https://www.jianshu.com/p/9dc...
https://zhuanlan.zhihu.com/p/...
https://github.com/norybaby/poet
————————————————————————————————————Mo (網址:http://momodel.cn)是一個支持 Python 的人工智能建模平臺,能幫助你快速開發訓練并部署 AI 應用。
文章版權歸作者所有,未經允許請勿轉載,若此文章存在違規行為,您可以聯系管理員刪除。
轉載請注明本文地址:http://m.specialneedsforspecialkids.com/yun/20011.html
摘要:針對區塊鏈技術推出的中文播客第三期更新啦,與宣布正式建立合作關系,共同推進解決方案的發展,團隊去了成都做封閉開發更多動態,都在這里社區動態終于更新啦本期是三位大佬一起暢聊攻擊,畫風幽默,內容全面。 Jan,29,2019showImg(https://segmentfault.com/img/bVbnWvk?w=1080&h=460); 親愛的 Nervos 粉絲們: 中國的農歷新年馬...
摘要:關于節日圣誕節,元旦,看大家情侶在朋友圈里發各種慶祝的或者祝福的話語,甚是感動,然后悄悄拉黑了。預覽效果本地下打開很卡,火狐正常圣誕樹早先的時候是圣誕節的時候,看到各種用字符組成圣誕樹的形式,于是自己就去試了下,還是比較簡單的。 關于節日 圣誕節,元旦,看大家(情侶)在朋友圈里發各種慶祝的或者祝福的話語,甚是感動,然后悄悄拉黑了。作為單身狗,我們也有自己慶祝節日的方式,今天我們就來實現...
摘要:關于節日圣誕節,元旦,看大家情侶在朋友圈里發各種慶祝的或者祝福的話語,甚是感動,然后悄悄拉黑了。預覽效果本地下打開很卡,火狐正常圣誕樹早先的時候是圣誕節的時候,看到各種用字符組成圣誕樹的形式,于是自己就去試了下,還是比較簡單的。 關于節日 圣誕節,元旦,看大家(情侶)在朋友圈里發各種慶祝的或者祝福的話語,甚是感動,然后悄悄拉黑了。作為單身狗,我們也有自己慶祝節日的方式,今天我們就來實現...
摘要:關于節日圣誕節,元旦,看大家情侶在朋友圈里發各種慶祝的或者祝福的話語,甚是感動,然后悄悄拉黑了。預覽效果本地下打開很卡,火狐正常圣誕樹早先的時候是圣誕節的時候,看到各種用字符組成圣誕樹的形式,于是自己就去試了下,還是比較簡單的。 關于節日 圣誕節,元旦,看大家(情侶)在朋友圈里發各種慶祝的或者祝福的話語,甚是感動,然后悄悄拉黑了。作為單身狗,我們也有自己慶祝節日的方式,今天我們就來實現...
閱讀 3609·2021-11-23 09:51
閱讀 1481·2021-11-04 16:08
閱讀 3554·2021-09-02 09:54
閱讀 3620·2019-08-30 15:55
閱讀 2600·2019-08-30 15:54
閱讀 962·2019-08-29 16:30
閱讀 2050·2019-08-29 16:15
閱讀 2321·2019-08-29 14:05