Spaces:
Running
Running
| """Web app page for showing codes for different examples in the dataset.""" | |
| import streamlit as st | |
| from streamlit_extras.switch_page_button import switch_page | |
| import code_search_utils | |
| import webapp_utils | |
# Page-wide constants.
total_examples = 2000  # upper bound (exclusive) for the "Example ID" input
prec_threshold = 0.01  # codes with precision below this are never listed

# NOTE(review): presumably restores widget values kept in session state
# across pages -- confirm in webapp_utils.
webapp_utils.load_widget_state()

# Everything below reads state that the Code Browser page puts into the
# session; if it is missing, send the user to that page first.
if "cb_acts" not in st.session_state:
    switch_page("Code_Browser")

_session = st.session_state
model_name = _session["model_name_id"]
seq_len = _session["seq_len"]
tokens_text = _session["tokens_text"]
tokens_str = _session["tokens_str"]
cb_acts = _session["cb_acts"]
act_count_ft_tkns = _session["act_count_ft_tkns"]
gcb = _session["gcb"]
def get_example_topic_codes(example_id):
    """Collect, per codebook, the codes that qualify as topic codes.

    For every codebook in `cb_acts`, precision and recall are computed over
    all `seq_len` token positions of the example. Codes are kept only if
    their precision is at least `prec_threshold` and their recall is at least
    the module-level `recall_threshold`.

    Args:
        example_id: index of the example in the dataset.

    Returns:
        A list with one `(cb_name, codes_pr)` entry per codebook, where
        `codes_pr` is a list of `(code, precision, recall, code_acts)` tuples
        for the surviving codes (empty if none survive).
    """
    positions = [(example_id, pos) for pos in range(seq_len)]
    per_codebook = []
    for cb_name, cb in cb_acts.items():
        base_name = code_search_utils.convert_to_base_name(cb_name, gcb=gcb)
        # stats == (codes, prec, rec, code_acts); kept as parallel sequences.
        stats = code_search_utils.get_code_precision_and_recall(
            positions,
            cb,
            act_count_ft_tkns[base_name],
        )
        # Filter on precision (index 1) first, then on recall (index 2),
        # applying each boolean mask to all four sequences at once.
        for stat_idx, threshold in ((1, prec_threshold), (2, recall_threshold)):
            keep = stats[stat_idx] >= threshold
            stats = tuple(values[keep] for values in stats)
        per_codebook.append((cb_name, list(zip(*stats))))
    return per_codebook
def find_next_example(example_id):
    """Find the example after `example_id` that has topic codes.

    The dataset is scanned cyclically: the search starts at the example right
    after `example_id`, wraps around past the last example, and stops once it
    is back at `example_id`. The first example for which
    `get_example_topic_codes` yields at least one code is written to
    `st.session_state["example_id"]`. If no such example is found, an error
    message is shown instead and the session state is left unchanged.

    Args:
        example_id: index of the currently selected example; it is not
            re-checked itself.
    """
    initial_example_id = example_id
    # Wrap on the very first step as well. With a plain `+ 1`, starting from
    # the last example would probe the out-of-range index `total_examples`
    # and then jump to 1, never looking at example 0.
    example_id = (example_id + 1) % total_examples
    while example_id != initial_example_id:
        all_codes = get_example_topic_codes(example_id)
        # Each entry is (cb_name, list_of_codes); a non-empty list means hit.
        if any(code_pr_infos for _, code_pr_infos in all_codes):
            st.session_state["example_id"] = example_id
            return
        example_id = (example_id + 1) % total_examples
    st.error(
        f"No examples found at the specified recall threshold: {recall_threshold}.",
        icon="🚨",
    )
def redirect_to_main_with_code(code, layer, head):
    """Store the chosen code in session state and open the main page.

    Writes `code` and `layer` to the `ct_act_code` / `ct_act_layer` session
    keys, writes `head` to `ct_act_head` only when the session's `is_attn`
    flag is truthy, then switches to the "Code Browser" page.
    """
    session = st.session_state
    session["ct_act_code"] = code
    session["ct_act_layer"] = layer
    if session["is_attn"]:
        session["ct_act_head"] = head
    switch_page("Code Browser")
def show_examples_for_topic_code(code, layer, head, code_act_ratio=0.3):
    """Render the examples on which the given code fires often enough.

    Args:
        code: code to look up.
        layer: layer of the codebook the code belongs to.
        head: head of the codebook the code belongs to.
        code_act_ratio: an example is shown only if the code's activation
            count on it is strictly greater than `seq_len * code_act_ratio`.
    """
    example_acts, _ = webapp_utils.get_code_acts(
        model_name,
        tokens_str,
        code,
        layer,
        head,
        ctx_size=5,
        return_example_list=True,
    )
    # Each entry pairs the rendered example text with its activation count.
    min_acts = seq_len * code_act_ratio
    dense_examples = [
        rendered for rendered, n_acts in example_acts if n_acts > min_acts
    ]
    st.markdown("#### Examples for Code")
    rendered_examples = webapp_utils.escape_markdown("".join(dense_examples))
    st.markdown(rendered_examples, unsafe_allow_html=True)
