diff --git a/.gitignore b/.gitignore
index b80f310..800ced4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,6 @@
 my_examples
 .idea
+.env
 .venv
 .DS_Store
 select_ai.egg-info
diff --git a/tests/conftest.py b/tests/conftest.py
index f9c54f0..ff4c9aa 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -33,6 +33,35 @@
 PYSAI_TEST_USER = "PYSAI_TEST_USER"
 
 PYSAI_OCI_CREDENTIAL_NAME = f"PYSAI_OCI_CREDENTIAL_{uuid.uuid4().hex.upper()}"
 
+_BASIC_SCHEMA_PRIVILEGES = (
+    "CREATE SESSION",
+    "CREATE TABLE",
+    "UNLIMITED TABLESPACE",
+)
+
+
+def _ensure_test_user_exists(username: str, password: str):
+    username_upper = username.upper()
+    with select_ai.cursor() as cr:
+        cr.execute(
+            "SELECT 1 FROM dba_users WHERE username = :username",
+            username=username_upper,
+        )
+        if cr.fetchone():
+            return
+        escaped_password = password.replace('"', '""')
+        cr.execute(
+            f'CREATE USER {username_upper} IDENTIFIED BY "{escaped_password}"'
+        )
+    select_ai.db.get_connection().commit()
+
+
+def _grant_basic_schema_privileges(username: str):
+    username_upper = username.upper()
+    with select_ai.cursor() as cr:
+        for privilege in _BASIC_SCHEMA_PRIVILEGES:
+            cr.execute(f"GRANT {privilege} TO {username_upper}")
+    select_ai.db.get_connection().commit()
 
 
 def get_env_value(name, default_value=None, required=False):
@@ -93,6 +122,11 @@ def test_env(pytestconfig):
 @pytest.fixture(autouse=True, scope="session")
 def setup_test_user(test_env):
     select_ai.connect(**test_env.connect_params(admin=True))
