metadata
library_name: peft
base_model: stabilityai/stablelm-3b-4e1t
license: mit
language:
  - en
metrics:
  - bleu
  - bertscore
  - accuracy
tags:
  - medical

Model Card for StableMed

Welcome to StableMed, an alpha-stage fine-tune of StableLM 3B (stabilityai/stablelm-3b-4e1t) for medical question answering.

Model Details

Model Description

This is a fine-tune of StableLM 3B for medical question answering, trained on the MedQuAD dataset. It is intended for education in public health and sanitation, specifically to improve our understanding of outreach and communication.

Uses

Use this model for educational purposes only; do not use it for decision support in real-world settings.

Use this model for medical question answering.

Use this model as an educational tool for learning about "miniature" models.

Direct Use

Medical question answering.

Downstream Use

Fine-tune this model further so it can work within a network or swarm of medical fine-tunes.

Out-of-Scope Use

Do not use this model in real-world settings.

Do not use this model directly.

Do not use this model for real-world decision support.

Bias, Risks, and Limitations

Evaluation with Giskard is coming soon.

Recommendations

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model.

DO NOT USE THIS MODEL WITHOUT EVALUATION

DO NOT USE THIS MODEL WITHOUT BENCHMARKING

DO NOT USE THIS MODEL WITHOUT FURTHER FINETUNING

How to Get Started with the Model

Use the code below to get started with the model.

import os
import textwrap

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

hf_token = os.environ.get('HUGGINGFACE_TOKEN')

# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model ID and the StableMed LoRA adapter repository
base_model_id = "stabilityai/stablelm-3b-4e1t"
model_directory = "Tonic/stablemed"

# Instantiate the tokenizer (left padding for decoder-only generation)
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Load the base model, then attach the PEFT adapter.
# StableLM-3B-4E1T ships its own modeling code (StableLMEpochForCausalLM),
# so it is loaded via AutoModelForCausalLM with trust_remote_code=True,
# not via MistralForCausalLM.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
peft_model = PeftModel.from_pretrained(base_model, model_directory, token=hf_token)
peft_model.to(device)
peft_model.eval()

# Wrap long lines of generated text for display
def wrap_text(text, width=90):
    lines = text.split('\n')
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    return '\n'.join(wrapped_lines)

# Single-turn, text-only helper (not used by the Gradio demo below)
def multimodal_prompt(user_input, system_prompt="You are an expert medical analyst:", max_new_tokens=400):
    # Combine user input and system prompt
    formatted_input = f"[INSTRUCTION]{system_prompt}[QUESTION]{user_input}"

    # Encode the input text and move it to the model's device
    model_inputs = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False).to(device)

    # Generate a response using the PEFT model
    output = peft_model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.1,
        do_sample=True,
    )

    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = output[0][model_inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

class ChatBot:
    def __init__(self):
        # Token ids of the conversation so far (None until the first turn)
        self.history = None

    def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
        # Combine user input and system prompt
        formatted_input = f"[INSTRUCTION:]{system_prompt}[QUESTION:] {user_input}"

        # Encode user input and move it to the model's device
        user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt").to(device)

        # Concatenate the user input with chat history
        if self.history is not None:
            chat_history_ids = torch.cat([self.history, user_input_ids], dim=-1)
        else:
            chat_history_ids = user_input_ids

        # Generate a response using the PEFT model; max_new_tokens bounds the reply
        # length regardless of how long the history has grown
        response = peft_model.generate(
            input_ids=chat_history_ids,
            attention_mask=torch.ones_like(chat_history_ids),
            max_new_tokens=400,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Keep the full exchange (prompt + reply) as history for the next turn
        self.history = response

        # Decode and return only the newly generated reply
        new_tokens = response[0][chat_history_ids.shape[-1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)

bot = ChatBot()

title = "👋🏻Welcome to Tonic's StableMed Chat🚀"
description = """
You can use this Space to test out the current model [StableMed](https://huggingface.co/Tonic/stablemed) or You can also use 😷StableMed⚕️ on your own data & in your own way by cloning this space. 🧬🔬🔍 Simply click here: <a style="display:inline-block" href="https://huggingface.co/spaces/Tonic/StableMed_Chat?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></h3> 
# Join us : 🌟TeamTonic🌟 is always making cool demos! Join our active builder's🛠️community on 👻Discord: [Discord](https://discord.gg/GWpVpekp) On 🤗Huggingface: [TeamTonic](https://huggingface.co/TeamTonic) & [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Polytonic](https://github.com/tonic-ai) & contribute to 🌟 [PolyGPT](https://github.com/tonic-ai/polygpt-alpha)
"""
examples = [["What is the proper treatment for buccal herpes?", "Please provide information on the most effective antiviral medications and home remedies for treating buccal herpes."]]

iface = gr.Interface(
    fn=bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=["text", "text"],  # Take user input and system prompt separately
    outputs="text",
    theme="ParityError/Anime"
)

iface.launch()

Training Details

Training Data

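Fine-tuning used MedQuAD, loaded with the Hugging Face datasets library (16,407 rows, each with a question type, a question, and an answer):

Dataset({
    features: ['qtype', 'Question', 'Answer'],
    num_rows: 16407
})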

Training Procedure

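Fine-tuning used LoRA adapters (rank 8) on top of the 4-bit quantized base model. PEFT's print_trainable_parameters() reports:

trainable params: 12940288 || all params: 1539606528 || trainable%: 0.8404931886596937

The "all params" figure is lower than the base model's roughly 2.8B parameters because bitsandbytes packs two 4-bit weights per byte, and the PEFT version used here counted each quantized tensor at that packed size.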
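Both figures in the PEFT report can be reproduced from the layer shapes in this card's module printouts. The sketch below assumes each adapter contributes only its lora_A and lora_B weights (no biases, as the LoRA printout shows) and that the 4-bit projection weights were counted at half size; with those assumptions the totals match the report exactly.

```python
# Layer shapes read off the module printouts in this card.
HIDDEN = 2560        # hidden size
INTERMEDIATE = 6912  # MLP width
VOCAB = 50304        # vocabulary size
N_LAYERS = 32        # decoder layers
RANK = 8             # LoRA rank (lora_A out_features)


def lora_params(in_features: int, out_features: int, r: int = RANK) -> int:
    """Weights of one LoRA adapter: lora_A (in -> r) plus lora_B (r -> out), no biases."""
    return in_features * r + r * out_features


# Adapters per decoder layer: four attention projections and three MLP projections.
per_layer = (
    4 * lora_params(HIDDEN, HIDDEN)            # q_proj, k_proj, v_proj, o_proj
    + 2 * lora_params(HIDDEN, INTERMEDIATE)    # gate_proj, up_proj
    + lora_params(INTERMEDIATE, HIDDEN)        # down_proj
)
trainable = N_LAYERS * per_layer + lora_params(HIDDEN, VOCAB)  # plus the lm_head adapter

# Frozen base weights. Assumption: Linear4bit tensors were counted at their packed
# (half) size; embeddings, lm_head and LayerNorm weights/biases are not quantized.
linear_4bit = N_LAYERS * (4 * HIDDEN * HIDDEN + 3 * HIDDEN * INTERMEDIATE)
unquantized = 2 * VOCAB * HIDDEN + (2 * N_LAYERS + 1) * 2 * HIDDEN
all_params = linear_4bit // 2 + unquantized + trainable

print(trainable)   # 12940288
print(all_params)  # 1539606528
print(100 * trainable / all_params)  # trainable %
```

Unpacked, the base model holds linear_4bit + unquantized = 2,795,443,200 parameters, which is where the "roughly 2.8B" figure comes from.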

Preprocessing

Base model structure after 4-bit quantization (the attention and MLP projections are bitsandbytes Linear4bit layers):

StableLMEpochForCausalLM(
  (model): StableLMEpochModel(
    (embed_tokens): Embedding(50304, 2560)
    (layers): ModuleList(
      (0-31): 32 x DecoderLayer(
        (self_attn): Attention(
          (q_proj): Linear4bit(in_features=2560, out_features=2560, bias=False)
          (k_proj): Linear4bit(in_features=2560, out_features=2560, bias=False)
          (v_proj): Linear4bit(in_features=2560, out_features=2560, bias=False)
          (o_proj): Linear4bit(in_features=2560, out_features=2560, bias=False)
          (rotary_emb): RotaryEmbedding()
        )
        (mlp): MLP(
          (gate_proj): Linear4bit(in_features=2560, out_features=6912, bias=False)
          (up_proj): Linear4bit(in_features=2560, out_features=6912, bias=False)
          (down_proj): Linear4bit(in_features=6912, out_features=2560, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      )
    )
    (norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=2560, out_features=50304, bias=False)
)

Data Formatting :

Given a target sentence construct the underlying meaning representation of the input sentence as a single function with attributes and attribute values.
This function should describe the target string accurately and the function must be one of the following ['inform', 'request', 'give_opinion', 'confirm', 'verify_attribute', 'suggest', 'request_explanation', 'recommend', 'request_attribute'].
The attributes must be one of the following: ['name', 'pathology', 'therapeutic', 'dosage', 'side_effects', 'contraindications', 'manufacturer', 'price', 'availability', 'administration', 'warnings', 'interactions', 'storage', 'expiration_date', 'formulation', 'strength', 'route_of_administration', 'class', 'prescription_required', 'generic_name', 'brand_name', 'patient_instructions']
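The constraints in this instruction (one function from a fixed list, attributes from a fixed list) can be checked mechanically. This is a minimal sketch that assumes a ViGGO-style meaning-representation syntax, function(attribute[value], ...); that syntax and the helper name validate_mr are illustrative assumptions, not a format confirmed by this card.

```python
import re

# Allowed functions and attributes, copied from the instruction above.
FUNCTIONS = {
    "inform", "request", "give_opinion", "confirm", "verify_attribute",
    "suggest", "request_explanation", "recommend", "request_attribute",
}
ATTRIBUTES = {
    "name", "pathology", "therapeutic", "dosage", "side_effects",
    "contraindications", "manufacturer", "price", "availability",
    "administration", "warnings", "interactions", "storage", "expiration_date",
    "formulation", "strength", "route_of_administration", "class",
    "prescription_required", "generic_name", "brand_name", "patient_instructions",
}

# Assumed syntax: function(attr[value], attr[value], ...)
MR_PATTERN = re.compile(r"^\s*(\w+)\((.*)\)\s*$")
ATTR_PATTERN = re.compile(r"(\w+)\[([^\]]*)\]")


def validate_mr(mr: str) -> bool:
    """Return True if `mr` uses one allowed function and only allowed attributes."""
    match = MR_PATTERN.match(mr)
    if not match:
        return False
    function, body = match.groups()
    if function not in FUNCTIONS:
        return False
    attributes = ATTR_PATTERN.findall(body)  # list of (attribute, value) pairs
    if not attributes:
        return False
    return all(attr in ATTRIBUTES for attr, _ in attributes)


print(validate_mr("inform(name[aciclovir], class[antiviral], route_of_administration[oral])"))
print(validate_mr("inform(colour[blue])"))
```

A filter like this could be used to drop generated targets that fall outside the allowed vocabulary before training or scoring.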

Training Hyperparameters

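  • Training regime: 4-bit NF4 quantized base weights with a bfloat16 compute dtype and trainable LoRA adapters (see the bitsandbytes configuration under Training procedure)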

Speeds, Sizes, Times

The Hugging Face Trainer reported the following at the end of training:

TrainOutput(global_step=2051, training_loss=0.6156479549198718, metrics={'train_runtime': 22971.4974, 'train_samples_per_second': 0.357, 'train_steps_per_second': 0.089, 'total_flos': 6.5950444363776e+16, 'train_loss': 0.6156479549198718, 'epoch': 0.5})
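A few derived quantities help read these metrics. The sketch below only does arithmetic on the figures above; because the Trainer rounds the per-second rates, the ratio of samples to steps is an estimate, and reading it as the effective batch size (per-device batch times gradient accumulation) is an inference, not a recorded hyperparameter.

```python
# Figures copied from the TrainOutput above.
global_step = 2051
train_runtime_s = 22971.4974
samples_per_second = 0.357
steps_per_second = 0.089
num_rows = 16407  # MedQuAD rows listed under Training Data

hours = train_runtime_s / 3600
# Rounded rates, so round the ratio to the nearest whole number of samples.
samples_per_step = round(samples_per_second / steps_per_second)
fraction_of_dataset = global_step * samples_per_step / num_rows

print(f"runtime: {hours:.2f} h")                          # runtime: 6.38 h
print(f"samples per optimizer step: {samples_per_step}")  # samples per optimizer step: 4
print(f"fraction of dataset seen: {fraction_of_dataset:.2f}")  # fraction of dataset seen: 0.50
```

The last figure agrees with the reported 'epoch': 0.5: about 2,051 × 4 ≈ 8,204 samples, roughly half of the 16,407 rows.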

Results

Training loss, logged every 50 steps:

Step    Training Loss
50      1.427000
100     0.763200
150     0.708200
200     0.662300
250     0.650900
300     0.617400
350     0.602900
400     0.608900
450     0.596100
500     0.602000
550     0.594700
600     0.584700
650     0.611000
700     0.558700
750     0.616300
800     0.568700
850     0.597300
900     0.607400
950     0.563200
1000    0.602900
1050    0.594900
1100    0.583000
1150    0.604500
1200    0.547400
1250    0.586600
1300    0.554300
1350    0.581000
1400    0.578900
1450    0.563200
1500    0.556800
1550    0.570300
1600    0.599800
1650    0.556000
1700    0.592500
1750    0.597200
1800    0.559100
1850    0.586100
1900    0.581100
1950    0.589400
2000    0.581100
2050    0.533100
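Each logged value is the average loss over the preceding 50 steps, so the mean of the logged values should approximate the run-wide training_loss in the Trainer output. The sketch below checks that and finds the lowest logged loss; the values are copied from the table, nothing else is assumed.

```python
# Training loss logged every 50 steps (copied from the table above).
steps = list(range(50, 2051, 50))
logged_losses = [
    1.4270, 0.7632, 0.7082, 0.6623, 0.6509, 0.6174, 0.6029, 0.6089,
    0.5961, 0.6020, 0.5947, 0.5847, 0.6110, 0.5587, 0.6163, 0.5687,
    0.5973, 0.6074, 0.5632, 0.6029, 0.5949, 0.5830, 0.6045, 0.5474,
    0.5866, 0.5543, 0.5810, 0.5789, 0.5632, 0.5568, 0.5703, 0.5998,
    0.5560, 0.5925, 0.5972, 0.5591, 0.5861, 0.5811, 0.5894, 0.5811,
    0.5331,
]
reported_training_loss = 0.6156479549198718  # training_loss from TrainOutput

# Mean of equal-length window averages ~ average over all logged steps.
mean_logged = sum(logged_losses) / len(logged_losses)
best_step, best_loss = min(zip(steps, logged_losses), key=lambda pair: pair[1])

print(f"mean of logged losses:  {mean_logged:.4f}")             # 0.6156
print(f"reported training_loss: {reported_training_loss:.4f}")  # 0.6156
print(f"lowest logged loss: {best_loss} at step {best_step}")
```

The small residual difference comes from step 2051, which the Trainer includes in its average but which falls after the last logging point.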

Environmental Impact

Carbon emissions can be estimated using the Machine Learning Impact calculator presented in Lacoste et al. (2019).

  • Hardware Type: NVIDIA T4 (see Hardware below)
  • Hours used: about 6.4 (train_runtime of 22,971 s in the Trainer output above)
  • Cloud Provider: [More Information Needed]
  • Compute Region: [More Information Needed]
  • Carbon Emitted: [More Information Needed]
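The estimate this kind of calculator produces boils down to energy used times the carbon intensity of the local grid. A rough sketch in that spirit, assuming a single T4 drawing its full 70 W board power for the whole run; the grid intensity below is a hypothetical placeholder, because the compute region is not recorded, so the printed figure is illustrative only.

```python
T4_BOARD_POWER_W = 70  # NVIDIA T4 maximum board power


def estimate_co2_kg(power_watts: float, hours: float, grid_kg_per_kwh: float, pue: float = 1.0) -> float:
    """Rough CO2eq estimate: energy (kWh, scaled by datacenter PUE) x grid intensity (kg/kWh)."""
    energy_kwh = power_watts / 1000 * hours * pue
    return energy_kwh * grid_kg_per_kwh


hours = 22971.4974 / 3600  # train_runtime from the Trainer output, in hours

# HYPOTHETICAL grid intensity: replace with the real value for the compute region.
print(f"{estimate_co2_kg(T4_BOARD_POWER_W, hours, grid_kg_per_kwh=0.4):.2f} kg CO2eq")
```

Real runs rarely hold a GPU at full board power, so this tends to overestimate; the calculator cited above is the better source once the provider and region are known.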

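Technical Specifications

Model Architecture and Objective

The base model is StableLM-3B-4E1T (StableLMEpochForCausalLM: 32 decoder layers, hidden size 2560, vocabulary of 50,304 tokens), trained with a causal language-modeling objective. With the LoRA adapters attached (rank 8, dropout 0.05, on every attention and MLP projection and on lm_head), the module structure is: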

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): StableLMEpochForCausalLM(
      (model): StableLMEpochModel(
        (embed_tokens): Embedding(50304, 2560)
        (layers): ModuleList(
          (0-31): 32 x DecoderLayer(
            (self_attn): Attention(
              (q_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=False)
              )
              (k_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=False)
              )
              (v_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=False)
              )
              (o_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=2560, bias=False)
              )
              (rotary_emb): RotaryEmbedding()
            )
            (mlp): MLP(
              (gate_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=6912, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=6912, bias=False)
              )
              (up_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=6912, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=2560, out_features=6912, bias=False)
              )
              (down_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=6912, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2560, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=6912, out_features=2560, bias=False)
              )
              (act_fn): SiLU()
            )
            (input_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
          )
        )
        (norm): LayerNorm((2560,), eps=1e-05, elementwise_affine=True)
      )
      (lm_head): Linear(
        in_features=2560, out_features=50304, bias=False
        (lora_dropout): ModuleDict(
          (default): Dropout(p=0.05, inplace=False)
        )
        (lora_A): ModuleDict(
          (default): Linear(in_features=2560, out_features=8, bias=False)
        )
        (lora_B): ModuleDict(
          (default): Linear(in_features=8, out_features=50304, bias=False)
        )
        (lora_embedding_A): ParameterDict()
        (lora_embedding_B): ParameterDict()
      )
    )
  )
)
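The adapter settings visible in this printout can be written as a peft LoraConfig. This is a reconstruction, not the original training script: r=8, lora_dropout=0.05 and the target modules are read off the printout, and bias="none" matches the trainable count containing only lora_A/lora_B weights, but lora_alpha is not recorded anywhere in this card, so the value below is a hypothetical placeholder.

```python
from peft import LoraConfig, TaskType

# Reconstructed from the PeftModelForCausalLM printout above.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,              # lora_A out_features / lora_B in_features
    lora_alpha=16,    # HYPOTHETICAL: not recorded in this card
    lora_dropout=0.05,
    bias="none",      # no bias parameters are trained
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
        "lm_head",                               # output head
    ],
)
```

Applying it would typically go through peft's prepare_model_for_kbit_training and get_peft_model on the 4-bit base model; the trainable parameter count does not depend on lora_alpha, so it should still come out at 12,940,288.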

Compute Infrastructure

GCS

Hardware

NVIDIA T4 GPU

Software

transformers, peft, torch, datasets; bitsandbytes for 4-bit quantization and gradio for the demo app

Model Card Authors

Tonic

Model Card Contact

Tonic

Training procedure

The following bitsandbytes quantization config was used during training:

  • quant_method: bitsandbytes
  • load_in_8bit: False
  • load_in_4bit: True
  • llm_int8_threshold: 6.0
  • llm_int8_skip_modules: None
  • llm_int8_enable_fp32_cpu_offload: False
  • llm_int8_has_fp16_weight: False
  • bnb_4bit_quant_type: nf4
  • bnb_4bit_use_double_quant: True
  • bnb_4bit_compute_dtype: bfloat16
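The quantization settings above can be expressed with transformers' BitsAndBytesConfig. A minimal sketch, assuming a CUDA GPU with bitsandbytes installed; the llm_int8_* entries above are the library defaults, so they are not set explicitly here.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Mirrors the bitsandbytes settings listed above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# StableLM-3B-4E1T ships custom modeling code, hence trust_remote_code=True.
model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-3b-4e1t",
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map="auto",
)
print(model)  # attention and MLP projections should appear as Linear4bit layers
```

The printed structure should resemble the base model configuration shown under Preprocessing.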

Framework versions

  • PEFT 0.6.2.dev0