How do I implement the Fn trait for one struct for different types of arguments?

只谈情不闲聊 submitted on 2019-12-07 05:15:16

Question


I have a simple classifier:

struct Clf {
    x: f64,
}

The classifier returns 0 if the observed value is smaller than x and 1 if bigger than x.

I want to implement the call operator for this classifier. However, the function should be able to take either a float or a vector as its argument. In the case of a vector, the output is a vector of 0s and 1s with the same length as the input vector:

let c = Clf { x: 0.0 };
let v = vec![-1.0, 0.5, 1.0];
println!("{}", c(0.5));     // prints 1
println!("{:?}", c(v));     // prints [0, 1, 1]

How can I write an implementation of Fn in this case?

impl Fn for Clf {
    extern "rust-call" fn call(/*...*/) {
        // ...
    }
}

Answer 1:


The short answer is: You can't. At least it won't work the way you want. I think the best way to show that is to walk through and see what happens, but the general idea is that Rust doesn't support function overloading.

For this example, we will be implementing FnOnce, because Fn requires FnMut which requires FnOnce. So, if we were to get this all sorted, we could do it for the other function traits.
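The trait hierarchy can be seen on stable Rust with ordinary closures instead of a hand-written impl. The helper names call_with, call_mut_with and call_once_with are illustrative only. A closure that only reads its environment implements Fn. Because Fn requires FnMut, which requires FnOnce, that same closure also satisfies the other two bounds:

```rust
// Illustrative helpers, one per trait in the Fn* hierarchy.
fn call_with<F: Fn(f64) -> i32>(f: F, v: f64) -> i32 {
    f(v)
}

fn call_mut_with<F: FnMut(f64) -> i32>(mut f: F, v: f64) -> i32 {
    f(v)
}

fn call_once_with<F: FnOnce(f64) -> i32>(f: F, v: f64) -> i32 {
    f(v)
}

fn main() {
    let x = 0.0;
    // Only reads `x`, so this closure implements Fn (and hence FnMut and FnOnce).
    // It captures `x` by shared reference, so the closure itself is Copy.
    let classify = |v: f64| if v > x { 1 } else { 0 };

    assert_eq!(call_with(classify, 0.5), 1);
    assert_eq!(call_mut_with(classify, 0.5), 1);
    assert_eq!(call_once_with(classify, -1.0), 0);
    println!("ok");
}
```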

First, this is unstable (nightly only), so we need some feature flags:

#![feature(unboxed_closures, fn_traits)]

Then, let's do the impl for taking an f64:

impl FnOnce<(f64,)> for Clf {
    type Output = i32;
    extern "rust-call" fn call_once(self, args: (f64,)) -> i32 {
        if args.0 > self.x {
            1
        } else {
            0
        }
    }
}

The arguments to the Fn family of traits are supplied via a tuple, so that's the (f64,) syntax; it's a tuple with just one element.
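Here is a short stable illustration of one-element tuples; the function names are illustrative. The trailing comma is what makes (0.5,) a tuple, while (0.5) is just 0.5. The pattern (x,) destructures such a tuple, the same pattern that appears in a call_once signature further down this page:

```rust
// Takes a one-element tuple and reads its only field.
fn first(args: (f64,)) -> f64 {
    args.0
}

// Destructures the one-element tuple directly in the parameter list.
fn destructure((x,): (f64,)) -> f64 {
    x
}

fn main() {
    let args = (0.5,); // trailing comma: this is a (f64,) tuple
    assert_eq!(first(args), 0.5);
    assert_eq!(destructure(args), 0.5);

    let (v,) = args; // same pattern in a `let`
    assert_eq!(v, 0.5);
    println!("ok");
}
```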

This is all well and good, and we can now do c(0.5), although it will consume c until we implement the other traits.
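Why call_once consumes c can be mimicked on stable Rust with plain methods. The method names below mirror the trait methods but are ordinary inherent methods, not the real Fn* traits. A by-value self moves the receiver, whereas &self, like Fn::call, only borrows it:

```rust
struct Clf {
    x: f64,
}

impl Clf {
    // Mirrors FnOnce::call_once: takes `self` by value, so the receiver is moved.
    fn call_once(self, v: f64) -> i32 {
        if v > self.x { 1 } else { 0 }
    }

    // Mirrors Fn::call: takes `&self`, so the receiver is only borrowed.
    fn call(&self, v: f64) -> i32 {
        if v > self.x { 1 } else { 0 }
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    assert_eq!(c.call(0.5), 1);
    assert_eq!(c.call(-1.0), 0); // `c` is still usable: it was only borrowed
    assert_eq!(c.call_once(0.5), 1); // this moves `c`
    // c.call(0.5); // would not compile: `c` has been moved
    println!("ok");
}
```

Deriving Copy and Clone on Clf is another way around the move, because each by-value call then copies the struct.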

Now let's do the same thing for Vecs:

impl FnOnce<(Vec<f64>,)> for Clf {
    type Output = Vec<i32>;
    extern "rust-call" fn call_once(self, args: (Vec<f64>,)) -> Vec<i32> {
        args.0
            .iter()
            .map(|&f| if f > self.x { 1 } else { 0 })
            .collect()
    }
}

Before Rust 1.33 nightly, you could not directly call c(v), or even c(0.5) (which worked before the second impl was added); the compiler reported an error saying the type of the function was not known. Basically, those versions of Rust didn't support this kind of function overloading. You could still call the functions using fully qualified syntax, where c(0.5) becomes FnOnce::call_once(c, (0.5,)).
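The fully qualified syntax can be tried on stable Rust with a stand-in trait. CallWith is hypothetical, not part of std. Like the Fn* traits, it is generic over an argument tuple, and naming the impl with <Clf as CallWith<...>>::call_with leaves nothing for the compiler to infer. With an ordinary trait the compiler can usually infer the impl from the argument type anyway, so this only illustrates the syntax, not the old nightly limitation.

```rust
/// Hypothetical stand-in for FnOnce<Args>: generic over an argument tuple.
trait CallWith<Args> {
    type Output;
    fn call_with(&self, args: Args) -> Self::Output;
}

struct Clf {
    x: f64,
}

impl CallWith<(f64,)> for Clf {
    type Output = i32;
    fn call_with(&self, args: (f64,)) -> i32 {
        if args.0 > self.x { 1 } else { 0 }
    }
}

impl CallWith<(Vec<f64>,)> for Clf {
    type Output = Vec<i32>;
    fn call_with(&self, args: (Vec<f64>,)) -> Vec<i32> {
        args.0.iter().map(|&f| if f > self.x { 1 } else { 0 }).collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    // Fully qualified: the exact impl is spelled out.
    let s = <Clf as CallWith<(f64,)>>::call_with(&c, (0.5,));
    let v = <Clf as CallWith<(Vec<f64>,)>>::call_with(&c, (vec![-1.0, 0.5, 1.0],));
    assert_eq!(s, 1);
    assert_eq!(v, vec![0, 1, 1]);

    // Trait-path form, analogous to FnOnce::call_once(c, (0.5,)).
    assert_eq!(CallWith::call_with(&c, (1.0f64,)), 1);
    println!("ok");
}
```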


Not knowing your bigger picture, I would solve this simply by giving Clf two methods, like so:

impl Clf {
    fn classify(&self, val: f64) -> u32 {
        if val > self.x {
            1
        } else {
            0
        }
    }

    fn classify_vec(&self, vals: Vec<f64>) -> Vec<u32> {
        vals.into_iter().map(|v| self.classify(v)).collect()
    }
}

Then your usage example becomes:

let c = Clf { x: 0.0 };
let v = vec![-1.0, 0.5, 1.0];
println!("{}", c.classify(0.5));     // prints 1
println!("{:?}", c.classify_vec(v)); // prints [0, 1, 1]

I would actually make the second method classify_slice and have it take &[f64], to be a bit more general; you could then still use it with Vecs by borrowing them: c.classify_slice(&v).
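Below is a sketch of that classify_slice variant. The method name comes from the suggestion above, and the body is an assumption that mirrors classify_vec. Because the method borrows a slice, it accepts a Vec, an array or a sub-slice, and it leaves the caller's data usable afterwards:

```rust
struct Clf {
    x: f64,
}

impl Clf {
    fn classify(&self, val: f64) -> u32 {
        if val > self.x { 1 } else { 0 }
    }

    // Borrows the input instead of taking ownership of a Vec.
    fn classify_slice(&self, vals: &[f64]) -> Vec<u32> {
        vals.iter().map(|&v| self.classify(v)).collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    assert_eq!(c.classify_slice(&v), vec![0, 1, 1]); // &Vec<f64> coerces to &[f64]
    assert_eq!(c.classify_slice(&[2.0, -2.0]), vec![1, 0]); // array
    assert_eq!(c.classify_slice(&v[1..]), vec![1, 1]); // sub-slice
    println!("{:?}", v); // `v` is still usable here
}
```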




Answer 2:


This is indeed possible, but you need a new trait and a ton of mess.

If you start with the abstraction

enum VecOrScalar<T> {
    Scalar(T),
    Vector(Vec<T>),
}

use VecOrScalar::*;

You want a way to use the type transformations

T      (hidden) -> VecOrScalar<T> -> T      (known)
Vec<T> (hidden) -> VecOrScalar<T> -> Vec<T> (known)

because then you can take a "hidden" type T, wrap it in a VecOrScalar and extract the real type T with a match.

You also want

T      (known) -> bool      = T::Output
Vec<T> (known) -> Vec<bool> = Vec<T>::Output

but without higher-kinded types, this is a bit tricky. Instead, you can do

T      (known) -> VecOrScalar<T> -> T::Output
Vec<T> (known) -> VecOrScalar<T> -> Vec<T>::Output

if you allow for a branch that can panic.

The trait will thus be

trait FromVecOrScalar<T> {
    type Output;

    fn put(self) -> VecOrScalar<T>;

    fn get(out: VecOrScalar<bool>) -> Self::Output;
}

with implementations

impl<T> FromVecOrScalar<T> for T {
    type Output = bool;

    fn put(self) -> VecOrScalar<T> {
        Scalar(self)
    }

    fn get(out: VecOrScalar<bool>) -> Self::Output {
        match out {
            Scalar(val) => val,
            Vector(_) => panic!("Wrong output type!"),
        }
    }
}
impl<T> FromVecOrScalar<T> for Vec<T> {
    type Output = Vec<bool>;

    fn put(self) -> VecOrScalar<T> {
        Vector(self)
    }

    fn get(out: VecOrScalar<bool>) -> Self::Output {
        match out {
            Vector(val) => val,
            Scalar(_) => panic!("Wrong output type!"),
        }
    }
}
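The round trip can be checked quickly on stable Rust; the main function is only an illustration. It shows three things: put picks the enum variant from the static type, get extracts the matching output type, and the mismatched arm is the branch that panics:

```rust
use std::panic;

enum VecOrScalar<T> {
    Scalar(T),
    Vector(Vec<T>),
}
use self::VecOrScalar::{Scalar, Vector};

trait FromVecOrScalar<T> {
    type Output;
    fn put(self) -> VecOrScalar<T>;
    fn get(out: VecOrScalar<bool>) -> Self::Output;
}

impl<T> FromVecOrScalar<T> for T {
    type Output = bool;
    fn put(self) -> VecOrScalar<T> {
        Scalar(self)
    }
    fn get(out: VecOrScalar<bool>) -> bool {
        match out {
            Scalar(val) => val,
            Vector(_) => panic!("Wrong output type!"),
        }
    }
}

impl<T> FromVecOrScalar<T> for Vec<T> {
    type Output = Vec<bool>;
    fn put(self) -> VecOrScalar<T> {
        Vector(self)
    }
    fn get(out: VecOrScalar<bool>) -> Vec<bool> {
        match out {
            Vector(val) => val,
            Scalar(_) => panic!("Wrong output type!"),
        }
    }
}

fn main() {
    // put(): the static type decides which variant is produced.
    match <f64 as FromVecOrScalar<f64>>::put(0.5) {
        Scalar(x) => assert_eq!(x, 0.5),
        Vector(_) => unreachable!(),
    }
    match <Vec<f64> as FromVecOrScalar<f64>>::put(vec![1.0, 2.0]) {
        Vector(v) => assert_eq!(v, vec![1.0, 2.0]),
        Scalar(_) => unreachable!(),
    }

    // get(): the same static type decides the output type.
    let b: bool = <f64 as FromVecOrScalar<f64>>::get(Scalar(true));
    let bs: Vec<bool> = <Vec<f64> as FromVecOrScalar<f64>>::get(Vector(vec![true, false]));
    assert!(b);
    assert_eq!(bs, vec![true, false]);

    // Feeding the "wrong" variant hits the panicking branch.
    let r = panic::catch_unwind(|| <f64 as FromVecOrScalar<f64>>::get(Vector(vec![true])));
    assert!(r.is_err());
    println!("ok");
}
```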

Your type

#[derive(Copy, Clone)]
struct Clf {
    x: f64,
}

will first implement the two branches:

impl Clf {
    fn calc_scalar(self, f: f64) -> bool {
        f > self.x
    }

    fn calc_vector(self, v: Vec<f64>) -> Vec<bool> {
        v.into_iter().map(|x| self.calc_scalar(x)).collect()
    }
}

Then it dispatches by implementing FnOnce<(T,)> for every T: FromVecOrScalar<f64>:

impl<T> FnOnce<(T,)> for Clf
where
    T: FromVecOrScalar<f64>,
{

with types

    type Output = T::Output;
    extern "rust-call" fn call_once(self, (arg,): (T,)) -> T::Output {

The dispatch first wraps the hidden type in the enum with put, so you can match on it, and then calls T::get on the result to hide it again.

        match arg.put() {
            Scalar(scalar) => T::get(Scalar(self.calc_scalar(scalar))),
            Vector(vector) => T::get(Vector(self.calc_vector(vector))),
        }
    }
}

Then, success:

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{}", c(0.5f64));
    println!("{:?}", c(v));
}

Since the compiler can see through all of this malarkey, it actually compiles away to basically the same assembly as a direct call to the calc_ methods.

That's not to say it's nice to write. Overloading like this is a pain, fragile and most certainly A Bad Idea™. Don't do it, though it's fine to know that you can.
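To try the same dispatch without nightly features, one option is to move the body of call_once into an ordinary generic method. This is a sketch, not the answer's code. The method name classify is hypothetical, and you call it as c.classify(x) rather than c(x):

```rust
enum VecOrScalar<T> {
    Scalar(T),
    Vector(Vec<T>),
}
use self::VecOrScalar::{Scalar, Vector};

trait FromVecOrScalar<T> {
    type Output;
    fn put(self) -> VecOrScalar<T>;
    fn get(out: VecOrScalar<bool>) -> Self::Output;
}

impl<T> FromVecOrScalar<T> for T {
    type Output = bool;
    fn put(self) -> VecOrScalar<T> { Scalar(self) }
    fn get(out: VecOrScalar<bool>) -> bool {
        match out { Scalar(val) => val, Vector(_) => panic!("Wrong output type!") }
    }
}

impl<T> FromVecOrScalar<T> for Vec<T> {
    type Output = Vec<bool>;
    fn put(self) -> VecOrScalar<T> { Vector(self) }
    fn get(out: VecOrScalar<bool>) -> Vec<bool> {
        match out { Vector(val) => val, Scalar(_) => panic!("Wrong output type!") }
    }
}

#[derive(Copy, Clone)]
struct Clf {
    x: f64,
}

impl Clf {
    fn calc_scalar(self, f: f64) -> bool {
        f > self.x
    }

    fn calc_vector(self, v: Vec<f64>) -> Vec<bool> {
        v.into_iter().map(|x| self.calc_scalar(x)).collect()
    }

    // Hypothetical stable replacement for the FnOnce<(T,)> impl:
    // same put/match/get dispatch, but as a normal generic method.
    fn classify<T: FromVecOrScalar<f64>>(self, arg: T) -> T::Output {
        match arg.put() {
            Scalar(scalar) => T::get(Scalar(self.calc_scalar(scalar))),
            Vector(vector) => T::get(Vector(self.calc_vector(vector))),
        }
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v: Vec<f64> = vec![-1.0, 0.5, 1.0];
    assert!(c.classify(0.5f64)); // T = f64, returns bool
    assert_eq!(c.classify(v), vec![false, true, true]); // T = Vec<f64>, returns Vec<bool>
    println!("ok");
}
```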




Answer 3:


You can't (but read until the end of the answer).

First of all, implementing the Fn* family of traits explicitly is unstable and subject to change at any time, so it'd be a bad idea to depend on that.

Secondly, and more importantly, before Rust 1.33 nightly the compiler would not let you call a value that has Fn* implementations for different argument types. It just couldn't work out what you wanted it to do, since normally that situation cannot arise. The only way around that was fully specifying the trait you wanted to call, but at that point, you've lost any possible ergonomic benefit of this approach.

Just define and implement your own trait instead of trying to use the Fn* traits. I took some liberties with the question to avoid/fix questionable aspects.

struct Clf {
    x: f64,
}

trait ClfExt<T: ?Sized> {
    type Result;
    fn classify(&self, arg: &T) -> Self::Result;
}

impl ClfExt<f64> for Clf {
    type Result = bool;
    fn classify(&self, arg: &f64) -> Self::Result {
        *arg > self.x
    }
}

impl ClfExt<[f64]> for Clf {
    type Result = Vec<bool>;
    fn classify(&self, arg: &[f64]) -> Self::Result {
        arg.iter().map(|v| self.classify(v)).collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{}", c.classify(&0.5f64));
    println!("{:?}", c.classify(&v[..]));
}
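With only these two impls, c.classify(&v) on a Vec<f64> infers T = Vec<f64> and finds no matching impl, which is presumably why the example writes &v[..]. One possible extension is a third impl that forwards to the slice version. This impl is a hypothetical addition, not part of the answer:

```rust
struct Clf {
    x: f64,
}

trait ClfExt<T: ?Sized> {
    type Result;
    fn classify(&self, arg: &T) -> Self::Result;
}

impl ClfExt<f64> for Clf {
    type Result = bool;
    fn classify(&self, arg: &f64) -> bool {
        *arg > self.x
    }
}

impl ClfExt<[f64]> for Clf {
    type Result = Vec<bool>;
    fn classify(&self, arg: &[f64]) -> Vec<bool> {
        arg.iter().map(|v| self.classify(v)).collect()
    }
}

// Hypothetical extra impl so that callers can pass &Vec<f64> directly.
impl ClfExt<Vec<f64>> for Clf {
    type Result = Vec<bool>;
    fn classify(&self, arg: &Vec<f64>) -> Vec<bool> {
        // as_slice() makes T = [f64] explicit, so this forwards to the slice impl.
        self.classify(arg.as_slice())
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    assert_eq!(c.classify(&v), vec![false, true, true]); // new Vec impl
    assert_eq!(c.classify(&v[..]), vec![false, true, true]); // original slice impl
    assert!(c.classify(&0.5f64)); // scalar impl
    println!("ok");
}
```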

How to use the Fn* traits

I've included this for the sake of completeness; do not actually do this. Not only is it unsupported, it's damn ugly.

#![feature(fn_traits, unboxed_closures)]

#[derive(Copy, Clone)]
struct Clf {
    x: f64,
}

impl FnOnce<(f64,)> for Clf {
    type Output = bool;
    extern "rust-call" fn call_once(self, args: (f64,)) -> Self::Output {
        args.0 > self.x
    }
}

impl<'a> FnOnce<(&'a [f64],)> for Clf {
    type Output = Vec<bool>;
    extern "rust-call" fn call_once(self, args: (&'a [f64],)) -> Self::Output {
        args.0
            .iter()
            .cloned()
            .map(|v| FnOnce::call_once(self, (v,)))
            .collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];

    // Required before Rust 1.33 nightly: fully qualified calls
    println!("{}", FnOnce::call_once(c, (0.5f64,)));
    println!("{:?}", FnOnce::call_once(c, (&v[..],)));

    // Rust 1.33 nightly and later: ordinary call syntax
    println!("{}", c(0.5f64));
    println!("{:?}", c(&v[..]));
}



Answer 4:


You can do this using nightly and unstable features:

#![feature(fn_traits, unboxed_closures)]
struct Clf {
    x: f64,
}

impl FnOnce<(f64,)> for Clf {
    type Output = i32;
    extern "rust-call" fn call_once(self, args: (f64,)) -> i32 {
        if args.0 > self.x {
            1
        } else {
            0
        }
    }
}

impl FnOnce<(Vec<f64>,)> for Clf {
    type Output = Vec<i32>;
    extern "rust-call" fn call_once(self, args: (Vec<f64>,)) -> Vec<i32> {
        args.0
            .iter()
            .map(|&f| if f > self.x { 1 } else { 0 })
            .collect()
    }
}

fn main() {
    let c = Clf { x: 0.0 };
    let v = vec![-1.0, 0.5, 1.0];
    println!("{:?}", c(0.5));

    let c = Clf { x: 0.0 };
    println!("{:?}", c(v));
}


Source: https://stackoverflow.com/questions/38672235/how-do-i-implement-the-fn-trait-for-one-struct-for-different-types-of-arguments
