Jun 18, 2021
新闻文本分类2.0
获取中...
鸽了一个多月,终于更新辣,这篇是[基于CNN+GRU的文本分类实践](基于CNN+GRU的文本分类实践 | Rufus的B滚木 (gitee.io) )的续集,话不多说,直接进入正题!
经过无数次的尝试,最后发现,还是全连接层坠爽,有奇效。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 model.add(Embedding(len (vocab) + 1 , 256 , input_length=300 )) model.add(Dense(512 )) model.add(BatchNormalization()) model.add(LeakyReLU()) model.add(Dropout(0.2 )) model.add(Dense(256 )) model.add(BatchNormalization()) model.add(LeakyReLU()) model.add(Dropout(0.2 )) model.add(Dense(256 )) model.add(BatchNormalization()) model.add(LeakyReLU()) model.add(Dropout(0.2 )) model.add(Flatten()) model.add(Dense(9 , activation='softmax' ))
然后,这次做了一个重大的改动,就是把”其他“类去除了, 因为在搜集数据集的时候,发现其他类的数据集根本无从下手,搜集出来,感觉也只是在模型里面充当噪声,然后经过一波冥思苦想,最后决定这样干,把神经网络最后不太确定的,就归作其他类。
那么应该怎样找出“不确定”的数据呢,我想到了利用标准差来解决这个问题,每次神经网络最后输出的是一个9个float类型数字的数组,现在计算这9个数字的方差,如果是机器比较确定的话,这9个数字势必有一个会特别接近1
其他8个数字特别接近0,如果不确定,则可能数据分布会相对平均,经过多次测验,当这九个数字的标准差小于0.25时,神经网络对判定结果的置信程度比较低。
下面是改进之后的detect.py检测
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 import tensorflow as tffrom tensorflow.keras.preprocessing.sequence import pad_sequencesimport osimport pandas as pdimport jiebaimport pickleimport numpy as npcw = lambda x: list (jieba.cut(x)) os.environ["CUDA_VISIBLE_DEVICES" ] = "-1" fr = open ('tokenizer' , 'rb' ) tokenizer = pickle.load(fr) fr.close() model = tf.keras.models.load_model("model_save/model_checkpoints" ) tags = ['体育' , '军事' , '娱乐' , '房产' , '教育' , '汽车' , '游戏' , '科技' , '财经' ] while True : y1 = input ("------------\n请输入一段新闻:(输入q退出)" ) if y1 == 'q' : break y1 = pd.DataFrame(y1, columns=["a" ], index=['b' ]) y1 = y1['a' ].apply(cw) y1 = str (y1['b' ]) y2 = tokenizer.texts_to_sequences([y1]) y3 = pad_sequences(y2, maxlen=300 ) result = model.predict(y3) print ('置信指数:' + str (np.std(result))) if np.std(result) < 0.25 : print ('该新闻类别为其他' ) else : pred = tf.argmax(result, axis=1 ) c = pred.__int__() print ('该新闻类别为' + tags[c])
简单的UI界面 为了做出一个能让正常人使用的UI效果,单独制作了两个文件
它们分别是用于预测单条新闻和多条新闻的两个函数(好像不用分成两个文件。。。)
另外感谢学长用pyqt写的页面
最后的效果大概是这样
感觉还不错,只是有的时候经常出很多奇奇怪怪的错误,等修复一些BUG,比赛完成之后,我会把源码放上来。
本文由 rufus 创作,采用 知识共享署名 4.0 国际许可协议。
本站文章除注明转载/出处外,均为本站原创或翻译,转载请务必署名。