Compare commits
2 Commits
61b597adcd
...
56830134d3
Author | SHA1 | Date | |
---|---|---|---|
|
56830134d3 | ||
|
3c52d24af0 |
35
embeddings.py
Normal file
35
embeddings.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModel, AutoTokenizer
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("Geotrend/distilbert-base-pl-cased")
|
||||||
|
model = AutoModel.from_pretrained("Geotrend/distilbert-base-pl-cased")
|
||||||
|
|
||||||
|
text = """
|
||||||
|
"nazwa": "Tatar wołowy","""
|
||||||
|
# "skladniki": [
|
||||||
|
# "wołowina",
|
||||||
|
# "cebula",
|
||||||
|
# "ogórki kiszone",
|
||||||
|
# "musztarda",
|
||||||
|
# "jajko",
|
||||||
|
# "pieprz",
|
||||||
|
# "sól"
|
||||||
|
# ],
|
||||||
|
# "alergeny": [
|
||||||
|
# "jajko",
|
||||||
|
# "gorczyca"
|
||||||
|
# ]
|
||||||
|
# """
|
||||||
|
encoded_input = tokenizer(text, return_tensors='pt', padding=True)
|
||||||
|
output = model(**encoded_input)
|
||||||
|
prompt = "tatar"
|
||||||
|
encoded_prompt = tokenizer(prompt, return_tensors='pt', padding=True)
|
||||||
|
output_prompt = model(**encoded_prompt)
|
||||||
|
|
||||||
|
text_embedding = output.last_hidden_state[:, 0, :]
|
||||||
|
prompt_embedding = output_prompt.last_hidden_state[:, 0, :]
|
||||||
|
cosine = torch.nn.functional.cosine_similarity(
|
||||||
|
text_embedding, prompt_embedding, dim=1)
|
||||||
|
|
||||||
|
print(cosine.item())
|
Loading…
Reference in New Issue
Block a user