What is the idiomatic way to get the index of a maximum or minimum floating point value in a slice or Vec in Rust?

Submitted by 痞子三分冷 on 2019-12-04 03:33:25

Question


Assumption: the Vec<f32> does not contain any NaN values or exhibit any NaN behavior.

Take the following sample set:

0.28  
0.3102
0.9856
0.3679
0.3697
0.46  
0.4311
0.9781
0.9891
0.5052
0.9173
0.932 
0.8365
0.5822
0.9981
0.9977

What is the neatest and most stable way to get the index of the highest value in the above list (values can be negative)?

My initial attempts were along the following lines:

let _tmp = *nets.iter().max_by(|i, j| i.partial_cmp(j).unwrap()).unwrap();    
let _i = nets.iter().position(|&element| element == _tmp).unwrap();

Here nets is a &Vec<f32>. This seems blatantly incorrect to me.

The Python equivalent, which works (given the assumption above), is:

_i = nets.index(max(nets))

Answer 1:


I will probably do something like this:

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let samples = vec![
        0.28, 0.3102, 0.9856, 0.3679, 0.3697, 0.46, 0.4311, 0.9781, 0.9891, 0.5052, 0.9173, 0.932,
        0.8365, 0.5822, 0.9981, 0.9977,
    ];

    // Use enumerate to get the index
    let mut iter = samples.iter().enumerate();
    // we get the first entry
    let init = iter.next().ok_or("Need at least one input")?;
    // we process the rest
    let result = iter.try_fold(init, |acc, x| {
        // `?` ends the whole fold with None if either value is NaN
        let cmp = x.1.partial_cmp(acc.1)?;
        // keep x only if it is greater than acc
        let max = if let std::cmp::Ordering::Greater = cmp {
            x
        } else {
            acc
        };
        Some(max)
    });
    println!("{:?}", result);

    Ok(())
}

This could be generalized by adding an extension trait on Iterator with, for example, a try_max_by method.
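A minimal sketch of that idea, assuming a hypothetical extension trait named TryMaxBy (nothing like it exists in std; the names TryMaxBy, try_max_by and index_of_max are illustrative). It wraps the same try_fold logic as above so that any comparator returning Option<Ordering>, such as partial_cmp, can abort on NaN:

```rust
use std::cmp::Ordering;

/// Hypothetical extension trait (not part of std): like `Iterator::max_by`,
/// but the comparator may return `None` to abort, e.g. when it meets a NaN.
trait TryMaxBy: Iterator + Sized {
    /// Returns `None` if the iterator is empty or if `compare` ever returns `None`;
    /// otherwise returns `Some` of the first maximum element.
    fn try_max_by<F>(mut self, mut compare: F) -> Option<Self::Item>
    where
        F: FnMut(&Self::Item, &Self::Item) -> Option<Ordering>,
    {
        let first = self.next()?;
        self.try_fold(first, |acc, x| {
            // Replace the accumulator only when `x` is strictly greater,
            // so ties resolve to the earliest element.
            match compare(&x, &acc)? {
                Ordering::Greater => Some(x),
                _ => Some(acc),
            }
        })
    }
}

// Blanket impl: every iterator gets `try_max_by`.
impl<I: Iterator> TryMaxBy for I {}

/// Index of the largest value, or `None` for an empty slice or any NaN.
fn index_of_max(values: &[f64]) -> Option<usize> {
    values
        .iter()
        .enumerate()
        .try_max_by(|(_, a), (_, b)| a.partial_cmp(b))
        .map(|(index, _)| index)
}
```

With the sample set from the question, index_of_max returns Some(14), the position of 0.9981.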




Answer 2:


Is there a reason why this wouldn't work?

use std::cmp::Ordering;

fn index_of_max(nets: &[f32]) -> Option<usize> {
    nets.iter()
        .enumerate()
        // Incomparable pairs (involving NaN) are treated as equal instead of panicking
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(Ordering::Equal))
        .map(|(index, _)| index)
}
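On Rust 1.62 or later, a variant of the max_by approach can use f32::total_cmp, which defines a total order over every f32 value, so neither unwrap nor a fallback Ordering is needed. The sketch below covers both the maximum and the minimum asked about in the title; the function names are illustrative:

```rust
/// Index of the largest value; `None` only for an empty slice.
/// `f32::total_cmp` orders every f32 (NaN included), so no fallback is needed.
fn index_of_max(values: &[f32]) -> Option<usize> {
    values
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(index, _)| index)
}

/// Index of the smallest value, using the same total order.
fn index_of_min(values: &[f32]) -> Option<usize> {
    values
        .iter()
        .enumerate()
        .min_by(|(_, a), (_, b)| a.total_cmp(b))
        .map(|(index, _)| index)
}
```

Two caveats on this design: max_by returns the last of several equal maxima while min_by returns the first, and total_cmp places a NaN above every number (or below, if its sign bit is set), so a stray NaN would be reported rather than skipped. Under the question's no-NaN assumption neither matters.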



Answer 3:


This is tricky because f32 does not implement Ord: NaN values prevent floating-point numbers from forming a total order, which the contract of Ord requires.
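A quick way to see the problem is to check how NaN behaves under the PartialEq and PartialOrd operations f32 does provide. This is a small illustrative sketch; the helper name nan_facts is hypothetical:

```rust
use std::cmp::Ordering;

/// Collects the facts that stop `f32` from implementing `Ord`:
/// NaN is not equal to itself, and it is unordered (`None`) relative to any number.
fn nan_facts() -> (bool, Option<Ordering>, Option<Ordering>) {
    let nan = f32::NAN;
    (
        nan == nan,               // reflexivity fails: false
        nan.partial_cmp(&1.0),    // NaN compared with a number: None
        1.0f32.partial_cmp(&nan), // and the other way round: None
    )
}
```

Ord::cmp must return an Ordering for every pair, and there is no consistent answer for those None cases, which is why the standard library only offers PartialOrd for floats.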

Third-party crates work around this by defining a numeric wrapper type that is not allowed to contain NaN; one example is ordered-float. If you use this crate to first convert the collection into NotNan values, you can write code very close to your original idea:

use ordered_float::NotNan;

let non_nan_floats: Vec<_> = nets.iter()
    .cloned()
    .map(NotNan::new)       // Attempt to convert each f32 to a NotNan
    .filter_map(Result::ok) // Keep the `NotNan`s and drop NaNs (dropping would shift indices, but NaN is ruled out here)
    .collect();

let max = non_nan_floats.iter().max().unwrap();
let index = non_nan_floats.iter().position(|element| element == max).unwrap();

Add this to Cargo.toml:

[dependencies]
ordered-float = "1.0.1"
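If you would rather not add a dependency, the same wrapper technique can be sketched by hand. The type NonNanF32 below is hypothetical and is not ordered-float's NotNan; it only shows the idea of a type whose constructor rejects NaN, which makes a total Ord implementation sound:

```rust
use std::cmp::Ordering;

/// Hypothetical minimal wrapper in the spirit of ordered-float's `NotNan`
/// (NOT that crate's type): an f32 that is guaranteed not to be NaN.
#[derive(Clone, Copy, Debug, PartialEq)]
struct NonNanF32(f32);

impl NonNanF32 {
    /// Returns `None` for NaN, so every constructed value is comparable.
    fn new(value: f32) -> Option<Self> {
        if value.is_nan() { None } else { Some(NonNanF32(value)) }
    }
}

// Sound because NaN can never be stored, so `==` is reflexive.
impl Eq for NonNanF32 {}

impl PartialOrd for NonNanF32 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for NonNanF32 {
    fn cmp(&self, other: &Self) -> Ordering {
        // Cannot fail: f32::partial_cmp only returns None when a NaN is involved.
        self.0.partial_cmp(&other.0).expect("NonNanF32 never holds NaN")
    }
}

/// Converts the whole slice up front (returning `None` on any NaN, so indices
/// stay aligned with the input), then uses the ordinary `Ord`-based `max_by_key`.
fn index_of_max(values: &[f32]) -> Option<usize> {
    let wrapped: Vec<NonNanF32> = values
        .iter()
        .copied()
        .map(NonNanF32::new)
        .collect::<Option<_>>()?;
    wrapped
        .iter()
        .enumerate()
        .max_by_key(|&(_, value)| *value)
        .map(|(index, _)| index)
}
```

Failing the whole conversion on NaN, instead of filtering NaNs out as the filter_map version does, keeps every index in the wrapped vector identical to the index in the original slice.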

Bonus material: the conversion can be made truly zero-cost (assuming you are really sure there are no NaN values!) by taking advantage of the fact that NotNan has a transparent representation, which lets a slice of f32 be reinterpreted as a slice of NotNan<f32>:

let non_nan_floats: &[NotNan<f32>] = unsafe { std::mem::transmute(nets.as_slice()) };



Answer 4:


You can find the maximum value with the following:

// Start from negative infinity rather than 0.0, so all-negative input still works
let max_value = my_vec.iter().fold(f32::NEG_INFINITY, |mut max, &val| {
    if val > max {
        max = val;
    }
    max
});

After finding max_value, you can look up its position in the vector:

let index = my_vec.iter().position(|&r| r == max_value).unwrap();

This iterates over the vector twice. To do it in a single pass, carry the index together with the max value as a tuple through the fold.
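A minimal single-pass sketch of that suggestion, assuming the values are f32 as in the question (the function name index_of_max is illustrative). It seeds the fold with the first (index, value) pair instead of a sentinel, so it also handles all-negative input and returns None for an empty slice:

```rust
/// Single pass: fold over (index, value) pairs, keeping the best pair seen so far.
/// Seeding with the first element avoids a sentinel such as 0.0.
fn index_of_max(values: &[f32]) -> Option<usize> {
    let mut iter = values.iter().copied().enumerate();
    let first = iter.next()?;
    let (index, _max) = iter.fold(first, |best, current| {
        // Strict `>` keeps the earliest index on ties; with the question's
        // no-NaN assumption every comparison is meaningful.
        if current.1 > best.1 { current } else { best }
    });
    Some(index)
}
```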




Source: https://stackoverflow.com/questions/53903318/what-is-the-idiomatic-way-to-get-the-index-of-a-maximum-or-minimum-floating-poin
