1from typing import List
2
3import altair as alt
4import numpy as np
5
6from labml import analytics
7from labml.analytics import IndicatorCollection
def calculate_percentages(means: List[np.ndarray], names: List[List[str]]):
    """Express every series as a fraction of the total of its group.

    A series' group is every series whose last name component agrees with
    its own once the final character is dropped (e.g. ``'x.a0'`` and
    ``'x.a1'`` both fall in group ``'x.a'``).

    Args:
        means: one value array per series.
        names: name components per series, aligned index-for-index with
            ``means``.

    Returns:
        A list aligned with ``means`` where entry ``i`` is ``means[i]``
        divided element-wise by the sum of its group's arrays (``means[i]``
        itself included). Machine epsilon is added to the denominator so an
        all-zero group does not divide by zero.
    """
    eps = np.finfo(float).eps
    result = []
    for idx, series in enumerate(means):
        prefix = names[idx][-1][:-1]
        # Accumulate in the series' own dtype, in index order of the members.
        group_total = np.zeros_like(series)
        members = (k for k, n in enumerate(names) if n[-1][:-1] == prefix)
        for k in members:
            group_total += means[k]
        result.append(series / (group_total + eps))
    return result
def _series_labels(names: List[List[str]]) -> List[str]:
    """Return one legend label per series: its last name component with the
    character-wise prefix shared by all series removed.

    The shared prefix is kept (not stripped) when removing it would leave
    some label empty, e.g. with a single series or when one name is itself
    the shared prefix.
    """
    from os.path import commonprefix

    full = [n[-1] for n in names]
    # commonprefix works character by character on arbitrary strings (not
    # by path component) and returns '' for an empty list.
    prefix = commonprefix(full)
    if any(len(f) == len(prefix) for f in full):
        prefix = ''
    return [f[len(prefix):] for f in full]


def plot_infosets(indicators: IndicatorCollection, *,
                  is_normalize: bool = True,
                  width: int = 600,
                  height: int = 300):
    """Plot one line per indicator series as an interactive Altair chart.

    Args:
        indicators: indicators whose data is fetched via
            ``analytics.indicator_data``.
        is_normalize: if True, each series is divided by the sum of all
            series whose last name component matches it except for the final
            character (see ``calculate_percentages``). Otherwise the raw
            values are plotted.
        width: chart width passed to ``Chart.properties``.
        height: chart height passed to ``Chart.properties``.

    Returns:
        An ``alt.Chart`` line chart coloured by series. Series can be
        selected by clicking the legend, and unselected series are drawn at
        near-zero opacity. With no data the chart is simply empty.
    """
    data, names = analytics.indicator_data(indicators)
    # Column 0 is used as the x-axis step and column 5 as the plotted value.
    # NOTE(review): column layout is defined by analytics.indicator_data;
    # confirm that column 5 is the mean.
    steps = [d[:, 0] for d in data]
    means = [d[:, 5] for d in data]

    values = calculate_percentages(means, names) if is_normalize else means
    labels = _series_labels(names)

    # tolist() converts numpy scalars to plain Python numbers so the inline
    # data is JSON-serializable (numpy integer scalars are not).
    table = [
        {'series': label, 'step': s, 'value': v}
        for label, step, value in zip(labels, steps, values)
        for s, v in zip(np.asarray(step).tolist(), np.asarray(value).tolist())
    ]

    # NOTE(review): selection_multi/add_selection are deprecated in Altair 5
    # (selection_point/add_params); kept for the Altair version targeted here.
    selection = alt.selection_multi(fields=['series'], bind='legend')

    return alt.Chart(alt.Data(values=table)).mark_line().encode(
        alt.X('step:Q'),
        alt.Y('value:Q'),
        alt.Color('series:N', scale=alt.Scale(scheme='tableau20')),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.0001))
    ).add_selection(
        selection
    ).properties(width=width, height=height)