diff --git a/examples/configs/content_safety_multilingual/config.yml b/examples/configs/content_safety_multilingual/config.yml new file mode 100644 index 000000000..c525aed99 --- /dev/null +++ b/examples/configs/content_safety_multilingual/config.yml @@ -0,0 +1,30 @@ +models: + - type: main + engine: nim + model: meta/llama-3.3-70b-instruct + + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemotron-safety-guard-8b-v3 + +rails: + config: + content_safety: + multilingual: + enabled: true + # Optional: customize refusal messages per language + # If not specified, built-in defaults are used + # refusal_messages: + # en: "Sorry, I cannot help with that request." + # es: "Lo siento, no puedo ayudar con esa solicitud." + # zh: "抱歉,我无法处理该请求。" + + input: + flows: + - content safety check input $model=content_safety + + output: + flows: + - content safety check output $model=content_safety + + # 你很笨 (nǐ hěn bèn) — “You are stupid / you are foolish.” diff --git a/examples/configs/content_safety_multilingual/prompts.yml b/examples/configs/content_safety_multilingual/prompts.yml new file mode 100644 index 000000000..595dd4c22 --- /dev/null +++ b/examples/configs/content_safety_multilingual/prompts.yml @@ -0,0 +1,90 @@ +prompts: + - task: content_safety_check_input $model=content_safety + content: | + Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below. + + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Other. + S15: Manipulation. + S16: Fraud/Deception. + S17: Malware. + S18: High Risk Gov Decision Making. + S19: Political/Misinformation/Conspiracy. + S20: Copyright/Trademark/Plagiarism. + S21: Unauthorized Advice. + S22: Illegal Activity. + S23: Immoral/Unethical. + + + + + user: {{ user_input }} + + + + Provide your safety assessment for the above conversation in the below output JSON format: + {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.} + + Do not include anything other than the output JSON in your response. + Output JSON: + output_parser: nemoguard_parse_prompt_safety + max_tokens: 50 + + - task: content_safety_check_output $model=content_safety + content: | + Task: Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below. + + + S1: Violence. + S2: Sexual. + S3: Criminal Planning/Confessions. + S4: Guns and Illegal Weapons. + S5: Controlled/Regulated Substances. + S6: Suicide and Self Harm. + S7: Sexual (minor). + S8: Hate/Identity Hate. + S9: PII/Privacy. + S10: Harassment. + S11: Threat. + S12: Profanity. + S13: Needs Caution. + S14: Other. + S15: Manipulation. + S16: Fraud/Deception. + S17: Malware. + S18: High Risk Gov Decision Making. + S19: Political/Misinformation/Conspiracy. + S20: Copyright/Trademark/Plagiarism. + S21: Unauthorized Advice. + S22: Illegal Activity. + S23: Immoral/Unethical. + + + + + user: {{ user_input }} + + response: agent: {{ bot_response }} + + + + Provide your safety assessment for the above conversation in the below output JSON format: + {"User Safety": safe or unsafe rating for the user message, "Response Safety": safe or unsafe rating for the agent response. Omit if no agent response present. "Safety Categories": a comma-separated list of applicable safety categories from the provided taxonomy. Omit if all safe.} + + Do not include anything other than the output JSON in your response. + Output JSON: + output_parser: nemoguard_parse_response_safety + max_tokens: 50 diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index f9f376236..7057d81a5 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Dict, Optional +from typing import Dict, FrozenSet, Optional from langchain_core.language_models import BaseLLM @@ -31,6 +31,7 @@ ) from nemoguardrails.llm.taskmanager import LLMTaskManager from nemoguardrails.logging.explain import LLMCallInfo +from nemoguardrails.rails.llm.config import RailsConfig log = logging.getLogger(__name__) @@ -220,3 +221,105 @@ async def content_safety_check_output( log.debug(f"Content safety output result cached for model '{model_name}'") return final_result + + +SUPPORTED_LANGUAGES: FrozenSet[str] = frozenset({"en", "es", "zh", "de", "fr", "hi", "ja", "ar", "th"}) + +DEFAULT_REFUSAL_MESSAGES: Dict[str, str] = { + "en": "I'm sorry, I can't respond to that.", + "es": "Lo siento, no puedo responder a eso.", + "zh": "抱歉,我无法回应。", + "de": "Es tut mir leid, darauf kann ich nicht antworten.", + "fr": "Je suis désolé, je ne peux pas répondre à cela.", + "hi": "मुझे खेद है, मैं इसका जवाब नहीं दे सकता।", + "ja": "申し訳ありませんが、それには回答できません。", + "ar": "عذراً، لا أستطيع الرد على ذلك.", + "th": "ขออภัย ฉันไม่สามารถตอบได้", +} + + +def _detect_language( + text: str, + max_text_length: Optional[int] = None, + normalize_text: bool = True, + cache_dir: Optional[str] = None, +) -> Optional[str]: + try: + from fast_langdetect import LangDetectConfig, LangDetector + + config = LangDetectConfig( + max_input_length=max_text_length, + normalize_input=normalize_text, + cache_dir=cache_dir, + ) + detector = LangDetector(config) + result = detector.detect(text, k=1) + if result and len(result) > 0: + return result[0].get("lang") + return None + except ImportError: + log.warning("fast-langdetect not installed, skipping") + return None + except Exception as e: + log.warning(f"fast-langdetect detection failed: {e}") + return None + + +def _get_refusal_message(lang: str, custom_messages: Optional[Dict[str, str]]) -> str: + if custom_messages and lang in custom_messages: + return custom_messages[lang] + if lang in DEFAULT_REFUSAL_MESSAGES: + return DEFAULT_REFUSAL_MESSAGES[lang] + if custom_messages and "en" in custom_messages: + return custom_messages["en"] + return DEFAULT_REFUSAL_MESSAGES["en"] + + +@action() +async def detect_language( + context: Optional[dict] = None, + config: Optional[RailsConfig] = None, +) -> dict: + user_message = "" + if context is not None: + user_message = context.get("user_message", "") + + custom_messages = None + max_text_length = None + normalize_text = True + cache_dir = None + + if config is not None: + multilingual_config = ( + config.rails.config.content_safety.multilingual + if hasattr(config, "rails") + and hasattr(config.rails, "config") + and hasattr(config.rails.config, "content_safety") + and hasattr(config.rails.config.content_safety, "multilingual") + else None + ) + if multilingual_config: + custom_messages = multilingual_config.refusal_messages + max_text_length = multilingual_config.max_text_length + normalize_text = multilingual_config.normalize_text + cache_dir = multilingual_config.cache_dir + + lang = ( + _detect_language( + user_message, + max_text_length=max_text_length, + normalize_text=normalize_text, + cache_dir=cache_dir, + ) + or "en" + ) + + if lang not in SUPPORTED_LANGUAGES: + lang = "en" + + refusal_message = _get_refusal_message(lang, custom_messages) + + return { + "language": lang, + "refusal_message": refusal_message, + } diff --git a/nemoguardrails/library/content_safety/flows.co b/nemoguardrails/library/content_safety/flows.co index a4c411e8f..4e5a5e2ff 100644 --- a/nemoguardrails/library/content_safety/flows.co +++ b/nemoguardrails/library/content_safety/flows.co @@ -3,7 +3,6 @@ flow content safety check input $model global $allowed $allowed = $response["allowed"] - # Policy violations are currently unused, but can be used to better phrase the bot output global $policy_violations $policy_violations = $response["policy_violations"] @@ -11,7 +10,12 @@ flow content safety check input $model if $system.config.enable_rails_exceptions send ContentSafetyCheckInputException(message="Input not allowed. The input was blocked by the 'content safety check input $model='{$model}'' flow.") else - bot refuse to respond + if $system.config.rails.config.content_safety.multilingual.enabled + $lang_result = await DetectLanguageAction() + $refusal_message = $lang_result["refusal_message"] + bot $refusal_message + else + bot refuse to respond abort flow content safety check output $model @@ -25,5 +29,10 @@ flow content safety check output $model if $system.config.enable_rails_exceptions send ContentSafetyCheckOuputException(message="Output not allowed. The output was blocked by the 'content safety check output $model='{$model}'' flow.") else - bot refuse to respond + if $system.config.rails.config.content_safety.multilingual.enabled + $lang_result = await DetectLanguageAction() + $refusal_message = $lang_result["refusal_message"] + bot $refusal_message + else + bot refuse to respond abort diff --git a/nemoguardrails/library/content_safety/flows.v1.co b/nemoguardrails/library/content_safety/flows.v1.co index 37aae4b77..32c48c3c7 100644 --- a/nemoguardrails/library/content_safety/flows.v1.co +++ b/nemoguardrails/library/content_safety/flows.v1.co @@ -1,5 +1,3 @@ -# TODO: add support for enable_rails_exceptions - define bot refuse to respond "I'm sorry, I can't respond to that." @@ -7,14 +5,18 @@ define flow content safety check input $response = execute content_safety_check_input $allowed = $response["allowed"] - # Policy violations are currently unused, but can be used to better phrase the bot output $policy_violations = $response["policy_violations"] if not $allowed if $config.enable_rails_exceptions create event ContentSafetyCheckInputException(message="Input not allowed. The input was blocked by the 'content safety check input $model='{$model}'' flow.") else - bot refuse to respond + if $config.rails.config.content_safety.multilingual.enabled + $lang_result = execute detect_language + $refusal_message = $lang_result["refusal_message"] + bot $refusal_message + else + bot refuse to respond stop define flow content safety check output @@ -26,5 +28,10 @@ define flow content safety check output if $config.enable_rails_exceptions create event ContentSafetyCheckOuputException(message="Output not allowed. The output was blocked by the 'content safety check output $model='{$model}'' flow.") else - bot refuse to respond + if $config.rails.config.content_safety.multilingual.enabled + $lang_result = execute detect_language + $refusal_message = $lang_result["refusal_message"] + bot $refusal_message + else + bot refuse to respond stop diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index c3909fafa..b41b0ef05 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -891,6 +891,47 @@ class AIDefenseRailConfig(BaseModel): ) +class MultilingualConfig(BaseModel): + """Configuration for multilingual refusal messages.""" + + enabled: bool = Field( + default=False, + description="If True, detect the language of user input and return refusal messages in the same language. " + "Supported languages: en (English), es (Spanish), zh (Chinese), de (German), fr (French), " + "hi (Hindi), ja (Japanese), ar (Arabic), th (Thai).", + ) + refusal_messages: Optional[Dict[str, str]] = Field( + default=None, + description="Custom refusal messages per language code. " + "If not specified, built-in defaults are used. " + "Example: {'en': 'Sorry, I cannot help.', 'es': 'Lo siento, no puedo ayudar.'}", + ) + max_text_length: Optional[int] = Field( + default=None, + description="Maximum text length for language detection. Text longer than this will be truncated. " + "If not specified, uses the library default (80 characters).", + ) + normalize_text: bool = Field( + default=True, + description="If True, normalize input text before language detection " + "(e.g., lowercase uppercase text to prevent misdetection).", + ) + cache_dir: Optional[str] = Field( + default=None, + description="Directory for storing downloaded language detection models. " + "If not specified, uses the system default cache location.", + ) + + +class ContentSafetyConfig(BaseModel): + """Configuration data for content safety rails.""" + + multilingual: MultilingualConfig = Field( + default_factory=MultilingualConfig, + description="Configuration for multilingual refusal messages.", + ) + + class RailsConfigData(BaseModel): """Configuration data for specific rails that are supported out-of-the-box.""" @@ -959,6 +1000,11 @@ class RailsConfigData(BaseModel): description="Configuration for Cisco AI Defense.", ) + content_safety: Optional[ContentSafetyConfig] = Field( + default_factory=ContentSafetyConfig, + description="Configuration for content safety rails.", + ) + class Rails(BaseModel): """Configuration of specific rails.""" diff --git a/poetry.lock b/poetry.lock index 62b095452..f9976e9aa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -676,6 +676,23 @@ humanfriendly = ">=9.1" [package.extras] cron = ["capturer (>=2.4)"] +[[package]] +name = "colorlog" +version = "6.10.1" +description = "Add colours to the output of Python's logging module." +optional = false +python-versions = ">=3.6" +files = [ + {file = "colorlog-6.10.1-py3-none-any.whl", hash = "sha256:2d7e8348291948af66122cff006c9f8da6255d224e7cf8e37d8de2df3bad8c9c"}, + {file = "colorlog-6.10.1.tar.gz", hash = "sha256:eb4ae5cb65fe7fec7773c2306061a8e63e02efc2c72eba9d27b0fa23c94f1321"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} + +[package.extras] +development = ["black", "flake8", "mypy", "pytest", "types-colorama"] + [[package]] name = "confection" version = "0.1.5" @@ -953,6 +970,22 @@ typing-extensions = {version = ">=4.6.0", markers = "python_version < \"3.13\""} [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "fast-langdetect" +version = "1.0.0" +description = "Quickly detect text language and segment language" +optional = false +python-versions = ">=3.9" +files = [ + {file = "fast_langdetect-1.0.0-py3-none-any.whl", hash = "sha256:aab9e3435cc667ac8ba8b1a38872f75492f65b7087901d0f3a02a88d436cd22a"}, + {file = "fast_langdetect-1.0.0.tar.gz", hash = "sha256:ea8ac6a8914e0ff1bfc1bbc0f25992eb913ddb69e63ea1b24e907e263d0cd113"}, +] + +[package.dependencies] +fasttext-predict = ">=0.9.2.4" +requests = ">=2.32.3" +robust-downloader = ">=0.0.2" + [[package]] name = "fastapi" version = "0.121.0" @@ -1005,6 +1038,92 @@ requests = ">=2.31,<3.0" tokenizers = ">=0.15,<1.0" tqdm = ">=4.66,<5.0" +[[package]] +name = "fasttext-predict" +version = "0.9.2.4" +description = "fasttext with wheels and no external dependency, but only the predict method (<1MB)" +optional = false +python-versions = "*" +files = [ + {file = "fasttext_predict-0.9.2.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ba432f33228928df5f2af6dfa50560cd77f9859914cffd652303fb02ba100456"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6a8e8f17eb894d450168d2590e23d809e845bd4fad5e39b5708dacb2fdb9b2c7"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19565fdf0bb9427831cfc75fca736ab9d71ba7ce02e3ea951e5839beb66560b6"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:cb6986815506e3261c0b3f6227dce49eeb4fd3422dab9cd37e2db2fb3691c68b"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:229dfdf8943dd76231206c7c9179e3f99d45879e5b654626ee7b73b7fa495d53"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-manylinux_2_31_armv7l.whl", hash = "sha256:397016ebfa9ec06d6dba09c29e295eea583ea3f45fa4592cc832b257dc84522e"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:fc93f9f8f7e982eb635bc860688be04f355fab3d76a243037e26862646f50430"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-musllinux_1_2_armv7l.whl", hash = "sha256:f4be96ac0b01a3cda82be90e7f6afdafab98919995825c27babd2749a8319be9"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f505f737f9493d22ee0c54af7c7eb7828624d5089a1e85072bdb1bd7d3f8f82e"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9ce69f28862dd551d43e27aa0a8de924b6b34412bff998c23c3d4abd70813183"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-win32.whl", hash = "sha256:864b6bb543275aee74360eee1d2cc23a440f09991e97efcdcf0b9a5af00f9aa9"}, + {file = "fasttext_predict-0.9.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:7e72abe12c13fd12f8bb137b1f7561096fbd3bb24905a27d9e93a4921ee68dc6"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:147996c86aa0928c7118f85d18b6a77c458db9ca236db26d44ee5ceaab0c0b6b"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5342f7363709e22524a31750c21e4b735b6666749a167fc03cc3bbf18ea8eccd"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6cbecd3908909339316f61db38030ce43890c25bddb06c955191458af13ccfc5"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9de4fcfb54bec35be6b0dffcdc5ace1a3a07f79ee3e8d33d13b82cc4116c5f2f"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5af82e09227d993befc00271407b9d3c8aae81d34b35f96208223faf609f4b0c"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-manylinux_2_31_armv7l.whl", hash = "sha256:337ee60179f32e8b0efa822e59316de15709c7684e7854021b4f6af82b7767ac"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:aa9da0c52e65a45dbc87df67015ec1d2712f04de47733e197176550521feea87"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:495efde8afb622266c0e4de41978a6db731a0a685e1db032e7d22937850c9b44"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e5726ba34d79a143b69426e29905eb4d3f4ee8aee94927b3bea3dd566712986b"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5ac2f35830705c61dd848314c4c077a393608c181725dc353a69361821aa69a8"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-win32.whl", hash = "sha256:7b2f8a5cf5f2c451777dbb7ea4957c7919a57ce29a4157a0a381933c9ea6fa70"}, + {file = "fasttext_predict-0.9.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:83a3c00fdb73a304bc529bc0ae0e225bc2cb956fcfb8e1c7a882b2a1aaa97e19"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:dcf8661da4f515551523470a745df246121f7e19736fcf3f48f04287963e6279"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:99dbfcc3f353da2639fd04fc574a65ff4195b018311f790583147cdc6eb122f4"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:427e99ba963b2c744ed7233304037a83b7adece97de6f361cfd356aa43cb87f3"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b9480cc75a906571a8e5fc717b91b4783f1820aaa5ed36a304d689280de8602"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11ef7af2a4431c76d2226e47334e86b9c4a78a98f6cb68b1ce9a1fc20e04c904"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:ecb0b854596ba847742597b35c2d0134fcf3a59214d09351d01535854078d56b"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:fbbcfefac10f625d95fc42f28d76cc5bf0c12875f147b5a79108a2669e64a2dc"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:a8cb78a00c04b7eb7da18b4805f8557b36911dc4375c947d8938897d2e131841"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:299ae56ad53e1381c65030143da7bcae12546fd32bc019215592ec1ee40fd19e"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:091938062002fe30d214f6e493a3a1e6180d401212d37eea23c29f4b55f3f347"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-win32.whl", hash = "sha256:981b8d9734623f8f9a8003970f765e14b1d91ee82c59c35e8eba6b76368fa95e"}, + {file = "fasttext_predict-0.9.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:bd3c33971c241577b0767e55d97acfda790f77378f9d5ee7872b6ee4bd63130b"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ddb85e62c95e4e02d417c782e3434ef65554df19e3522f5230f6be15a9373c05"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:102129d45cf98dda871e83ae662f71d999b9ef6ff26bc842ffc1520a1f82930c"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05ba6a0fbf8cb2141b8ca2bc461db97af8ac31a62341e4696a75048b9de39e10"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c7a779215571296ecfcf86545cb30ec3f1c6f43cbcd69f83cc4f67049375ea1"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ddd2f03f3f206585543f5274b1dbc5f651bae141a1b14c9d5225c2a12e5075c2"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-manylinux_2_31_armv7l.whl", hash = "sha256:748f9edc3222a1fb7a61331c4e06d3b7f2390ae493f91f09d372a00b81762a8d"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:1aee47a40757cd24272b34eaf9ceeea86577fd0761b0fd0e41599c6549abdf04"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:6ff0f152391ee03ffc18495322100c01735224f7843533a7c4ff33c8853d7be1"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:4d92f5265318b41d6e68659fd459babbff692484e492c5013995b90a56b517c9"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3a7720cce1b8689d88df76cac1425e84f9911c69a4e40a5309d7d3435e1bb97c"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-win32.whl", hash = "sha256:d16acfced7871ed0cd55b476f0dbdddc7a5da1ffc9745a3c5674846cf1555886"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313-win_amd64.whl", hash = "sha256:96a23328729ce62a851f8953582e576ca075ee78d637df4a78a2b3609784849e"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:b1357d0d9d8568db84668b57e7c6880b9c46f757e8954ad37634402d36f09dba"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:9604c464c5d86c7eba34b040080be7012e246ef512b819e428b7deb817290dae"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc6da186c2e4497cbfaba9c5424e58c7b72728b25d980829eb96daccd7cface1"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:366ed2ca4f4170418f3585e92059cf17ee2c963bf179111c5b8ba48f06cd69d1"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2f1877edbb815a43e7d38cc7332202e759054cf0b5a4b7e34a743c0f5d6e7333"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-manylinux_2_31_armv7l.whl", hash = "sha256:f63c31352ba6fc910290b0fe12733770acd8cfa0945fcb9cf3984d241abcfc9d"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:898e14b03fbfb0a8d9a5185a0a00ff656772b3baa37cad122e06e8e4d6da3832"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-musllinux_1_2_armv7l.whl", hash = "sha256:a33bb5832a69fc54d18cadcf015677c1acb5ccc7f0125d261df2a89f8aff01f6"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:7fe9e98bd0701d598bf245eb2fbf592145cd03551684a2102a4b301294b9bd87"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dcb8c5a74c1785f005fd83d445137437b79ac70a2dfbfe4bb1b09aa5643be545"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-win32.whl", hash = "sha256:a85c7de3d4480faa12b930637fca9c23144d1520786fedf9ba8edd8642ed4aea"}, + {file = "fasttext_predict-0.9.2.4-cp313-cp313t-win_amd64.whl", hash = "sha256:be0933fa4af7abae09c703d28f9e17c80e7069eb6f92100b21985b777f4ea275"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8ff71f9905567271a760139978dec62f8c224f20c8c42a45addd4830fa3db977"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:89401fa60533a9307bf26c312f3a47c58f9f8daf735532a03b0a88af391a6b7a"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b8e51eef5ebb1905b3b10e0f19cec7f0259f9134cfde76e4c172ac5dff3d1f1"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4d4bd0178d295ed898903fc8e1454682a44e9e3db8bc3e777c3e122f2c5d2a39"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37717d593560d2d82911ba644dc0eb0c8d9b270b005d59bc278ae1465b77b50e"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-manylinux_2_31_armv7l.whl", hash = "sha256:144decf434c79b80cacbb14007602ca0e563a951000dc7ca3308d022b1c6a56c"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:abd5f77f491f83f9f2f374c38adb9432fae1e92db28fdd2cf5c0f3db48e1f805"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-musllinux_1_2_armv7l.whl", hash = "sha256:25f3f82b847a320ce595dc772f5e1054ef0a1aa02e7d39feb0ea6374dc83aa55"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:6390f898bbc83a85447338e1a68d1730d5a5ca68292ea3621718c3f4be39288f"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:038bf374a9b9bd665fe58ef28a9b6a4703f8ba1de93bb747b974d7f78f023222"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-win32.whl", hash = "sha256:639ab150585ceb3832912d9b623122735481cff676876040ca9b08312264634a"}, + {file = "fasttext_predict-0.9.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:91c84cfb18a3a617e785fc9aa3bd4c80ffbe20009beb8f9e63e362160cb71a08"}, + {file = "fasttext_predict-0.9.2.4-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b11ba9414aa71754f798a102cf7d3df53307055b2b0f0b258a3f2d59c5a12cfa"}, + {file = "fasttext_predict-0.9.2.4-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3c89c769e3646bdb341487a68835239f35a4a0959cc1a8d8a7d215f40b22a230"}, + {file = "fasttext_predict-0.9.2.4-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f3b9cd4a2cf4c4853323f57c5da6ecffca6aeb9b6d8751ee40fe611d6edf8dd"}, + {file = "fasttext_predict-0.9.2.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1c92905396c74e5cb29ddbfa763b5addec1581b6e0eae4cbe82248dfe733557e"}, + {file = "fasttext_predict-0.9.2.4.tar.gz", hash = "sha256:18a6fb0d74c7df9280db1f96cb75d990bfd004fa9d669493ea3dd3d54f84dbc7"}, +] + [[package]] name = "filelock" version = "3.19.1" @@ -4415,6 +4534,25 @@ pygments = ">=2.13.0,<3.0.0" [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] +[[package]] +name = "robust-downloader" +version = "0.0.2" +description = "A Simple Robust Downloader written in Python" +optional = false +python-versions = "*" +files = [ + {file = "robust-downloader-0.0.2.tar.gz", hash = "sha256:08c938b96e317abe6b037e34230a91bda9b5d613f009bca4a47664997c61de90"}, + {file = "robust_downloader-0.0.2-py3-none-any.whl", hash = "sha256:8fe08bfb64d714fd1a048a7df6eb7b413eb4e624309a49db2c16fbb80a62869d"}, +] + +[package.dependencies] +colorlog = "*" +requests = "*" +tqdm = "*" + +[package.extras] +dev = ["black", "pre-commit (>=3.3.3)", "pytest", "pytest-cov", "ruff"] + [[package]] name = "rpds-py" version = "0.27.1" @@ -6458,10 +6596,11 @@ files = [ cffi = ["cffi (>=1.17)"] [extras] -all = ["aiofiles", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] +all = ["aiofiles", "fast-langdetect", "google-cloud-language", "langchain-nvidia-ai-endpoints", "langchain-openai", "numpy", "numpy", "numpy", "numpy", "opentelemetry-api", "presidio-analyzer", "presidio-anonymizer", "streamlit", "tqdm", "yara-python"] eval = ["numpy", "numpy", "numpy", "numpy", "streamlit", "tornado", "tqdm"] gcp = ["google-cloud-language"] jailbreak = ["yara-python"] +multilingual = ["fast-langdetect"] nvidia = ["langchain-nvidia-ai-endpoints"] openai = ["langchain-openai"] sdd = ["presidio-analyzer", "presidio-anonymizer"] @@ -6470,4 +6609,4 @@ tracing = ["aiofiles", "opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.14" -content-hash = "256ef22eb6c36141d9805311bdae9aee6c550ca13deb67b8205c639ecde4bc5c" +content-hash = "5f621add3bdfe92f78c38e14702f22cef7adb3a98b6ca6e494bb44ed834bdd97" diff --git a/pyproject.toml b/pyproject.toml index ef6bf9218..3e864afef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,9 @@ google-cloud-language = { version = ">=2.14.0", optional = true } # jailbreak injection yara-python = { version = "^4.5.1", optional = true } +# multilingual content safety - language detection +fast-langdetect = { version = ">=1.0.0", optional = true } + [tool.poetry.extras] sdd = ["presidio-analyzer", "presidio-anonymizer"] eval = ["tqdm", "numpy", "streamlit", "tornado"] @@ -113,6 +116,7 @@ gcp = ["google-cloud-language"] tracing = ["opentelemetry-api", "aiofiles"] nvidia = ["langchain-nvidia-ai-endpoints"] jailbreak = ["yara-python"] +multilingual = ["fast-langdetect"] # Poetry does not support recursive dependencies, so we need to add all the dependencies here. # I also support their decision. There is no PEP for recursive dependencies, but it has been supported in pip since version 21.2. # It is here for backward compatibility. @@ -128,6 +132,7 @@ all = [ "aiofiles", "langchain-nvidia-ai-endpoints", "yara-python", + "fast-langdetect", ] [tool.poetry.group.dev] @@ -147,6 +152,7 @@ pytest-profiling = "^1.7.0" yara-python = "^4.5.1" opentelemetry-api = "^1.34.1" opentelemetry-sdk = "^1.34.1" +fast-langdetect = ">=1.0.0" pyright = "^1.1.405" ruff = "0.14.6" diff --git a/tests/test_content_safety_actions.py b/tests/test_content_safety_actions.py index 8d7d10ea9..1e3151b60 100644 --- a/tests/test_content_safety_actions.py +++ b/tests/test_content_safety_actions.py @@ -13,18 +13,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch -# conftest.py import pytest from nemoguardrails.library.content_safety.actions import ( + DEFAULT_REFUSAL_MESSAGES, + SUPPORTED_LANGUAGES, + _detect_language, + _get_refusal_message, content_safety_check_input, content_safety_check_output, content_safety_check_output_mapping, + detect_language, ) from tests.utils import FakeLLM +try: + import fast_langdetect # noqa + + HAS_FAST_LANGDETECT = True +except ImportError: + HAS_FAST_LANGDETECT = False + +requires_fast_langdetect = pytest.mark.skipif(not HAS_FAST_LANGDETECT, reason="fast-langdetect not installed") + @pytest.fixture def fake_llm(): @@ -150,3 +163,148 @@ def test_content_safety_check_output_mapping_default(): """Test content_safety_check_output_mapping defaults to allowed=False when key is missing.""" result = {"policy_violations": []} assert content_safety_check_output_mapping(result) is False + + +@requires_fast_langdetect +class TestDetectLanguage: + @pytest.mark.parametrize( + "text,expected_lang", + [ + ("Hello, how are you today?", "en"), + ("Hola, ¿cómo estás hoy?", "es"), + ("你好,你今天好吗?", "zh"), + ("Guten Tag, wie geht es Ihnen?", "de"), + ("Bonjour, comment allez-vous?", "fr"), + ("こんにちは、お元気ですか?", "ja"), + ], + ids=["english", "spanish", "chinese", "german", "french", "japanese"], + ) + def test_detect_language(self, text, expected_lang): + assert _detect_language(text) == expected_lang + + def test_detect_language_empty_string(self): + result = _detect_language("") + assert result is None or result == "en" + + def test_detect_language_import_error(self): + with patch.dict("sys.modules", {"fast_langdetect": None}): + import nemoguardrails.library.content_safety.actions as actions_module + + _original_detect_language = actions_module._detect_language + + def patched_detect_language(text): + try: + raise ImportError("No module named 'fast_langdetect'") + except ImportError: + return None + + with patch.object(actions_module, "_detect_language", patched_detect_language): + result = actions_module._detect_language("Hello") + assert result is None + + def test_detect_language_exception(self): + with patch("fast_langdetect.LangDetector.detect", side_effect=Exception("Detection failed")): + result = _detect_language("Hello") + assert result is None + + +class TestGetRefusalMessage: + @pytest.mark.parametrize("lang", list(SUPPORTED_LANGUAGES)) + def test_default_messages(self, lang): + result = _get_refusal_message(lang, None) + assert result == DEFAULT_REFUSAL_MESSAGES[lang] + + def test_custom_message_used_when_available(self): + custom = {"en": "Custom refusal", "es": "Rechazo personalizado"} + assert _get_refusal_message("en", custom) == "Custom refusal" + assert _get_refusal_message("es", custom) == "Rechazo personalizado" + + def test_unsupported_lang_falls_back_to_english(self): + assert _get_refusal_message("xyz", None) == DEFAULT_REFUSAL_MESSAGES["en"] + assert _get_refusal_message("xyz", {"en": "Custom fallback"}) == "Custom fallback" + + def test_lang_not_in_custom_uses_default(self): + custom = {"en": "Custom English"} + assert _get_refusal_message("es", custom) == DEFAULT_REFUSAL_MESSAGES["es"] + + +@requires_fast_langdetect +class TestDetectLanguageAction: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "user_message,expected_lang", + [ + ("Hello, how are you?", "en"), + ("Hola, ¿cómo estás?", "es"), + ("你好", "zh"), + ], + ids=["english", "spanish", "chinese"], + ) + async def test_detect_language_action(self, user_message, expected_lang): + context = {"user_message": user_message} + result = await detect_language(context=context, config=None) + assert result["language"] == expected_lang + assert result["refusal_message"] == DEFAULT_REFUSAL_MESSAGES[expected_lang] + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "context", + [None, {"user_message": ""}], + ids=["no_context", "empty_message"], + ) + async def test_detect_language_action_defaults_to_english(self, context): + result = await detect_language(context=context, config=None) + assert result["language"] == "en" + assert result["refusal_message"] == DEFAULT_REFUSAL_MESSAGES["en"] + + @pytest.mark.asyncio + async def test_detect_language_action_unsupported_language_falls_back_to_english(self): + with patch( + "nemoguardrails.library.content_safety.actions._detect_language", + return_value="xyz", + ): + context = {"user_message": "some text"} + result = await detect_language(context=context, config=None) + assert result["language"] == "en" + assert result["refusal_message"] == DEFAULT_REFUSAL_MESSAGES["en"] + + @pytest.mark.asyncio + async def test_detect_language_action_with_config_custom_messages(self): + mock_config = MagicMock() + mock_config.rails.config.content_safety.multilingual.refusal_messages = { + "en": "Custom: Cannot help", + "es": "Personalizado: No puedo ayudar", + } + + context = {"user_message": "Hello"} + result = await detect_language(context=context, config=mock_config) + assert result["language"] == "en" + assert result["refusal_message"] == "Custom: Cannot help" + + @pytest.mark.asyncio + async def test_detect_language_action_with_config_no_multilingual(self): + mock_config = MagicMock() + mock_config.rails.config.content_safety.multilingual = None + + context = {"user_message": "Hello"} + result = await detect_language(context=context, config=mock_config) + assert result["language"] == "en" + assert result["refusal_message"] == DEFAULT_REFUSAL_MESSAGES["en"] + + +class TestSupportedLanguagesAndDefaults: + def test_supported_languages_count(self): + assert len(SUPPORTED_LANGUAGES) == 9 + + def test_supported_languages_contents(self): + expected = {"en", "es", "zh", "de", "fr", "hi", "ja", "ar", "th"} + assert SUPPORTED_LANGUAGES == expected + + def test_default_refusal_messages_has_all_supported_languages(self): + for lang in SUPPORTED_LANGUAGES: + assert lang in DEFAULT_REFUSAL_MESSAGES + + def test_default_refusal_messages_are_non_empty(self): + for _lang, message in DEFAULT_REFUSAL_MESSAGES.items(): + assert message + assert len(message) > 0 diff --git a/tests/test_rails_config.py b/tests/test_rails_config.py index 796011d82..e0522ba6c 100644 --- a/tests/test_rails_config.py +++ b/tests/test_rails_config.py @@ -23,7 +23,9 @@ from nemoguardrails.llm.prompts import TaskPrompt from nemoguardrails.rails.llm.config import ( + ContentSafetyConfig, Model, + MultilingualConfig, RailsConfig, _get_flow_model, _validate_rail_prompts, @@ -1015,3 +1017,110 @@ def test_hero_topic_safety_prompt_raises(self): content: Verify the user input is on-topic """ ) + + +class TestMultilingualConfig: + def test_defaults(self): + config = MultilingualConfig() + assert config.enabled is False + assert config.refusal_messages is None + assert config.max_text_length is None + assert config.normalize_text is True + assert config.cache_dir is None + + def test_with_custom_messages(self): + custom = {"en": "Custom", "es": "Personalizado"} + config = MultilingualConfig(enabled=True, refusal_messages=custom) + assert config.enabled is True + assert config.refusal_messages == custom + + def test_with_detection_options(self): + config = MultilingualConfig( + enabled=True, + max_text_length=200, + normalize_text=False, + cache_dir="/custom/cache", + ) + assert config.enabled is True + assert config.max_text_length == 200 + assert config.normalize_text is False + assert config.cache_dir == "/custom/cache" + + +class TestContentSafetyConfigModel: + def test_defaults(self): + config = ContentSafetyConfig() + assert config.multilingual.enabled is False + assert config.multilingual.refusal_messages is None + assert config.multilingual.max_text_length is None + assert config.multilingual.normalize_text is True + assert config.multilingual.cache_dir is None + + def test_with_multilingual(self): + custom = {"en": "Custom"} + config = ContentSafetyConfig(multilingual=MultilingualConfig(enabled=True, refusal_messages=custom)) + assert config.multilingual.enabled is True + assert config.multilingual.refusal_messages == custom + + +class TestMultilingualConfigInRailsConfig: + BASE_YAML = """ + models: + - type: content_safety + engine: nim + model: nvidia/llama-3.1-nemoguard-8b-content-safety + rails: + {rails_config} + input: + flows: + - content safety check input $model=content_safety + prompts: + - task: content_safety_check_input $model=content_safety + content: Check content safety + """ + + def test_multilingual_disabled_by_default(self): + config = RailsConfig.from_content(yaml_content=self.BASE_YAML.format(rails_config="")) + assert config.rails.config.content_safety.multilingual.enabled is False + + def test_multilingual_enabled_with_custom_messages(self): + rails_config = """ + config: + content_safety: + multilingual: + enabled: true + refusal_messages: + en: "Custom English" + es: "Personalizado" + """ + config = RailsConfig.from_content(yaml_content=self.BASE_YAML.format(rails_config=rails_config)) + assert config.rails.config.content_safety.multilingual.enabled is True + assert config.rails.config.content_safety.multilingual.refusal_messages["en"] == "Custom English" + assert config.rails.config.content_safety.multilingual.refusal_messages["es"] == "Personalizado" + + def test_multilingual_enabled_no_custom_messages(self): + rails_config = """ + config: + content_safety: + multilingual: + enabled: true + """ + config = RailsConfig.from_content(yaml_content=self.BASE_YAML.format(rails_config=rails_config)) + assert config.rails.config.content_safety.multilingual.enabled is True + assert config.rails.config.content_safety.multilingual.refusal_messages is None + + def test_multilingual_with_detection_options(self): + rails_config = """ + config: + content_safety: + multilingual: + enabled: true + max_text_length: 200 + normalize_text: false + cache_dir: "/custom/cache" + """ + config = RailsConfig.from_content(yaml_content=self.BASE_YAML.format(rails_config=rails_config)) + assert config.rails.config.content_safety.multilingual.enabled is True + assert config.rails.config.content_safety.multilingual.max_text_length == 200 + assert config.rails.config.content_safety.multilingual.normalize_text is False + assert config.rails.config.content_safety.multilingual.cache_dir == "/custom/cache"