整合Netty|tensorflow實現 聊天AI--PigPig養成記(2)
整合Netty
通過上一節的學習我們已經可以訓練得到一隻傲嬌的聊天AI_PigPig了。
本章將介紹專案關於Netty的整合問題,將其我們的AI_PigPig可以通過web應用與大家日常互撩。
由於只是一個小測試,所以不考慮效能方面的問題,在下一章我們將重點處理效率難關,整合Redis。
關於Netty的學習大家可以看我的另一篇文章,本節中關於Netty部分的程式碼改編自該文章中的 netty聊天小練習 ,文章中會有詳細的講解。
Python程式碼改動
首先對測試訓練結果的程式碼進行改動,將輸入輸出流重定向自作為中間媒介的測試檔案中。
with tf.Session() as sess:#開啟作為一次會話 # 恢復前一次訓練 ckpt = tf.train.get_checkpoint_state('.')#從檢查點檔案中返回一個狀態(ckpt) #如果ckpt存在,輸出模型路徑 if ckpt != None: print(ckpt.model_checkpoint_path) model.saver.restore(sess, ckpt.model_checkpoint_path)#儲存模型引數 else: print("沒找到模型") #測試該模型的能力 while True: #從檔案中進行讀取 #input_string = input('me > ') #測試檔案輸入格式為"[內容]:[名字]" #eg.你好:AI【表示AI的回覆】 #你好:user【表示使用者的輸入】 with open('./temp.txt','r+',encoding='ANSI') as myf: #從檔案中讀取使用者的輸入 line=myf.read() list1=line.split(':') #長度為一,表明不符合輸入格式,設定為"no",則不進行測試處理 if len(list1)==1: input_string='no' else: #符合輸入格式,證明是使用者輸入的 #input_string為使用者輸入的內容 input_string=list1[0] myf.seek(0) #清空檔案 myf.truncate() #寫入"no",若讀到"no",則不進行測試處理 myf.write('no') # 退出 if input_string == 'quit': exit() #若讀到"no",則不進行測試處理 if input_string != 'no': input_string_vec = []#輸入字串向量化 for words in input_string.strip(): input_string_vec.append(vocab_en.get(words, UNK_ID))#get()函式:如果words在詞表中,返回索引號;否則,返回UNK_ID bucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])#保留最小的大於輸入的bucket的id encoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id) #get_batch(A,B):兩個引數,A為大小為len(buckets)的元組,返回了指定bucket_id的encoder_inputs,decoder_inputs,target_weights _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True) #得到其輸出 outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]#求得最大的預測範圍列表 if EOS_ID in outputs:#如果EOS_ID在輸出內部,則輸出列表為[,,,,:End] outputs = outputs[:outputs.index(EOS_ID)] response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])#轉為解碼詞彙分別新增到回覆中 print('AI-PigPig > ' + response)#輸出回覆 #將AI的回覆以要求的格式進行寫入,方便Netty程式讀取 with open('./temp1.txt','w',encoding='ANSI') as myf1: myf1.write(response+':AI')
Netty程式
完整程式碼參見 連結 netty包下。
在原本的ChatHandler類中添加了從檔案中讀取資料的方法readFromFile,以及向檔案中覆蓋地寫入資料的方法writeToFile。
//從檔案中讀取資料 private static String readFromFile(String filePath) { File file=new File(filePath); String line=null; String name=null; String content=null; try { //以content:name的形式寫入 BufferedReader br=new BufferedReader(new FileReader(file)); line=br.readLine(); String [] arr=line.split(":"); if(arr.length==1) { name=null; content=null; }else { content=arr[0]; name=arr[1]; } br.close(); } catch (FileNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } return content; } //向檔案中覆蓋地寫入 private static void writeToFile(String filePath,String content) { File file =new File(filePath); try { FileWriter fileWriter=new FileWriter(file); fileWriter.write(""); fileWriter.flush(); fileWriter.write(content); fileWriter.close(); } catch (IOException e) { e.printStackTrace(); } }
對原來的channelRead0方法進行修改,將輸入輸出流重定向到臨時檔案中。
@Override protected void channelRead0(ChannelHandlerContext ctx, TextWebSocketFrame msg) throws Exception { System.out.println("channelRead0"); //得到使用者輸入的訊息,需要寫入檔案/快取中,讓AI進行讀取 String content=msg.text(); if(content==null||content=="") { System.out.println("content 為null"); return ; } System.out.println("接收到的訊息:"+content); //寫入 writeToFile(writeFilePath, content+":user"); //給AI回覆與寫入的時間,後期會增對效能方面進行改進 Thread.sleep(1000); //讀取AI返回的內容 String AIsay=readFromFile(readFilePath); //讀取後馬上寫入 writeToFile(readFilePath,"no"); //沒有說,或者還沒說 if(AIsay==null||AIsay==""||AIsay=="no") { System.out.println("AIsay為空或no"); return; } System.out.println("AI說:"+AIsay); clients.writeAndFlush( new TextWebSocketFrame( "AI_PigPig在"+LocalDateTime.now() +"說:"+AIsay)); }
客戶端程式碼
<!DOCTYPE html> <html> <head> <meta charset="utf-8" /> <title></title> </head> <body> <div>傳送訊息:</div> <input type="text" id="msgContent"/> <input type="button" value="點我傳送" onclick="CHAT.chat()"/> <div>接受訊息:</div> <div id="receiveMsg" style="background-color: gainsboro;"></div> <script type="application/javascript"> window.CHAT = { socket: null, init: function() { if (window.WebSocket) { CHAT.socket = new WebSocket("ws://192.168.0.104:8088/ws"); CHAT.socket.onopen = function() { console.log("連線建立成功..."); }, CHAT.socket.onclose = function() { console.log("連線關閉..."); }, CHAT.socket.onerror = function() { console.log("發生錯誤..."); }, CHAT.socket.onmessage = function(e) { console.log("接受到訊息:" + e.data); var receiveMsg = document.getElementById("receiveMsg"); var html = receiveMsg.innerHTML; receiveMsg.innerHTML = html + "<br/>" + e.data; } } else { alert("瀏覽器不支援websocket協議..."); } }, chat: function() { var msg = document.getElementById("msgContent"); CHAT.socket.send(msg.value); } }; CHAT.init(); </script> </body> </html>
測試結果
客戶端傳送訊息
使用者與AI日常互撩