"""
Chemical classification module for chemsource.
This module provides AI-powered classification functionality for chemical
entities and compounds.
"""
import re
from typing import Optional, List, Union, Any, Tuple

from openai import OpenAI
from spellchecker import SpellChecker
[docs]
def classify(name: str,
             input_text: Optional[str] = None,
             api_key: Optional[str] = None,
             baseprompt: Optional[str] = None,
             model: str = 'gpt-4o',
             temperature: float = 0,
             top_p: float = 0,
             max_length: int = 250000,
             clean_output: bool = False,
             explanation: bool = False,
             explanation_separator: str = "EXPLANATION_COMPLETE",
             output_explanation: bool = False,
             allowed_categories: Optional[List[str]] = None,
             custom_client: Optional[Any] = None,
             spell_checker: Optional[SpellChecker] = None
             ) -> Union[str, List[str], Tuple[List[str], str]]:
    """
    Classify a chemical compound using an AI language model.

    The prompt is built by substituting ``name`` for every ``COMPOUND_NAME``
    placeholder in ``baseprompt`` and appending ``input_text``. Every
    occurrence of ``name`` in the resulting prompt is then lower-cased and the
    prompt is truncated to ``max_length`` characters before being sent to the
    model as a single chat message.

    Args:
        name (str): The name of the chemical compound to classify.
        input_text (str, optional): Additional information about the compound,
            appended after the base prompt. ``None`` is treated as empty text.
        api_key (str, optional): API key for the language model service.
            Ignored when ``custom_client`` is given.
        baseprompt (str): Prompt template. Must contain the ``COMPOUND_NAME``
            placeholder at least once.
        model (str, optional): Name of the language model to use. The value
            ``"deepseek-chat"`` selects the DeepSeek base URL. Defaults to 'gpt-4o'.
        temperature (float, optional): Sampling temperature. Defaults to 0.
        top_p (float, optional): Top-p parameter for nucleus sampling. Defaults to 0.
        max_length (int, optional): Maximum length of the prompt in characters.
            Defaults to 250000.
        clean_output (bool, optional): Whether to clean and validate the output.
            Defaults to False.
        explanation (bool, optional): Whether to expect and extract an explanation
            from the model response. Only used when clean_output=True. The
            response should contain an explanation, then the separator, then
            the classification. Defaults to False.
        explanation_separator (str, optional): Delimiter between the explanation
            and the classification. Only used when both clean_output=True and
            explanation=True. Whitespace runs inside it are collapsed to single
            spaces, the same way the response is normalised.
            Defaults to "EXPLANATION_COMPLETE".
        output_explanation (bool, optional): Whether to return the explanation
            text alongside the classification. Only takes effect when both
            explanation=True and clean_output=True. Defaults to False.
        allowed_categories (List[str], optional): Categories used to filter the
            cleaned output. Required when clean_output=True.
        custom_client (Any, optional): Client object exposing
            ``chat.completions.create``. When given, the prompt is sent with the
            "user" role instead of "system".
        spell_checker (SpellChecker, optional): Spell checker whose
            ``correction`` method is applied to each item before filtering.

    Returns:
        Union[str, List[str], Tuple[List[str], str]]:
            - If clean_output=False: the raw model output
            - If clean_output=True and output_explanation=False: list of categories
            - If clean_output=True, explanation=True, and output_explanation=True:
              tuple of (category_list, explanation_text)

    Raises:
        ValueError: If clean_output is True but allowed_categories is None; if
            output_explanation=True but explanation=False; if baseprompt is
            missing or lacks the ``COMPOUND_NAME`` placeholder; or if
            explanation=True and the separator is not found in the response.

    Example:
        >>> classify("aspirin", "pain relief medication", api_key="your_key",
        ...          baseprompt="Classify COMPOUND_NAME: ")
        "MEDICAL"
        >>> classify("aspirin", "pain relief medication", api_key="your_key",
        ...          baseprompt="Classify COMPOUND_NAME: ",
        ...          clean_output=True, allowed_categories=["MEDICAL", "FOOD"])
        ["MEDICAL"]
        >>> # Getting both classification and explanation
        >>> custom_prompt = ("Explain why, then say EXPLANATION_COMPLETE, "
        ...                  "then classify COMPOUND_NAME: ")
        >>> categories, explanation = classify(
        ...     "aspirin", "pain relief", api_key="your_key",
        ...     baseprompt=custom_prompt, clean_output=True,
        ...     explanation=True, output_explanation=True,
        ...     allowed_categories=["MEDICAL", "FOOD"])
        >>> print(categories)   # ["MEDICAL"]
        >>> print(explanation)  # "Aspirin is widely used as a pain reliever..."
    """
    # Validate arguments first so that bad input is reported before a client
    # is constructed or a request is sent.
    if clean_output and allowed_categories is None:
        raise ValueError("If clean_output is True, a list in allowed_categories must be provided to filter the output.")
    if output_explanation and not explanation:
        raise ValueError("If output_explanation is True, explanation must also be True.")
    if baseprompt is None or "COMPOUND_NAME" not in baseprompt:
        raise ValueError("baseprompt must be provided and must contain the 'COMPOUND_NAME' placeholder.")

    if custom_client is not None:
        client = custom_client
    elif model == "deepseek-chat":
        client = OpenAI(
            api_key=api_key,
            base_url="https://api.deepseek.com"
        )
    else:
        client = OpenAI(
            api_key=api_key
        )

    name_text = str(name)
    # A missing input_text must not leak the literal string "None" into the prompt.
    extra_text = "" if input_text is None else str(input_text)
    # Substitute every placeholder occurrence, not only the first one.
    prompt = baseprompt.replace("COMPOUND_NAME", name_text) + extra_text
    # NOTE: this lower-cases every occurrence of the name in the whole prompt
    # (template and input_text included), not just the substituted placeholder.
    prompt = prompt.replace(name_text, name_text.lower())
    prompt = prompt[:max_length]

    # Use user role for custom clients (like Gemini) that may not support system messages
    message_role = "user" if custom_client is not None else "system"
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": message_role, "content": prompt}],
        temperature=temperature,
        top_p=top_p,
        stream=False
    )
    content = response.choices[0].message.content
    if not clean_output:
        return content

    # The API may return no text content; treat that as an empty answer.
    # Collapse newlines and repeated whitespace into single spaces.
    cleaned_response_string = re.sub(r"\s+", " ", content or "").strip()

    cleaned_explanation = ""
    if explanation:
        # Normalise the separator like the response so that a separator
        # containing newlines or repeated spaces can still be found.
        separator = re.sub(r"\s+", " ", explanation_separator)
        parts = cleaned_response_string.split(separator)
        if len(parts) < 2:
            raise ValueError(
                f"Explanation separator '{explanation_separator}' not found in model response. "
                f"When explanation=True, the model must include the separator in its response. "
                f"Response received: {cleaned_response_string[:200]}..."
            )
        # Explanation is the text before the first separator; the
        # classification is the text between the first and second separator
        # (or up to the end when the separator occurs only once).
        cleaned_explanation = parts[0].strip()
        cleaned_response_string = parts[1].strip()

    classification_list = [item.strip() for item in cleaned_response_string.split(",")]

    # Built once: the case-insensitive lookup used when no spell checker is given.
    allowed_upper = {category.upper() for category in allowed_categories}
    updated_classification_list = []
    for item in classification_list:
        if spell_checker is not None:
            # correction() may return None when it has no suggestion; keep the
            # original item in that case instead of silently dropping it.
            updated_item = spell_checker.correction(item) or item
            if updated_item in allowed_categories:
                updated_classification_list.append(updated_item)
        elif item.upper() in allowed_upper:
            # Without a spell checker, match case-insensitively and keep the
            # item as the model wrote it.
            updated_classification_list.append(item)

    if explanation and output_explanation:
        return updated_classification_list, cleaned_explanation
    return updated_classification_list