Add my custom loss function to torch

旧时难觅i 2021-01-06 03:53

I want to add a loss function to torch that calculates the edit distance between predicted and target values. Is there an easy way to implement this idea? Or do I have to wr

2 Answers
  • 2021-01-06 04:15

    If your criterion can be represented as a composition of existing modules and criteria, it's a good idea to simply build that composition with containers. The only problem is that the standard containers are designed to work with modules only, not criteria. The difference is in the :forward method signature:

    module:forward(input)
    criterion:forward(input, target)
    

    Luckily, we are free to define our own container that can work with criteria too. For example, a generalized sequential container:

    require 'nn'

    -- A Sequential container whose forward also accepts a target.
    local GeneralizedSequential, _ = torch.class('nn.GeneralizedSequential', 'nn.Sequential')

    function GeneralizedSequential:forward(input, target)
        return self:updateOutput(input, target)
    end

    function GeneralizedSequential:updateOutput(input, target)
        local currentOutput = input
        for i=1,#self.modules do
            -- Plain modules ignore the extra target argument;
            -- a criterion in the chain uses it.
            currentOutput = self.modules[i]:updateOutput(currentOutput, target)
        end
        self.output = currentOutput
        return currentOutput
    end
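
    Overriding forward alone is enough because Lua silently discards arguments beyond a function's declared parameters: ordinary modules such as nn.LogSoftMax drop the target the loop passes them, and only the final criterion uses it. (As far as I can tell from nn.Sequential's source, the inherited backward likewise hands its second argument to the last module, which a criterion reads as the target, so backward needs no override.) The plain-Lua sketch below needs no torch; it mimics the loop with two hypothetical stand-in objects, double and sqErr, that are not part of nn:

    ```lua
    -- Hypothetical stand-ins (not torch objects): a "module" whose
    -- updateOutput takes only an input, and a "criterion" that also
    -- takes a target.
    local double = {}
    function double:updateOutput(input)  -- an extra target argument is dropped
      local out = {}
      for i, v in ipairs(input) do out[i] = 2 * v end
      return out
    end

    local sqErr = {}
    function sqErr:updateOutput(input, target)
      local sum = 0
      for i, v in ipairs(input) do
        local d = v - target[i]
        sum = sum + d * d
      end
      return sum
    end

    -- Same loop as GeneralizedSequential:updateOutput: pass target to every stage.
    local function generalizedForward(stages, input, target)
      local current = input
      for i = 1, #stages do
        current = stages[i]:updateOutput(current, target)
      end
      return current
    end

    local loss = generalizedForward({double, sqErr}, {1, 2}, {2, 3})
    print(loss)  -- prints 1: {1,2} doubles to {2,4}, then (2-2)^2 + (4-3)^2
    ```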
    

    Below is an illustration of how nn.CrossEntropyCriterion can be reimplemented with this generalized sequential container:

    -- Cross-entropy as LogSoftMax followed by ClassNLLCriterion.
    function MyCrossEntropyCriterion(weights)
        local criterion = nn.GeneralizedSequential()
        criterion:add(nn.LogSoftMax())
        criterion:add(nn.ClassNLLCriterion(weights))
        return criterion
    end
    

    To check that everything is correct, compare the results against the built-in nn.CrossEntropyCriterion on the same input:

    output = torch.rand(3,3)
    target = torch.Tensor({1, 2, 3})
    
    mycrit = MyCrossEntropyCriterion()
    -- print(mycrit)
    print(mycrit:forward(output, target))
    print(mycrit:backward(output, target))
    
    crit = nn.CrossEntropyCriterion()
    -- print(crit)
    print(crit:forward(output, target))
    print(crit:backward(output, target))
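
    The two should agree because cross-entropy over raw scores is exactly log-softmax followed by the negative log-likelihood of the target class: for a score vector x and target t, the loss is -(x[t] - log(sum_j exp(x[j]))), and with default settings nn.ClassNLLCriterion averages it over the batch. The plain-Lua sketch below reproduces that arithmetic without torch; the helper names and the fixed scores are hypothetical, class weights are left out, and class indices are 1-based as in the example above:

    ```lua
    -- Numerically stable log-softmax of a plain Lua array of scores.
    local function logSoftMax(x)
      local maxv = -math.huge
      for _, v in ipairs(x) do maxv = math.max(maxv, v) end
      local sumExp = 0
      for _, v in ipairs(x) do sumExp = sumExp + math.exp(v - maxv) end
      local logZ = maxv + math.log(sumExp)
      local out = {}
      for i, v in ipairs(x) do out[i] = v - logZ end
      return out
    end

    -- Two-stage version: log-softmax, then mean negative log-likelihood
    -- (what the GeneralizedSequential composition computes, unweighted).
    local function composedLoss(batch, targets)
      local total = 0
      for n, x in ipairs(batch) do
        total = total - logSoftMax(x)[targets[n]]
      end
      return total / #batch
    end

    -- Direct cross-entropy formula: -(x[t] - log(sum_j exp(x[j]))), averaged.
    local function crossEntropy(batch, targets)
      local total = 0
      for n, x in ipairs(batch) do
        local sumExp = 0
        for _, v in ipairs(x) do sumExp = sumExp + math.exp(v) end
        total = total + (math.log(sumExp) - x[targets[n]])
      end
      return total / #batch
    end

    local batch = {{0.2, 0.5, 0.1}, {0.9, 0.3, 0.4}, {0.6, 0.7, 0.8}}
    local targets = {1, 2, 3}
    print(composedLoss(batch, targets), crossEntropy(batch, targets))
    ```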
    
  • 2021-01-06 04:27

    Just to add to the accepted answer: you need to make sure that the loss function you define (edit distance in your case) is differentiable with respect to the network parameters.
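
    This is a real obstacle for edit distance. If the prediction is turned into a symbol sequence (for example, by taking the arg-max class at each step), the distance is an integer that stays constant under small changes to the network's scores, so its gradient is zero almost everywhere and gives no training signal. A plain-Lua sketch of the standard Levenshtein dynamic program makes that discreteness visible; the function name editDistance and the sample label sequences are hypothetical:

    ```lua
    -- Levenshtein distance between two arrays of symbols (unit cost for
    -- insertion, deletion and substitution), keeping two rows of the DP table.
    local function editDistance(a, b)
      local prev = {}
      for j = 0, #b do prev[j] = j end          -- cost from the empty prefix of a
      for i = 1, #a do
        local cur = { [0] = i }                 -- delete the first i symbols of a
        for j = 1, #b do
          local cost = (a[i] == b[j]) and 0 or 1
          cur[j] = math.min(prev[j] + 1,        -- deletion
                            cur[j - 1] + 1,     -- insertion
                            prev[j - 1] + cost) -- substitution or match
        end
        prev = cur
      end
      return prev[#b]
    end

    -- Hypothetical arg-max label sequences for a prediction and a target.
    local predicted = {3, 1, 4, 1, 5}
    local target    = {3, 1, 1, 5, 9}
    print(editDistance(predicted, target))  -- prints 2 (delete 4, insert 9)
    ```

    Because the result only changes when an arg-max label flips, small parameter updates usually leave it unchanged; this is why edit distance is typically reported as an evaluation metric rather than minimized directly.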
