# Source code for chemsource.classifier

"""
Chemical classification module for chemsource.

This module provides AI-powered classification functionality for chemical
entities and compounds.
"""

from typing import Any, List, Optional, Tuple, Union

from openai import OpenAI
from spellchecker import SpellChecker


def _build_prompt(
    baseprompt: Optional[str],
    name: str,
    input_text: Optional[str],
    max_length: int,
) -> str:
    """Fill the prompt template with the compound name and append the context text."""
    if baseprompt is None:
        raise ValueError(
            "A baseprompt containing the COMPOUND_NAME placeholder must be provided."
        )
    if "COMPOUND_NAME" not in baseprompt:
        raise ValueError("baseprompt must contain the COMPOUND_NAME placeholder.")

    name_text = str(name)
    # A missing input_text contributes nothing rather than the literal "None".
    context = "" if input_text is None else str(input_text)

    # Substitute every placeholder so text after a second one is not dropped.
    prompt = baseprompt.replace("COMPOUND_NAME", name_text) + context
    # The compound name is lower-cased wherever it occurs in the prompt.
    prompt = prompt.replace(name_text, name_text.lower())
    return prompt[:max_length]


def _filter_categories(
    items: List[str],
    allowed_categories: List[str],
    spell_checker: Optional[Any],
) -> List[str]:
    """Keep only the items that map onto one of the allowed categories."""
    # Built once; the membership test below is case-insensitive.
    allowed_upper = {category.upper() for category in allowed_categories}

    kept = []
    for item in items:
        if spell_checker is not None:
            corrected = spell_checker.correction(item)
            if corrected in allowed_categories:
                kept.append(corrected)
                continue
            # The checker may return None or a re-cased word; fall through so an
            # item that already names an allowed category is not lost.
        if item.upper() in allowed_upper:
            kept.append(item)
    return kept


