Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
cb281aff65 |
@ -3,7 +3,10 @@ import sys
|
||||
|
||||
import numpy as np
|
||||
from rank_bm25 import BM25Okapi
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
RECIPES_INGREDIENTS_PATH = os.getenv('RECIPES_INGREDIENTS_PATH')
|
||||
|
||||
def weighted_search(tokenized_query, bm25_recipes, bm25_ingredients,
|
||||
tok_text_recipes, tok_text_ingredients, weight_recipe=10, weight_ingredient=1):
|
||||
@ -29,7 +32,7 @@ def search_recipe(query):
|
||||
|
||||
tok_text_ingredients = []
|
||||
tok_text_recipes = []
|
||||
with open('AMUseBotFront/ai_talks/AMUseBotBackend/utils/tools/ingredients_recipes_merged.csv', 'r') as file:
|
||||
with open(RECIPES_INGREDIENTS_PATH, 'r') as file:
|
||||
for line in file:
|
||||
line = line.split(', ')
|
||||
ingredients_splitted = [x for x in line[1].split(',')]
|
||||
|
@ -3,6 +3,7 @@ from random import randrange
|
||||
|
||||
from AMUseBotBackend.src.DP.dp import DP
|
||||
from AMUseBotBackend.src.DST.dst import DST
|
||||
import streamlit.components.v1 as components
|
||||
|
||||
import graphviz
|
||||
import streamlit as st
|
||||
@ -10,6 +11,10 @@ from PIL import Image
|
||||
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
|
||||
from src.utils.lang import en
|
||||
|
||||
# from live_asr import LiveWav2Vec2
|
||||
# english_model = "facebook/wav2vec2-large-960h-lv60-self"
|
||||
# asr = LiveWav2Vec2(english_model,device_name="default")
|
||||
# asr.start()
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@ -86,6 +91,19 @@ def show_graph():
|
||||
pass
|
||||
st.graphviz_chart(graph)
|
||||
|
||||
# def mermaid(code: str) -> None:
|
||||
# components.html(
|
||||
# f"""
|
||||
# <pre class="mermaid">
|
||||
# {code}
|
||||
# </pre>
|
||||
#
|
||||
# <script type="module">
|
||||
# import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
|
||||
# mermaid.initialize({{ startOnLoad: true }});
|
||||
# </script>
|
||||
# """
|
||||
# )
|
||||
|
||||
def main() -> None:
|
||||
c1, c2 = st.columns(2)
|
||||
@ -108,7 +126,11 @@ def main() -> None:
|
||||
|
||||
get_user_input()
|
||||
show_chat_buttons()
|
||||
|
||||
# mermaid("""graph TD;
|
||||
# A -->|Dupa|B;
|
||||
# A --> C;
|
||||
# B --> D;
|
||||
# C --> D;""")
|
||||
show_conversation()
|
||||
with st.sidebar:
|
||||
show_graph()
|
||||
|
@ -44,8 +44,19 @@ def get_user_input():
|
||||
def on_send():
|
||||
st.session_state.past.append(st.session_state.user_text)
|
||||
|
||||
def startASR():
|
||||
try:
|
||||
while True:
|
||||
text, sample_length, inference_time = asr.get_last_text()
|
||||
st.write(f"{sample_length:.3f}s"
|
||||
+ f"\t{inference_time:.3f}s"
|
||||
+ f"\t{text}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
asr.stop()
|
||||
|
||||
def show_chat_buttons() -> None:
|
||||
b0, b1, b2 = st.columns(3)
|
||||
b0, b1, b2, b3 = st.columns(4)
|
||||
with b0, b1, b2:
|
||||
b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
|
||||
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
|
||||
@ -55,6 +66,7 @@ def show_chat_buttons() -> None:
|
||||
file_name="ai-talks-chat.json",
|
||||
mime="application/json",
|
||||
)
|
||||
b3.button(label="voice", on_click=startASR)
|
||||
|
||||
# def show_chat(ai_content: str, user_text: str) -> None:
|
||||
# first_message = True
|
||||
@ -77,12 +89,23 @@ def show_chat_buttons() -> None:
|
||||
# if first_message:
|
||||
# print('message 3')
|
||||
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
||||
|
||||
def scroll_to_bottom():
|
||||
st.markdown(
|
||||
"""
|
||||
<script>
|
||||
const container = document.querySelector('.stContainer > div')
|
||||
container.scrollTop = container.scrollHeight
|
||||
</script>
|
||||
""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
def show_chat() -> None:
|
||||
for i in range(len(st.session_state.past)):
|
||||
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
||||
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
||||
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
||||
scroll_to_bottom()
|
||||
|
||||
def show_conversation() -> None:
|
||||
if st.session_state.messages:
|
||||
@ -110,6 +133,9 @@ def show_conversation() -> None:
|
||||
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
|
||||
ai_content = st.session_state.generated[-1]
|
||||
st.session_state.messages.append({"role": "assistant", "content": ai_content})
|
||||
container = st.container()
|
||||
container.write("This is inside the container")
|
||||
st.write("This is outside the container")
|
||||
show_chat()
|
||||
st.divider()
|
||||
show_audio_player(ai_content)
|
||||
|
Loading…
Reference in New Issue
Block a user