Thu, 31 Dec 2009
A monad for probability and provenance
Suppose a monad value represents all the possible outcomes of an event, each with a probability of occurrence. For concreteness, let's suppose all our probability distributions are discrete. Then we might have:
    data ProbDist p a = ProbDist [(a,p)] deriving (Eq, Show)

    unpd (ProbDist ps) = ps

Each a is an outcome, and each p is the probability of that outcome occurring. For example, biased and unbiased coins:
    unbiasedCoin = ProbDist [ ("heads", 0.5), ("tails", 0.5) ]
    biasedCoin   = ProbDist [ ("heads", 0.6), ("tails", 0.4) ]
Or a couple of simple functions for making dice:
    import Data.Ratio

    d sides = ProbDist [(i, 1 % sides) | i <- [1 .. sides]]
    die = d 6
d n is an n-sided die.
The Functor instance is straightforward:
    instance Functor (ProbDist p) where
      fmap f (ProbDist pas) = ProbDist $ map (\(a,p) -> (f a, p)) pas

The Monad instance requires return and >>=. The return function merely takes an event and turns it into a distribution where that event occurs with probability 1.

I find join easier to think about than >>=. The join function takes a nested distribution, where each outcome of the outer distribution specifies an inner distribution for the actual events, and collapses it into a regular, overall distribution. For example, suppose you put a biased coin and an unbiased coin in a bag, then pull one out and flip it:
    bag :: ProbDist Double (ProbDist Double String)
    bag = ProbDist [ (biasedCoin, 0.5), (unbiasedCoin, 0.5) ]

The join operator collapses this into a single ProbDist Double String:
    ProbDist [("heads",0.3), ("tails",0.2), ("heads",0.25), ("tails",0.25)]

It would be nice if join could combine the duplicate heads into a single ("heads", 0.55) entry. But that would force an Eq a constraint on the event type, which isn't allowed, because (>>=) must work for all data types, not just for instances of Eq. This is a problem with Haskell, not with the monad itself. It's the same problem that prevents one from making a good set monad in Haskell, even though categorially sets are a perfectly good monad. (The return function constructs singletons, and the join function is simply set union.) Maybe in the next language.
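The collapse that join performs can also be written out directly. What follows is only a minimal sketch, not code from this article: the names joinPD and bindViaJoin are invented for illustration, and the coins use Rational weights rather than Double so that the arithmetic is exact.

```haskell
import Data.Ratio ((%))

data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

instance Functor (ProbDist p) where
  fmap f (ProbDist pas) = ProbDist [(f a, p) | (a, p) <- pas]

-- Hypothetical name: flatten a distribution of distributions by multiplying
-- each inner weight by the weight of the outer outcome that selected it.
joinPD :: Num p => ProbDist p (ProbDist p a) -> ProbDist p a
joinPD (ProbDist outer) =
  ProbDist [(a, p * q) | (ProbDist inner, p) <- outer, (a, q) <- inner]

-- Hypothetical name: (>>=) recovered from join (map first, then flatten).
bindViaJoin :: Num p => ProbDist p a -> (a -> ProbDist p b) -> ProbDist p b
bindViaJoin m f = joinPD (fmap f m)

-- The same coins as in the article, with exact Rational weights (an assumption).
unbiasedCoin, biasedCoin :: ProbDist Rational String
unbiasedCoin = ProbDist [("heads", 1 % 2), ("tails", 1 % 2)]
biasedCoin   = ProbDist [("heads", 3 % 5), ("tails", 2 % 5)]

bag :: ProbDist Rational (ProbDist Rational String)
bag = ProbDist [(biasedCoin, 1 % 2), (unbiasedCoin, 1 % 2)]
```

With these definitions, joinPD bag lists heads and tails twice each, with weights 3 % 10, 1 % 5, 1 % 4 and 1 % 4, which are the same four entries as the Double version.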
Perhaps someone else will find the >>= operator easier to understand than join? I don't know. Anyway, it's simple enough to derive once you understand join; here's the code:
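    import Control.Monad (ap, liftM2, liftM3)

    -- GHC 7.10 and later require an Applicative instance before Monad
    instance (Num p) => Applicative (ProbDist p) where
      pure a = ProbDist [(a, 1)]
      (<*>)  = ap

    instance (Num p) => Monad (ProbDist p) where
      return = pure
      (ProbDist pas) >>= f = ProbDist $ do
        (a, p) <- pas
        let (ProbDist pbs) = f a
        (b, q) <- pbs
        return (b, p*q)

So now we can do some straightforward experiments: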
    liftM2 (+) (d 6) (d 6)

    ProbDist [(2,1 % 36),(3,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),
              (3,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),
              (4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),
              (5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),
              (6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(11,1 % 36),
              (7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(11,1 % 36),(12,1 % 36)]

This is nasty-looking; we really need to merge the multiple listings of the same event. Here is a function to do that:
    agglomerate :: (Num p, Eq b) => (a -> b) -> ProbDist p a -> ProbDist p b
    agglomerate f pd = ProbDist $ foldr insert [] (unpd (fmap f pd))
      where
        -- add the weight to an existing entry, or append a new one
        insert (k, p) [] = [(k, p)]
        insert (k, p) ((k', p'):kps)
          | k == k'   = (k, p+p'):kps
          | otherwise = (k', p'):(insert (k,p) kps)

    agg :: (Num p, Eq a) => ProbDist p a -> ProbDist p a
    agg = agglomerate id

Then agg $ liftM2 (+) (d 6) (d 6) produces:
    ProbDist [(12,1 % 36),(11,1 % 18),(10,1 % 12),(9,1 % 9),
              (8,5 % 36),(7,1 % 6),(6,5 % 36),(5,1 % 9),
              (4,1 % 12),(3,1 % 18),(2,1 % 36)]

Hey, that's correct.
There must be a shorter way to write insert. It really bothers me, because it looks like it should be possible to do it as a fold. But I couldn't make it look any better.
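For what it's worth, here is one possible alternative, offered as a sketch and not as the fold hoped for above. It uses the Prelude's break to split the list at the first matching key, so the explicit recursion disappears but the two-case structure remains. The names insertWeight and mergeWeights are invented for this example, and it works on bare association lists rather than on ProbDist.

```haskell
-- Hypothetical replacement for the local `insert`: split the list at the first
-- entry whose key matches, then either add the weights or append a new entry.
insertWeight :: (Num p, Eq k) => (k, p) -> [(k, p)] -> [(k, p)]
insertWeight (k, p) kps =
  case break ((== k) . fst) kps of
    (before, (_, p') : after) -> before ++ (k, p + p') : after
    (before, [])              -> before ++ [(k, p)]

-- Hypothetical name: merge duplicate keys, in the same order as the
-- `foldr insert []` used by agglomerate.
mergeWeights :: (Num p, Eq k) => [(k, p)] -> [(k, p)]
mergeWeights = foldr insertWeight []
```

This keeps the Eq-only constraint and the output order of the original. If an Ord constraint were acceptable, a map-based merge would be another option, at the cost of changing the order of the entries.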
You are not limited to calculating probabilities. The monad actually will count things. For example, let us throw three dice and count how many ways there are to throw various numbers of sixes. (To get counts rather than probabilities, each face of die must carry weight 1 instead of 1 % 6.)

    eq6 n = if n == 6 then 1 else 0

    agg $ liftM3 (\a b c -> eq6 a + eq6 b + eq6 c) die die die

    ProbDist [(3,1),(2,15),(1,75),(0,125)]

There is one way to throw three sixes, 15 ways to throw two sixes, 75 ways to throw one six, and 125 ways to throw no sixes. So ProbDist is a misnomer.
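These counts can be checked independently of the monad. The sketch below is only a cross-check and all of its names are invented here. It computes the number of ways to throw k sixes with three dice once from the closed form (choose which k dice show a six, then give each other die one of its 5 remaining faces) and once by enumerating all 216 throws.

```haskell
-- Binomial coefficient "n choose k" for small non-negative arguments.
choose :: Integer -> Integer -> Integer
choose n k = product [n - k + 1 .. n] `div` product [1 .. k]

-- Closed form: pick which k of the 3 dice show a six, then let each of the
-- remaining dice show any of the 5 other faces.
waysBySixes :: Integer -> Integer
waysBySixes k = choose 3 k * 5 ^ (3 - k)

-- Brute force over all 6 * 6 * 6 throws, as an independent check.
waysByEnumeration :: Integer -> Integer
waysByEnumeration k =
  fromIntegral (length [() | a <- faces, b <- faces, c <- faces, sixes [a, b, c] == k])
  where
    faces :: [Integer]
    faces = [1 .. 6]
    sixes xs = fromIntegral (length (filter (== 6) xs))
```

Both map waysBySixes [3,2,1,0] and map waysByEnumeration [3,2,1,0] give [1,15,75,125], matching the weights the monad produced.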
It's easy to convert counts to probabilities:
    probMap :: (p -> q) -> ProbDist p a -> ProbDist q a
    probMap f (ProbDist pds) = ProbDist $ (map (\(a,p) -> (a, f p))) pds

    normalize :: (Fractional p) => ProbDist p a -> ProbDist p a
    normalize pd@(ProbDist pas) = probMap (/ total) pd
      where total = sum . (map snd) $ pas

    normalize $ agg $ probMap toRational $ liftM3 (\a b c -> eq6 a + eq6 b + eq6 c) die die die

    ProbDist [(3,1 % 216),(2,5 % 72),(1,25 % 72),(0,125 % 216)]

I think this is the first time I've gotten to write die die die in a computer program.
The do notation is very nice. Here we calculate the distribution where we roll four dice and discard the smallest:
    stat = do
      a <- d 6
      b <- d 6
      c <- d 6
      d <- d 6
      return (a+b+c+d - minimum [a,b,c,d])

    probMap fromRational $ agg stat

    ProbDist [(18,1.6203703703703703e-2),
              (17,4.1666666666666664e-2),
              (16,7.253086419753087e-2),
              (15,0.10108024691358025),
              (14,0.12345679012345678),
              (13,0.13271604938271606),
              (12,0.12885802469135801),
              (11,0.11419753086419752),
              (10,9.41358024691358e-2),
              (9,7.021604938271606e-2),
              (8,4.7839506172839504e-2),
              (7,2.9320987654320986e-2),
              (6,1.6203703703703703e-2),
              (5,7.716049382716049e-3),
              (4,3.0864197530864196e-3),
              (3,7.716049382716049e-4)]

One thing I was hoping to get didn't work out. I had this idea that I'd be able to calculate the outcome of a game of craps like this:
    dice = liftM2 (+) (d 6) (d 6)

    point n = do
      roll <- dice
      case roll of
        7 -> return "lose"
        _ | roll == n -> return "win"
          | otherwise -> point n

    craps = do
      roll <- dice
      case roll of
        2  -> return "lose"
        3  -> return "lose"
        4  -> point 4
        5  -> point 5
        6  -> point 6
        7  -> return "win"
        8  -> point 8
        9  -> point 9
        10 -> point 10
        11 -> return "win"
        12 -> return "lose"

This doesn't work at all; point is an infinite loop because the first value of dice, namely 2, causes a recursive call. I might be able to do something about this, but I'll have to think about it more.
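One possible way around the loop, sketched here under stated assumptions and not worked out in this article, is to avoid the recursion entirely. Every roll in the point phase is independent, so rolls that are neither the point nor 7 change nothing, and the result is the same as rolling once from the dice distribution conditioned on "the roll is n or 7". The helpers conditionOn and probOf are invented names, and the types are pinned to Rational and Integer for simplicity.

```haskell
import Control.Monad (ap, liftM2)
import Data.Ratio ((%))

data ProbDist p a = ProbDist [(a, p)] deriving (Eq, Show)

instance Functor (ProbDist p) where
  fmap f (ProbDist pas) = ProbDist [(f a, p) | (a, p) <- pas]

instance Num p => Applicative (ProbDist p) where
  pure a = ProbDist [(a, 1)]
  (<*>)  = ap

instance Num p => Monad (ProbDist p) where
  ProbDist pas >>= f =
    ProbDist [(b, p * q) | (a, p) <- pas, let ProbDist pbs = f a, (b, q) <- pbs]

d :: Integer -> ProbDist Rational Integer
d sides = ProbDist [(i, 1 % sides) | i <- [1 .. sides]]

dice :: ProbDist Rational Integer
dice = liftM2 (+) (d 6) (d 6)

-- Hypothetical helper: keep only the outcomes that satisfy the predicate and
-- rescale their weights so that they sum to 1 again.
conditionOn :: Fractional p => (a -> Bool) -> ProbDist p a -> ProbDist p a
conditionOn ok (ProbDist pas) = ProbDist [(a, p / total) | (a, p) <- kept]
  where
    kept  = filter (ok . fst) pas
    total = sum (map snd kept)

-- The point phase without recursion: only a roll of n or 7 ends it.
point :: Integer -> ProbDist Rational String
point n = do
  roll <- conditionOn (\r -> r == n || r == 7) dice
  return (if roll == n then "win" else "lose")

craps :: ProbDist Rational String
craps = do
  roll <- dice
  case roll of
    _ | roll `elem` [7, 11]    -> return "win"
      | roll `elem` [2, 3, 12] -> return "lose"
      | otherwise              -> point roll

-- Hypothetical helper: total weight attached to one outcome.
probOf :: (Eq a, Num p) => a -> ProbDist p a -> p
probOf x (ProbDist pas) = sum [p | (a, p) <- pas, a == x]
```

Under these assumptions probOf "win" craps comes out to 244 % 495, a little under one half. The trade-off is that the infinite process is resolved by an argument made outside the monad instead of being computed by it.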
It also occurred to me that the use of * in the definition of >>= / join could be generalized. A couple of years back I mentioned a paper of Green, Karvounarakis, and Tannen that discusses "provenance semirings". The idea is that each item in a database is annotated with some "provenance" information about why it is there, and you want to calculate the provenance for items in tables that are computed from table joins. My earlier explanation is here.
One special case of provenance information is that the provenances are probabilities that the database information is correct, and then the probabilities are calculated correctly for the joins, by multiplication and addition of probabilities. But in the general case the provenances are opaque symbols, and the multiplication and addition construct regular expressions over these symbols. One could generalize ProbDist similarly, and the ProbDist monad (even more of a misnomer this time) would calculate the provenance automatically. It occurs to me now that there's probably a natural way to view a database table join as a sort of Kleisli composition, but this article has gone on too long already.
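A rough sketch of that generalization follows. Everything in it is an assumption made for illustration and is not taken from the paper or from this article: the Semiring class, the Annotated type, the Prov expression type, and the two toy tables. The monad's >>= uses the semiring's multiplication in place of *, and merging duplicate rows uses its addition in place of +.

```haskell
import Control.Monad (ap)

-- Hypothetical class: just the operations the monad and the merge need.
class Semiring s where
  zero, one   :: s
  plus, times :: s -> s -> s

-- ProbDist with the weight type generalized to any semiring annotation.
data Annotated s a = Annotated [(a, s)] deriving (Eq, Show)

instance Functor (Annotated s) where
  fmap f (Annotated xs) = Annotated [(f a, s) | (a, s) <- xs]

instance Semiring s => Applicative (Annotated s) where
  pure a = Annotated [(a, one)]
  (<*>)  = ap

instance Semiring s => Monad (Annotated s) where
  Annotated xs >>= f =
    Annotated [(b, s `times` t) | (a, s) <- xs, let Annotated ys = f a, (b, t) <- ys]

-- Merge duplicate rows, combining their annotations with `plus`.
merge :: (Semiring s, Eq a) => Annotated s a -> Annotated s a
merge (Annotated xs) = Annotated (foldr ins [] xs)
  where
    ins (k, s) kss = case break ((== k) . fst) kss of
      (before, (_, s') : after) -> before ++ (k, s `plus` s') : after
      (before, [])              -> before ++ [(k, s)]

-- Opaque provenance symbols combined into symbolic expressions.
infixl 6 :+:
infixl 7 :*:
data Prov = Zero | One | Sym String | Prov :+: Prov | Prov :*: Prov
  deriving (Eq, Show)

instance Semiring Prov where
  zero = Zero
  one  = One
  plus Zero x = x
  plus x Zero = x
  plus x y    = x :+: y
  times Zero _ = Zero
  times _ Zero = Zero
  times One x  = x
  times x One  = x
  times x y    = x :*: y

-- Two toy tables (made up), each row tagged with its own symbol.
people :: Annotated Prov (String, Int)
people = Annotated [(("alice", 1), Sym "r1"), (("bob", 2), Sym "r2")]

courses :: Annotated Prov (Int, String)
courses = Annotated [((1, "math"), Sym "s1"), ((1, "art"), Sym "s2")]

-- A table join written in do notation; non-matching pairs yield no rows.
joined :: Annotated Prov (String, String)
joined = do
  (name, k)    <- people
  (k', course) <- courses
  if k == k' then return (name, course) else Annotated []
```

In this sketch, joined contains ("alice","math") annotated with Sym "r1" :*: Sym "s1" and ("alice","art") annotated with Sym "r1" :*: Sym "s2". Projecting onto the name with merge (fmap fst joined) leaves a single "alice" row whose annotation is the sum of those two products.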
Happy new year, everyone.
[ Addendum 20100103: unsurprisingly, this is not a new idea. Several readers wrote in with references to previous discussion of this monad, and related monads. It turns out that the idea goes back at least to 1981. ]
My thanks to Graham Hunter for his donation.