def classify(
    name: str,
    input_text: Optional[str] = None,
    api_key: Optional[str] = None,
    baseprompt: Optional[str] = None,
    model: str = 'gpt-4o',
    temperature: float = 0,
    top_p: float = 0,
    max_length: int = 250000,
    clean_output: bool = False,
    explanation: bool = False,
    explanation_separator: str = "EXPLANATION_COMPLETE",
    output_explanation: bool = False,
    allowed_categories: Optional[List[str]] = None,
    custom_client: Optional[Any] = None,
    # Forward reference: only type checkers need to resolve the class.
    spell_checker: Optional["SpellChecker"] = None,
) -> Union[str, List[str], Tuple[List[str], str]]:
    """
    Classify a chemical compound using an AI language model.

    The prompt is built by substituting ``name`` for the ``COMPOUND_NAME``
    placeholder in ``baseprompt`` and appending ``input_text``; the result is
    sent as a single chat message and the reply is returned raw or cleaned.

    Args:
        name (str): The name of the chemical compound to classify. It is
            lower-cased wherever it appears in the final prompt.
        input_text (str, optional): Additional information about the compound,
            appended after the filled-in template. ``None`` appends nothing.
        api_key (str, optional): API key for the language model service.
            Unused when ``custom_client`` is given.
        baseprompt (str): Prompt template; must contain ``COMPOUND_NAME``.
        model (str, optional): Name of the language model to use. The value
            ``"deepseek-chat"`` selects the DeepSeek endpoint. Defaults to
            'gpt-4o'.
        temperature (float, optional): Temperature passed to the model.
            Defaults to 0.
        top_p (float, optional): Top-p parameter passed to the model.
            Defaults to 0.
        max_length (int, optional): Maximum length of the prompt in
            characters; longer prompts are truncated. Defaults to 250000.
        clean_output (bool, optional): Whether to split the reply on commas
            and filter it against ``allowed_categories``. Defaults to False.
        explanation (bool, optional): Whether the reply is expected to contain
            an explanation, then the separator, then the classification.
            Only used when clean_output=True. Defaults to False.
        explanation_separator (str, optional): Delimiter between explanation
            and classification. Only used when both clean_output=True and
            explanation=True. Defaults to "EXPLANATION_COMPLETE".
        output_explanation (bool, optional): Whether to return the explanation
            text alongside the classification. Only takes effect when both
            explanation=True and clean_output=True. Defaults to False.
        allowed_categories (List[str], optional): Categories used to filter
            the cleaned output. Required when clean_output=True.
        custom_client (Any, optional): Client object exposing
            ``chat.completions.create``; used instead of creating an OpenAI
            client.
        spell_checker (SpellChecker, optional): Object whose ``correction``
            method is applied to each item before filtering.

    Returns:
        Union[str, List[str], Tuple[List[str], str]]:
            - If clean_output=False: raw model output string.
            - If clean_output=True and output_explanation=False: list of
              categories.
            - If clean_output=True, explanation=True and
              output_explanation=True: tuple of
              (category_list, explanation_text).

    Raises:
        ValueError: If clean_output is True but allowed_categories is None;
            if output_explanation=True but explanation=False; if baseprompt
            is missing or lacks the COMPOUND_NAME placeholder; or if
            explanation=True but the explanation_separator is not found in
            the response.

    Example:
        >>> classify("aspirin", "pain relief medication", api_key="your_key",
        ...          baseprompt="Classify COMPOUND_NAME: ")
        "MEDICAL"
        >>> classify("aspirin", "pain relief medication", api_key="your_key",
        ...          baseprompt="Classify COMPOUND_NAME: ",
        ...          clean_output=True, allowed_categories=["MEDICAL", "FOOD"])
        ["MEDICAL"]
        >>> # Getting both classification and explanation
        >>> custom_prompt = "Explain why, then say EXPLANATION_COMPLETE, then classify COMPOUND_NAME: "
        >>> categories, explanation = classify("aspirin", "pain relief", api_key="your_key",
        ...                                    baseprompt=custom_prompt, clean_output=True,
        ...                                    explanation=True, output_explanation=True,
        ...                                    allowed_categories=["MEDICAL", "FOOD"])
    """
    # Validate arguments before any client is constructed or request is sent.
    if clean_output and allowed_categories is None:
        raise ValueError("If clean_output is True, a list in allowed_categories must be provided to filter the output.")
    if output_explanation and not explanation:
        raise ValueError("If output_explanation is True, explanation must also be True.")

    prompt = _build_prompt(baseprompt, name, input_text, max_length)

    if custom_client is not None:
        client = custom_client
    elif model == "deepseek-chat":
        client = OpenAI(api_key=api_key, base_url="https://api.deepseek.com")
    else:
        client = OpenAI(api_key=api_key)

    # Use user role for custom clients (like Gemini) that may not support system messages
    message_role = "user" if custom_client is not None else "system"
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": message_role, "content": prompt}],
        temperature=temperature,
        top_p=top_p,
        stream=False,
    )
    content = response.choices[0].message.content

    if not clean_output:
        return content

    # Guard against a reply without text content so cleaning does not crash.
    cleaned_response_string = (content or "").replace("\n", " ").strip()

    cleaned_explanation = ""
    if explanation:
        # partition splits at the FIRST separator and keeps everything after it.
        before, found, after = cleaned_response_string.partition(explanation_separator)
        if not found:
            raise ValueError(
                f"Explanation separator '{explanation_separator}' not found in model response. "
                f"When explanation=True, the model must include the separator in its response. "
                f"Response received: {cleaned_response_string[:200]}..."
            )
        cleaned_explanation = before.strip()
        cleaned_response_string = after.strip()

    # Collapse whitespace runs inside each comma-separated item.
    classification_list = [
        " ".join(piece.split()) for piece in cleaned_response_string.split(",")
    ]
    filtered = _filter_categories(classification_list, allowed_categories, spell_checker)

    if explanation and output_explanation:
        return filtered, cleaned_explanation
    return filtered