diff --git a/ai_talks/AMUseBotBackend/src/tools/search.py b/ai_talks/AMUseBotBackend/src/tools/search.py index 5f3164e..d45048a 100644 --- a/ai_talks/AMUseBotBackend/src/tools/search.py +++ b/ai_talks/AMUseBotBackend/src/tools/search.py @@ -3,7 +3,10 @@ import sys import numpy as np from rank_bm25 import BM25Okapi +from dotenv import load_dotenv +load_dotenv() +RECIPES_INGREDIENTS_PATH = os.getenv('RECIPES_INGREDIENTS_PATH') def weighted_search(tokenized_query, bm25_recipes, bm25_ingredients, tok_text_recipes, tok_text_ingredients, weight_recipe=10, weight_ingredient=1): @@ -29,7 +32,7 @@ def search_recipe(query): tok_text_ingredients = [] tok_text_recipes = [] - with open('AMUseBotFront/ai_talks/AMUseBotBackend/utils/tools/ingredients_recipes_merged.csv', 'r') as file: + with open(RECIPES_INGREDIENTS_PATH, 'r') as file: for line in file: line = line.split(', ') ingredients_splitted = [x for x in line[1].split(',')] diff --git a/ai_talks/chat.py b/ai_talks/chat.py index e38fd4a..87d2835 100644 --- a/ai_talks/chat.py +++ b/ai_talks/chat.py @@ -3,6 +3,7 @@ from random import randrange from AMUseBotBackend.src.DP.dp import DP from AMUseBotBackend.src.DST.dst import DST +import streamlit.components.v1 as components import graphviz import streamlit as st @@ -10,6 +11,10 @@ from PIL import Image from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation from src.utils.lang import en +# from live_asr import LiveWav2Vec2 +# english_model = "facebook/wav2vec2-large-960h-lv60-self" +# asr = LiveWav2Vec2(english_model,device_name="default") +# asr.start() import os from dotenv import load_dotenv @@ -86,6 +91,19 @@ def show_graph(): pass st.graphviz_chart(graph) +# def mermaid(code: str) -> None: +# components.html( +# f""" +#
+# {code} +#+# +# +# """ +# ) def main() -> None: c1, c2 = st.columns(2) @@ -108,7 +126,11 @@ def main() -> None: get_user_input() show_chat_buttons() - + # mermaid("""graph TD; + # A -->|Dupa|B; + # A --> C; + # B --> D; + # C --> D;""") show_conversation() with st.sidebar: show_graph() diff --git a/ai_talks/src/utils/conversation.py b/ai_talks/src/utils/conversation.py index 67c9340..6a6e5be 100644 --- a/ai_talks/src/utils/conversation.py +++ b/ai_talks/src/utils/conversation.py @@ -44,8 +44,19 @@ def get_user_input(): def on_send(): st.session_state.past.append(st.session_state.user_text) +def startASR(): + try: + while True: + text, sample_length, inference_time = asr.get_last_text() + st.write(f"{sample_length:.3f}s" + + f"\t{inference_time:.3f}s" + + f"\t{text}") + + except KeyboardInterrupt: + asr.stop() + def show_chat_buttons() -> None: - b0, b1, b2 = st.columns(3) + b0, b1, b2, b3 = st.columns(4) with b0, b1, b2: b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send) b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat) @@ -55,6 +66,7 @@ def show_chat_buttons() -> None: file_name="ai-talks-chat.json", mime="application/json", ) + b3.button(label="voice", on_click=startASR) # def show_chat(ai_content: str, user_text: str) -> None: # first_message = True @@ -77,12 +89,23 @@ def show_chat_buttons() -> None: # if first_message: # print('message 3') # message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed) - +def scroll_to_bottom(): + st.markdown( + """ + + """, + unsafe_allow_html=True, + ) + def show_chat() -> None: for i in range(len(st.session_state.past)): message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed) message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed) message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed) + scroll_to_bottom() def show_conversation() -> None: if st.session_state.messages: @@ -110,6 +133,9 @@ def show_conversation() -> None: #random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5)) ai_content = st.session_state.generated[-1] st.session_state.messages.append({"role": "assistant", "content": ai_content}) + container = st.container() + container.write("This is inside the container") + st.write("This is outside the container") show_chat() st.divider() show_audio_player(ai_content)