|
@@ -64,6 +64,11 @@ for i in range(n_classes):
|
|
|
|
|
|
data[f"truth_vocs.{i}"] = samples["truth_vocs"][:,0,i]
|
|
|
|
|
|
+ data["alpha_dev"] = samples["alpha_dev"]
|
|
|
+ data["sigma_dev"] = samples["sigma_dev"]
|
|
|
+ data["beta_dev"] = samples["beta_dev"]
|
|
|
+ data["child_dev_age"] = samples["child_dev_age"][:,0]
|
|
|
+
|
|
|
if "mus" in samples:
|
|
|
for j in range(n_classes):
|
|
|
data[f"mus.{i}.{j}"] = samples["mus"][:,i,j]
|
|
@@ -71,6 +76,8 @@ for i in range(n_classes):
|
|
|
|
|
|
data = pd.DataFrame(data)
|
|
|
|
|
|
+pair_plot(data[["alpha_dev", "sigma_dev", "beta_dev", "child_dev_age"]], f"dev")
|
|
|
+
|
|
|
for i in range (1,n_classes):
|
|
|
cols = [
|
|
|
f"alpha_corpus_level.{i}", f"alpha_child_level.{i}",
|