feat: added run_chat for chat models (#88)

pull/89/head
Laurel Orr 1 year ago committed by GitHub
parent afe0fc5a1d
commit 8548329be9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -80,8 +80,8 @@ You can also just set `export COHERE_API_KEY=<COHERE_API_KEY>` and not use `clie
You can see the model details and possible model inputs to `run()` via
```python
print(manifest.client_pool.get_client().get_model_params())
print(manifest.client_pool.get_client().get_model_inputs())
print(manifest.client_pool.get_current_client().get_model_params())
print(manifest.client_pool.get_current_client().get_model_inputs())
```
## Global Cache

@ -47,7 +47,7 @@
" cache_name=\"sqlite\",\n",
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
")\n",
"print(manifest.client_pool.get_client().get_model_params())"
"print(manifest.client_pool.get_current_client().get_model_params())"
]
},
{
@ -86,7 +86,7 @@
" cache_name=\"sqlite\",\n",
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
")\n",
"print(manifest_diff.client_pool.get_client().get_model_params())"
"print(manifest_diff.client_pool.get_current_client().get_model_params())"
]
},
{

@ -37,7 +37,7 @@
"from manifest import Manifest\n",
"\n",
"manifest = Manifest(client_name=\"openaiembedding\")\n",
"print(manifest.client_pool.get_client().get_model_params())"
"print(manifest.client_pool.get_next_client().get_model_params())"
]
},
{
@ -100,7 +100,7 @@
" cache_name=\"sqlite\",\n",
" cache_connection=\"my_sqlite_manifest.sqlite\"\n",
")\n",
"print(manifest.client_pool.get_client().get_model_params())"
"print(manifest.client_pool.get_next_client().get_model_params())"
]
},
{

@ -158,12 +158,20 @@ class ClientConnectionPool:
for client in self.client_pool:
client.close()
def get_client(self) -> Client:
def num_clients(self) -> int:
"""Get number of clients."""
return len(self.client_pool)
def get_next_client(self) -> Client:
"""Get client."""
client_int = self.scheduler.get_client()
self.current_client_id = client_int
return self.client_pool[client_int]
def get_current_client(self) -> Client:
"""Get current client."""
return self.client_pool[self.current_client_id]
def start_timer(self) -> None:
"""Start timer."""
self.client_pool_metrics[self.current_client_id].start = time.time()

@ -309,7 +309,7 @@ class Manifest:
"""
is_batch = isinstance(prompt, list)
# Get the client to run
client = self.client_pool.get_client()
client = self.client_pool.get_next_client()
stop_token = stop_token if stop_token is not None else self.stop_token
# Must pass kwargs as dict for client "pop" methods removed used arguments
request_params = client.get_request(prompt, kwargs)
@ -386,10 +386,10 @@ class Manifest:
if chunk_size > 0:
for i in range(0, len(prompts), chunk_size):
prompt_chunks.append(
(self.client_pool.get_client(), prompts[i : i + chunk_size])
(self.client_pool.get_next_client(), prompts[i : i + chunk_size])
)
else:
prompt_chunks = [(self.client_pool.get_client(), prompts)]
prompt_chunks = [(self.client_pool.get_next_client(), prompts)]
# Run the chunks
tasks = []
@ -487,7 +487,7 @@ class Manifest:
"""
is_batch = False
# Get the client to run
client = self.client_pool.get_client()
client = self.client_pool.get_next_client()
# Get a request for an empty prompt to handle all kwargs
request_params = client.get_request("", kwargs)
# Add prompt and cast as chat request
@ -543,7 +543,7 @@ class Manifest:
Returns:
response from prompt.
"""
client = self.client_pool.get_client()
client = self.client_pool.get_next_client()
# Must pass kwargs as dict for client "pop" methods removed used arguments
request_params = client.get_request(prompt, kwargs)
request_params_as_score = LMScoreRequest(**request_params.to_dict())

@ -46,13 +46,13 @@ def test_timing() -> None:
client_connection2 = ClientConnection(client_name="dummy")
connection_pool = ClientConnectionPool([client_connection1, client_connection2])
connection_pool.get_client()
connection_pool.get_next_client()
assert connection_pool.current_client_id == 0
connection_pool.start_timer()
time.sleep(2)
connection_pool.end_timer()
connection_pool.get_client()
connection_pool.get_next_client()
assert connection_pool.current_client_id == 1
connection_pool.start_timer()
time.sleep(1)

@ -43,7 +43,7 @@ def test_init(sqlite_cache: str) -> None:
cache_connection=sqlite_cache,
)
assert len(manifest.client_pool.client_pool) == 1
client = manifest.client_pool.get_client()
client = manifest.client_pool.get_next_client()
assert isinstance(client, DummyClient)
assert isinstance(manifest.cache, SQLiteCache)
assert client.n == 1 # type: ignore
@ -56,7 +56,7 @@ def test_init(sqlite_cache: str) -> None:
stop_token="\n",
)
assert len(manifest.client_pool.client_pool) == 1
client = manifest.client_pool.get_client()
client = manifest.client_pool.get_next_client()
assert isinstance(client, DummyClient)
assert isinstance(manifest.cache, NoopCache)
assert client.n == 3 # type: ignore
@ -81,7 +81,7 @@ def test_run(sqlite_cache: str, n: int, return_response: bool) -> None:
assert str(exc_info.value) == "[('bad_input', 5)] arguments are not recognized."
# Allow params in the request object but not in the client to go through
assert "top_k" not in manifest.client_pool.get_client().PARAMS
assert "top_k" not in manifest.client_pool.get_next_client().PARAMS
result = manifest.run(prompt, return_response=return_response, top_k=5)
assert result is not None

Loading…
Cancel
Save