Weighted MSE custom loss function in Keras

巧了我就是萌 submitted on 2019-12-20 10:11:11

Question


I'm working with time series data, predicting 60 days ahead.

I'm currently using mean squared error as my loss function, and the results are bad.

I want to implement a weighted mean squared error in which the early outputs matter much more than the later ones.

Weighted mean squared error formula:
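One form consistent with the code in this post is sketched below. It assumes a weight of 1/i for the i-th predicted day and n = 60 outputs; the symbols y_i (true value) and \hat{y}_i (predicted value) are illustrative, since the post does not spell the formula out.

```latex
\mathrm{WMSE} = \frac{1}{n} \sum_{i=1}^{n} \frac{1}{i} \left( y_i - \hat{y}_i \right)^2
```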

So I need some way to iterate over a tensor's elements with an index (since I need to iterate over both the predicted and the true values at the same time), then write the result to a tensor with only one element. Both tensors are shaped (?,60), but they are really (1,60) lists.

Nothing I've tried works. Here's the code for the broken version:

def weighted_mse(y_true,y_pred):
    wmse = K.cast(0.0,'float')

    size = K.shape(y_true)[0]
    for i in range(0,K.eval(size)):
        wmse += 1/(i+1)*K.square((y_true[i]-y_pred)[i])

    wmse /= K.eval(size)
    return wmse

I am currently getting this error as a result:

InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'dense_2_target' with dtype float
 [[Node: dense_2_target = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Having read the replies to similar posts, I don't think a mask can accomplish the task, and looping over the elements of one tensor wouldn't work either, since I wouldn't be able to access the corresponding element in the other tensor.

Any suggestions would be appreciated.


Answer 1:


You can use this approach:

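from keras import backend as K  # K is the backend alias used in the question

def weighted_mse(yTrue, yPred):
    # vector of ones shaped like one sample's output, e.g. (60,)
    ones = K.ones_like(yTrue[0, :])
    # running sum of the ones gives 1, 2, ..., 60, similar to range(1, 61)
    idx = K.cumsum(ones)

    # weight step i by 1/i, then average over every sample and step
    return K.mean((1 / idx) * K.square(yTrue - yPred))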

Using ones_like with cumsum lets you apply this loss function to any output shaped (samples, classes), whatever the number of classes.
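To check by hand what the backend expression computes, here is a minimal pure-Python sketch of the same arithmetic on nested lists. The helper name `weighted_mse_reference` and the sample values are illustrative assumptions, not part of the original answer; inputs are assumed to be shaped (samples, steps).

```python
def weighted_mse_reference(y_true, y_pred):
    """Plain-Python mirror of K.mean((1/idx) * K.square(yTrue - yPred)).

    Hypothetical helper for checking values by hand; both inputs are
    lists of equal-length rows shaped (samples, steps).
    """
    total = 0.0
    count = 0
    for true_row, pred_row in zip(y_true, y_pred):
        # start at 1 so step i gets weight 1/i, matching cumsum of ones
        for i, (t, p) in enumerate(zip(true_row, pred_row), start=1):
            total += (t - p) ** 2 / i
            count += 1
    # K.mean with no axis argument averages over every element
    return total / count


y_true = [[1.0, 2.0, 3.0]]
y_pred = [[0.0, 0.0, 0.0]]
# squared errors 1, 4, 9 weighted by 1, 1/2, 1/3 give (1 + 2 + 3) / 3
print(weighted_mse_reference(y_true, y_pred))  # prints 2.0
```

Because the mean runs over all elements, the result is also divided by the number of samples, so the scale of the loss does not depend on batch size.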


Hint: always use backend functions when working with tensors. You can use slices, but avoid iterating.



Source: https://stackoverflow.com/questions/46242187/weighted-mse-custom-loss-function-in-keras
