import numpy as np
import seaborn as sns
import random
import pandas as pd
import matplotlib.pyplot as plt
import cvxpy as cp
from scipy.stats import binom, randint, rv_discrete
# Common support {0, 1, ..., max_n} used by every distribution below.
max_n = 20
state_space = np.arange(0, max_n + 1, 1)
state_space  # display the support
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20])
# Source distribution: discrete uniform on {0, ..., max_n}
# (scipy's randint excludes `high`).
mu = randint.freeze(low=0, high=max_n + 1)

# Target distribution: equal-weight mixture of Binomial(max_n, 0.7) and
# Binomial(max_n, 0.2), tabulated on state_space.
nu1 = binom.freeze(n=max_n, p=0.7)
nu2 = binom.freeze(n=max_n, p=0.2)
nu_mixture_pmf = nu1.pmf(state_space) * 0.5 + 0.5 * nu2.pmf(state_space)
nu = rv_discrete(values=(state_space, nu_mixture_pmf))
# Overlay the two marginals as semi-transparent bars.
for bar_label, dist in (('mu', mu), ('nu', nu)):
    plt.bar(state_space, dist.pmf(state_space), label=bar_label, alpha=0.6)
plt.legend()
<matplotlib.legend.Legend at 0x7fe02ab12550>
# Every cell (x, y) of the product grid, as dicts usable as **kwargs;
# x varies slowest, y fastest.
pairs = [dict(x=i, y=j) for i in state_space for j in state_space]
def jplot(data):
    """Draw a seaborn joint histogram of (x, y) samples and show it.

    Parameters
    ----------
    data : iterable of dict
        Each item maps 'x' and 'y' to a value. The iterable is fully
        materialised into a list before building the DataFrame.
    """
    frame = pd.DataFrame(list(data))
    # Treat both coordinates as discrete so each integer gets its own bin.
    sns.jointplot(
        data=frame,
        x='x',
        y='y',
        kind='hist',
        marginal_kws={'discrete': True},
        joint_kws={'discrete': True},
    )
    plt.show()
def q_independent(x, y):
    """Return P(X=x, Y=y) for the independent coupling, i.e. mu(x) * nu(y)."""
    px = mu.pmf(x)
    py = nu.pmf(y)
    return px * py
# Draw 10000 grid cells with probabilities given by the independent coupling.
weights = [q_independent(**pair) for pair in pairs]
samples = np.random.choice(pairs, 10000, p=weights)
data = list(samples)
jplot(data)
Toy example: `state_space = [0, 1, 2, 3]`, `mu = lambda x: [1/4, 1/4, 1/4, 1/4][x]`, `nu = lambda y: [1/8, 1/8, 0, 3/4][y]`.
# Maximal-agreement coupling via linear programming: find a joint pmf Q on
# state_space x state_space whose row sums are mu and column sums are nu,
# maximising P(X == Y) = trace(Q).
n_states = len(state_space)
Q = cp.Variable((n_states, n_states))
proba_constraints = [cp.sum(Q) == 1, Q >= 0]
# Vectorised marginals: axis=0 gives column sums (must equal nu),
# axis=1 gives row sums (must equal mu).
marginals = [
    cp.sum(Q, axis=0) == nu.pmf(state_space),
    cp.sum(Q, axis=1) == mu.pmf(state_space),
]
problem = cp.Problem(
    cp.Maximize(cp.trace(Q)),
    constraints=proba_constraints + marginals,
)
max_agreement = problem.solve()
# Q.value is None unless a solution was found; fail loudly here rather than
# with an AttributeError in the clipping cell that follows.
if problem.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
    raise RuntimeError(f"coupling LP did not solve: status={problem.status}")
max_agreement
0.6641521252440812
# The solver may return tiny negative entries: zero them, then renormalise
# so the table is a valid probability vector for np.random.choice.
Q_value = np.clip(Q.value, 0, None)
Q_value = Q_value / np.sum(Q_value)
Q_value.sum()  # should display ~1.0
0.9999999999999999
def q_optimal(x, y, *, table=Q_value):
    """Return P(X=x, Y=y) under the LP coupling that maximises P(X == Y).

    ``table`` defaults to the normalised solution ``Q_value`` as it exists
    when this function is defined. Binding it here, rather than reading the
    global at call time, matters: the Wasserstein cells later rebind
    ``Q_value``. Without the binding, this function would silently start
    returning that other coupling.
    """
    return table[x, y]
