scrollbar attempt

This commit is contained in:
Adrian Charkiewicz 2023-06-15 23:18:09 +02:00
parent 5633019411
commit cb281aff65
3 changed files with 55 additions and 4 deletions

View File

@ -3,7 +3,10 @@ import sys
import numpy as np import numpy as np
from rank_bm25 import BM25Okapi from rank_bm25 import BM25Okapi
from dotenv import load_dotenv
load_dotenv()
RECIPES_INGREDIENTS_PATH = os.getenv('RECIPES_INGREDIENTS_PATH')
def weighted_search(tokenized_query, bm25_recipes, bm25_ingredients, def weighted_search(tokenized_query, bm25_recipes, bm25_ingredients,
tok_text_recipes, tok_text_ingredients, weight_recipe=10, weight_ingredient=1): tok_text_recipes, tok_text_ingredients, weight_recipe=10, weight_ingredient=1):
@ -29,7 +32,7 @@ def search_recipe(query):
tok_text_ingredients = [] tok_text_ingredients = []
tok_text_recipes = [] tok_text_recipes = []
with open('AMUseBotFront/ai_talks/AMUseBotBackend/utils/tools/ingredients_recipes_merged.csv', 'r') as file: with open(RECIPES_INGREDIENTS_PATH, 'r') as file:
for line in file: for line in file:
line = line.split(', ') line = line.split(', ')
ingredients_splitted = [x for x in line[1].split(',')] ingredients_splitted = [x for x in line[1].split(',')]

View File

@ -3,6 +3,7 @@ from random import randrange
from AMUseBotBackend.src.DP.dp import DP from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST from AMUseBotBackend.src.DST.dst import DST
import streamlit.components.v1 as components
import graphviz import graphviz
import streamlit as st import streamlit as st
@ -10,6 +11,10 @@ from PIL import Image
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
from src.utils.lang import en from src.utils.lang import en
# from live_asr import LiveWav2Vec2
# english_model = "facebook/wav2vec2-large-960h-lv60-self"
# asr = LiveWav2Vec2(english_model,device_name="default")
# asr.start()
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
@ -86,6 +91,19 @@ def show_graph():
pass pass
st.graphviz_chart(graph) st.graphviz_chart(graph)
# def mermaid(code: str) -> None:
# components.html(
# f"""
# <pre class="mermaid">
# {code}
# </pre>
#
# <script type="module">
# import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
# mermaid.initialize({{ startOnLoad: true }});
# </script>
# """
# )
def main() -> None: def main() -> None:
c1, c2 = st.columns(2) c1, c2 = st.columns(2)
@ -108,7 +126,11 @@ def main() -> None:
get_user_input() get_user_input()
show_chat_buttons() show_chat_buttons()
# mermaid("""graph TD;
# A -->|Dupa|B;
# A --> C;
# B --> D;
# C --> D;""")
show_conversation() show_conversation()
with st.sidebar: with st.sidebar:
show_graph() show_graph()

View File

@ -44,8 +44,19 @@ def get_user_input():
def on_send(): def on_send():
st.session_state.past.append(st.session_state.user_text) st.session_state.past.append(st.session_state.user_text)
def startASR():
try:
while True:
text, sample_length, inference_time = asr.get_last_text()
st.write(f"{sample_length:.3f}s"
+ f"\t{inference_time:.3f}s"
+ f"\t{text}")
except KeyboardInterrupt:
asr.stop()
def show_chat_buttons() -> None: def show_chat_buttons() -> None:
b0, b1, b2 = st.columns(3) b0, b1, b2, b3 = st.columns(4)
with b0, b1, b2: with b0, b1, b2:
b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send) b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat) b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
@ -55,6 +66,7 @@ def show_chat_buttons() -> None:
file_name="ai-talks-chat.json", file_name="ai-talks-chat.json",
mime="application/json", mime="application/json",
) )
b3.button(label="voice", on_click=startASR)
# def show_chat(ai_content: str, user_text: str) -> None: # def show_chat(ai_content: str, user_text: str) -> None:
# first_message = True # first_message = True
@ -77,12 +89,23 @@ def show_chat_buttons() -> None:
# if first_message: # if first_message:
# print('message 3') # print('message 3')
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed) # message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
def scroll_to_bottom():
st.markdown(
"""
<script>
const container = document.querySelector('.stContainer > div')
container.scrollTop = container.scrollHeight
</script>
""",
unsafe_allow_html=True,
)
def show_chat() -> None: def show_chat() -> None:
for i in range(len(st.session_state.past)): for i in range(len(st.session_state.past)):
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed) message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed) message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed) message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
scroll_to_bottom()
def show_conversation() -> None: def show_conversation() -> None:
if st.session_state.messages: if st.session_state.messages:
@ -110,6 +133,9 @@ def show_conversation() -> None:
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5)) #random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
ai_content = st.session_state.generated[-1] ai_content = st.session_state.generated[-1]
st.session_state.messages.append({"role": "assistant", "content": ai_content}) st.session_state.messages.append({"role": "assistant", "content": ai_content})
container = st.container()
container.write("This is inside the container")
st.write("This is outside the container")
show_chat() show_chat()
st.divider() st.divider()
show_audio_player(ai_content) show_audio_player(ai_content)