# JAXServer 🧬

JAXServer is one of the utilities offered by EasyDeL; it helps you host LLMs and run inference with them. It is also hackable, so you can override its methods with your own, and it supports both mid-level and high-level APIs. It also gives you pre-built, ready-to-use Gradio Chat and Instruct pages.
- Supported models: every model that has `transformers.FlaxPreTrainedModel` as its parent :)
## Input Configs
The config input is a dictionary that contains the following keys:
- `port`: The port number that the server will listen on. Default: `2059`
- `batch_size`: The batch size used for generation. Default: `1`
- `max_sequence_length`: The maximum length of a sequence. Default: `2048`
- `max_new_tokens`: The maximum number of new tokens the model generates for a single request. Default: `2048`
- `max_compile_tokens`: The maximum number of tokens streamed from the model in a single step. Default: `32`
- `temperature`: The temperature parameter for sampling from the model's output distribution. Default: `0.1`
- `top_p`: The top-p parameter for sampling from the model's output distribution. Default: `0.95`
- `top_k`: The top-k parameter for sampling from the model's output distribution. Default: `50`
- `mesh_axes_shape`: The shape of the mesh axes used to shard the model. Default: `(1, -1, 1, 1)`
- `host`: The host address for the server. Default: `'0.0.0.0'`
- `dtype`: The data type of the model's parameters. Default: `'fp16'`
- `mesh_axes_names`: The names of the mesh axes used to shard the model. Default: `("dp", "fsdp", "tp", "sp")`
- `logging`: Whether the server should log its progress. Default: `True`
- `stream_tokens_for_gradio`: Whether the model should stream tokens to Gradio. Default: `True`
- `use_prefix_tokenizer`: Whether the model should use a prefix tokenizer. Default: `True`
- `pre_compile`: Whether the model should be pre-compiled. Default: `True`
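For example, a minimal config dictionary overriding these defaults might look like the sketch below; every key and value comes from the list above, but how you pass it to `JAXServer` depends on the EasyDeL version you are using, so treat that part as an assumption:

```python
# A sketch of a JAXServer config dict; the keys and values are exactly the
# documented defaults. Passing it to JAXServer is version-dependent.
config = {
    "port": 2059,
    "batch_size": 1,
    "max_sequence_length": 2048,
    "max_new_tokens": 2048,
    "max_compile_tokens": 32,
    "temperature": 0.1,
    "top_p": 0.95,
    "top_k": 50,
    "mesh_axes_shape": (1, -1, 1, 1),
    "mesh_axes_names": ("dp", "fsdp", "tp", "sp"),
    "host": "0.0.0.0",
    "dtype": "fp16",
    "logging": True,
    "stream_tokens_for_gradio": True,
    "use_prefix_tokenizer": True,
    "pre_compile": True,
}
```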
## JAXServer Functions

JAXServer has `format_chat` and `format_instruct` methods that you have to implement in order to prompt your model:
```python
def format_instruct(self, system: str, instruction: str) -> str:
    """
    Here you will get the system message and instruction from the user, and you can apply your prompting style.
    """
    raise NotImplementedError()

def format_chat(self, history: typing.List[str], prompt: str, system: typing.Union[str, None]) -> str:
    """
    Here you will get the system message, prompt, and history from the user, and you can apply your prompting style.
    """
    raise NotImplementedError()
```
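As an illustration, here is a minimal sketch of a subclass implementing both formatters; the prompt template itself is a made-up example, and treating each `history` entry as a `[user, assistant]` pair follows how the Gradio functions below use it:

```python
import typing

from EasyDel import JAXServer  # import path may differ across EasyDeL versions

class MyJAXServer(JAXServer):
    def format_instruct(self, system: str, instruction: str) -> str:
        # A simple, hypothetical instruct template.
        return f"{system}\n\n### Instruction:\n{instruction}\n\n### Response:\n"

    def format_chat(self, history: typing.List[str], prompt: str,
                    system: typing.Union[str, None]) -> str:
        # Each history entry is treated as a [user, assistant] pair,
        # matching the Gradio chat functions shown later on this page.
        text = f"{system}\n" if system else ""
        for user_turn, model_turn in history:
            text += f"User: {user_turn}\nAssistant: {model_turn}\n"
        text += f"User: {prompt}\nAssistant:"
        return text
```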
JAXServer also contains a method named `sample`; with the `sample` method you can generate text from text. What does it do, and how does it work? Here are the inputs that the `sample` function takes in:
```python
def sample(self,
           string,
           *,
           greedy: bool = False,
           max_new_tokens: int = None,
           **kwargs
           ) -> [str, int]:
    ...
```
- Arguments:
    - `string` (str): The string to be tokenized and fed to the model.
    - `greedy` (bool): Whether to use the greedy search method.
    - `max_new_tokens` (int): The number of new tokens to generate.
- Yields:
    - `string` (str): The text predicted so far, extended with the next tokens.
    - `num_used_tokens` (int): The number of tokens used so far to generate the answer.
You can use this function outside the class like this:
```python
for string, num_used_tokens in server.sample(
        'im a string',
        greedy=False,
        max_new_tokens=256  # or None to use the maximum number passed in the config
):
    print(f'\r{num_used_tokens}: {string}', end="")
```
## Gradio Functions 🤖

If you want to change the Gradio response functions, you can override them like this.

### Chat Gradio Function

This is the default Gradio chat function, and this is how it looks:
```python
def sample_gradio_chat(self, prompt, history, max_new_tokens, system, greedy):
    string = self.format_chat(history=history, prompt=prompt, system=system)

    if not self.config.stream_tokens_for_gradio:
        response = ""
        for response, _ in self.sample(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
        ):
            ...
        history.append([prompt, response])
    else:
        history.append([prompt, ""])
        for response, _ in self.sample(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
        ):
            history[-1][-1] = response
            yield "", history
    return "", history
```
And here's an example of overriding it in order to use Llama 2 models:
```python
def sample_gradio_chat(self, prompt, history, max_new_tokens, system, greedy):
    def prompt_llama2_model(message: str, chat_history, system_prompt: str) -> str:
        do_strip = False
        texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
        for user_input, response in chat_history:
            user_input = user_input.strip() if do_strip else user_input
            do_strip = True
            texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
        message = message.strip() if do_strip else message
        texts.append(f'{message} [/INST]')
        return "".join(texts)

    string = prompt_llama2_model(
        message=prompt,
        chat_history=history or [],
        system_prompt=system
    )
    if not self.config.stream_tokens_for_gradio:
        response = ""
        for response, _ in self.sample(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
        ):
            ...
        history.append([prompt, response])
    else:
        history.append([prompt, ""])
        for response, _ in self.sample(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens
        ):
            history[-1][-1] = response
            yield "", history
    return "", history
```
As you can see, you can easily override these functions however you want and use them with some simple changes. You can also use their Gradio client, or use JAXServer's built-in FastAPI methods.
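For instance, a hypothetical call through the `gradio_client` package might look like the sketch below; the endpoint name and argument order depend on how the Gradio interface is built, so treat both as assumptions:

```python
from gradio_client import Client

# Address of the running JAXServer Gradio app; ip and port are placeholders.
client = Client("http://ip:port")

# api_name and the positional arguments are assumptions; run
# client.view_api() to see the real endpoint signatures.
result = client.predict("can you help me?", api_name="/chat")
print(result)
```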
## FastAPI 🌪

### Instruct API

To override this API you have to implement `forward_instruct` just the way you want; the default implementation of this function is:
```python
def forward_instruct(self, data: InstructRequest):
    if not self._funcs_generated:
        return {
            'status': "down"
        }

    string = self.config.instruct_format.format(instruct=data.prompt, system=data.system)
    response, used_tokens = [None] * 2
    for response, used_tokens in self.sample(
            string=string,
            greedy=data.greedy,
            max_new_tokens=None
    ):
        ...
    self.number_of_served_request_until_last_up_time += 1
    return {
        'input': f'{string}',
        'response': response,
        'tokens_used': used_tokens,
    }
```
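For example, here is a sketch of an override that builds the prompt with the `format_instruct` method described earlier instead of `config.instruct_format` (assuming you have implemented it in your subclass):

```python
def forward_instruct(self, data: InstructRequest):
    if not self._funcs_generated:
        return {'status': "down"}

    # Use the user-defined formatter from the subclass instead of the
    # config-level template string.
    string = self.format_instruct(system=data.system, instruction=data.prompt)
    response, used_tokens = None, None
    for response, used_tokens in self.sample(
            string=string,
            greedy=data.greedy,
            max_new_tokens=None
    ):
        ...
    self.number_of_served_request_until_last_up_time += 1
    return {
        'input': string,
        'response': response,
        'tokens_used': used_tokens,
    }
```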
- Pydantic `BaseModel` class for this API's data in FastAPI:
```python
from typing import Optional
from pydantic import BaseModel

class InstructRequest(BaseModel):
    prompt: str
    system: Optional[str] = None
    temperature: Optional[float] = None
    greedy: Optional[bool] = False
```
- And here's an example of using this API from Python, creating a simple client with the `requests` library:
```python
import requests

content = {
    'prompt': 'can you code a simple neural network in c++ for me',
    'system': 'You are an AI assistant; generate a short and useful response',
    'temperature': 0.1,
    'greedy': False
}

response = requests.post(
    url='http://ip:port/instruct',
    json=content
).json()

print(response['response'])  # response of the model
print(response['input'])     # the input passed to the model
```
### Chat API

To override this API you have to implement `forward_chat` just the way you want; the default implementation of this function is:
```python
def forward_chat(self, data: ChatRequest):
    if not self._funcs_generated:
        return {
            'status': "down"
        }

    history = self.process_chat_history(data.history or [])
    history += self.config.prompt_prefix_chat + data.prompt + self.config.prompt_postfix_chat

    response, used_tokens = [None] * 2
    for response, used_tokens in self.process(
            string=history,
            greedy=data.greedy,
            max_new_tokens=None
    ):
        ...
    self.number_of_served_request_until_last_up_time += 1
    return {
        'input': f'{history}',
        'response': response,
        'tokens_used': used_tokens,
    }
```
- Pydantic `BaseModel` class for this API's data in FastAPI:
```python
from typing import List, Optional, Union
from pydantic import BaseModel

class ChatRequest(BaseModel):
    prompt: str
    history: Union[List[List], None] = None
    temperature: Optional[float] = None
    greedy: Optional[bool] = False
```
- And here's an example of using this API from Python, creating a simple client with the `requests` library:
```python
import requests

content = {
    'prompt': 'can you code a simple neural network in c++ for me',
    'history': [
        ['hello how are you', "Hello!\nThanks, I'm here to assist you. Do you have any question that I could help you with?"]
    ],
    'temperature': 0.1,
    'greedy': False
}

response = requests.post(
    url='http://ip:port/chat',
    json=content
).json()

print(response['response'])  # response of the model
print(response['input'])     # the input passed to the model
```
## Status 📣

Simply by sending a GET request to `https://ip:port/status` you will receive basic information about the server and how it is being run: the number of cores in use, the number of generated prompts, the number of requests, and so on.
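A minimal sketch of querying this endpoint with `requests` (the exact response fields depend on the EasyDeL version you are running):

```python
import requests

# ip and port are placeholders for your server's address.
status = requests.get('https://ip:port/status').json()
print(status)
```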