# Draw 10000 grid cells from the max-agreement coupling table.
weights = [q_optimal(**pair) for pair in pairs]
samples = np.random.choice(pairs, 10000, p=weights)
data = list(samples)
jplot(data)
# Comonotone (quantile) coupling: the same uniform U feeds both quantile
# functions, X = F_mu^{-1}(U) and Y = F_nu^{-1}(U). ppf is evaluated on the
# whole array at once instead of once per sample.
U = np.random.uniform(size=10000)
data = [{'x': x, 'y': y} for x, y in zip(mu.ppf(U), nu.ppf(U))]
jplot(data)
# Antitone (counter-monotone) coupling: Y uses the reflected uniform 1 - U,
# so large X pairs with small Y. ppf is evaluated on whole arrays at once.
U = np.random.uniform(size=10000)
data = [{'x': x, 'y': y} for x, y in zip(mu.ppf(U), nu.ppf(1 - U))]
jplot(data)
# Optimal-transport coupling: among joint pmfs Q with row sums mu and column
# sums nu, minimise the expected cost E|X - Y|. The optimum is the
# Wasserstein-1 distance between mu and nu for the ground metric |i - j|.
n_states = len(state_space)
Q = cp.Variable((n_states, n_states))
proba_constraints = [cp.sum(Q) == 1, Q >= 0]
# axis=0 -> column sums (nu), axis=1 -> row sums (mu).
marginals = [
    cp.sum(Q, axis=0) == nu.pmf(state_space),
    cp.sum(Q, axis=1) == mu.pmf(state_space),
]
# cost[i, j] = |state_space[i] - state_space[j]|, as one constant matrix
# instead of a Python list of (n+1)^2 scalar cvxpy expressions.
positions = np.asarray(state_space)
cost = np.abs(positions[:, None] - positions[None, :])
problem = cp.Problem(
    cp.Minimize(cp.sum(cp.multiply(cost, Q))),
    constraints=proba_constraints + marginals,
)
wasserstein_1 = problem.solve()
# Q.value is None unless a solution was found; fail loudly here.
if problem.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
    raise RuntimeError(f"transport LP did not solve: status={problem.status}")
wasserstein_1
1.3248694788553057
# Drop tiny negative solver artefacts and renormalise into a probability table.
Q_value = np.clip(Q.value, 0, None)
Q_value = Q_value / np.sum(Q_value)
Q_value.sum()  # should display ~1.0
1.0
def q_wasserstein(x, y, *, table=Q_value):
    """Return P(X=x, Y=y) under the coupling that minimises E|X - Y|.

    ``table`` is bound to the current ``Q_value`` at definition time.
    Re-running the earlier cells that rebind ``Q_value`` (the
    max-agreement coupling) therefore cannot silently change what this
    function returns.
    """
    return table[x, y]
# Draw 10000 grid cells from the optimal-transport coupling table.
weights = [q_wasserstein(**pair) for pair in pairs]
samples = np.random.choice(pairs, 10000, p=weights)
data = list(samples)
jplot(data)
# Decompose mu and nu around their overlap min(mu, nu):
#   mu = p * pi + (1 - p) * mu_tilde,   nu = p * pi + (1 - p) * nu_tilde,
# where p = sum(min(mu, nu)) is the largest achievable P(X == Y).
min_array = np.minimum(mu.pmf(state_space),
                       nu.pmf(state_space))
p = min_array.sum()

# pi: the normalised overlap. If p == 0 (disjoint supports) it would be 0/0.
# The samplers below draw B ~ Bernoulli(p), so pi is never used then; fall
# back to mu just to keep a valid pmf.
if np.isclose(p, 0.0):
    pi_pmf = mu.pmf(state_space)
else:
    pi_pmf = (1/p) * min_array
pi = rv_discrete(values=(state_space, pi_pmf))

# Residuals: if p == 1 (mu == nu) dividing by 1 - p gives NaNs, and
# rv_discrete rejects those. With B always 1 the residuals are never drawn,
# so fall back to the marginals themselves.
if np.isclose(p, 1.0):
    mu_tilde_pmf = mu.pmf(state_space)
    nu_tilde_pmf = nu.pmf(state_space)
else:
    mu_tilde_pmf = (mu.pmf(state_space) - min_array)/(1-p)
    nu_tilde_pmf = (nu.pmf(state_space) - min_array)/(1-p)
mu_tilde = rv_discrete(values=(state_space, mu_tilde_pmf))
nu_tilde = rv_discrete(values=(state_space, nu_tilde_pmf))
# Overlay the overlap, both marginals and both residuals (same order, same colours).
for dist in (pi, mu, mu_tilde, nu, nu_tilde):
    plt.plot(state_space, dist.pmf(state_space))
[<matplotlib.lines.Line2D at 0x7fe01e734d50>]
# Both marginals, with their pointwise minimum (the shared mass) highlighted.
bar_specs = [
    (mu.pmf(state_space), 'mu', 0.5),
    (nu.pmf(state_space), 'nu', 0.5),
    (min_array, 'minimum', 0.7),
]
for heights, bar_label, opacity in bar_specs:
    plt.bar(state_space, heights, label=bar_label, alpha=opacity)
