Scala and State Monad

巧了我就是萌 submitted on 2019-12-04 03:12:06

To understand the "second run", let's analyse it "backwards".

The signature def flatMap[B](f: A => State[S, B]): State[S, B] suggests that we need to run a function f and return its result.

To execute function f we need to give it an A. Where do we get one?
Well, we have run, which can give us an A out of an S, so we need an S.

Because of that we do: s => val (a, t) = run(s) ... We read it as "given an S, execute the run function, which produces an A and a new S". And this is our "first" run.

Now we have an A and we can execute f. That's what we wanted, and f(a) gives us a new State[S, B]. If we do that, we have a function which takes an S and returns a State[S, B]:

(s: S) => {
   val (a, t) = run(s)
   f(a) // State[S, B]
}

But a function S => State[S, B] isn't what we want to return! We want to return just a State[S, B].

How do we do that? We can wrap this function in a State:

State(s => ... f(a))

But that doesn't work, because State takes an S => (B, S), not an S => State[S, B]. So we need to get a (B, S) out of the State[S, B].
We do that by calling its run method with the state we just produced in the previous step! And that is our "second" run.
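A small sketch of why the second run has to start from the intermediate state `t` rather than from the original `s`. It assumes a counter state of type Int; the helpers `firstRunOnly`, `bind`, `next` and `pairWithNext` are hypothetical names for illustration, not part of the original answer.

```scala
final case class State[S, A](run: S => (A, S))

// Hypothetical: stops after the first run, so the result is S => State[S, B]
// and the intermediate state `t` is thrown away.
def firstRunOnly[S, A, B](sa: State[S, A], f: A => State[S, B]): S => State[S, B] =
  (s: S) => {
    val (a, t) = sa.run(s) // first run: an A and the next state t
    f(a)                   // t is never used
  }

// Hypothetical: also performs the second run, feeding `t` into f(a).
def bind[S, A, B](sa: State[S, A], f: A => State[S, B]): State[S, B] =
  State((s: S) => {
    val (a, t) = sa.run(s) // first run
    f(a).run(t)            // second run, starting from t
  })

// Counter: returns the current value and increments the state.
val next: State[Int, Int] = State((n: Int) => (n, n + 1))

// Reads the counter again and pairs the new reading with `a`.
val pairWithNext: Int => State[Int, (Int, Int)] =
  a => State((n: Int) => ((a, n), n + 1))
```

Starting from 0, `bind(next, pairWithNext).run(0)` gives `((0, 1), 2)`. With `firstRunOnly`, the only state left to run `f(a)` with is the original one, so `firstRunOnly(next, pairWithNext)(0).run(0)` reads 0 twice and gives `((0, 0), 1)`: the first increment is lost.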

So as a result we have the following transformation performed by a flatMap:

s => {                  // when a state is provided
  val (a, t) = run(s)   // produce an `A` and a new state value
  val resState = f(a)   // produce a new `State[S, B]`
  resState.run(t)       // return `(B, S)`
}

This gives us an S => (B, S), and we just wrap it with the State constructor.
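Putting the derivation together, a minimal sketch of the whole class might look like this. The `map` method and the `next`/`sumOfTwo` example values are assumptions added for illustration; only `flatMap` comes from the derivation above.

```scala
final case class State[S, A](run: S => (A, S)) {
  // Sequence this computation with one that depends on its result.
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State((s: S) => {
      val (a, t) = run(s)  // first run: produce an A and a new state t
      val resState = f(a)  // build the next State[S, B] from that A
      resState.run(t)      // second run: return (B, S), starting from t
    })

  // Assumed helper: flatMap followed by a step that leaves the state unchanged.
  def map[B](f: A => B): State[S, B] =
    flatMap(a => State[S, B](s => (f(a), s)))
}

// Example: a counter that returns its current value and then increments.
val next: State[Int, Int] = State[Int, Int](n => (n, n + 1))

// Read the counter twice and add the two readings.
val sumOfTwo: State[Int, Int] =
  next.flatMap(a => next.map(b => a + b))
```

Running `sumOfTwo.run(10)` reads 10, then 11, and returns `(21, 12)`: the sum, plus the counter after two increments.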

Another way of looking at these "two runs" is:
first - we transform the state ourselves with "our" run function;
second - we pass that transformed state to the State returned by f and let it do its own transformation.

So we are, in effect, chaining state transformations one after another. And that's exactly what monads do: they let us sequence computations.
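To make the order of the two transformations visible, here is a sketch where the state is a log of messages. The `say` helper and the `program` value are hypothetical, chosen only so that each step leaves a trace.

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State((s: S) => {
      val (a, t) = run(s) // our transformation happens first...
      f(a).run(t)         // ...then f's State transforms its result
    })
}

// Hypothetical helper: appends a message and returns the new log size.
def say(msg: String): State[Vector[String], Int] =
  State[Vector[String], Int](log => (log.size + 1, log :+ msg))

// The second step depends on the first step's result (the log size).
val program: State[Vector[String], Int] =
  say("first").flatMap(n => say(s"second (after $n message)"))
```

Running `program` on an empty log yields a count of 2 and the log `Vector("first", "second (after 1 message)")`: the second step only ever sees the state left behind by the first.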

The state monad boils down to this function from one state to another state (plus a value of type A):

type StatefulComputation[S, +A] = S => (A, S)
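Because a stateful computation is just a function, two of them can already be chained by hand, with no wrapper class. A minimal sketch; `andThenBind`, `pop` and `popTwoAndSum` are hypothetical names:

```scala
type StatefulComputation[S, +A] = S => (A, S)

// Hypothetical: run `first`, then feed its value and the state it
// produced into the computation chosen by `next`.
def andThenBind[S, A, B](
    first: StatefulComputation[S, A],
    next: A => StatefulComputation[S, B]
): StatefulComputation[S, B] =
  (s: S) => {
    val (a, t) = first(s) // first run
    next(a)(t)            // second run, from the new state
  }

// Pops the head of a list, returning it together with the remaining tail.
val pop: StatefulComputation[List[Int], Int] = xs => (xs.head, xs.tail)

// Pop two elements and return their sum.
val popTwoAndSum: StatefulComputation[List[Int], Int] =
  andThenBind(pop, (a: Int) =>
    andThenBind(pop, (b: Int) => (s: List[Int]) => (a + b, s)))
```

`popTwoAndSum(List(1, 2, 3))` pops 1, then 2, and returns `(3, List(3))`. Wrapping the function in a case class, as below, lets this chaining live in a `flatMap` method, which is what for-comprehensions need.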

The implementation mentioned by Tony in that blog post "captures" that function in the run field of the case class:

case class State[S, A](run: S => (A, S))

The flatMap implementation that binds one State to another calls two different runs:

    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State((s: S) => {
        // the `run` on the actual `state`
        val (a, nextState) = run(s) // a: A, nextState: S

        // the `run` on the bound `state`
        f(a).run(nextState)
      })

EDIT: Example of flatMap between two States

Consider a function that simply calls .head on a List to get the A, and .tail to get the next state S:

// stateful computation: `S => (A, S)` where `S` is `List[A]`
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

A simple binding of two State(head[Int]) computations:

// flatMap example
val result = for {
  a <- State(head[Int])
  b <- State(head[Int])
} yield Map('a' -> a,
            'b' -> b)

The expected behaviour of the for-comprehension is to "extract" the first element of the list into a and the second one into b. The resulting state S is the remaining tail of the list passed to run:

scala> result.run(List(1, 2, 3, 4, 5))
(Map(a -> 1, b -> 2),List(3, 4, 5))
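A self-contained sketch that makes the example above runnable. It assumes `State` also has a `map` method (the final `yield` of a for-comprehension desugars to `map`); that `map` is written here for illustration and is not taken from the original post.

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State((s: S) => {
      val (a, t) = run(s) // first run: this computation
      f(a).run(t)         // second run: the bound computation
    })

  // Assumed: needed by the for-comprehension's `yield`.
  def map[B](f: A => B): State[S, B] =
    flatMap(a => State[S, B](s => (f(a), s)))
}

// Stateful computation `S => (A, S)` where `S` is `List[A]`.
// Note: xs.head throws on an empty list.
def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

val result = for {
  a <- State(head[Int])
  b <- State(head[Int])
} yield Map('a' -> a, 'b' -> b)
```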

How? By calling the "stateful computation" head[Int], which is stored in run, on some state s:

s => run(s)

That gives the head (the A) and the tail (the new S) of the list. Now we need to pass the tail to the next State(head[Int]):

f(a).run(t)

where f is the function in the flatMap signature:

def flatMap[B](f: A => State[S, B]): State[S, B]
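The two runs can be traced by hand with concrete values. In this sketch `f` is a hypothetical function that reads the next head and pairs it with `a`; it stands in for whatever function is passed to `flatMap`.

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State((s: S) => {
      val (a, t) = run(s)
      f(a).run(t)
    })
}

def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

val first: State[List[Int], Int] = State(head[Int])

// Hypothetical `f`: reads the next head and pairs it with `a`.
val f: Int => State[List[Int], (Int, Int)] =
  a => State((xs: List[Int]) => {
    val (b, rest) = head(xs)
    ((a, b), rest)
  })

// First run, by hand: the head becomes the value, the tail the new state.
val afterFirst = first.run(List(1, 2, 3, 4, 5)) // (1, List(2, 3, 4, 5))
val a = afterFirst._1 // the A
val t = afterFirst._2 // the new S

// Second run, by hand: pass the tail to the State produced by f(a).
val manual = f(a).run(t) // ((1, 2), List(3, 4, 5))

// flatMap performs exactly these two runs.
val viaFlatMap = first.flatMap(f).run(List(1, 2, 3, 4, 5))
```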

To better understand what f is in this example, we can desugar the for-comprehension to:

val result = State(head[Int]).flatMap {
  a => State(head[Int]).map {
    b => Map('a' -> a, 'b' -> b)
  }
}
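Pulling the outer lambda out of the desugared code and giving it a name makes `f` explicit. A sketch, repeating the minimal `State` (with an assumed `map`) so that it runs on its own:

```scala
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State((s: S) => {
      val (a, t) = run(s)
      f(a).run(t)
    })

  // Assumed, as in the desugared code above.
  def map[B](f: A => B): State[S, B] =
    flatMap(a => State[S, B](s => (f(a), s)))
}

def head[A](xs: List[A]): (A, List[A]) = (xs.head, xs.tail)

// The `f` passed to flatMap: given the first head `a`, it builds the
// State that reads the second head `b` and yields the Map.
val f: Int => State[List[Int], Map[Char, Int]] =
  a => State(head[Int]).map(b => Map('a' -> a, 'b' -> b))

val desugared = State(head[Int]).flatMap(f)
```

Calling `f(1)` on its own just builds a `State`; it only consumes a list element once its `run` is given the tail produced by the first run.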

With f(a) we pass a into the function, and with run(t) we pass along the modified state.

I have accepted @AlexyRaga's answer to my question. I think @Filippo's answer was very good as well and, in fact, gave me some additional food for thought. Thanks to both of you.

I think the conceptual difficulty I was having was really mostly to do with what the run method 'means', that is, what its purpose and result are. I was looking at it as a 'transition' function (from one state to the next). And, after a fashion, that is what it does. However, it doesn't transition from a given (this) state to the next state. Instead, it takes an initial state and returns the (this) state's value and a new 'current' state (not the next state in the state-transition sequence).

That is why the flatMap method is implemented the way it is. When you generate a new State, you need the current value/state pair from it, based on the passed-in initial state; that pair can then be wrapped in a new State object as a function. You are not really transitioning to a new state, just re-wrapping the generated state in a new State object.
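The point that a `State` value describes a transition rather than holding a state can be seen by running one value from different initial states. A minimal sketch, with a hypothetical `next` counter:

```scala
final case class State[S, A](run: S => (A, S))

// One State value: return the current counter and increment it.
val next: State[Int, Int] = State((n: Int) => (n, n + 1))

// The same value, run from different initial states, gives different
// results: the State itself holds no "current" state.
val fromZero = next.run(0)
val fromTen = next.run(10)
```

`fromZero` is `(0, 1)` and `fromTen` is `(10, 11)`; the state only exists as the argument passed to `run` and the second element of its result.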

I was too steeped in traditional state machines to see what was going on here.

Thanks again, everyone.
