Compare commits
1 Commits
Author | SHA1 | Date | |
---|---|---|---|
cb281aff65 |
@ -3,7 +3,10 @@ import sys
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from rank_bm25 import BM25Okapi
|
from rank_bm25 import BM25Okapi
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
RECIPES_INGREDIENTS_PATH = os.getenv('RECIPES_INGREDIENTS_PATH')
|
||||||
|
|
||||||
def weighted_search(tokenized_query, bm25_recipes, bm25_ingredients,
|
def weighted_search(tokenized_query, bm25_recipes, bm25_ingredients,
|
||||||
tok_text_recipes, tok_text_ingredients, weight_recipe=10, weight_ingredient=1):
|
tok_text_recipes, tok_text_ingredients, weight_recipe=10, weight_ingredient=1):
|
||||||
@ -29,7 +32,7 @@ def search_recipe(query):
|
|||||||
|
|
||||||
tok_text_ingredients = []
|
tok_text_ingredients = []
|
||||||
tok_text_recipes = []
|
tok_text_recipes = []
|
||||||
with open('AMUseBotFront/ai_talks/AMUseBotBackend/utils/tools/ingredients_recipes_merged.csv', 'r') as file:
|
with open(RECIPES_INGREDIENTS_PATH, 'r') as file:
|
||||||
for line in file:
|
for line in file:
|
||||||
line = line.split(', ')
|
line = line.split(', ')
|
||||||
ingredients_splitted = [x for x in line[1].split(',')]
|
ingredients_splitted = [x for x in line[1].split(',')]
|
||||||
|
@ -3,6 +3,7 @@ from random import randrange
|
|||||||
|
|
||||||
from AMUseBotBackend.src.DP.dp import DP
|
from AMUseBotBackend.src.DP.dp import DP
|
||||||
from AMUseBotBackend.src.DST.dst import DST
|
from AMUseBotBackend.src.DST.dst import DST
|
||||||
|
import streamlit.components.v1 as components
|
||||||
|
|
||||||
import graphviz
|
import graphviz
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
@ -10,6 +11,10 @@ from PIL import Image
|
|||||||
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
|
from src.utils.conversation import get_user_input, show_chat_buttons, show_conversation
|
||||||
from src.utils.lang import en
|
from src.utils.lang import en
|
||||||
|
|
||||||
|
# from live_asr import LiveWav2Vec2
|
||||||
|
# english_model = "facebook/wav2vec2-large-960h-lv60-self"
|
||||||
|
# asr = LiveWav2Vec2(english_model,device_name="default")
|
||||||
|
# asr.start()
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
@ -86,6 +91,19 @@ def show_graph():
|
|||||||
pass
|
pass
|
||||||
st.graphviz_chart(graph)
|
st.graphviz_chart(graph)
|
||||||
|
|
||||||
|
# def mermaid(code: str) -> None:
|
||||||
|
# components.html(
|
||||||
|
# f"""
|
||||||
|
# <pre class="mermaid">
|
||||||
|
# {code}
|
||||||
|
# </pre>
|
||||||
|
#
|
||||||
|
# <script type="module">
|
||||||
|
# import mermaid from 'https://cdn.jsdelivr.net/npm/mermaid@10/dist/mermaid.esm.min.mjs';
|
||||||
|
# mermaid.initialize({{ startOnLoad: true }});
|
||||||
|
# </script>
|
||||||
|
# """
|
||||||
|
# )
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
c1, c2 = st.columns(2)
|
c1, c2 = st.columns(2)
|
||||||
@ -108,7 +126,11 @@ def main() -> None:
|
|||||||
|
|
||||||
get_user_input()
|
get_user_input()
|
||||||
show_chat_buttons()
|
show_chat_buttons()
|
||||||
|
# mermaid("""graph TD;
|
||||||
|
# A -->|Dupa|B;
|
||||||
|
# A --> C;
|
||||||
|
# B --> D;
|
||||||
|
# C --> D;""")
|
||||||
show_conversation()
|
show_conversation()
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
show_graph()
|
show_graph()
|
||||||
|
@ -44,8 +44,19 @@ def get_user_input():
|
|||||||
def on_send():
|
def on_send():
|
||||||
st.session_state.past.append(st.session_state.user_text)
|
st.session_state.past.append(st.session_state.user_text)
|
||||||
|
|
||||||
|
def startASR():
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
text, sample_length, inference_time = asr.get_last_text()
|
||||||
|
st.write(f"{sample_length:.3f}s"
|
||||||
|
+ f"\t{inference_time:.3f}s"
|
||||||
|
+ f"\t{text}")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
asr.stop()
|
||||||
|
|
||||||
def show_chat_buttons() -> None:
|
def show_chat_buttons() -> None:
|
||||||
b0, b1, b2 = st.columns(3)
|
b0, b1, b2, b3 = st.columns(4)
|
||||||
with b0, b1, b2:
|
with b0, b1, b2:
|
||||||
b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
|
b0.button(label=st.session_state.locale.chat_run_btn, on_click=on_send)
|
||||||
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
|
b1.button(label=st.session_state.locale.chat_clear_btn, on_click=clear_chat)
|
||||||
@ -55,6 +66,7 @@ def show_chat_buttons() -> None:
|
|||||||
file_name="ai-talks-chat.json",
|
file_name="ai-talks-chat.json",
|
||||||
mime="application/json",
|
mime="application/json",
|
||||||
)
|
)
|
||||||
|
b3.button(label="voice", on_click=startASR)
|
||||||
|
|
||||||
# def show_chat(ai_content: str, user_text: str) -> None:
|
# def show_chat(ai_content: str, user_text: str) -> None:
|
||||||
# first_message = True
|
# first_message = True
|
||||||
@ -77,12 +89,23 @@ def show_chat_buttons() -> None:
|
|||||||
# if first_message:
|
# if first_message:
|
||||||
# print('message 3')
|
# print('message 3')
|
||||||
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
# message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
||||||
|
def scroll_to_bottom():
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
<script>
|
||||||
|
const container = document.querySelector('.stContainer > div')
|
||||||
|
container.scrollTop = container.scrollHeight
|
||||||
|
</script>
|
||||||
|
""",
|
||||||
|
unsafe_allow_html=True,
|
||||||
|
)
|
||||||
|
|
||||||
def show_chat() -> None:
|
def show_chat() -> None:
|
||||||
for i in range(len(st.session_state.past)):
|
for i in range(len(st.session_state.past)):
|
||||||
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
message(st.session_state.generated[i], key=str(i), seed=st.session_state.seed)
|
||||||
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
message(st.session_state.past[i], is_user=True, key=str(i) + "_user", seed=st.session_state.seed)
|
||||||
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
message(st.session_state.generated[-1], key=str(-1), seed=st.session_state.seed)
|
||||||
|
scroll_to_bottom()
|
||||||
|
|
||||||
def show_conversation() -> None:
|
def show_conversation() -> None:
|
||||||
if st.session_state.messages:
|
if st.session_state.messages:
|
||||||
@ -110,6 +133,9 @@ def show_conversation() -> None:
|
|||||||
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
|
#random_str = ''.join(choices(string.ascii_uppercase + string.digits, k=5))
|
||||||
ai_content = st.session_state.generated[-1]
|
ai_content = st.session_state.generated[-1]
|
||||||
st.session_state.messages.append({"role": "assistant", "content": ai_content})
|
st.session_state.messages.append({"role": "assistant", "content": ai_content})
|
||||||
|
container = st.container()
|
||||||
|
container.write("This is inside the container")
|
||||||
|
st.write("This is outside the container")
|
||||||
show_chat()
|
show_chat()
|
||||||
st.divider()
|
st.divider()
|
||||||
show_audio_player(ai_content)
|
show_audio_player(ai_content)
|
||||||
|
Loading…
Reference in New Issue
Block a user