您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

基于最小二乘拟合SIR模型

基于最小二乘拟合SIR模型

我正在将我的评论转换为完整的答案。

问题是由于模型设置不正确而引起的。为了简化微分方程,我将提到dS(t)/dt并且dI(t)/dt作为SI分别。

# incorrect
S = -S * I * beta
I = S * I * beta - I * gamma

# correct
S = -S * I * beta / N
I = S * I * beta / N - I * gamma

通过错误地设置微分方程,变化率即从y(t)到y(t + dt)的变化是错误的。因此,不仅会得到不正确积分的I(t),而且还会将其除以N(或k,正如您所称的),这甚至会导致错误

我们知道这些特定方程的耦合系统要求S(t)+ I(t)+ R(t)= N,其中N是总体常数。从声明初始条件的方式来看,我们推断N为1。请注意,这也与max(ydata)小于1一致。

# IO + SO + R0 is always 1 regardless of "value"
I0 = value
S0 = 1 - I0
R0 = 0

此外,您的处理k方式确实令人怀疑。您的数据似乎已经被标准化,但是您将其乘以0.1。如您所见,k = 1./sum(ydata)与人口常数无关。通过对I0 = ydata[0] * kI(t)进行除以k,您可以有效地按比例缩小数据,仅稍后再按比例放大。无论总体常数是多少,这几乎将I(t)限制在0-1范围内。

您只需设置所有初始条件和未知参数并查看产生的结果,即可验证模型是否错误odeint()。您会注意到S(0),I(0)和R(0)可能与您给它们提供的值不对应,这表明您做错了事k。但是要发现有缺陷的动力学演变,您只需简单地查看模型即可。

在此处输入图片说明

一个棘手的解决方案是设置k = 1.0。一切顺利,因为乘法和除法无效,即使您在技术上仍然做错了计算。但是,如果您的人口常数曾经假设不等于1,那么一切都会破裂。所以为了彻底

手动将其设置k为总体常数,除非您也试图适合S0,I0和/或R0,否则无论如何都应该知道。

写变化模型的正确率SI

摆脱任何np.divide(array, k)计算,并且

kfitFunc()参数中删除,不要将其附加到param_init列表中。尽管此最后一个动作是可选的,不会影响结果,但从技术上来说还是正确的。这是因为k,即使您没有最终在任何地方使用它来影响您的计算,优化求解器也会尝试通过它寻找最佳值。

如果要进行最小二乘拟合,可以使用curve_fit(),在内部调用最小二乘方法。您仍然需要为配件创建包装函数,该包装函数必须通过数字方式集成系统以获取各种beta和gamma值,但是您不必手动进行任何SSE计算。

curve_fit()还将返回协方差矩阵,您可以用它来估计拟合变量的置信区间。从协方差矩阵计算置信区间的进一步相关讨论可以在这里找到。

import numpy as np
import matplotlib.pyplot as plt
from scipy import integrate, optimize

ydata = ['1e-06', '1.49920166169172e-06', '2.24595472686361e-06', '3.36377954575331e-06', '5.03793663882291e-06', '7.54533628058909e-06', '1.13006564683911e-05', '1.69249500601052e-05', '2.53483161761933e-05', '3.79636391699325e-05', '5.68567547875179e-05', '8.51509649182741e-05', '0.000127522555808945', '0.000189928392105942', '0.000283447055673738', '0.000423064043409294', '0.000631295993246634', '0.000941024110897193', '0.00140281896645859', '0.00209085569326554', '0.00311449589149717', '0.00463557784224762', '0.00689146863803467', '0.010227347567051', '0.0151380084180746', '0.0223233100045688', '0.0327384810150231', '0.0476330618585758', '0.0685260046667727', '0.0970432959143974', '0.134525888779423', '0.181363340075877', '0.236189247803334', '0.295374180276257', '0.353377036130714', '0.404138746080267', '0.442876028839178', '0.467273954573897', '0.477529937494976', '0.475582401936257', '0.464137179474659', '0.445930281787152', '0.423331710456602', '0.39821360956389', '0.371967226561944', '0.345577884704341', '0.319716449520481', '0.294819942458255', '0.271156813453547', '0.24887641905719', '0.228045466022105', '0.208674420183194', '0.190736203926912', '0.174179448652951', '0.158937806544529', '0.144936441326754', '0.132096533873646', '0.120338367115739', '0.10958340819268', '0.099755679236243', '0.0907826241267504', '0.0825956203546979', '0.0751302384111894', '0.0683263295744258', '0.0621279977639921', '0.0564834809370572', '0.0513449852139111', '0.0466684871328814', '0.042413516167789', '0.0385429293775096', '0.035022685071934', '0.0318216204865132', '0.0289112368382048', '0.0262654939162707', '0.0238606155312519', '0.021674906523588', '0.0196885815912485', '0.0178836058829335', '0.0162435470852779', '0.0147534385851646', '0.0133996531928511', '0.0121697868544064', '0.0110525517526551', '0.0100376781867076', '0.00911582462544914', '0.00827849534575178', '0.00751796508841916', '0.00682721019158058', '0.00619984569061827', '0.00563006790443123', '0.00511260205894446', '0.00464265452957236', '0.00421586931435123', '0.00382828837833139', '0.00347631553734708', '0.00315668357532714', '0.00286642431380459', '0.00260284137520731', '0.00236348540287827', '0.00214613152062159', '0.00194875883295343']
xdata = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101']

ydata = np.array(ydata, dtype=float)
xdata = np.array(xdata, dtype=float)

def sir_model(y, x, beta, gamma):
    S = -beta * y[0] * y[1] / N
    R = gamma * y[1]
    I = -(S + R)
    return S, I, R

def fit_odeint(x, beta, gamma):
    return integrate.odeint(sir_model, (S0, I0, R0), x, args=(beta, gamma))[:,1]

N = 1.0
I0 = ydata[0]
S0 = N - I0
R0 = 0.0

popt, pcov = optimize.curve_fit(fit_odeint, xdata, ydata)
fitted = fit_odeint(xdata, *popt)

plt.plot(xdata, ydata, 'o')
plt.plot(xdata, fitted)
plt.show()

您可能会注意到一些运行时警告,但是这些警告主要是由于最初搜索了最小化求解器(Levenburg- Marquardt),它会尝试一些值,beta并且gamma会在积分过程中引起数值溢出。但是,应该尽快将其定为更合理的值。如果您尝试使用的其他求解器minimize(),则会注意到类似的警告。

其他 2022/1/1 18:28:20 有431人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