is_attn = st.session_state["is_attn"]

# ---- Page header and explanation -------------------------------------------
st.markdown("## Topic Code")
topic_code_description = (
    "Topic codes are codes that activate many different times on passages that describe a particular"
    " topic or concept (e.g. “fire”). This interface provides a way to search for such codes by looking"
    " at different examples in the dataset (ExampleID) and finding codes that activate on some fraction"
    " of the tokens in that example (Recall Threshold). Decrease the Recall Threshold to view more possible"
    " topic codes and increase it to see fewer. Click “Find Next Example” to find the next example with at"
    " least one code firing on that example above the Recall Threshold.\n\n"
    "Topic codes are displayed for the codebook model selected on the Code Browser page. To view topic codes"
    " for a different model, go to the Code Browser page and select a different model."
)
st.write(topic_code_description)

# ---- Controls: example id, recall threshold, truncation, sort order --------
ex_col, r_col, trunc_col, sort_col = st.columns([1, 1, 1, 1])
example_id = ex_col.number_input(
    "Example ID",
    min_value=0,
    max_value=total_examples - 1,
    value=0,
    key="example_id",
)
# `recall_threshold` is also read as a global by get_example_topic_codes and
# find_next_example, so the name must stay as is.
recall_threshold = r_col.slider(
    "Recall Threshold",
    min_value=0.0,
    max_value=1.0,
    value=0.2,
    key="recall",
    help="Recall Threshold is the minimum fraction of tokens in the example that the code must activate on.",
)
example_truncation = trunc_col.number_input(
    "Max Output Chars",
    min_value=0,
    max_value=102400,
    value=1024,
    key="max_chars",
)
sort_by_options = ["Precision", "Recall", "Num Acts"]
sort_by_name = sort_col.radio(
    "Sort By",
    sort_by_options,
    index=1,
    horizontal=True,
    help="Sorts the codes by the selected metric.",
)
# Position of the chosen metric in `sort_by_options` (0, 1 or 2).
sort_by = sort_by_options.index(sort_by_name)

button = st.button(
    "Find Next Example",
    key="find_next_example",
    on_click=find_next_example,
    args=(example_id,),
    help="Find an example which has codes above the recall threshold.",
)

# ---- Example text, truncated to the requested number of characters ---------
st.markdown("### Example Text")
example_text = tokens_text[example_id]
trunc_suffix = "..." if example_truncation < len(example_text) else ""
st.write(example_text[:example_truncation] + trunc_suffix)

# ---- Table header; the "Head" column only exists when `is_attn` is set -----
cols = st.columns(7 if is_attn else 6)
cols[0].markdown("Search", help="Button to see token activations for the code.")
cols[1].write("Layer")
if is_attn:
    cols[2].write("Head")
cols[-4].write("Code")
cols[-3].write("Precision")
cols[-2].write("Recall")
cols[-1].markdown(
    "Num Acts",
    help="Number of tokens that the code activates on in the acts dataset.",
)

# ---- Table rows: one per (codebook, code) pair, best metric first ----------
all_codes = [
    (cb_name, code_pr_info)
    for cb_name, code_pr_infos in get_example_topic_codes(example_id)
    for code_pr_info in code_pr_infos
]
# Each info tuple is (code, precision, recall, num_acts); index 0 is the code
# itself, hence the +1 offset to reach the selected metric.
all_codes.sort(key=lambda entry: entry[1][sort_by + 1], reverse=True)

for cb_name, (code, precision, recall, num_acts) in all_codes:
    cols = st.columns(7 if is_attn else 6)
    code_button = cols[0].button(
        "🔍",
        key=f"ex-code-{code}-{cb_name}",
    )
    layer, head = code_search_utils.get_layer_head_from_adv_name(cb_name)
    cols[1].write(str(layer))
    if is_attn:
        cols[2].write(str(head))
    cols[-4].write(code)
    cols[-3].write(f"{precision * 100:.2f}%")
    cols[-2].write(f"{recall * 100:.2f}%")
    cols[-1].write(str(num_acts))
    if code_button:
        # Rendered directly below the row whose search button was pressed.
        show_examples_for_topic_code(
            code,
            layer,
            head,
            code_act_ratio=recall_threshold,
        )

if not all_codes:
    st.markdown(
        f"<div style='text-align:center'>No codes found at recall threshold = {recall_threshold}."
        " Consider decreasing the recall threshold.</div>",
        unsafe_allow_html=True,
    )