

JAXServer 🧬

JAXServer is one of the utilities offered by EasyDeL. It helps you host, serve, and process requests with LLMs, and it is also hackable, so you can override its methods with your own. It supports both mid-level and high-level APIs and also gives you pre-built, ready-to-use Gradio Chat and Instruct pages.

  • Supported Models:
    • Every model that has transformers.FlaxPreTrainedModel as its parent :)

Input Configs

The config input is a dictionary that contains the following keys (a minimal example dictionary follows the list):

  • port: The port number that the server will listen on.
    • Default Value has been set to 2059
  • batch_size: The batch size used for generation.
    • Default Value has been set to 1
  • max_sequence_length: The maximum length of a sequence.
    • Default Value has been set to 2048
  • max_new_tokens: The maximum number of new tokens the model will generate for a single request.
    • Default Value has been set to 2048
  • max_compile_tokens: The maximum number of tokens that can be streamed to the model in a single batch.
    • Default Value has been set to 32
  • temperature: The temperature parameter for sampling from the model's output distribution.
    • Default Value has been set to 0.1
  • top_p: The top-p parameter for sampling from the model's output distribution.
    • Default Value has been set to 0.95
  • top_k: The top-k parameter for sampling from the model's output distribution.
    • Default Value has been set to 50
  • mesh_axes_shape: The shape of the device mesh axes used to shard the model.
    • Default Value has been set to (1, -1, 1, 1)
  • host: The host address for the server.
    • Default Value has been set to '0.0.0.0'
  • dtype: The data type for the model's parameters.
    • Default Value has been set to 'fp16'
  • mesh_axes_names: The names of the mesh axes used to shard the model.
    • Default Value has been set to ("dp", "fsdp", "tp", "sp")
  • logging: Whether the server should log its progress.
    • Default Value has been set to True
  • stream_tokens_for_gradio: Whether the model should stream tokens to Gradio.
    • Default Value has been set to True
  • use_prefix_tokenizer: Whether the model should use a prefix tokenizer.
    • Default Value has been set to True
  • pre_compile: Whether the model should be pre-compiled.
    • Default Value has been set to True
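
As a quick reference, here is a minimal sketch of such a config dictionary built only from the documented keys and their defaults above; how it is passed to JAXServer depends on your EasyDeL version, so treat it as an illustration of the keys rather than a complete setup.

# A sketch of a JAXServer config dictionary using the documented defaults.
config = {
    'port': 2059,
    'batch_size': 1,
    'max_sequence_length': 2048,
    'max_new_tokens': 2048,
    'max_compile_tokens': 32,
    'temperature': 0.1,
    'top_p': 0.95,
    'top_k': 50,
    'mesh_axes_shape': (1, -1, 1, 1),
    'host': '0.0.0.0',
    'dtype': 'fp16',
    'mesh_axes_names': ("dp", "fsdp", "tp", "sp"),
    'logging': True,
    'stream_tokens_for_gradio': True,
    'use_prefix_tokenizer': True,
    'pre_compile': True
}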

JAXServer Functions

JAXServer has format_chat and format_instruct functions that you have to implement in order to prompt your model:


def format_instruct(self, system: str, instruction: str) -> str:
    """
    Here you will get the system and instruction from user, and you can apply your prompting style
    """
    raise NotImplementedError()


def format_chat(self, history: typing.List[str], prompt: str, system: typing.Union[str, None]) -> str:
    """
    Here you will get the system, prompt and history from user, and you can apply your prompting style
    """
    raise NotImplementedError()
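
For example, a minimal sketch of a subclass implementing both methods with a simple, purely illustrative prompt template (the template is an assumption and not tied to any particular model; the exact import path for JAXServer depends on your EasyDeL version) could look like this:

import typing

# from EasyDel import JAXServer  # exact import path depends on your EasyDeL version


class MyJAXServer(JAXServer):
    def format_instruct(self, system: str, instruction: str) -> str:
        # Illustrative template only; replace it with the prompt format
        # your model was trained on.
        return f"### System:\n{system}\n\n### Instruction:\n{instruction}\n\n### Response:\n"

    def format_chat(self, history: typing.List[str], prompt: str, system: typing.Union[str, None]) -> str:
        # `history` is a list of [user, assistant] pairs, as used by the
        # Gradio functions below.
        text = f"### System:\n{system}\n\n" if system is not None else ""
        for user_turn, assistant_turn in history:
            text += f"### User:\n{user_turn}\n\n### Assistant:\n{assistant_turn}\n\n"
        text += f"### User:\n{prompt}\n\n### Assistant:\n"
        return text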

JAXServer contains a method named .sample; using the sample method you can generate text from text.

What does it do and how does it work? Here are the inputs that the sample function takes:

def sample(self,
           string,
           *,
           greedy: bool = False,
           max_new_tokens: int = None,
           **kwargs
           ) -> typing.Generator[typing.Tuple[str, int], None, None]:
    ...
  • Arguments:
    • string: the input string to be tokenized (str)
    • greedy: whether to use greedy search (bool)
    • max_new_tokens: the number of new tokens to generate (int)
  • Yields:
    • string: the text generated so far, including the newly predicted tokens (str)
    • number of used tokens: the number of tokens used to generate the answer (int)

You can use this function outside the class like this:

for string, num_used_tokens in server.sample(
        'im a string',
        greedy=False,
        max_new_tokens=256  # or None to use the maximum number passed in the config
):
    print(f'\r{num_used_tokens}: {string}', end="")

Gradio Functions 🤖

If you want to change the Gradio response functions, you can override them like this.

Chat Gradio Function

This is the default Gradio chat function and how it looks:

def sample_gradio_chat(self, prompt, history, max_new_tokens, system, greedy):
    string = self.format_chat(history=history, prompt=prompt, system=system)

    if not self.config.stream_tokens_for_gradio:
        response = ""
        for response, _ in self.sample(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
        ):
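            # drain the generator; after the loop, `response` holds the full generated text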
            ...
        history.append([prompt, response])
    else:
        history.append([prompt, ""])
        for response, _ in self.sample(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
        ):
            history[-1][-1] = response
            yield "", history
    return "", history

And here's an example of changing that in order to use Llama models:

def sample_gradio_chat(self, prompt, history, max_new_tokens, system, greedy):
    def prompt_llama2_model(message: str, chat_history,
                            system_prompt: str) -> str:

        do_strip = False
        texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
        for user_input, response in chat_history:
            user_input = user_input.strip() if do_strip else user_input
            do_strip = True
            texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
        message = message.strip() if do_strip else message
        texts.append(f'{message} [/INST]')
        return "".join(texts)

    string = prompt_llama2_model(
        message=prompt,
        chat_history=history or [],
        system_prompt=system
    )
    if not self.config.stream_tokens_for_gradio:
        response = ""
        for response, _ in self.sample(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens,
        ):
            ...
        history.append([prompt, response])
    else:
        history.append([prompt, ""])
        for response, _ in self.sample(
                string=string,
                greedy=greedy,
                max_new_tokens=max_new_tokens
        ):
            history[-1][-1] = response
            yield "", history

    return "", history

As you can see, you can easily override these functions however you want with some simple changes. You can also use the Gradio client or JAXServer's built-in FastAPI methods.

FastAPI 🌪

Instruct API

To override this API, implement forward_instruct however you want. The default implementation of this function is:

def forward_instruct(self, data: InstructRequest):
    if not self._funcs_generated:
        return {
            'status': "down"
        }

    string = self.config.instruct_format.format(instruct=data.prompt, system=data.system)
    response, used_tokens = [None] * 2
    for response, used_tokens in self.sample(
            string=string,
            greedy=data.greedy,
            max_new_tokens=None
    ):
        ...
    self.number_of_served_request_until_last_up_time += 1
    return {
        'input': f'{string}',
        'response': response,
        'tokens_used': used_tokens,
    }
  • Pydantic BaseModel class for the request data in FastAPI:
class InstructRequest(BaseModel):
    prompt: str
    system: Optional[str] = None
    temperature: Optional[float] = None
    greedy: Optional[bool] = False
  • And here's an example of using this API from Python, creating a simple client with the requests library:
import requests

content = {
    'prompt': 'can you code a simple neural network in c++ for me',
    'system': 'You are an AI assistant generate short and useful response',
    'temperature': 0.1,
    'greedy': False
}

response = requests.post(
    url='http://ip:port/instruct',
    json=content
).json()

print(response['response'])
# Response of model
print(response['input'])
# The input passed to the model

Chat API

To override this API, implement forward_chat however you want. The default implementation of this function is:

def forward_chat(self, data: ChatRequest):
    if not self._funcs_generated:
        return {
            'status': "down"
        }

    history = self.process_chat_history(data.history or [])
    history += self.config.prompt_prefix_chat + data.prompt + self.config.prompt_postfix_chat

    response, used_tokens = [None] * 2
    for response, used_tokens in self.process(
            string=history,
            greedy=data.greedy,
            max_new_tokens=None
    ):
        ...
    self.number_of_served_request_until_last_up_time += 1
    return {
        'input': f'{history}',
        'response': response,
        'tokens_used': used_tokens,
    }
  • Pydantic BaseModel class for the request data in FastAPI:
class ChatRequest(BaseModel):
    prompt: str
    history: Union[List[List], None] = None
    temperature: Optional[float] = None
    greedy: Optional[bool] = False
  • And here's an example of using this API from Python, creating a simple client with the requests library:
import requests

content = {
    'prompt': 'can you code a simple neural network in c++ for me',
    'history': [
        ['hello how are you', 'Hello!\nThanks, I am here to assist you. Do you have any question that I could help you with?']
    ],
    'temperature': 0.1,
    'greedy': False
}

response = requests.post(
    url='http://ip:port/chat',
    json=content
).json()

print(response['response'])
# Response of model
print(response['input'])
# The input passed to the model

Status 📣

Simply by sending a GET request to http://ip:port/status you will receive basic information about the server and how it is running: the number of cores in use, the number of generated prompts, the number of requests served, and so on.
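
For example, here is a minimal sketch of checking the status endpoint with the requests library; the exact fields returned depend on your EasyDeL version, so the example simply prints the whole response.

import requests

# Replace ip and port with your server's address.
response = requests.get(url='http://ip:port/status').json()

print(response)
# Basic information about the server, as described above.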