Merge branch 'arc53:main' into main
@ -0,0 +1,23 @@
|
||||
repo:
|
||||
- '*'
|
||||
|
||||
github:
|
||||
- .github/**/*
|
||||
|
||||
application:
|
||||
- application/**/*
|
||||
|
||||
docs:
|
||||
- docs/**/*
|
||||
|
||||
extensions:
|
||||
- extensions/**/*
|
||||
|
||||
frontend:
|
||||
- frontend/**/*
|
||||
|
||||
scripts:
|
||||
- scripts/**/*
|
||||
|
||||
tests:
|
||||
- tests/**/*
|
@ -0,0 +1,15 @@
|
||||
# https://github.com/actions/labeler
|
||||
name: Pull Request Labeler
|
||||
on:
|
||||
- pull_request_target
|
||||
jobs:
|
||||
triage:
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: write
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/labeler@v4
|
||||
with:
|
||||
repo-token: "${{ secrets.GITHUB_TOKEN }}"
|
||||
sync-labels: true
|
@ -1,27 +1,139 @@
|
||||
from application.llm.base import BaseLLM
|
||||
from application.core.settings import settings
|
||||
import requests
|
||||
import json
|
||||
import io
|
||||
|
||||
|
||||
|
||||
class LineIterator:
|
||||
"""
|
||||
A helper class for parsing the byte stream input.
|
||||
|
||||
The output of the model will be in the following format:
|
||||
```
|
||||
b'{"outputs": [" a"]}\n'
|
||||
b'{"outputs": [" challenging"]}\n'
|
||||
b'{"outputs": [" problem"]}\n'
|
||||
...
|
||||
```
|
||||
|
||||
While usually each PayloadPart event from the event stream will contain a byte array
|
||||
with a full json, this is not guaranteed and some of the json objects may be split across
|
||||
PayloadPart events. For example:
|
||||
```
|
||||
{'PayloadPart': {'Bytes': b'{"outputs": '}}
|
||||
{'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
|
||||
```
|
||||
|
||||
This class accounts for this by concatenating bytes written via the 'write' function
|
||||
and then exposing a method which will return lines (ending with a '\n' character) within
|
||||
the buffer via the 'scan_lines' function. It maintains the position of the last read
|
||||
position to ensure that previous bytes are not exposed again.
|
||||
"""
|
||||
|
||||
def __init__(self, stream):
|
||||
self.byte_iterator = iter(stream)
|
||||
self.buffer = io.BytesIO()
|
||||
self.read_pos = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
while True:
|
||||
self.buffer.seek(self.read_pos)
|
||||
line = self.buffer.readline()
|
||||
if line and line[-1] == ord('\n'):
|
||||
self.read_pos += len(line)
|
||||
return line[:-1]
|
||||
try:
|
||||
chunk = next(self.byte_iterator)
|
||||
except StopIteration:
|
||||
if self.read_pos < self.buffer.getbuffer().nbytes:
|
||||
continue
|
||||
raise
|
||||
if 'PayloadPart' not in chunk:
|
||||
print('Unknown event type:' + chunk)
|
||||
continue
|
||||
self.buffer.seek(0, io.SEEK_END)
|
||||
self.buffer.write(chunk['PayloadPart']['Bytes'])
|
||||
|
||||
class SagemakerAPILLM(BaseLLM):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.url = settings.SAGEMAKER_API_URL
|
||||
import boto3
|
||||
runtime = boto3.client(
|
||||
'runtime.sagemaker',
|
||||
aws_access_key_id='xxx',
|
||||
aws_secret_access_key='xxx',
|
||||
region_name='us-west-2'
|
||||
)
|
||||
|
||||
|
||||
self.endpoint = settings.SAGEMAKER_ENDPOINT
|
||||
self.runtime = runtime
|
||||
|
||||
|
||||
def gen(self, model, engine, messages, stream=False, **kwargs):
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
|
||||
response = requests.post(
|
||||
url=self.url,
|
||||
headers={
|
||||
"Content-Type": "application/json; charset=utf-8",
|
||||
},
|
||||
data=json.dumps({"input": prompt})
|
||||
)
|
||||
|
||||
return response.json()['answer']
|
||||
# Construct payload for endpoint
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"stream": False,
|
||||
"parameters": {
|
||||
"do_sample": True,
|
||||
"temperature": 0.1,
|
||||
"max_new_tokens": 30,
|
||||
"repetition_penalty": 1.03,
|
||||
"stop": ["</s>", "###"]
|
||||
}
|
||||
}
|
||||
body_bytes = json.dumps(payload).encode('utf-8')
|
||||
|
||||
# Invoke the endpoint
|
||||
response = self.runtime.invoke_endpoint(EndpointName=self.endpoint,
|
||||
ContentType='application/json',
|
||||
Body=body_bytes)
|
||||
result = json.loads(response['Body'].read().decode())
|
||||
import sys
|
||||
print(result[0]['generated_text'], file=sys.stderr)
|
||||
return result[0]['generated_text'][len(prompt):]
|
||||
|
||||
def gen_stream(self, model, engine, messages, stream=True, **kwargs):
|
||||
raise NotImplementedError("Sagemaker does not support streaming")
|
||||
context = messages[0]['content']
|
||||
user_question = messages[-1]['content']
|
||||
prompt = f"### Instruction \n {user_question} \n ### Context \n {context} \n ### Answer \n"
|
||||
|
||||
|
||||
# Construct payload for endpoint
|
||||
payload = {
|
||||
"inputs": prompt,
|
||||
"stream": True,
|
||||
"parameters": {
|
||||
"do_sample": True,
|
||||
"temperature": 0.1,
|
||||
"max_new_tokens": 512,
|
||||
"repetition_penalty": 1.03,
|
||||
"stop": ["</s>", "###"]
|
||||
}
|
||||
}
|
||||
body_bytes = json.dumps(payload).encode('utf-8')
|
||||
|
||||
# Invoke the endpoint
|
||||
response = self.runtime.invoke_endpoint_with_response_stream(EndpointName=self.endpoint,
|
||||
ContentType='application/json',
|
||||
Body=body_bytes)
|
||||
#result = json.loads(response['Body'].read().decode())
|
||||
event_stream = response['Body']
|
||||
start_json = b'{'
|
||||
for line in LineIterator(event_stream):
|
||||
if line != b'' and start_json in line:
|
||||
#print(line)
|
||||
data = json.loads(line[line.find(start_json):].decode('utf-8'))
|
||||
if data['token']['text'] not in ["</s>", "###"]:
|
||||
print(data['token']['text'],end='')
|
||||
yield data['token']['text']
|
@ -1,4 +1,4 @@
|
||||
## To customize a main prompt navigate to `/application/prompt/combine_prompt.txt`
|
||||
## To customize a main prompt, navigate to `/application/prompt/combine_prompt.txt`
|
||||
|
||||
You can try editing it to see how the model responses.
|
||||
|
||||
|
@ -1,3 +1,3 @@
|
||||
# Please put appropriate value
|
||||
VITE_API_HOST=http://localhost:7091
|
||||
VITE_API_HOST=http://0.0.0.0:7091
|
||||
VITE_API_STREAMING=true
|
@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="11" viewBox="0 0 14 11" fill="none">
|
||||
<path d="M4.95919 10.1906C4.84318 10.1902 4.72847 10.166 4.62222 10.1194C4.51596 10.0729 4.42041 10.0049 4.34152 9.91985L0.229353 5.54538C0.0756344 5.38157 -0.00671208 5.1634 0.000428491 4.93886C0.00756906 4.71433 0.103612 4.50183 0.267428 4.34812C0.431245 4.1944 0.649417 4.11205 0.873948 4.11919C1.09848 4.12633 1.31098 4.22238 1.4647 4.38619L4.95073 8.10068L12.0666 0.316329C12.1389 0.226405 12.2287 0.152193 12.3306 0.0982513C12.4326 0.0443098 12.5445 0.0117775 12.6594 0.00265255C12.7744 -0.00647237 12.89 0.00800286 12.9992 0.045189C13.1084 0.082375 13.2088 0.141487 13.2943 0.218894C13.3798 0.296301 13.4485 0.390369 13.4964 0.49532C13.5442 0.600272 13.57 0.713891 13.5723 0.829198C13.5746 0.944506 13.5534 1.05907 13.5098 1.16585C13.4662 1.27263 13.4012 1.36937 13.3189 1.45014L5.58533 9.91139C5.50718 9.998 5.41197 10.0675 5.30567 10.1156C5.19938 10.1636 5.0843 10.1892 4.96766 10.1906H4.95919Z" fill="#747474"/>
|
||||
</svg>
|
After Width: | Height: | Size: 1.0 KiB |
@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="11" viewBox="0 0 14 11" fill="none">
|
||||
<path d="M4.95919 10.1906C4.84318 10.1902 4.72847 10.166 4.62222 10.1194C4.51596 10.0729 4.42041 10.0049 4.34152 9.91985L0.229353 5.54538C0.0756344 5.38157 -0.00671208 5.1634 0.000428491 4.93886C0.00756906 4.71433 0.103612 4.50183 0.267428 4.34812C0.431245 4.1944 0.649417 4.11205 0.873948 4.11919C1.09848 4.12633 1.31098 4.22238 1.4647 4.38619L4.95073 8.10068L12.0666 0.316329C12.1389 0.226405 12.2287 0.152193 12.3306 0.0982513C12.4326 0.0443098 12.5445 0.0117775 12.6594 0.00265255C12.7744 -0.00647237 12.89 0.00800286 12.9992 0.045189C13.1084 0.082375 13.2088 0.141487 13.2943 0.218894C13.3798 0.296301 13.4485 0.390369 13.4964 0.49532C13.5442 0.600272 13.57 0.713891 13.5723 0.829198C13.5746 0.944506 13.5534 1.05907 13.5098 1.16585C13.4662 1.27263 13.4012 1.36937 13.3189 1.45014L5.58533 9.91139C5.50718 9.998 5.41197 10.0675 5.30567 10.1156C5.19938 10.1636 5.0843 10.1892 4.96766 10.1906H4.95919Z" fill="#747474"/>
|
||||
</svg>
|
After Width: | Height: | Size: 1.0 KiB |
@ -0,0 +1,3 @@
|
||||
<svg width="14" height="17" stroke-width="1.15" viewBox="0 0 14 17" >
|
||||
<path d="M13.8013 5.01282L8.80645 0.191795C8.67953 0.0691399 8.50734 0.000152609 8.32774 0H6.09677C5.43801 0 4.80623 0.252586 4.34041 0.702193C3.8746 1.1518 3.6129 1.7616 3.6129 2.39744V3.48718H2.48387C1.82511 3.48718 1.19332 3.73977 0.727509 4.18937C0.261693 4.63898 0 5.24878 0 5.88462V14.6026C0 15.2384 0.261693 15.8482 0.727509 16.2978C1.19332 16.7474 1.82511 17 2.48387 17H8.80645C9.46521 17 10.097 16.7474 10.5628 16.2978C11.0286 15.8482 11.2903 15.2384 11.2903 14.6026V13.5128H11.5161C12.1749 13.5128 12.8067 13.2602 13.2725 12.8106C13.7383 12.361 14 11.7512 14 11.1154V5.44872C13.9929 5.28447 13.9219 5.12884 13.8013 5.01282ZM9.03226 2.23179L11.6877 4.79487H9.03226V2.23179ZM9.93548 14.6026C9.93548 14.8916 9.81653 15.1688 9.6048 15.3731C9.39306 15.5775 9.10589 15.6923 8.80645 15.6923H2.48387C2.18443 15.6923 1.89726 15.5775 1.68552 15.3731C1.47379 15.1688 1.35484 14.8916 1.35484 14.6026V5.88462C1.35484 5.5956 1.47379 5.31842 1.68552 5.11405C1.89726 4.90968 2.18443 4.79487 2.48387 4.79487H3.6129V11.1154C3.6129 11.7512 3.8746 12.361 4.34041 12.8106C4.80623 13.2602 5.43801 13.5128 6.09677 13.5128H9.93548V14.6026ZM11.5161 12.2051H6.09677C5.79734 12.2051 5.51016 12.0903 5.29843 11.886C5.08669 11.6816 4.96774 11.4044 4.96774 11.1154V2.39744C4.96774 2.10842 5.08669 1.83124 5.29843 1.62687C5.51016 1.4225 5.79734 1.30769 6.09677 1.30769H7.67742V5.44872C7.67976 5.62143 7.75188 5.78643 7.87842 5.90856C8.00496 6.03069 8.1759 6.10031 8.35484 6.10256H12.6452V11.1154C12.6452 11.4044 12.5262 11.6816 12.3145 11.886C12.1027 12.0903 11.8156 12.2051 11.5161 12.2051Z" fill="#949494"/>
|
||||
</svg>
|
After Width: | Height: | Size: 1.6 KiB |
@ -1 +1 @@
|
||||
<svg width="256px" height="256px" viewBox="0 -28.5 256 256" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" preserveAspectRatio="xMidYMid" fill="#000000"><g id="SVGRepo_bgCarrier" stroke-width="0"></g><g id="SVGRepo_tracerCarrier" stroke-linecap="round" stroke-linejoin="round"></g><g id="SVGRepo_iconCarrier"> <g> <path d="M216.856339,16.5966031 C200.285002,8.84328665 182.566144,3.2084988 164.041564,0 C161.766523,4.11318106 159.108624,9.64549908 157.276099,14.0464379 C137.583995,11.0849896 118.072967,11.0849896 98.7430163,14.0464379 C96.9108417,9.64549908 94.1925838,4.11318106 91.8971895,0 C73.3526068,3.2084988 55.6133949,8.86399117 39.0420583,16.6376612 C5.61752293,67.146514 -3.4433191,116.400813 1.08711069,164.955721 C23.2560196,181.510915 44.7403634,191.567697 65.8621325,198.148576 C71.0772151,190.971126 75.7283628,183.341335 79.7352139,175.300261 C72.104019,172.400575 64.7949724,168.822202 57.8887866,164.667963 C59.7209612,163.310589 61.5131304,161.891452 63.2445898,160.431257 C105.36741,180.133187 151.134928,180.133187 192.754523,160.431257 C194.506336,161.891452 196.298154,163.310589 198.110326,164.667963 C191.183787,168.842556 183.854737,172.420929 176.223542,175.320965 C180.230393,183.341335 184.861538,190.991831 190.096624,198.16893 C211.238746,191.588051 232.743023,181.531619 254.911949,164.955721 C260.227747,108.668201 245.831087,59.8662432 216.856339,16.5966031 Z M85.4738752,135.09489 C72.8290281,135.09489 62.4592217,123.290155 62.4592217,108.914901 C62.4592217,94.5396472 72.607595,82.7145587 85.4738752,82.7145587 C98.3405064,82.7145587 108.709962,94.5189427 108.488529,108.914901 C108.508531,123.290155 98.3405064,135.09489 85.4738752,135.09489 Z M170.525237,135.09489 C157.88039,135.09489 147.510584,123.290155 147.510584,108.914901 C147.510584,94.5396472 157.658606,82.7145587 170.525237,82.7145587 C183.391518,82.7145587 193.761324,94.5189427 193.539891,108.914901 C193.539891,123.290155 183.391518,135.09489 170.525237,135.09489 Z" fill="#61626b" fill-rule="nonzero"> </path> </g> </g></svg>
|
||||
<svg fill="none" width="20" height="20" viewBox="0 0 22 20" xmlns="http://www.w3.org/2000/svg" xml:space="preserve"><path d="M18.942 5.556a16.299 16.299 0 0 0-4.126-1.297c-.178.321-.385.754-.529 1.097a15.175 15.175 0 0 0-4.573 0 11.583 11.583 0 0 0-.535-1.097 16.274 16.274 0 0 0-4.129 1.3c-2.611 3.946-3.319 7.794-2.965 11.587a16.494 16.494 0 0 0 5.061 2.593 12.65 12.65 0 0 0 1.084-1.785 10.689 10.689 0 0 1-1.707-.831c.143-.106.283-.217.418-.331 3.291 1.539 6.866 1.539 10.118 0 .137.114.277.225.418.331-.541.326-1.114.606-1.71.832a12.52 12.52 0 0 0 1.084 1.785 16.46 16.46 0 0 0 5.064-2.595c.415-4.396-.709-8.209-2.973-11.589zM8.678 14.813c-.988 0-1.798-.922-1.798-2.045s.793-2.047 1.798-2.047 1.815.922 1.798 2.047c.001 1.123-.793 2.045-1.798 2.045zm6.644 0c-.988 0-1.798-.922-1.798-2.045s.793-2.047 1.798-2.047 1.815.922 1.798 2.047c0 1.123-.793 2.045-1.798 2.045z" fill="black" fill-opacity="0.54"/></svg>
|
Before Width: | Height: | Size: 2.0 KiB After Width: | Height: | Size: 912 B |
@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="15" viewBox="0 0 16 15" fill="none">
|
||||
<path d="M10.0588 2.74568L12.5294 5.15732M8.41176 14H15M1.82353 10.7845L1 14L4.29412 13.1961L13.8355 3.88237C14.1443 3.58087 14.3178 3.172 14.3178 2.74568C14.3178 2.31936 14.1443 1.9105 13.8355 1.609L13.6939 1.47073C13.385 1.16932 12.9662 1 12.5294 1C12.0927 1 11.6738 1.16932 11.3649 1.47073L1.82353 10.7845Z" stroke="#747474" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
After Width: | Height: | Size: 500 B |
@ -1 +1,5 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" height="1em" viewBox="0 0 448 512"><!--! Font Awesome Free 6.4.2 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license (Commercial License) Copyright 2023 Fonticons, Inc. --><style>svg{fill:#4c4d52}</style><path d="M400 32H48C21.5 32 0 53.5 0 80v352c0 26.5 21.5 48 48 48h352c26.5 0 48-21.5 48-48V80c0-26.5-21.5-48-48-48zM277.3 415.7c-8.4 1.5-11.5-3.7-11.5-8 0-5.4.2-33 .2-55.3 0-15.6-5.2-25.5-11.3-30.7 37-4.1 76-9.2 76-73.1 0-18.2-6.5-27.3-17.1-39 1.7-4.3 7.4-22-1.7-45-13.9-4.3-45.7 17.9-45.7 17.9-13.2-3.7-27.5-5.6-41.6-5.6-14.1 0-28.4 1.9-41.6 5.6 0 0-31.8-22.2-45.7-17.9-9.1 22.9-3.5 40.6-1.7 45-10.6 11.7-15.6 20.8-15.6 39 0 63.6 37.3 69 74.3 73.1-4.8 4.3-9.1 11.7-10.6 22.3-9.5 4.3-33.8 11.7-48.3-13.9-9.1-15.8-25.5-17.1-25.5-17.1-16.2-.2-1.1 10.2-1.1 10.2 10.8 5 18.4 24.2 18.4 24.2 9.7 29.7 56.1 19.7 56.1 19.7 0 13.9.2 36.5.2 40.6 0 4.3-3 9.5-11.5 8-66-22.1-112.2-84.9-112.2-158.3 0-91.8 70.2-161.5 162-161.5S388 165.6 388 257.4c.1 73.4-44.7 136.3-110.7 158.3zm-98.1-61.1c-1.9.4-3.7-.4-3.9-1.7-.2-1.5 1.1-2.8 3-3.2 1.9-.2 3.7.6 3.9 1.9.3 1.3-1 2.6-3 3zm-9.5-.9c0 1.3-1.5 2.4-3.5 2.4-2.2.2-3.7-.9-3.7-2.4 0-1.3 1.5-2.4 3.5-2.4 1.9-.2 3.7.9 3.7 2.4zm-13.7-1.1c-.4 1.3-2.4 1.9-4.1 1.3-1.9-.4-3.2-1.9-2.8-3.2.4-1.3 2.4-1.9 4.1-1.5 2 .6 3.3 2.1 2.8 3.4zm-12.3-5.4c-.9 1.1-2.8.9-4.3-.6-1.5-1.3-1.9-3.2-.9-4.1.9-1.1 2.8-.9 4.3.6 1.3 1.3 1.8 3.3.9 4.1zm-9.1-9.1c-.9.6-2.6 0-3.7-1.5s-1.1-3.2 0-3.9c1.1-.9 2.8-.2 3.7 1.3 1.1 1.5 1.1 3.3 0 4.1zm-6.5-9.7c-.9.9-2.4.4-3.5-.6-1.1-1.3-1.3-2.8-.4-3.5.9-.9 2.4-.4 3.5.6 1.1 1.3 1.3 2.8.4 3.5zm-6.7-7.4c-.4.9-1.7 1.1-2.8.4-1.3-.6-1.9-1.7-1.5-2.6.4-.6 1.5-.9 2.8-.4 1.3.7 1.9 1.8 1.5 2.6z"/></svg>
|
||||
<svg width="800px" height="800px" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg">
|
||||
<title>github</title>
|
||||
<rect width="24" height="24" fill="none"/>
|
||||
<path d="M12,2A10,10,0,0,0,8.84,21.5c.5.08.66-.23.66-.5V19.31C6.73,19.91,6.14,18,6.14,18A2.69,2.69,0,0,0,5,16.5c-.91-.62.07-.6.07-.6a2.1,2.1,0,0,1,1.53,1,2.15,2.15,0,0,0,2.91.83,2.16,2.16,0,0,1,.63-1.34C8,16.17,5.62,15.31,5.62,11.5a3.87,3.87,0,0,1,1-2.71,3.58,3.58,0,0,1,.1-2.64s.84-.27,2.75,1a9.63,9.63,0,0,1,5,0c1.91-1.29,2.75-1,2.75-1a3.58,3.58,0,0,1,.1,2.64,3.87,3.87,0,0,1,1,2.71c0,3.82-2.34,4.66-4.57,4.91a2.39,2.39,0,0,1,.69,1.85V21c0,.27.16.59.67.5A10,10,0,0,0,12,2Z" fill="black" fill-opacity="0.54"/>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 1.7 KiB After Width: | Height: | Size: 679 B |
@ -0,0 +1,3 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="12" height="15" viewBox="0 0 12 15" fill="none">
|
||||
<path d="M0.857143 13.3333C0.857143 13.7754 1.03775 14.1993 1.35925 14.5118C1.68074 14.8244 2.11677 15 2.57143 15H9.42857C9.88323 15 10.3193 14.8244 10.6408 14.5118C10.9622 14.1993 11.1429 13.7754 11.1429 13.3333V3.33333H0.857143V13.3333ZM2.57143 5H9.42857V13.3333H2.57143V5ZM9 0.833333L8.14286 0H3.85714L3 0.833333H0V2.5H12V0.833333H9Z" fill="#747474"/>
|
||||
</svg>
|
After Width: | Height: | Size: 459 B |
@ -0,0 +1,128 @@
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
import { useSelector } from 'react-redux';
|
||||
import Edit from '../assets/edit.svg';
|
||||
import Exit from '../assets/exit.svg';
|
||||
import Message from '../assets/message.svg';
|
||||
import CheckMark from '../assets/checkmark.svg';
|
||||
import Trash from '../assets/trash.svg';
|
||||
|
||||
import { selectConversationId } from '../preferences/preferenceSlice';
|
||||
import { useOutsideAlerter } from '../hooks';
|
||||
|
||||
interface ConversationProps {
|
||||
name: string;
|
||||
id: string;
|
||||
}
|
||||
interface ConversationTileProps {
|
||||
conversation: ConversationProps;
|
||||
selectConversation: (arg1: string) => void;
|
||||
onDeleteConversation: (arg1: string) => void;
|
||||
onSave: ({ name, id }: ConversationProps) => void;
|
||||
}
|
||||
|
||||
export default function ConversationTile({
|
||||
conversation,
|
||||
selectConversation,
|
||||
onDeleteConversation,
|
||||
onSave,
|
||||
}: ConversationTileProps) {
|
||||
const conversationId = useSelector(selectConversationId);
|
||||
const tileRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const [isEdit, setIsEdit] = useState(false);
|
||||
const [conversationName, setConversationsName] = useState('');
|
||||
useOutsideAlerter(
|
||||
tileRef,
|
||||
() =>
|
||||
handleSaveConversation({
|
||||
id: conversationId || conversation.id,
|
||||
name: conversationName,
|
||||
}),
|
||||
[conversationName],
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setConversationsName(conversation.name);
|
||||
}, [conversation.name]);
|
||||
|
||||
function handleEditConversation() {
|
||||
setIsEdit(true);
|
||||
}
|
||||
|
||||
function handleSaveConversation(changedConversation: ConversationProps) {
|
||||
if (changedConversation.name.trim().length) {
|
||||
onSave(changedConversation);
|
||||
setIsEdit(false);
|
||||
} else {
|
||||
onClear();
|
||||
}
|
||||
}
|
||||
|
||||
function onClear() {
|
||||
setConversationsName(conversation.name);
|
||||
setIsEdit(false);
|
||||
}
|
||||
return (
|
||||
<div
|
||||
ref={tileRef}
|
||||
onClick={() => {
|
||||
selectConversation(conversation.id);
|
||||
}}
|
||||
className={`my-auto mx-4 mt-4 flex h-12 cursor-pointer items-center justify-between gap-4 rounded-3xl hover:bg-gray-100 ${
|
||||
conversationId === conversation.id ? 'bg-gray-100' : ''
|
||||
}`}
|
||||
>
|
||||
<div
|
||||
className={`flex ${
|
||||
conversationId === conversation.id ? 'w-[75%]' : 'w-[95%]'
|
||||
} gap-4`}
|
||||
>
|
||||
<img src={Message} className="ml-2 w-5"></img>
|
||||
{isEdit ? (
|
||||
<input
|
||||
autoFocus
|
||||
type="text"
|
||||
className="h-6 w-full px-1 text-sm font-normal leading-6 outline-[#0075FF] focus:outline-1"
|
||||
value={conversationName}
|
||||
onChange={(e) => setConversationsName(e.target.value)}
|
||||
/>
|
||||
) : (
|
||||
<p className="my-auto overflow-hidden overflow-ellipsis whitespace-nowrap text-sm font-normal leading-6 text-eerie-black">
|
||||
{conversationName}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
{conversationId === conversation.id ? (
|
||||
<div className="flex">
|
||||
<img
|
||||
src={isEdit ? CheckMark : Edit}
|
||||
alt="Edit"
|
||||
className="mr-2 h-4 w-4 cursor-pointer hover:opacity-50"
|
||||
id={`img-${conversation.id}`}
|
||||
onClick={(event) => {
|
||||
event.stopPropagation();
|
||||
isEdit
|
||||
? handleSaveConversation({
|
||||
id: conversationId,
|
||||
name: conversationName,
|
||||
})
|
||||
: handleEditConversation();
|
||||
}}
|
||||
/>
|
||||
<img
|
||||
src={isEdit ? Exit : Trash}
|
||||
alt="Exit"
|
||||
className={`mr-4 ${
|
||||
isEdit ? 'h-3 w-3' : 'h-4 w-4'
|
||||
}mt-px cursor-pointer hover:opacity-50`}
|
||||
id={`img-${conversation.id}`}
|
||||
onClick={(event) => {
|
||||
event.stopPropagation();
|
||||
isEdit ? onClear() : onDeleteConversation(conversation.id);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
) : null}
|
||||
</div>
|
||||
);
|
||||
}
|
@ -0,0 +1,96 @@
|
||||
# FILEPATH: /path/to/test_sagemaker.py
|
||||
|
||||
import json
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from application.llm.sagemaker import SagemakerAPILLM, LineIterator
|
||||
|
||||
class TestSagemakerAPILLM(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.sagemaker = SagemakerAPILLM()
|
||||
self.context = "This is the context"
|
||||
self.user_question = "What is the answer?"
|
||||
self.messages = [
|
||||
{"content": self.context},
|
||||
{"content": "Some other message"},
|
||||
{"content": self.user_question}
|
||||
]
|
||||
self.prompt = f"### Instruction \n {self.user_question} \n ### Context \n {self.context} \n ### Answer \n"
|
||||
self.payload = {
|
||||
"inputs": self.prompt,
|
||||
"stream": False,
|
||||
"parameters": {
|
||||
"do_sample": True,
|
||||
"temperature": 0.1,
|
||||
"max_new_tokens": 30,
|
||||
"repetition_penalty": 1.03,
|
||||
"stop": ["</s>", "###"]
|
||||
}
|
||||
}
|
||||
self.payload_stream = {
|
||||
"inputs": self.prompt,
|
||||
"stream": True,
|
||||
"parameters": {
|
||||
"do_sample": True,
|
||||
"temperature": 0.1,
|
||||
"max_new_tokens": 512,
|
||||
"repetition_penalty": 1.03,
|
||||
"stop": ["</s>", "###"]
|
||||
}
|
||||
}
|
||||
self.body_bytes = json.dumps(self.payload).encode('utf-8')
|
||||
self.body_bytes_stream = json.dumps(self.payload_stream).encode('utf-8')
|
||||
self.response = {
|
||||
"Body": MagicMock()
|
||||
}
|
||||
self.result = [
|
||||
{
|
||||
"generated_text": "This is the generated text"
|
||||
}
|
||||
]
|
||||
self.response['Body'].read.return_value.decode.return_value = json.dumps(self.result)
|
||||
|
||||
def test_gen(self):
|
||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint',
|
||||
return_value=self.response) as mock_invoke_endpoint:
|
||||
output = self.sagemaker.gen(None, None, self.messages)
|
||||
mock_invoke_endpoint.assert_called_once_with(
|
||||
EndpointName=self.sagemaker.endpoint,
|
||||
ContentType='application/json',
|
||||
Body=self.body_bytes
|
||||
)
|
||||
self.assertEqual(output,
|
||||
self.result[0]['generated_text'][len(self.prompt):])
|
||||
|
||||
def test_gen_stream(self):
|
||||
with patch.object(self.sagemaker.runtime, 'invoke_endpoint_with_response_stream',
|
||||
return_value=self.response) as mock_invoke_endpoint:
|
||||
output = list(self.sagemaker.gen_stream(None, None, self.messages))
|
||||
mock_invoke_endpoint.assert_called_once_with(
|
||||
EndpointName=self.sagemaker.endpoint,
|
||||
ContentType='application/json',
|
||||
Body=self.body_bytes_stream
|
||||
)
|
||||
self.assertEqual(output, [])
|
||||
|
||||
class TestLineIterator(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.stream = [
|
||||
{'PayloadPart': {'Bytes': b'{"outputs": [" a"]}\n'}},
|
||||
{'PayloadPart': {'Bytes': b'{"outputs": [" challenging"]}\n'}},
|
||||
{'PayloadPart': {'Bytes': b'{"outputs": [" problem"]}\n'}}
|
||||
]
|
||||
self.line_iterator = LineIterator(self.stream)
|
||||
|
||||
def test_iter(self):
|
||||
self.assertEqual(iter(self.line_iterator), self.line_iterator)
|
||||
|
||||
def test_next(self):
|
||||
self.assertEqual(next(self.line_iterator), b'{"outputs": [" a"]}')
|
||||
self.assertEqual(next(self.line_iterator), b'{"outputs": [" challenging"]}')
|
||||
self.assertEqual(next(self.line_iterator), b'{"outputs": [" problem"]}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|