+    _ensure_test_user_exists(
+        username=test_env.test_user,
+        password=test_env.test_user_password,
+    )
+    _grant_basic_schema_privileges(username=test_env.test_user)
     select_ai.grant_privileges(users=[test_env.test_user])
     select_ai.grant_http_access(
        users=[test_env.test_user],
diff --git a/tests/gsd/test_2000_synthetic_data.py b/tests/gsd/test_2000_synthetic_data.py
new file mode 100644
index 0000000..137237c
--- /dev/null
+++ b/tests/gsd/test_2000_synthetic_data.py
@@ -0,0 +1,137 @@
+# -----------------------------------------------------------------------------
+# Copyright (c) 2025, Oracle and/or its affiliates.
+#
+# Licensed under the Universal Permissive License v 1.0 as shown at
+# http://oss.oracle.com/licenses/upl.
+# ----------------------------------------------------------------------------- + +""" +2000 - Synthetic data generation tests +""" + +import uuid + +import pytest +import select_ai +from select_ai import ( + Profile, + ProfileAttributes, + SyntheticDataAttributes, + SyntheticDataParams, +) + +PROFILE_PREFIX = f"PYSAI_2000_{uuid.uuid4().hex.upper()}" + + +def _build_attributes(record_count=1, **kwargs): + return SyntheticDataAttributes( + object_name="people", + record_count=record_count, + **kwargs, + ) + + +@pytest.fixture(scope="module") +def synthetic_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def synthetic_profile_attributes(oci_credential, synthetic_provider): + return ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": "ADMIN", "name": "people"}, + {"owner": "ADMIN", "name": "gymnast"}, + ], + provider=synthetic_provider, + ) + + +@pytest.fixture(scope="module") +def synthetic_profile(synthetic_profile_attributes): + profile = Profile( + profile_name=f"{PROFILE_PREFIX}_SYNC", + attributes=synthetic_profile_attributes, + description="Synthetic data test profile", + replace=True, + ) + yield profile + try: + profile.delete(force=True) + except Exception: + pass + + +def test_2000_generate_with_full_params(synthetic_profile): + """Generate synthetic data with full parameter set""" + params = SyntheticDataParams(sample_rows=10, priority="HIGH") + attributes = _build_attributes( + record_count=5, + params=params, + user_prompt="age must be greater than 20", + ) + result = synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +def test_2001_generate_minimum_fields(synthetic_profile): + """Generate synthetic data with minimum fields""" + attributes = _build_attributes() + result = synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +def test_2002_generate_zero_sample_rows(synthetic_profile): + """Generate synthetic data with zero sample rows""" + params = SyntheticDataParams(sample_rows=0, priority="HIGH") + attributes = _build_attributes(params=params) + result = synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +def test_2003_generate_single_sample_row(synthetic_profile): + """Generate synthetic data with single sample row""" + params = SyntheticDataParams(sample_rows=1, priority="HIGH") + attributes = _build_attributes(params=params) + result = synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +def test_2004_generate_low_priority(synthetic_profile): + """Generate synthetic data with low priority""" + params = SyntheticDataParams(sample_rows=1, priority="LOW") + attributes = _build_attributes(params=params) + result = synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +def test_2005_generate_missing_object_name(synthetic_profile): + """Missing object_name raises error""" + attributes = SyntheticDataAttributes(record_count=1) + with pytest.raises(Exception): + synthetic_profile.generate_synthetic_data(attributes) + + +def test_2006_generate_invalid_priority(synthetic_profile): + """Invalid priority raises error""" + params = SyntheticDataParams(sample_rows=1, priority="CRITICAL") + attributes = _build_attributes(params=params) + with pytest.raises(Exception): + synthetic_profile.generate_synthetic_data(attributes) + + +def 
test_2007_generate_negative_record_count(synthetic_profile): + """Negative record count raises error""" + attributes = _build_attributes(record_count=-5) + with pytest.raises(Exception): + synthetic_profile.generate_synthetic_data(attributes) + + +def test_2008_generate_with_none_attributes(synthetic_profile): + """Passing None as attributes raises error""" + with pytest.raises(Exception): + synthetic_profile.generate_synthetic_data(None) diff --git a/tests/gsd/test_2100_synthetic_data_async.py b/tests/gsd/test_2100_synthetic_data_async.py new file mode 100644 index 0000000..7648256 --- /dev/null +++ b/tests/gsd/test_2100_synthetic_data_async.py @@ -0,0 +1,148 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +2100 - Synthetic data generation tests (async) +""" + +import uuid + +import pytest +import select_ai +from select_ai import ( + AsyncProfile, + ProfileAttributes, + SyntheticDataAttributes, + SyntheticDataParams, +) + +PROFILE_PREFIX = f"PYSAI_2100_{uuid.uuid4().hex.upper()}" + + +def _build_attributes(record_count=1, **kwargs): + return SyntheticDataAttributes( + object_name="people", + record_count=record_count, + **kwargs, + ) + + +@pytest.fixture(scope="module") +def async_synthetic_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def async_synthetic_profile_attributes( + oci_credential, async_synthetic_provider +): + return ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": "ADMIN", "name": "people"}, + {"owner": "ADMIN", "name": "gymnast"}, + ], + provider=async_synthetic_provider, + ) + + +@pytest.fixture(scope="module") +async def async_synthetic_profile(async_synthetic_profile_attributes): + profile = await AsyncProfile( + profile_name=f"{PROFILE_PREFIX}_ASYNC", + attributes=async_synthetic_profile_attributes, + description="Synthetic data async test profile", + replace=True, + ) + yield profile + try: + await profile.delete(force=True) + except Exception: + pass + + +@pytest.mark.anyio +async def test_2100_generate_with_full_params(async_synthetic_profile): + """Generate synthetic data with full parameter set""" + params = SyntheticDataParams(sample_rows=10, priority="HIGH") + attributes = _build_attributes( + record_count=5, + params=params, + user_prompt="age must be greater than 20", + ) + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2101_generate_minimum_fields(async_synthetic_profile): + """Generate synthetic data with minimum fields""" + attributes = _build_attributes() + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2102_generate_zero_sample_rows(async_synthetic_profile): + """Generate synthetic data with zero sample rows""" + params = SyntheticDataParams(sample_rows=0, priority="HIGH") + attributes = _build_attributes(params=params) + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2103_generate_single_sample_row(async_synthetic_profile): 
+ """Generate synthetic data with single sample row""" + params = SyntheticDataParams(sample_rows=1, priority="HIGH") + attributes = _build_attributes(params=params) + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2104_generate_low_priority(async_synthetic_profile): + """Generate synthetic data with low priority""" + params = SyntheticDataParams(sample_rows=1, priority="LOW") + attributes = _build_attributes(params=params) + result = await async_synthetic_profile.generate_synthetic_data(attributes) + assert result is None + + +@pytest.mark.anyio +async def test_2105_generate_missing_object_name(async_synthetic_profile): + """Missing object_name raises error""" + attributes = SyntheticDataAttributes(record_count=1) + with pytest.raises(Exception): + await async_synthetic_profile.generate_synthetic_data(attributes) + + +@pytest.mark.anyio +async def test_2106_generate_invalid_priority(async_synthetic_profile): + """Invalid priority raises error""" + params = SyntheticDataParams(sample_rows=1, priority="CRITICAL") + attributes = _build_attributes(params=params) + with pytest.raises(Exception): + await async_synthetic_profile.generate_synthetic_data(attributes) + + +@pytest.mark.anyio +async def test_2107_generate_negative_record_count(async_synthetic_profile): + """Negative record count raises error""" + attributes = _build_attributes(record_count=-5) + with pytest.raises(Exception): + await async_synthetic_profile.generate_synthetic_data(attributes) + + +@pytest.mark.anyio +async def test_2108_generate_with_none_attributes(async_synthetic_profile): + """Passing None as attributes raises error""" + with pytest.raises(Exception): + await async_synthetic_profile.generate_synthetic_data(None) diff --git a/tests/profiles/test_1200_profile.py b/tests/profiles/test_1200_profile.py index 5226af7..e14c4dd 100644 --- a/tests/profiles/test_1200_profile.py +++ b/tests/profiles/test_1200_profile.py @@ -206,7 +206,7 @@ def test_1207(): assert profile.attributes.provider.model == "meta.llama-3.1-70b-instruct" -def test_1208(oci_credential): +def test_1208(oci_credential, oci_compartment_id): """Set multiple attributes for a Profile""" profile = Profile(PYSAI_1200_PROFILE) profile_attrs = ProfileAttributes( @@ -214,6 +214,7 @@ def test_1208(oci_credential): provider=select_ai.OCIGenAIProvider( model="meta.llama-4-maverick-17b-128e-instruct-fp8", region="us-chicago-1", + oci_compartment_id=oci_compartment_id, oci_apiformat="GENERIC", ), object_list=[{"owner": "ADMIN", "name": "gymnasts"}], diff --git a/tests/profiles/test_1400_conversation.py b/tests/profiles/test_1400_conversation.py new file mode 100644 index 0000000..f2e3d6b --- /dev/null +++ b/tests/profiles/test_1400_conversation.py @@ -0,0 +1,203 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +1400 - Conversation API tests +""" + +import uuid + +import pytest +import select_ai +from oracledb import DatabaseError +from select_ai import Conversation, ConversationAttributes + +CONVERSATION_PREFIX = f"PYSAI_1400_{uuid.uuid4().hex.upper()}" + + +@pytest.fixture +def conversation_factory(): + created = [] + + def _create(**kwargs): + attributes = ConversationAttributes(**kwargs) + conv = Conversation(attributes=attributes) + conv.create() + created.append(conv) + return conv + + yield _create + + for conv in created: + conv.delete(force=True) + + +@pytest.fixture +def conversation(conversation_factory): + return conversation_factory(title=f"{CONVERSATION_PREFIX}_ACTIVE") + + +def test_1400_create_with_title(conversation): + """Create a conversation with title""" + assert conversation.conversation_id + + +def test_1401_create_with_description(conversation_factory): + """Create a conversation with title and description""" + conv = conversation_factory( + title=f"{CONVERSATION_PREFIX}_HISTORY", + description="LLM's understanding of history of science", + ) + attrs = conv.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_HISTORY" + assert attrs.description == "LLM's understanding of history of science" + + +def test_1402_create_without_title(conversation_factory): + """Create a conversation without providing a title""" + conv = conversation_factory() + attrs = conv.get_attributes() + assert attrs.title == "New Conversation" + + +def test_1403_create_with_missing_attributes(): + """Missing attributes raise AttributeError""" + conv = Conversation(attributes=None) + with pytest.raises(AttributeError): + conv.create() + + +def test_1404_get_attributes(conversation): + """Fetch conversation attributes""" + attrs = conversation.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_ACTIVE" + assert attrs.description is None + + +def test_1405_set_attributes(conversation): + """Update conversation attributes""" + updated = ConversationAttributes( + title=f"{CONVERSATION_PREFIX}_UPDATED", + description="Updated Description", + ) + conversation.set_attributes(updated) + attrs = conversation.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_UPDATED" + assert attrs.description == "Updated Description" + + +def test_1406_set_attributes_with_none(conversation): + """Setting empty attributes raises AttributeError""" + with pytest.raises(AttributeError): + conversation.set_attributes(None) + + +def test_1407_delete_conversation(conversation_factory): + """Delete conversation and validate removal""" + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_DELETE") + conv.delete(force=True) + with pytest.raises(select_ai.errors.ConversationNotFoundError): + conv.get_attributes() + + +def test_1408_delete_twice(conversation_factory): + """Deleting an already deleted conversation raises DatabaseError""" + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_DELETE_TWICE") + conv.delete(force=True) + with pytest.raises(DatabaseError): + conv.delete() + + +def test_1409_list_contains_created_conversation(conversation): + """Conversation list contains the created conversation""" + conversation_ids = {item.conversation_id for item in Conversation.list()} + assert conversation.conversation_id in conversation_ids + + +def test_1410_multiple_conversations_have_unique_ids(conversation_factory): + """Multiple conversations produce unique identifiers""" + titles = [ + 
f"{CONVERSATION_PREFIX}_AI", + f"{CONVERSATION_PREFIX}_DB", + f"{CONVERSATION_PREFIX}_MATH", + ] + conversations = [conversation_factory(title=title) for title in titles] + ids = {conv.conversation_id for conv in conversations} + assert len(ids) == len(titles) + + +def test_1411_create_with_long_values(): + """Creating conversation with overly long values fails""" + conv = Conversation( + attributes=ConversationAttributes( + title="A" * 255, + description="B" * 1000, + ) + ) + with pytest.raises(Exception): + conv.create() + + +def test_1412_set_attributes_with_invalid_id(): + """Updating conversation with invalid id raises DatabaseError""" + conv = Conversation(conversation_id="fake_id") + with pytest.raises(DatabaseError): + conv.set_attributes(ConversationAttributes(title="Invalid")) + + +def test_1413_delete_with_invalid_id(): + """Deleting conversation with invalid id raises DatabaseError""" + conv = Conversation(conversation_id="fake_id") + with pytest.raises(DatabaseError): + conv.delete() + + +def test_1414_get_attributes_with_invalid_id(): + """Fetching attributes for invalid conversation raises ConversationNotFound""" + conv = Conversation(conversation_id="invalid") + with pytest.raises(select_ai.errors.ConversationNotFoundError): + conv.get_attributes() + + +def test_1415_get_attributes_for_deleted_conversation(conversation_factory): + """Fetching attributes after deletion raises ConversationNotFound""" + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_TO_DELETE") + conv.delete(force=True) + with pytest.raises(select_ai.errors.ConversationNotFoundError): + conv.get_attributes() + + +def test_1416_list_contains_new_conversation(conversation_factory): + """List reflects newly created conversation""" + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_LIST") + listed = list(Conversation.list()) + assert any(item.conversation_id == conv.conversation_id for item in listed) + + +def test_1417_list_returns_conversation_instances(): + """List returns Conversation objects""" + listed = list(Conversation.list()) + assert all(isinstance(item, Conversation) for item in listed) + + +def test_1418_get_attributes_without_description(conversation_factory): + """Conversation created without description has None description""" + conv = conversation_factory(title=f"{CONVERSATION_PREFIX}_NO_DESC") + attrs = conv.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_NO_DESC" + assert attrs.description is None + + +def test_1419_create_with_description_none(conversation_factory): + """Explicitly setting description to None is allowed""" + conv = conversation_factory( + title=f"{CONVERSATION_PREFIX}_NONE_DESC", + description=None, + ) + attrs = conv.get_attributes() + assert attrs.title == f"{CONVERSATION_PREFIX}_NONE_DESC" + assert attrs.description is None diff --git a/tests/profiles/test_1500_conversation_async.py b/tests/profiles/test_1500_conversation_async.py new file mode 100644 index 0000000..c87e4a5 --- /dev/null +++ b/tests/profiles/test_1500_conversation_async.py @@ -0,0 +1,249 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +1500 - AsyncConversation API tests +""" + +import uuid + +import pytest +import select_ai +from oracledb import DatabaseError +from select_ai import AsyncConversation, ConversationAttributes + +CONVERSATION_PREFIX = f"PYSAI_1500_{uuid.uuid4().hex.upper()}" + + +@pytest.fixture +async def async_conversation_factory(): + created = [] + + async def _create(**kwargs): + attributes = ConversationAttributes(**kwargs) + conversation = AsyncConversation(attributes=attributes) + await conversation.create() + created.append(conversation) + return conversation + + yield _create + + for conversation in created: + await conversation.delete(force=True) + + +@pytest.fixture +async def async_conversation(async_conversation_factory): + return await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_ACTIVE" + ) + + +@pytest.mark.anyio +async def test_1500_create_with_title(async_conversation): + """Create an async conversation with title""" + assert async_conversation.conversation_id + + +@pytest.mark.anyio +async def test_1501_create_with_description(async_conversation_factory): + """Create an async conversation with title and description""" + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_HISTORY", + description="LLM's understanding of history of science", + ) + attributes = await conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_HISTORY" + assert ( + attributes.description == "LLM's understanding of history of science" + ) + + +@pytest.mark.anyio +async def test_1502_create_without_title(async_conversation_factory): + """Create an async conversation without providing a title""" + conversation = await async_conversation_factory() + attributes = await conversation.get_attributes() + assert attributes.title == "New Conversation" + + +@pytest.mark.anyio +async def test_1503_create_with_missing_attributes(): + """Missing attributes raise AttributeError""" + conversation = AsyncConversation(attributes=None) + with pytest.raises(AttributeError): + await conversation.create() + + +@pytest.mark.anyio +async def test_1504_get_attributes(async_conversation): + """Fetch async conversation attributes""" + attributes = await async_conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_ACTIVE" + assert attributes.description is None + + +@pytest.mark.anyio +async def test_1505_set_attributes(async_conversation): + """Update async conversation attributes""" + updated = ConversationAttributes( + title=f"{CONVERSATION_PREFIX}_UPDATED", + description="Updated Description", + ) + await async_conversation.set_attributes(updated) + attributes = await async_conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_UPDATED" + assert attributes.description == "Updated Description" + + +@pytest.mark.anyio +async def test_1506_set_attributes_with_none(async_conversation): + """Setting empty attributes raises AttributeError""" + with pytest.raises(AttributeError): + await async_conversation.set_attributes(None) + + +@pytest.mark.anyio +async def test_1507_delete_conversation(async_conversation_factory): + """Delete async conversation and validate removal""" + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_DELETE" + ) + await conversation.delete(force=True) + with pytest.raises(select_ai.errors.ConversationNotFoundError): + await conversation.get_attributes() + + +@pytest.mark.anyio 
+async def test_1508_delete_twice(async_conversation_factory): + """Deleting an already deleted async conversation raises DatabaseError""" + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_DELETE_TWICE" + ) + await conversation.delete(force=True) + with pytest.raises(DatabaseError): + await conversation.delete() + + +@pytest.mark.anyio +async def test_1509_list_contains_created_conversation(async_conversation): + """Async conversation list contains the created conversation""" + ids = {item.conversation_id async for item in AsyncConversation.list()} + assert async_conversation.conversation_id in ids + + +@pytest.mark.anyio +async def test_1510_multiple_conversations_have_unique_ids( + async_conversation_factory, +): + """Multiple async conversations produce unique identifiers""" + titles = [ + f"{CONVERSATION_PREFIX}_AI", + f"{CONVERSATION_PREFIX}_DB", + f"{CONVERSATION_PREFIX}_MATH", + ] + conversations = [ + await async_conversation_factory(title=title) for title in titles + ] + ids = {conversation.conversation_id for conversation in conversations} + assert len(ids) == len(titles) + + +@pytest.mark.anyio +async def test_1511_create_with_long_values(): + """Creating async conversation with overly long values fails""" + conversation = AsyncConversation( + attributes=ConversationAttributes( + title="A" * 255, + description="B" * 1000, + ) + ) + with pytest.raises(Exception): + await conversation.create() + + +@pytest.mark.anyio +async def test_1512_set_attributes_with_invalid_id(): + """Updating async conversation with invalid id raises DatabaseError""" + conversation = AsyncConversation(conversation_id="fake_id") + with pytest.raises(DatabaseError): + await conversation.set_attributes( + ConversationAttributes(title="Invalid") + ) + + +@pytest.mark.anyio +async def test_1513_delete_with_invalid_id(): + """Deleting async conversation with invalid id raises DatabaseError""" + conversation = AsyncConversation(conversation_id="fake_id") + with pytest.raises(DatabaseError): + await conversation.delete() + + +@pytest.mark.anyio +async def test_1514_get_attributes_with_invalid_id(): + """Fetching attributes for invalid async conversation raises ConversationNotFound""" + conversation = AsyncConversation(conversation_id="invalid") + with pytest.raises(select_ai.errors.ConversationNotFoundError): + await conversation.get_attributes() + + +@pytest.mark.anyio +async def test_1515_get_attributes_for_deleted_conversation( + async_conversation_factory, +): + """Fetching attributes after deletion raises ConversationNotFound""" + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_TO_DELETE" + ) + await conversation.delete(force=True) + with pytest.raises(select_ai.errors.ConversationNotFoundError): + await conversation.get_attributes() + + +@pytest.mark.anyio +async def test_1516_list_contains_new_conversation(async_conversation_factory): + """List reflects newly created async conversation""" + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_LIST" + ) + listed = [item async for item in AsyncConversation.list()] + assert any( + item.conversation_id == conversation.conversation_id for item in listed + ) + + +@pytest.mark.anyio +async def test_1517_list_returns_async_conversation_instances(): + """List returns AsyncConversation objects""" + listed = [item async for item in AsyncConversation.list()] + assert all(isinstance(item, AsyncConversation) for item in listed) + + +@pytest.mark.anyio +async def 
test_1518_get_attributes_without_description( + async_conversation_factory, +): + """Async conversation created without description has None description""" + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_NO_DESC" + ) + attributes = await conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_NO_DESC" + assert attributes.description is None + + +@pytest.mark.anyio +async def test_1519_create_with_description_none(async_conversation_factory): + """Explicitly setting description to None is allowed""" + conversation = await async_conversation_factory( + title=f"{CONVERSATION_PREFIX}_NONE_DESC", + description=None, + ) + attributes = await conversation.get_attributes() + assert attributes.title == f"{CONVERSATION_PREFIX}_NONE_DESC" + assert attributes.description is None diff --git a/tests/profiles/test_1600_generate.py b/tests/profiles/test_1600_generate.py new file mode 100644 index 0000000..831626a --- /dev/null +++ b/tests/profiles/test_1600_generate.py @@ -0,0 +1,282 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1600 - Profile generate API tests +""" + +import json +import uuid + +import oracledb +import pandas as pd +import pytest +import select_ai +from select_ai import ( + Conversation, + ConversationAttributes, + Profile, + ProfileAttributes, +) +from select_ai.profile import Action + +PROFILE_PREFIX = f"PYSAI_1600_{uuid.uuid4().hex.upper()}" + +PROMPTS = [ + "What is a database?", + "How many gymnasts in database?", + "How many people are there in the database?", +] + + +@pytest.fixture(scope="module") +def generate_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def generate_profile_attributes(test_env, oci_credential, generate_provider): + return ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": test_env.test_user, "name": "people"}, + {"owner": test_env.test_user, "name": "gymnast"}, + ], + provider=generate_provider, + ) + + +@pytest.fixture(scope="module") +def generate_profile(generate_profile_attributes): + profile = Profile( + profile_name=f"{PROFILE_PREFIX}_POSITIVE", + attributes=generate_profile_attributes, + description="Generate Calls Test Profile", + replace=True, + ) + profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + profile.delete(force=True) + + +@pytest.fixture +def negative_profile(test_env, oci_credential, generate_provider): + profile_name = f"{PROFILE_PREFIX}_NEG_{uuid.uuid4().hex.upper()}" + attributes = ProfileAttributes( + credential_name=oci_credential["credential_name"], + provider=generate_provider, + ) + profile = Profile( + profile_name=profile_name, + attributes=attributes, + description="Generate Calls Negative Test Profile", + replace=True, + ) + profile.set_attribute( + attribute_name="object_list", + attribute_value=json.dumps( + [ + {"owner": test_env.test_user, "name": "people"}, + {"owner": test_env.test_user, "name": "gymnast"}, + ] + ), + ) + profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield 
profile + profile.delete(force=True) + + +def test_1600_action_enum_members(): + """Validate Action enum exposes expected members""" + for member in [ + "RUNSQL", + "SHOWSQL", + "EXPLAINSQL", + "NARRATE", + "CHAT", + "SHOWPROMPT", + ]: + assert hasattr(Action, member) + + +def test_1601_action_enum_values(): + """Validate Action enum values""" + assert Action.RUNSQL.value == "runsql" + assert Action.SHOWSQL.value == "showsql" + assert Action.EXPLAINSQL.value == "explainsql" + assert Action.NARRATE.value == "narrate" + assert Action.CHAT.value == "chat" + + +def test_1602_action_from_string(): + """Validate Action enum construction from string""" + assert Action("runsql") is Action.RUNSQL + assert Action("chat") is Action.CHAT + assert Action("explainsql") is Action.EXPLAINSQL + assert Action("narrate") is Action.NARRATE + assert Action("showsql") is Action.SHOWSQL + + +def test_1603_action_invalid_string(): + """Invalid enum string raises ValueError""" + with pytest.raises(ValueError): + Action("invalid_action") + + +def test_1604_show_sql(generate_profile): + """show_sql returns SQL text""" + for prompt in PROMPTS: + show_sql = generate_profile.show_sql(prompt=prompt) + assert isinstance(show_sql, str) + assert "SELECT" in show_sql.upper() + + +def test_1605_show_prompt(generate_profile): + """show_prompt returns prompt text""" + for prompt in PROMPTS: + show_prompt = generate_profile.show_prompt(prompt=prompt) + assert isinstance(show_prompt, str) + assert len(show_prompt) > 0 + + +def test_1606_run_sql(generate_profile): + """run_sql returns DataFrame""" + df = generate_profile.run_sql(prompt=PROMPTS[1]) + assert isinstance(df, pd.DataFrame) + assert len(df.columns) > 0 + + +def test_1607_chat(generate_profile): + """chat returns text response""" + response = generate_profile.chat(prompt="What is OCI ?") + assert isinstance(response, str) + assert len(response) > 0 + + +def test_1608_narrate(generate_profile): + """narrate returns narrative text""" + for prompt in PROMPTS: + narration = generate_profile.narrate(prompt=prompt) + assert isinstance(narration, str) + assert len(narration) > 0 + + +def test_1609_chat_session(generate_profile): + """chat_session provides a session context""" + conversation = Conversation(attributes=ConversationAttributes()) + with generate_profile.chat_session( + conversation=conversation, delete=True + ) as session: + assert session is not None + + +def test_1610_explain_sql(generate_profile): + """explain_sql returns explanation text""" + for prompt in PROMPTS: + explain_sql = generate_profile.explain_sql(prompt=prompt) + assert isinstance(explain_sql, str) + assert len(explain_sql) > 0 + + +def test_1611_generate_runsql(generate_profile): + """generate with RUNSQL returns DataFrame""" + df = generate_profile.generate(prompt=PROMPTS[1], action=Action.RUNSQL) + assert isinstance(df, pd.DataFrame) + + +def test_1612_generate_showsql(generate_profile): + """generate with SHOWSQL returns SQL""" + sql = generate_profile.generate(prompt=PROMPTS[1], action=Action.SHOWSQL) + assert isinstance(sql, str) + assert "SELECT" in sql.upper() + + +def test_1613_generate_chat(generate_profile): + """generate with CHAT returns response""" + chat_resp = generate_profile.generate( + prompt="Tell me about OCI", action=Action.CHAT + ) + assert isinstance(chat_resp, str) + assert len(chat_resp) > 0 + + +def test_1614_generate_narrate(generate_profile): + """generate with NARRATE returns response""" + narrate_resp = generate_profile.generate( + prompt=PROMPTS[1], 
action=Action.NARRATE + ) + assert isinstance(narrate_resp, str) + assert len(narrate_resp) > 0 + + +def test_1615_generate_explainsql(generate_profile): + """generate with EXPLAINSQL returns explanation""" + for prompt in PROMPTS: + explain_sql = generate_profile.generate( + prompt=prompt, action=Action.EXPLAINSQL + ) + assert isinstance(explain_sql, str) + assert len(explain_sql) > 0 + + +def test_1616_empty_prompt_raises_value_error(negative_profile): + """Empty prompts raise ValueError for profile methods""" + with pytest.raises(ValueError): + negative_profile.chat(prompt="") + with pytest.raises(ValueError): + negative_profile.narrate(prompt="") + with pytest.raises(ValueError): + negative_profile.show_sql(prompt="") + with pytest.raises(ValueError): + negative_profile.show_prompt(prompt="") + with pytest.raises(ValueError): + negative_profile.run_sql(prompt="") + with pytest.raises(ValueError): + negative_profile.explain_sql(prompt="") + + +def test_1617_none_prompt_raises_value_error(negative_profile): + """None prompts raise ValueError for profile methods""" + with pytest.raises(ValueError): + negative_profile.chat(prompt=None) + with pytest.raises(ValueError): + negative_profile.narrate(prompt=None) + with pytest.raises(ValueError): + negative_profile.show_sql(prompt=None) + with pytest.raises(ValueError): + negative_profile.show_prompt(prompt=None) + with pytest.raises(ValueError): + negative_profile.run_sql(prompt=None) + with pytest.raises(ValueError): + negative_profile.explain_sql(prompt=None) + + +# def test_1618_run_sql_with_ambiguous_prompt(negative_profile): +# """Ambiguous prompt raises DatabaseError for run_sql""" +# with pytest.raises(oracledb.DatabaseError): +# negative_profile.run_sql(prompt="delete data from user") + + +# def test_1619_run_sql_with_invalid_object_list(negative_profile): +# """run_sql with non existent table raises DatabaseError""" +# negative_profile.set_attribute( +# attribute_name="object_list", +# attribute_value=json.dumps( +# [{"owner": test_env.test_user, "name": "non_existent_table"}] +# ), +# ) +# with pytest.raises(oracledb.DatabaseError): +# negative_profile.run_sql(prompt="How many entries in the table") diff --git a/tests/profiles/test_1700_generate_async.py b/tests/profiles/test_1700_generate_async.py new file mode 100644 index 0000000..7efa337 --- /dev/null +++ b/tests/profiles/test_1700_generate_async.py @@ -0,0 +1,310 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +1700 - AsyncProfile generate API tests +""" + +import json +import uuid + +import oracledb +import pandas as pd +import pytest +import select_ai +from select_ai import ( + AsyncConversation, + AsyncProfile, + ConversationAttributes, + ProfileAttributes, +) +from select_ai.profile import Action + +PROFILE_PREFIX = f"PYSAI_1700_{uuid.uuid4().hex.upper()}" + +PROMPTS = [ + "What is a database?", + "How many gymnasts in database?", + "How many people are in the database?", +] + + +@pytest.fixture(scope="module") +def async_generate_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def async_generate_profile_attributes( + oci_credential, async_generate_provider, test_env +): + return ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": test_env.test_user, "name": "people"}, + {"owner": test_env.test_user, "name": "gymnast"}, + ], + provider=async_generate_provider, + ) + + +@pytest.fixture(scope="module") +async def async_generate_profile(async_generate_profile_attributes): + profile = await AsyncProfile( + profile_name=f"{PROFILE_PREFIX}_POSITIVE", + attributes=async_generate_profile_attributes, + description="Async generate calls test profile", + replace=True, + ) + await profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + await profile.delete(force=True) + + +@pytest.fixture +async def async_negative_profile( + oci_credential, async_generate_provider, test_env +): + profile_name = f"{PROFILE_PREFIX}_NEG_{uuid.uuid4().hex.upper()}" + attributes = ProfileAttributes( + credential_name=oci_credential["credential_name"], + provider=async_generate_provider, + ) + profile = await AsyncProfile( + profile_name=profile_name, + attributes=attributes, + description="Async generate calls negative test profile", + replace=True, + ) + await profile.set_attribute( + attribute_name="object_list", + attribute_value=json.dumps( + [ + {"owner": test_env.test_user, "name": "people"}, + {"owner": test_env.test_user, "name": "gymnast"}, + ] + ), + ) + await profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + await profile.delete(force=True) + + +@pytest.mark.anyio +async def test_1700_action_enum_members(): + """Validate Action enum exposes expected members""" + for member in [ + "RUNSQL", + "SHOWSQL", + "EXPLAINSQL", + "NARRATE", + "CHAT", + "SHOWPROMPT", + ]: + assert hasattr(Action, member) + + +@pytest.mark.anyio +async def test_1701_action_enum_values(): + """Validate Action enum values""" + assert Action.RUNSQL.value == "runsql" + assert Action.SHOWSQL.value == "showsql" + assert Action.EXPLAINSQL.value == "explainsql" + assert Action.NARRATE.value == "narrate" + assert Action.CHAT.value == "chat" + + +@pytest.mark.anyio +async def test_1702_action_from_string(): + """Validate Action enum construction from string""" + assert Action("runsql") is Action.RUNSQL + assert Action("chat") is Action.CHAT + assert Action("explainsql") is Action.EXPLAINSQL + assert Action("narrate") is Action.NARRATE + assert Action("showsql") is Action.SHOWSQL + + +@pytest.mark.anyio +async def test_1703_action_invalid_string(): + """Invalid enum string raises ValueError""" + with pytest.raises(ValueError): + Action("invalid_action") + + 
+@pytest.mark.anyio +async def test_1704_show_sql(async_generate_profile): + """show_sql returns SQL text""" + for prompt in PROMPTS: + show_sql = await async_generate_profile.show_sql(prompt=prompt) + assert isinstance(show_sql, str) + assert "SELECT" in show_sql.upper() + + +@pytest.mark.anyio +async def test_1705_show_prompt(async_generate_profile): + """show_prompt returns prompt text""" + for prompt in PROMPTS: + show_prompt = await async_generate_profile.show_prompt(prompt=prompt) + assert isinstance(show_prompt, str) + assert len(show_prompt) > 0 + + +@pytest.mark.anyio +async def test_1706_run_sql(async_generate_profile): + """run_sql returns DataFrame""" + dataframe = await async_generate_profile.run_sql(prompt=PROMPTS[1]) + assert isinstance(dataframe, pd.DataFrame) + assert len(dataframe.columns) > 0 + + +@pytest.mark.anyio +async def test_1707_chat(async_generate_profile): + """chat returns text response""" + response = await async_generate_profile.chat(prompt="What is OCI ?") + assert isinstance(response, str) + assert len(response) > 0 + + +@pytest.mark.anyio +async def test_1708_narrate(async_generate_profile): + """narrate returns narrative text""" + for prompt in PROMPTS: + narration = await async_generate_profile.narrate(prompt=prompt) + assert isinstance(narration, str) + assert len(narration) > 0 + + +@pytest.mark.anyio +async def test_1709_chat_session(async_generate_profile): + """chat_session provides a session context""" + conversation = AsyncConversation(attributes=ConversationAttributes()) + async with async_generate_profile.chat_session( + conversation=conversation, delete=True + ) as session: + assert session is not None + + +@pytest.mark.anyio +async def test_1710_explain_sql(async_generate_profile): + """explain_sql returns explanation text""" + for prompt in PROMPTS: + explain_sql = await async_generate_profile.explain_sql(prompt=prompt) + assert isinstance(explain_sql, str) + assert len(explain_sql) > 0 + + +@pytest.mark.anyio +async def test_1711_generate_runsql(async_generate_profile): + """generate with RUNSQL returns DataFrame""" + dataframe = await async_generate_profile.generate( + prompt=PROMPTS[1], action=Action.RUNSQL + ) + assert isinstance(dataframe, pd.DataFrame) + + +@pytest.mark.anyio +async def test_1712_generate_showsql(async_generate_profile): + """generate with SHOWSQL returns SQL""" + sql = await async_generate_profile.generate( + prompt=PROMPTS[1], action=Action.SHOWSQL + ) + assert isinstance(sql, str) + assert "SELECT" in sql.upper() + + +@pytest.mark.anyio +async def test_1713_generate_chat(async_generate_profile): + """generate with CHAT returns response""" + chat_response = await async_generate_profile.generate( + prompt="Tell me about OCI", action=Action.CHAT + ) + assert isinstance(chat_response, str) + assert len(chat_response) > 0 + + +@pytest.mark.anyio +async def test_1714_generate_narrate(async_generate_profile): + """generate with NARRATE returns response""" + narrate_response = await async_generate_profile.generate( + prompt=PROMPTS[1], action=Action.NARRATE + ) + assert isinstance(narrate_response, str) + assert len(narrate_response) > 0 + + +@pytest.mark.anyio +async def test_1715_generate_explainsql(async_generate_profile): + """generate with EXPLAINSQL returns explanation""" + for prompt in PROMPTS: + explain_sql = await async_generate_profile.generate( + prompt=prompt, action=Action.EXPLAINSQL + ) + assert isinstance(explain_sql, str) + assert len(explain_sql) > 0 + + +@pytest.mark.anyio +async def 
test_1716_empty_prompt_raises_value_error(async_negative_profile): + """Empty prompts raise ValueError for async profile methods""" + with pytest.raises(ValueError): + await async_negative_profile.chat(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.narrate(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.show_sql(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.show_prompt(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.run_sql(prompt="") + with pytest.raises(ValueError): + await async_negative_profile.explain_sql(prompt="") + + +@pytest.mark.anyio +async def test_1717_none_prompt_raises_value_error(async_negative_profile): + """None prompts raise ValueError for async profile methods""" + with pytest.raises(ValueError): + await async_negative_profile.chat(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.narrate(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.show_sql(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.show_prompt(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.run_sql(prompt=None) + with pytest.raises(ValueError): + await async_negative_profile.explain_sql(prompt=None) + + +# @pytest.mark.anyio +# async def test_1718_run_sql_with_ambiguous_prompt(async_negative_profile): +# """Ambiguous prompt raises DatabaseError for run_sql""" +# with pytest.raises(oracledb.DatabaseError): +# await async_negative_profile.run_sql(prompt="select from user") + + +# @pytest.mark.anyio +# async def test_1719_run_sql_with_invalid_object_list(async_negative_profile, test_env): +# """run_sql with non existent table raises DatabaseError""" +# await async_negative_profile.set_attribute( +# attribute_name="object_list", +# attribute_value=json.dumps( +# [{"owner": test_env.test_user, "name": "non_existent_table"}] +# ), +# ) +# with pytest.raises(oracledb.DatabaseError): +# await async_negative_profile.run_sql(prompt="How many entries in the table") diff --git a/tests/profiles/test_1800_chat_session.py b/tests/profiles/test_1800_chat_session.py new file mode 100644 index 0000000..52b90b6 --- /dev/null +++ b/tests/profiles/test_1800_chat_session.py @@ -0,0 +1,223 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. 
+# ----------------------------------------------------------------------------- + +""" +1800 - Chat session API tests +""" + +import uuid + +import pytest +import select_ai +from select_ai import ( + Conversation, + ConversationAttributes, + Profile, + ProfileAttributes, +) + +PROFILE_PREFIX = f"PYSAI_1800_{uuid.uuid4().hex.upper()}" + +CATEGORY_PROMPTS = { + "database": [ + ("What is a database?", "database"), + ("Explain the difference between SQL and NoSQL.", "sql"), + ("Give me an example of a SQL SELECT query.", "select"), + ("How do transactions ensure consistency?", "transaction"), + ("What are indexes and why are they used?", "index"), + ], + "cloud": [ + ("What is cloud computing?", "cloud"), + ("Explain IaaS, PaaS, and SaaS briefly.", "iaas"), + ("What is the benefit of auto-scaling?", "scaling"), + ("How do cloud regions and availability zones differ?", "region"), + ("What is serverless computing?", "serverless"), + ], + "ai": [ + ("What is artificial intelligence?", "intelligence"), + ("Explain supervised vs unsupervised learning.", "supervised"), + ("What are neural networks?", "neural"), + ("How does reinforcement learning work?", "reinforcement"), + ("Give me a real-world use case of AI.", "ai"), + ], + "physics": [ + ("What is Newton's first law?", "newton"), + ("Explain the concept of gravity.", "gravity"), + ("How does friction affect motion?", "friction"), + ("What is the difference between speed and velocity?", "velocity"), + ("Explain kinetic and potential energy with examples.", "energy"), + ], + "general": [ + ("What is the capital of Japan?", "tokyo"), + ("Tell me a fun fact about space.", "space"), + ("Who invented the telephone?", "telephone"), + ("What is the fastest land animal?", "cheetah"), + ("Explain why the sky looks blue.", "sky"), + ], +} + + +@pytest.fixture(scope="module") +def chat_session_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +def chat_session_profile(oci_credential, chat_session_provider): + profile = Profile( + profile_name=f"{PROFILE_PREFIX}_PROFILE", + attributes=ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": "ADMIN", "name": "people"}, + {"owner": "ADMIN", "name": "gymnast"}, + ], + provider=chat_session_provider, + ), + description="Chat session test profile", + replace=True, + ) + profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + profile.delete(force=True) + + +@pytest.fixture +def conversation_factory(): + conversations = [] + + def _create(**kwargs): + conversation = Conversation( + attributes=ConversationAttributes(**kwargs) + ) + conversation.create() + conversations.append(conversation) + return conversation + + yield _create + + for conversation in conversations: + conversation.delete(force=True) + + +def _assert_keywords(session, prompts): + for prompt, keyword in prompts: + response = session.chat(prompt=prompt) + assert keyword.lower() in response.lower() + + +def test_1800_database_chat_session( + chat_session_profile, conversation_factory +): + """Chat session processes database prompts""" + conversation = conversation_factory( + title="Database", + description="LLM's understanding of databases", + ) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session: + assert session is not None + _assert_keywords(session, 
CATEGORY_PROMPTS["database"]) + + +def test_1801_physics_chat_session_delete_true( + chat_session_profile, conversation_factory +): + """Chat session deletes conversation when delete=True""" + conversation = conversation_factory(title="Physics") + with chat_session_profile.chat_session( + conversation=conversation, delete=True + ) as session: + _assert_keywords(session, CATEGORY_PROMPTS["physics"]) + with pytest.raises(Exception): + conversation.delete() + + +def test_1802_multiple_sessions_same_conversation( + chat_session_profile, conversation_factory +): + """Same conversation supports multiple chat sessions""" + conversation = conversation_factory( + title="Cloud Two Session", + description="LLM's understanding of cloud using multiple chat sessions.", + ) + with chat_session_profile.chat_session( + conversation=conversation + ) as session_one: + _assert_keywords(session_one, CATEGORY_PROMPTS["cloud"][:3]) + with chat_session_profile.chat_session( + conversation=conversation + ) as session_two: + _assert_keywords(session_two, CATEGORY_PROMPTS["cloud"][3:]) + + +def test_1803_many_sessions_same_conversation( + chat_session_profile, conversation_factory +): + """Conversation reused across several sessions""" + conversation = conversation_factory( + title="Multi Session", + description="LLM's understanding of cloud using multiple chat sessions.", + ) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_one: + _assert_keywords(session_one, CATEGORY_PROMPTS["cloud"][:3]) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_two: + _assert_keywords(session_two, CATEGORY_PROMPTS["cloud"][3:]) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_three: + _assert_keywords(session_three, CATEGORY_PROMPTS["ai"][:3]) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_four: + _assert_keywords(session_four, CATEGORY_PROMPTS["ai"][3:]) + with chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_five: + _assert_keywords(session_five, CATEGORY_PROMPTS["general"]) + + +def test_1804_special_characters(chat_session_profile, conversation_factory): + """Chat session handles special characters""" + conversation = conversation_factory( + title="Special Character Test ✨😊你", + description="♥️✨你好", + ) + with chat_session_profile.chat_session( + conversation=conversation, delete=True + ) as session: + response = session.chat( + prompt="Tell me something with lot of emojis and special characters 🚀🔥" + ) + assert isinstance(response, str) + assert "error" not in response.lower() + + +def test_1805_invalid_conversation_object(chat_session_profile): + """Passing non conversation object raises error""" + with pytest.raises(Exception): + with chat_session_profile.chat_session(conversation="fake-object"): + pass + + +# def test_1806_missing_conversation_attributes(chat_session_profile): +# """Conversation without attributes raises error""" +# conversation = Conversation(attributes=None) +# with pytest.raises(Exception): +# with chat_session_profile.chat_session(conversation=conversation): +# _assert_keywords(chat_session_profile, [("Hello World", "hello")]) diff --git a/tests/profiles/test_1900_chat_session_async.py b/tests/profiles/test_1900_chat_session_async.py new file mode 100644 index 0000000..73dae19 --- /dev/null +++ b/tests/profiles/test_1900_chat_session_async.py @@ -0,0 +1,240 @@ +# 
----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. +# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +""" +1900 - Async chat session API tests +""" + +import uuid + +import pytest +import select_ai +from select_ai import ( + AsyncConversation, + AsyncProfile, + ConversationAttributes, + ProfileAttributes, +) + +PROFILE_PREFIX = f"PYSAI_1900_{uuid.uuid4().hex.upper()}" + +CATEGORY_PROMPTS = { + "database": [ + ("What is a database?", "database"), + ("Explain the difference between SQL and NoSQL.", "sql"), + ("Give me an example of a SQL SELECT query.", "select"), + ("How do transactions ensure consistency?", "transaction"), + ("What are indexes and why are they used?", "index"), + ], + "cloud": [ + ("What is cloud computing?", "cloud"), + ("Explain IaaS, PaaS, and SaaS briefly.", "iaas"), + ("What is the benefit of auto-scaling?", "scaling"), + ("How do cloud regions and availability zones differ?", "region"), + ("What is serverless computing?", "serverless"), + ], + "ai": [ + ("What is artificial intelligence?", "intelligence"), + ("Explain supervised vs unsupervised learning.", "supervised"), + ("What are neural networks?", "neural"), + ("How does reinforcement learning work?", "reinforcement"), + ("Give me a real-world use case of AI.", "ai"), + ], + "physics": [ + ("What is Newton's first law?", "newton"), + ("Explain the concept of gravity.", "gravity"), + ("How does friction affect motion?", "friction"), + ("What is the difference between speed and velocity?", "velocity"), + ("Explain kinetic and potential energy with examples.", "energy"), + ], + "general": [ + ("What is the capital of Japan?", "tokyo"), + ("Tell me a fun fact about space.", "space"), + ("Who invented the telephone?", "telephone"), + ("What is the fastest land animal?", "cheetah"), + ("Explain why the sky looks blue.", "sky"), + ], +} + + +@pytest.fixture(scope="module") +def async_chat_session_provider(oci_compartment_id): + return select_ai.OCIGenAIProvider( + oci_compartment_id=oci_compartment_id, + oci_apiformat="GENERIC", + ) + + +@pytest.fixture(scope="module") +async def async_chat_session_profile( + oci_credential, async_chat_session_provider +): + profile = await AsyncProfile( + profile_name=f"{PROFILE_PREFIX}_PROFILE", + attributes=ProfileAttributes( + credential_name=oci_credential["credential_name"], + object_list=[ + {"owner": "ADMIN", "name": "people"}, + {"owner": "ADMIN", "name": "gymnast"}, + ], + provider=async_chat_session_provider, + ), + description="Async chat session test profile", + replace=True, + ) + await profile.set_attribute( + attribute_name="model", + attribute_value="meta.llama-3.1-405b-instruct", + ) + yield profile + await profile.delete(force=True) + + +@pytest.fixture +async def async_conversation_factory(): + conversations = [] + + async def _create(**kwargs): + conversation = AsyncConversation( + attributes=ConversationAttributes(**kwargs) + ) + await conversation.create() + conversations.append(conversation) + return conversation + + yield _create + + for conversation in conversations: + await conversation.delete(force=True) + + +async def _assert_keywords(session, prompts): + for prompt, keyword in prompts: + response = await session.chat(prompt=prompt) + assert keyword.lower() in response.lower() + + +@pytest.mark.anyio +async def 
test_1900_database_chat_session( + async_chat_session_profile, async_conversation_factory +): + """Async chat session processes database prompts""" + conversation = await async_conversation_factory( + title="Database", + description="LLM's understanding of databases", + ) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session: + assert session is not None + await _assert_keywords(session, CATEGORY_PROMPTS["database"]) + + +@pytest.mark.anyio +async def test_1901_physics_chat_session_delete_true( + async_chat_session_profile, async_conversation_factory +): + """Async chat session deletes conversation when delete=True""" + conversation = await async_conversation_factory(title="Physics") + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=True + ) as session: + await _assert_keywords(session, CATEGORY_PROMPTS["physics"]) + with pytest.raises(Exception): + await conversation.delete() + + +@pytest.mark.anyio +async def test_1902_multiple_sessions_same_conversation( + async_chat_session_profile, async_conversation_factory +): + """Same async conversation supports multiple chat sessions""" + conversation = await async_conversation_factory( + title="Cloud Two Session", + description="LLM's understanding of cloud using multiple chat sessions.", + ) + async with async_chat_session_profile.chat_session( + conversation=conversation + ) as session_one: + await _assert_keywords(session_one, CATEGORY_PROMPTS["cloud"][:3]) + async with async_chat_session_profile.chat_session( + conversation=conversation + ) as session_two: + await _assert_keywords(session_two, CATEGORY_PROMPTS["cloud"][3:]) + + +@pytest.mark.anyio +async def test_1903_many_sessions_same_conversation( + async_chat_session_profile, async_conversation_factory +): + """Conversation reused across several async sessions""" + conversation = await async_conversation_factory( + title="Multi Session", + description="LLM's understanding of cloud using multiple chat sessions.", + ) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_one: + await _assert_keywords(session_one, CATEGORY_PROMPTS["cloud"][:3]) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_two: + await _assert_keywords(session_two, CATEGORY_PROMPTS["cloud"][3:]) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_three: + await _assert_keywords(session_three, CATEGORY_PROMPTS["ai"][:3]) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_four: + await _assert_keywords(session_four, CATEGORY_PROMPTS["ai"][3:]) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=False + ) as session_five: + await _assert_keywords(session_five, CATEGORY_PROMPTS["general"]) + + +@pytest.mark.anyio +async def test_1904_special_characters( + async_chat_session_profile, async_conversation_factory +): + """Async chat session handles special characters""" + conversation = await async_conversation_factory( + title="Special Character Test ✨😊你", + description="♥️✨你好", + ) + async with async_chat_session_profile.chat_session( + conversation=conversation, delete=True + ) as session: + response = await session.chat( + prompt="Tell me something with lot of emojis and special characters 🚀🔥" + ) + assert isinstance(response, str) + assert "error" not in 
response.lower() + + +@pytest.mark.anyio +async def test_1905_invalid_conversation_object(async_chat_session_profile): + """Passing non conversation object raises error""" + with pytest.raises(Exception): + async with async_chat_session_profile.chat_session( + conversation="fake-object" + ): + pass + + +@pytest.mark.anyio +async def test_1906_missing_conversation_attributes( + async_chat_session_profile, +): + """Conversation without attributes raises error""" + conversation = AsyncConversation(attributes=None) + with pytest.raises(Exception): + async with async_chat_session_profile.chat_session( + conversation=conversation + ): + await conversation.chat(prompt="Hello World")