plt.legend()
<matplotlib.legend.Legend at 0x7fe01e68fb90>
# Faded marginals, the raw minimum, and its normalised version pi on top.
bar_specs = [
    (mu.pmf(state_space), 'mu', 0.2),
    (nu.pmf(state_space), 'nu', 0.2),
    (min_array, 'minimum', 1.0),
    (pi.pmf(state_space), 'pi', 0.5),
]
for heights, bar_label, opacity in bar_specs:
    plt.bar(state_space, heights, label=bar_label, alpha=opacity)
plt.legend()
plt.show()
# Same comparison as line plots; dict order fixes the drawing/colour order.
curves = {
    'mu': mu.pmf(state_space),
    'nu': nu.pmf(state_space),
    'minimum': min_array,
    'pi': pi.pmf(state_space),
}
for curve_label, values in curves.items():
    plt.plot(state_space, values, label=curve_label)
plt.legend()
plt.show()
# Compare mu, nu and the residual mu_tilde left after removing the overlap.
bar_specs = [
    (mu.pmf(state_space), 'mu', 0.8),
    (nu.pmf(state_space), 'nu', 0.4),
    (mu_tilde.pmf(state_space), 'mu tilde', 0.3),
]
for heights, bar_label, opacity in bar_specs:
    plt.bar(state_space, heights, label=bar_label, alpha=opacity)
plt.legend()
<matplotlib.legend.Legend at 0x7fe01de941d0>
Verification: the two curves should be indistinguishable if the marginal for mu is correct.
# mu rebuilt from the decomposition p * pi + (1 - p) * mu_tilde.
mu_rebuilt = p*pi.pmf(state_space) + (1-p)*mu_tilde.pmf(state_space)
plt.plot(state_space, mu.pmf(state_space))
plt.plot(state_space, mu_rebuilt)
[<matplotlib.lines.Line2D at 0x7fe01e95d890>]
Verification: the two curves should be indistinguishable if the marginal for nu is correct.
# nu rebuilt from the decomposition p * pi + (1 - p) * nu_tilde.
nu_rebuilt = p*pi.pmf(state_space) + (1-p)*nu_tilde.pmf(state_space)
plt.plot(state_space, nu.pmf(state_space))
plt.plot(state_space, nu_rebuilt)
[<matplotlib.lines.Line2D at 0x7fe01e0e3d50>]
# Maximal coupling with independent residuals. With probability p both
# coordinates share one draw Z ~ pi (so X == Y); otherwise X ~ mu_tilde and
# Y ~ nu_tilde come from independent uniforms. All draws are vectorised
# rather than calling scalar ppf once per sample in a Python loop.
n_samples = 10000
B = np.random.binomial(1, p, size=n_samples).astype(bool)
Z = pi.ppf(np.random.uniform(size=n_samples))
X_tilde = mu_tilde.ppf(np.random.uniform(size=n_samples))
Y_tilde = nu_tilde.ppf(np.random.uniform(size=n_samples))
X = np.where(B, Z, X_tilde)
Y = np.where(B, Z, Y_tilde)
data = [{'x': x, 'y': y} for x, y in zip(X, Y)]
jplot(data)
# Maximal coupling whose residuals are coupled comonotonically: the same
# uniform U1 drives both residual quantile functions. Vectorised over all
# samples instead of a per-sample Python loop.
n_samples = 10000
B = np.random.binomial(1, p, size=n_samples).astype(bool)
Z = pi.ppf(np.random.uniform(size=n_samples))
U1 = np.random.uniform(size=n_samples)
X_tilde = mu_tilde.ppf(U1)
Y_tilde = nu_tilde.ppf(U1)  # same uniform as X_tilde
X = np.where(B, Z, X_tilde)
Y = np.where(B, Z, Y_tilde)
data = [{'x': x, 'y': y} for x, y in zip(X, Y)]
jplot(data)
# Maximal coupling whose residuals are coupled antitonically: Y_tilde uses
# the reflected uniform 1 - U1. Vectorised over all samples instead of a
# per-sample Python loop.
n_samples = 10000
B = np.random.binomial(1, p, size=n_samples).astype(bool)
Z = pi.ppf(np.random.uniform(size=n_samples))
U1 = np.random.uniform(size=n_samples)
X_tilde = mu_tilde.ppf(U1)
Y_tilde = nu_tilde.ppf(1 - U1)  # reflection using U1
X = np.where(B, Z, X_tilde)
Y = np.where(B, Z, Y_tilde)
data = [{'x': x, 'y': y} for x, y in zip(X, Y)]
jplot(data)