您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

Python sklearn在训练期间显示损失值

5b51 2022/1/14 8:21:48 python 字数 2759 阅读 546 来源 www.jb51.cc/python

我想在训练期间检查我的损失值,这样我就可以观察每次迭代时的损失.到目前为止,我还没有找到一个简单的方法让scikit学会给我一个损失值的历史,我也没有找到一个功能已经在scikit中为我绘制损失.如果无法绘制这个,那么如果我可以简单地在classifier.fit的末尾获取最终的损失值,那就太好了.注意:我知道某些解决方案是封闭的形式.我正在使用几个没有分析

概述

我想在训练期间检查我的损失值,这样我就可以观察每次迭代时的损失.到目前为止,我还没有找到一个简单的方法让scikit学会给我一个损失值的历史,我也没有找到一个功能已经在scikit中为我绘制损失.

如果无法绘制这个,那么如果我可以简单地在classifier.fit的末尾获取最终的损失值,那就太好了.

注意:我知道某些解决方案是封闭的形式.我正在使用几个没有分析解决方案的分类器,例如逻辑回归和svm.

有没有人有什么建议?

old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
clf = SGDClassifier(**kwargs,verbose=1)
clf.fit(X_tr,y_tr)
sys.stdout = old_stdout
loss_history = mystdout.getvalue()
loss_list = []
for line in loss_history.split('\n'):
    if(len(line.split("loss: ")) == 1):
        continue
    loss_list.append(float(line.split("loss: ")[-1]))
plt.figure()
plt.plot(np.arange(len(loss_list)),loss_list)
plt.savefig("warmstart_plots/pure_SGD:"+str(kwargs)+".png")
plt.xlabel("Time in epochs")
plt.ylabel("Loss")
plt.close()

代码将采用普通的SGDClassifier(几乎任何线性分类器),并拦截verbose = 1标志,然后将进行拆分以从详细打印中获取损失.显然这是较慢但会给我们带来损失并打印出来.


如果您也喜欢它,动动您的小指点个赞吧

除非注明,文章均由 laddyq.com 整理发布,欢迎转载。

转载请注明:
链接:http://laddyq.com
来源:laddyq.com
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。


联系我
置顶