def show_label_distribution(
    sample_labels: Union[List[str], List[List[str]]],
    all_labels: Optional[Union[List[str], List[List[str]]]] = None,
):
    """Render a "Label Distribution" section with a bar chart of label proportions.

    With only ``sample_labels``, a single bar chart plots ``"Proportion"``
    against ``"Label"``. When ``all_labels`` is also supplied, both label sets
    are stacked into one table tagged by a ``"Label Set"`` column ("Sample" /
    "All Documents"). The chart is then faceted into one column per label,
    with one coloured bar per label set.

    Args:
        sample_labels: Labels for the sample, either a flat list of strings
            or a list of string lists (per the type hint). If ``None``,
            nothing is rendered.
        all_labels: Optional labels for the full set, in the same form as
            ``sample_labels``. When omitted, no comparison is drawn.
    """
    # Without sample labels there is nothing to display.
    if sample_labels is None:
        return

    st.header("Label Distribution")
    # NOTE(review): _collect_label_counts is defined elsewhere; the encodings
    # below assume it returns a DataFrame with "Label" and "Proportion" columns.
    sample_counts = _collect_label_counts(sample_labels)

    if all_labels is not None:
        # Tag each row with its origin so both sets can share a single chart.
        sample_counts["Label Set"] = "Sample"
        full_counts = _collect_label_counts(all_labels)
        full_counts["Label Set"] = "All Documents"
        combined = pd.concat([sample_counts, full_counts])

        # Within each facet, keep the sample bar first, then the full set.
        set_axis = alt.X(
            "Label Set",
            type="nominal",
            title=None,
            sort=["Sample", "All Documents"],
        )
        # One facet column per label; horizontal header text.
        per_label_facet = alt.Column(
            "Label", type="nominal", header=alt.Header(labelAngle=0)
        )
        bars = alt.Chart(combined, width=100).mark_bar()
        chart = bars.encode(
            set_axis,
            alt.Y("Proportion", type="quantitative"),
            per_label_facet,
            alt.Color("Label Set", type="nominal", legend=None),
        )
    else:
        bars = alt.Chart(sample_counts, height=500, width=700).mark_bar()
        chart = bars.encode(
            alt.X("Label", type="nominal"),
            alt.Y("Proportion", type="quantitative"),
        )

    st.altair_chart(chart)