The Universe of Discourse
           
Thu, 31 Dec 2009

A monad for probability and provenance
I don't quite remember how I arrived at this, but it occurred to me last week that probability distributions form a monad. This is the first time I've invented a new monad that I hadn't seen before; then I implemented it and it behaved pretty much the way I thought it would. So I feel like I've finally arrived, monadwise.

Suppose a monad value represents all the possible outcomes of an event, each with a probability of occurrence. For concreteness, let's suppose all our probability distributions are discrete. Then we might have:

	data ProbDist p a = ProbDist [(a,p)] deriving (Eq, Show)
	unpd (ProbDist ps) = ps
Each a is an outcome, and each p is the probability of that outcome occurring. For example, biased and unbiased coins:

    unbiasedCoin = ProbDist [ ("heads", 0.5),
                              ("tails", 0.5) ]

    biasedCoin   = ProbDist [ ("heads", 0.6),
                              ("tails", 0.4) ]

Or a couple of simple functions for making dice:

    import Data.Ratio

    d sides = ProbDist [(i, 1 % sides) | i <- [1 .. sides]]
    die = d 6

d n is an n-sided die.

The Functor instance is straightforward:

    instance Functor (ProbDist p) where
      fmap f (ProbDist pas) = ProbDist $ map (\(a,p) -> (f a, p)) pas
The Monad instance requires return and >>=. The return function merely takes an event and turns it into a distribution where that event occurs with probability 1. I find join easier to think about than >>=. The join function takes a nested distribution, where each outcome of the outer distribution specifies an inner distribution for the actual events, and collapses it into a regular, overall distribution. For example, suppose you put a biased coin and an unbiased coin in a bag, then pull one out and flip it:

	  bag :: ProbDist Double (ProbDist Double String)
	  bag = ProbDist [ (biasedCoin, 0.5),
                           (unbiasedCoin, 0.5) ]
The join operator collapses this into a single ProbDist Double String:

	ProbDist [("heads",0.3),
                  ("tails",0.2),
                  ("heads",0.25),
                  ("tails",0.25)]
It would be nice if join could combine the duplicate heads into a single ("heads", 0.55) entry. But that would force an Eq a constraint on the event type, which isn't allowed, because (>>=) must work for all data types, not just for instances of Eq. This is a problem with Haskell, not with the monad itself. It's the same problem that prevents one from making a good set monad in Haskell, even though categorially sets are a perfectly good monad. (The return function constructs singletons, and the join function is simply set union.) Maybe in the next language.
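
The collapsing step itself can be written down directly. What follows is only a minimal sketch: the name joinPD is made up for illustration, and the coins use Rational weights rather than Double so that the arithmetic is exact. Each inner outcome's probability is multiplied by the probability of the inner distribution it came from, and duplicates are left unmerged.

```haskell
import Data.Ratio ((%))

data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

-- joinPD is a hypothetical name, not a library function.
-- Scale every inner probability q by the outer probability p
-- of the distribution it was drawn from.
joinPD :: Num p => ProbDist p (ProbDist p a) -> ProbDist p a
joinPD (ProbDist outer) =
  ProbDist [ (a, p * q) | (ProbDist inner, p) <- outer, (a, q) <- inner ]

biasedCoin, unbiasedCoin :: ProbDist Rational String
biasedCoin   = ProbDist [("heads", 3 % 5), ("tails", 2 % 5)]
unbiasedCoin = ProbDist [("heads", 1 % 2), ("tails", 1 % 2)]

bag :: ProbDist Rational (ProbDist Rational String)
bag = ProbDist [(biasedCoin, 1 % 2), (unbiasedCoin, 1 % 2)]

-- joinPD bag lists "heads" and "tails" twice each, as in the text,
-- with weights 3 % 10, 1 % 5, 1 % 4 and 1 % 4.
```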

Perhaps someone else will find the >>= operator easier to understand than join? I don't know. Anyway, it's simple enough to derive once you understand join; here's the code:

	instance (Num p) => Monad (ProbDist p) where
	  return a = ProbDist [(a, 1)]
	  (ProbDist pas) >>= f = ProbDist $ do
				   (a, p) <- pas
				   let (ProbDist pbs) = f a
				   (b, q) <- pbs
				   return (b, p*q)
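
A note for anyone trying this on a recent compiler: since GHC 7.10, Applicative is a superclass of Monad, so the Monad instance alone no longer compiles. A minimal sketch of the extra boilerplate follows, assuming the same ProbDist type. The behaviour of >>= is unchanged; it is merely written as a list comprehension, and <*> is derived from it with ap. The name twoDice is made up for the example.

```haskell
import Control.Monad (ap, liftM2)
import Data.Ratio ((%))

data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

instance Functor (ProbDist p) where
  fmap f (ProbDist pas) = ProbDist [ (f a, p) | (a, p) <- pas ]

-- Required since GHC 7.10; `ap` builds <*> out of >>=.
instance Num p => Applicative (ProbDist p) where
  pure a = ProbDist [(a, 1)]
  (<*>)  = ap

instance Num p => Monad (ProbDist p) where
  ProbDist pas >>= f =
    ProbDist [ (b, p * q) | (a, p) <- pas, let ProbDist pbs = f a, (b, q) <- pbs ]

-- d n is an n-sided die, as in the text.
d :: Integer -> ProbDist Rational Integer
d sides = ProbDist [ (i, 1 % sides) | i <- [1 .. sides] ]

-- twoDice is a hypothetical name for the sum of two dice.
twoDice :: ProbDist Rational Integer
twoDice = liftM2 (+) (d 6) (d 6)
```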
So now we can do some straightforward experiments:

	liftM2 (+) (d 6) (d 6)

	ProbDist [(2,1 % 36),(3,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 %
	36),(7,1 % 36),(3,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 %
	36),(7,1 % 36),(8,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 %
	36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(5,1 % 36),(6,1 %
	36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(6,1 %
	36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(11,1 %
	36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(11,1 %
	36),(12,1 % 36)]
This is nasty-looking; we really need to merge the multiple listings of the same event. Here is a function to do that:

        agglomerate :: (Num p, Eq b) => (a -> b) -> ProbDist p a -> ProbDist p b
        agglomerate f pd = ProbDist $ foldr insert [] (unpd (fmap f pd)) where
          insert (k, p) [] = [(k, p)]
          insert (k, p) ((k', p'):kps) | k == k' = (k, p+p'):kps
                                       | otherwise = (k', p'):(insert (k,p) kps)


        agg :: (Num p, Eq a) => ProbDist p a -> ProbDist p a
        agg = agglomerate id
Then agg $ liftM2 (+) (d 6) (d 6) produces:

        ProbDist [(12,1 % 36),(11,1 % 18),(10,1 % 12),(9,1 % 9),
                  (8,5 % 36),(7,1 % 6),(6,5 % 36),(5,1 % 9),
                  (4,1 % 12),(3,1 % 18),(2,1 % 36)]
Hey, that's correct.

There must be a shorter way to write insert. It really bothers me, because it looks like it should be possible to do it as a fold. But I couldn't make it look any better.
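
Two possible shortenings, offered only as sketches. The first keeps the Eq constraint and the same behaviour as insert, using break to find the matching key. The second assumes an Ord constraint is acceptable and lets Data.Map (from the containers package that ships with GHC) do the merging; its result comes out in ascending key order rather than the order shown above. The names insertEq, aggEq and aggOrd are made up for illustration.

```haskell
import qualified Data.Map as Map
import Data.Ratio ((%))

data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

-- Eq-only variant: add p to the first entry with key k,
-- or append a new entry at the end if there is none.
insertEq :: (Num p, Eq a) => (a, p) -> [(a, p)] -> [(a, p)]
insertEq (k, p) kps = case break ((== k) . fst) kps of
  (before, (_, p') : after) -> before ++ (k, p + p') : after
  (_, [])                   -> kps ++ [(k, p)]

aggEq :: (Num p, Eq a) => ProbDist p a -> ProbDist p a
aggEq (ProbDist pas) = ProbDist (foldr insertEq [] pas)

-- Ord variant: fromListWith sums the weights of equal outcomes.
aggOrd :: (Num p, Ord a) => ProbDist p a -> ProbDist p a
aggOrd (ProbDist pas) = ProbDist (Map.toList (Map.fromListWith (+) pas))

-- The joined bag-of-coins distribution, with exact weights.
coins :: ProbDist Rational String
coins = ProbDist [ ("heads", 3 % 10), ("tails", 1 % 5)
                 , ("heads", 1 % 4),  ("tails", 1 % 4) ]
```

Both aggEq coins and aggOrd coins give heads a weight of 11 % 20 and tails 9 % 20; they differ only in the order of the two entries.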

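You are not limited to calculating probabilities. If each outcome is given a weight of 1 instead of a probability, the monad actually will count things. For example, let us throw three dice and count how many ways there are to throw various numbers of sixes:

        -- assumed redefinition for this example: each face has weight 1, not 1 % 6
        die = ProbDist [(i, 1) | i <- [1 .. 6]]
        eq6 n = if n == 6 then 1 else 0
        agg $ liftM3 (\a b c -> eq6 a + eq6 b + eq6 c) die die die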

      ProbDist [(3,1),(2,15),(1,75),(0,125)]
There is one way to throw three sixes, 15 ways to throw two sixes, 75 ways to throw one six, and 125 ways to throw no sixes. So ProbDist is a misnomer.

It's easy to convert counts to probabilities:

	probMap :: (p -> q) -> ProbDist p a -> ProbDist q a
	probMap f (ProbDist pds) = ProbDist $ (map (\(a,p) -> (a, f p))) pds

	normalize :: (Fractional p) => ProbDist p a -> ProbDist p a
	normalize pd@(ProbDist pas) = probMap (/ total) pd where
	    total = sum . (map snd) $ pas

        normalize $ agg $ probMap toRational $ 
               liftM3 (\a b c -> eq6 a + eq6 b + eq6 c) die die die

      ProbDist [(3,1 % 216),(2,5 % 72),(1,25 % 72),(0,125 % 216)]
I think this is the first time I've gotten to write die die die in a computer program.

The do notation is very nice. Here we calculate the distribution where we roll four dice and discard the smallest:

        stat = do
                 a <- d 6
                 b <- d 6
                 c <- d 6
                 d <- d 6
                 return (a+b+c+d - minimum [a,b,c,d])

        probMap fromRational $ agg stat

	ProbDist [(18,1.6203703703703703e-2),
                  (17,4.1666666666666664e-2), (16,7.253086419753087e-2),
                  (15,0.10108024691358025),   (14,0.12345679012345678),
                  (13,0.13271604938271606),   (12,0.12885802469135801),
                  (11,0.11419753086419752),   (10,9.41358024691358e-2),
                   (9,7.021604938271606e-2),   (8,4.7839506172839504e-2),
                   (7,2.9320987654320986e-2),  (6,1.6203703703703703e-2),
                   (5,7.716049382716049e-3),   (4,3.0864197530864196e-3),
                   (3,7.716049382716049e-4)]

One thing I was hoping to get didn't work out. I had this idea that I'd be able to calculate the outcome of a game of craps like this:

	dice = liftM2 (+) (d 6) (d 6)

        point n = do
          roll <- dice
          case roll of 7 -> return "lose"
                       _ | roll == n -> return "win"
                         | otherwise -> point n

        craps = do
          roll <- dice
          case roll of 2 -> return "lose"
                       3 -> return "lose"
                       4 -> point 4
                       5 -> point 5
                       6 -> point 6
                       7 -> return "win"
                       8 -> point 8
                       9 -> point 9
                       10 -> point 10
                       11 -> return "win"
                       12 -> return "lose"
This doesn't work at all; point is an infinite loop because the first value of dice, namely 2, causes a recursive call. I might be able to do something about this, but I'll have to think about it more.
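
One way around the infinite recursion, offered only as a sketch, is to skip the re-rolls analytically. Once the point is n, the only rolls that end the game are n and 7, so the outcome is the dice distribution restricted to those two values and rescaled to sum to 1. This conditioning argument is a standard one and not part of the original program; the helpers restrict and winProb are made-up names, and the block repeats the monad definition (with the Applicative instance that recent GHC versions require) so that it stands alone.

```haskell
import Control.Monad (ap, liftM2)
import Data.Ratio ((%))

data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

instance Functor (ProbDist p) where
  fmap f (ProbDist pas) = ProbDist [ (f a, p) | (a, p) <- pas ]
instance Num p => Applicative (ProbDist p) where
  pure a = ProbDist [(a, 1)]
  (<*>)  = ap
instance Num p => Monad (ProbDist p) where
  ProbDist pas >>= f =
    ProbDist [ (b, p * q) | (a, p) <- pas, let ProbDist pbs = f a, (b, q) <- pbs ]

dice :: ProbDist Rational Integer
dice = liftM2 (+) die die
  where die = ProbDist [ (i, 1 % 6) | i <- [1 .. 6] ]

-- Hypothetical helper: keep only outcomes satisfying ok, then rescale
-- the remaining weights so that they sum to 1.
restrict :: (a -> Bool) -> ProbDist Rational a -> ProbDist Rational a
restrict ok (ProbDist pas) = ProbDist [ (a, p / total) | (a, p) <- kept ]
  where kept  = filter (ok . fst) pas
        total = sum (map snd kept)

-- No recursion: condition on the roll being either the point or 7.
point :: Integer -> ProbDist Rational String
point n = do
  roll <- restrict (\r -> r == n || r == 7) dice
  return (if roll == n then "win" else "lose")

craps :: ProbDist Rational String
craps = do
  roll <- dice
  case roll of
    _ | roll `elem` [7, 11]    -> return "win"
      | roll `elem` [2, 3, 12] -> return "lose"
      | otherwise              -> point roll

-- Hypothetical helper: total weight of the "win" outcomes.
winProb :: ProbDist Rational String -> Rational
winProb (ProbDist pas) = sum [ p | (a, p) <- pas, a == "win" ]
```

Under these assumptions winProb craps comes to 244 % 495, a little under one half.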

It also occurred to me that the use of * in the definition of >>= / join could be generalized. A couple of years back I mentioned a paper of Green, Karvounarakis, and Tannen that discusses "provenance semirings". The idea is that each item in a database is annotated with some "provenance" information about why it is there, and you want to calculate the provenance for items in tables that are computed from table joins. I explained this in an earlier article.

One special case of provenance information is that the provenances are probabilities that the database information is correct, and then the probabilities are calculated correctly for the joins, by multiplication and addition of probabilities. But in the general case the provenances are opaque symbols, and the multiplication and addition construct regular expressions over these symbols. One could generalize ProbDist similarly, and the ProbDist monad (even more of a misnomer this time) would calculate the provenance automatically. It occurs to me now that there's probably a natural way to view a database table join as a sort of Kleisli composition, but this article has gone on too long already.
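
To make the generalization a little more concrete, here is a rough sketch in which the multiplication in >>= and the addition in agg are replaced by the operations of an arbitrary semiring. Everything named here (Semiring, Annotated, Prov, combine, joined) is made up for illustration and is not taken from the paper or from any library. Prov is an unsimplified expression tree, so the semiring and monad laws hold only up to the usual algebraic identities.

```haskell
import Control.Monad (ap)

-- Hypothetical class: just enough structure to annotate outcomes.
class Semiring s where
  zero, one   :: s
  plus, times :: s -> s -> s

-- Counting, as with the three dice.
instance Semiring Integer where
  zero  = 0
  one   = 1
  plus  = (+)
  times = (*)

-- Opaque provenance symbols, combined into symbolic expressions.
data Prov = Zero | One | Sym String | Plus Prov Prov | Times Prov Prov
  deriving (Eq, Show)

instance Semiring Prov where
  zero  = Zero
  one   = One
  plus  = Plus
  times = Times

newtype Annotated s a = Annotated [(a, s)] deriving (Eq, Show)

instance Functor (Annotated s) where
  fmap f (Annotated xs) = Annotated [ (f a, s) | (a, s) <- xs ]
instance Semiring s => Applicative (Annotated s) where
  pure a = Annotated [(a, one)]
  (<*>)  = ap
instance Semiring s => Monad (Annotated s) where
  Annotated xs >>= f =
    Annotated [ (b, s `times` t) | (a, s) <- xs, let Annotated ys = f a, (b, t) <- ys ]

-- Merge equal outcomes with `plus`; the analogue of agg.
combine :: (Semiring s, Eq a) => Annotated s a -> Annotated s a
combine (Annotated xs) = Annotated (foldr ins [] xs)
  where ins (k, s) [] = [(k, s)]
        ins (k, s) ((k', s') : rest)
          | k == k'   = (k, s `plus` s') : rest
          | otherwise = (k', s') : ins (k, s) rest

-- Two derivations of the same key, each followed by a step annotated "c".
joined :: Annotated Prov String
joined = combine $ do
  k <- Annotated [("k", Sym "a"), ("k", Sym "b")]
  Annotated [(k, Sym "c")]
```

Here joined carries the single key "k" with the annotation Plus (Times (Sym "a") (Sym "c")) (Times (Sym "b") (Sym "c")), that is, ac + bc.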

Happy new year, everyone.

[ Addendum 20100103: unsurprisingly, this is not a new idea. Several readers wrote in with references to previous discussion of this monad, and related monads. It turns out that the idea goes back at least to 1981. ]


My thanks to Graham Hunter for his donation.

