Compare commits

...

1 Commits

Author SHA1 Message Date
cb281aff65 scrollbar attempt 2023-06-15 23:18:09 +02:00
3 changed files with 55 additions and 4 deletions

View File

@ -3,7 +3,10 @@ import sys
import numpy as np
from rank_bm25 import BM25Okapi
from dotenv import load_dotenv
load_dotenv()
RECIPES_INGREDIENTS_PATH = os.getenv('RECIPES_INGREDIENTS_PATH')
def weighted_search(tokenized_query, bm25_recipes, bm25_ingredients,
tok_text_recipes, tok_text_ingredients, weight_recipe=10, weight_ingredient=1):
@ -29,7 +32,7 @@ def search_recipe(query):
tok_text_ingredients = []
tok_text_recipes = []
with open('AMUseBotFront/ai_talks/AMUseBotBackend/utils/tools/ingredients_recipes_merged.csv', 'r') as file:
with open(RECIPES_INGREDIENTS_PATH, 'r') as file:
for line in file:
line = line.split(', ')
ingredients_splitted = [x for x in line[1].split(',')]

View File

@ -3,6 +3,7 @@ from random import randrange
from AMUseBotBackend.src.DP.dp import DP
from AMUseBotBackend.src.DST.dst import DST
import streamlit.components.v1 as components
import graphviz
import streamlit as st
@ -10,6 +11,10 @@ from PIL import Image
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
from src.utils.lang import en
# from live_asr import LiveWav2Vec2
# english_model = "facebook/wav2vec2-large-960h-lv60-self"
# asr = LiveWav2Vec2(english_model,device_name="default")
# asr.start()
import os
from dotenv import load_dotenv
@ -86,6 +91,19 @@ def show_graph():
pass
st.graphviz_chart(graph)
# def mermaid(code: str) -> None:
# components.html(
# f"""
# <pre class="mermaid">
# {code}
# </pre>
#
# <script type="module">
# import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
# mermaid.initialize({{ startOnLoad: true }});
# </script>
# """
# )
def main() -> None:
c1, c2 = st.columns(2)
@ -108,7 +126,11 @@ def main() -> None:
get_user_input()
show_chat_buttons()
# mermaid("""graph TD;
# A -->|Dupa|B;
# A --> C;
# B --> D;
# C --> D;""")
show_conversation()
with st.sidebar:
show_graph()

View File

@ -44,8 +44,19 @@ def get_user_input():
def on_send():
st.session_state.past.append(st.session_state.user_text)
def startASR():
try:
while True:
text, sample_length, inference_time = asr.get_last_text()
st.write(f"{sample_length:.3f}s"
+ f"\t{inference_time:.3f}s"
+ f"\t{text}")
except KeyboardInterrupt:
asr.stop()
def show_chat_buttons() -> None:
b0, b1, b2 = st.columns(3)
b0, b1, b2, b3 = st.columns(4)
with b0, b1, b2:
b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
@ -55,6 +66,7 @@ def show_chat_buttons() -> None:
file_name="ai-talks-chat.json",
mime="application/json",
)
b3.button(label="voice", on_click=startASR)
# def show_chat(ai_content: str, user_text: str) -> None:
# first_message = True
@ -77,12 +89,23 @@ def show_chat_buttons() -> None:
# if first_message:
# print('message 3')
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
def scroll_to_bottom():
st.markdown(
"""
<script>
const container = document.querySelector('.stContainer > div')
container.scrollTop = container.scrollHeight
</script>
""",
unsafe_allow_html=True,
)
def show_chat() -> None:
for i in range(len(st.session_state.past)):
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
scroll_to_bottom()
def show_conversation() -> None:
if st.session_state.messages:
@ -110,6 +133,9 @@ def show_conversation() -> None:
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
ai_content = st.session_state.generated[-1]
st.session_state.messages.append({"role": "assistant", "content": ai_content})
container = st.container()
container.write("This is inside the container")
st.write("This is outside the container")
show_chat()
st.divider()
show_audio_player(ai_content)