[上一篇文章](https://steemit.com/cn-stem/@hongtao/tensorflow-keras-classification-with-keras)我们介绍了如何使用Keras处理分类问题,那Tensorflow可不可以像[处理回归问](https://steemit.com/cn-stem/@hongtao/tensorflow)题一样,直接处理分类问题呢? 答案当然是肯定的。这篇文章我们就用[之前](https://steemit.com/cn-stem/@hongtao/tensorflow-keras-classification-with-keras) 相同的数据,来学习如何用Tensorflow训练一个线性分类器。数据预处理的过程就略过了,可以参考[上一篇文章](https://steemit.com/cn-stem/@hongtao/tensorflow-keras-classification-with-keras)。 同样的,为了方便与读者交流,所有的源代码都放在了这里:https://github.com/zht007/tensorflow-practice/ ## 1. 定义数据shape 在Keras中,我们只需要考虑数据的输入和输出shape,中间的Shape以及参数的Shape,Keras都可以自动帮我们搞定。然而在Tensorflow就必须手动定义参数的Shape了。 首先我们还是要借助keras的工具将标签y转换成one hot 的数据。 ```python import tensorflow as tf from tensorflow.keras.utils import to_categorical y_train_cat = to_categorical(y_train) y_test_cat = to_categorical(y_test) ``` 在定义权重W的Shape的时候,不妨教大家一个技巧, Rows: W.shape[0] = X.shape[1], (输入的feature数); columns: W.shape[1] = Y.shape[1], (输出的classfication数) ```python n_features = X_train.shape[1] n_classes = y_train_cat.shape[1] w_shape = (n_features, n_classes) b_shape = (1, n_classes) ``` ## 2. Variables和Placeholders 参数W和b是Variables,要训练的X和Y是Placeholders ```python W = tf.Variable(initial_value = tf.random.normal(shape = w_shape)) b = tf.Variable(initial_value = tf.random.normal(shape = b_shape)) X = tf.placeholder(tf.float32) y_true = tf.placeholder(tf.float32) ``` ## 3. 计算图谱Graph 与线性回归一样,线性分类器的计算图谱如下 ```python y_hat = tf.matmul(X,W) + b ``` ## 4. 损失函数和Optimizer 损失函数需要选择softmax_cross_entropy,Optimizer与线性回归一样,用梯度下降Optimizer就OK了。 ``` loss = tf.losses.softmax_cross_entropy(y_true, y_hat) optimizer = tf.train.GradientDescentOptimizer(0.05) train = optimizer.minimize(loss) ``` ## 5. Session中训练 初始化Vaiable之后就可以在Session中进行训练啦。为了实现在Keras中存储损失函数的记录的功能,我们手动定义了字典history,用来存储训练组和验证组的损失函数变化过程。 ```python epochs = 50000 history = {'loss':list(),'val_loss':list()} with tf.Session() as sess: sess.run(init) for epoch in range(epochs): sess.run(train,{X:X_train, y_true:y_train_cat}) history['loss'].append(sess.run(loss, {X: X_train, y_true: y_train_cat})) history['val_loss'].append(sess.run(loss, {X: X_test, y_true: y_test_cat})) if epoch % 100 == 0: print("Iteration {}:\tloss={:.6f}:\tval_loss={:.6f}" .format(epoch, history['loss'][epoch], history['val_loss'][epoch])) y_pred = sess.run(y_hat, {X: X_test}) W_final, b_final = sess.run([W, b]) ``` ## 6. 验证结果 最后我们将训练结果可视化,可以看到效果还不错,损失函数的下降曲线非常平滑,而且训练集和测试集的损失函数也相差不大。  --- 同步到我的简书 https://www.jianshu.com/u/bd506afc6fc1
author | hongtao | ||||||
---|---|---|---|---|---|---|---|
permlink | tensorflow-tensorflow-classification-with-tensorflow | ||||||
category | cn-stem | ||||||
json_metadata | {"community":"busy","app":"steemit/0.1","format":"markdown","tags":["cn-stem","tensorflow","team-cn","busy","cn"],"links":["https://steemit.com/cn-stem/@hongtao/tensorflow-keras-classification-with-keras","https://steemit.com/cn-stem/@hongtao/tensorflow","https://github.com/zht007/tensorflow-practice/","https://www.jianshu.com/u/bd506afc6fc1"],"image":["https://ws1.sinaimg.cn/large/006tKfTcgy1g1778h15byj30bu09o3yv.jpg"]} | ||||||
created | 2019-03-18 12:06:33 | ||||||
last_update | 2019-03-20 10:25:51 | ||||||
depth | 0 | ||||||
children | 1 | ||||||
last_payout | 2019-03-25 12:06:33 | ||||||
cashout_time | 1969-12-31 23:59:59 | ||||||
total_payout_value | 0.030 HBD | ||||||
curator_payout_value | 0.006 HBD | ||||||
pending_payout_value | 0.000 HBD | ||||||
promoted | 0.000 HBD | ||||||
body_length | 2,696 | ||||||
author_reputation | 3,241,267,862,629 | ||||||
root_title | "Tensorflow入门——Tensorflow处理分类问题,Classification with Tensorflow" | ||||||
beneficiaries |
| ||||||
max_accepted_payout | 1,000,000.000 HBD | ||||||
percent_hbd | 10,000 | ||||||
post_id | 81,512,875 | ||||||
net_rshares | 66,684,517,296 | ||||||
author_curate_reward | "" |
voter | weight | wgt% | rshares | pct | time |
---|---|---|---|---|---|
justyy | 0 | 46,584,736,993 | 2.06% | ||
busy.org | 0 | 27,407,662 | 0.3% | ||
superbing | 0 | 697,386,202 | 7.39% | ||
dailystats | 0 | 2,102,140,321 | 7.39% | ||
jianan | 0 | 1,305,002,965 | 7.83% | ||
anxin | 0 | 145,772,601 | 8.03% | ||
hongtao | 0 | 212,719,332 | 52% | ||
woolfe19861008 | 0 | 117,127,131 | 8.01% | ||
dailychina | 0 | 1,991,321,354 | 7.4% | ||
dongfengman | 0 | 758,670,343 | 8.01% | ||
ethanlee | 0 | 192,370,164 | 6.65% | ||
lilypang22 | 0 | 166,501,694 | 7.29% | ||
sweet-jenny8 | 0 | 1,485,941,958 | 8.01% | ||
laiyuehta | 0 | 97,913,550 | 5.17% | ||
turtlegraphics | 0 | 660,235,341 | 7.39% | ||
criticizemars | 0 | 437,411,579 | 100% | ||
dauphinegriping | 0 | 445,088,733 | 100% | ||
roadlifted | 0 | 448,374,197 | 100% | ||
zeniththrash | 0 | 448,115,865 | 100% | ||
periodantenna | 0 | 448,372,549 | 100% | ||
securelagan | 0 | 441,987,206 | 100% | ||
witnesstools | 0 | 635,276,025 | 7.39% | ||
ilovecoding | 0 | 628,884,163 | 7.38% | ||
steemfuckeos | 0 | 398,586,652 | 7.39% | ||
hozn4ukhlytriwc | 0 | 129,037,783 | 15% | ||
allofrar | 0 | 501,335,503 | 100% | ||
resubrig | 0 | 501,342,321 | 100% | ||
ofringen | 0 | 501,323,075 | 100% | ||
iteaston | 0 | 501,300,764 | 100% | ||
umilo | 0 | 501,313,812 | 100% | ||
imedyedr | 0 | 501,283,749 | 100% | ||
esoredrof | 0 | 501,289,599 | 100% | ||
dendedou | 0 | 501,277,213 | 100% | ||
uroffu | 0 | 501,220,119 | 100% | ||
terry3t | 0 | 1,166,448,778 | 100% |
Congratulations @hongtao! You have completed the following achievement on the Steem blockchain and have been rewarded with new badge(s) : <table><tr><td>https://steemitimages.com/60x70/http://steemitboard.com/@hongtao/voted.png?201903180426</td><td>You received more than 2000 upvotes. Your next target is to reach 3000 upvotes.</td></tr> </table> <sub>_You can view [your badges on your Steem Board](https://steemitboard.com/@hongtao) and compare to others on the [Steem Ranking](http://steemitboard.com/ranking/index.php?name=hongtao)_</sub> <sub>_If you no longer want to receive notifications, reply to this comment with the word_ `STOP`</sub> **Do not miss the last post from @steemitboard:** <table><tr><td><a href="https://steemit.com/drugwars/@steemitboard/drugwars-early-adopter"><img src="https://steemitimages.com/64x128/https://cdn.steemitimages.com/DQmYGN7R653u4hDFyq1hM7iuhr2bdAP1v2ApACDNtecJAZ5/image.png"></a></td><td><a href="https://steemit.com/drugwars/@steemitboard/drugwars-early-adopter">Are you a DrugWars early adopter? Benvenuto in famiglia!</a></td></tr></table> > You can upvote this notification to help all Steem users. Learn how [here](https://steemit.com/steemitboard/@steemitboard/http-i-cubeupload-com-7ciqeo-png)!
author | steemitboard |
---|---|
permlink | steemitboard-notify-hongtao-20190318t140426000z |
category | cn-stem |
json_metadata | {"image":["https://steemitboard.com/img/notify.png"]} |
created | 2019-03-18 14:04:24 |
last_update | 2019-03-18 14:04:24 |
depth | 1 |
children | 0 |
last_payout | 2019-03-25 14:04:24 |
cashout_time | 1969-12-31 23:59:59 |
total_payout_value | 0.000 HBD |
curator_payout_value | 0.000 HBD |
pending_payout_value | 0.000 HBD |
promoted | 0.000 HBD |
body_length | 1,253 |
author_reputation | 38,975,615,169,260 |
root_title | "Tensorflow入门——Tensorflow处理分类问题,Classification with Tensorflow" |
beneficiaries | [] |
max_accepted_payout | 1,000,000.000 HBD |
percent_hbd | 10,000 |
post_id | 81,517,922 |
net_rshares | 0 |