diff --git a/src/story_protocol_python_sdk/abi/DisputeModule/DisputeModule_client.py b/src/story_protocol_python_sdk/abi/DisputeModule/DisputeModule_client.py index 58b9476f..985b87e4 100644 --- a/src/story_protocol_python_sdk/abi/DisputeModule/DisputeModule_client.py +++ b/src/story_protocol_python_sdk/abi/DisputeModule/DisputeModule_client.py @@ -71,5 +71,8 @@ def build_tagIfRelatedIpInfringed_transaction( ipIdToTag, infringerDisputeId ).build_transaction(tx_params) + def isIpTagged(self, ipId): + return self.contract.functions.isIpTagged(ipId).call() + def isWhitelistedDisputeTag(self, tag): return self.contract.functions.isWhitelistedDisputeTag(tag).call() diff --git a/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py b/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py index a3b17c86..4149bbbd 100644 --- a/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py +++ b/src/story_protocol_python_sdk/abi/GroupingModule/GroupingModule_client.py @@ -43,6 +43,14 @@ def build_addIp_transaction( groupIpId, ipIds, maxAllowedRewardShare ).build_transaction(tx_params) + def removeIp(self, groupIpId, ipIds): + return self.contract.functions.removeIp(groupIpId, ipIds).transact() + + def build_removeIp_transaction(self, groupIpId, ipIds, tx_params): + return self.contract.functions.removeIp( + groupIpId, ipIds + ).build_transaction(tx_params) + def claimReward(self, groupId, token, ipIds): return self.contract.functions.claimReward(groupId, token, ipIds).transact() diff --git a/src/story_protocol_python_sdk/abi/IPAssetRegistry/IPAssetRegistry_client.py b/src/story_protocol_python_sdk/abi/IPAssetRegistry/IPAssetRegistry_client.py index 9d753c6f..243a5136 100644 --- a/src/story_protocol_python_sdk/abi/IPAssetRegistry/IPAssetRegistry_client.py +++ b/src/story_protocol_python_sdk/abi/IPAssetRegistry/IPAssetRegistry_client.py @@ -51,3 +51,6 @@ def ipId(self, chainId, tokenContract, tokenId): def isRegistered(self, id): return self.contract.functions.isRegistered(id).call() + + def isRegisteredGroup(self, groupId): + return self.contract.functions.isRegisteredGroup(groupId).call() diff --git a/src/story_protocol_python_sdk/abi/LicenseRegistry/LicenseRegistry_client.py b/src/story_protocol_python_sdk/abi/LicenseRegistry/LicenseRegistry_client.py index d8243fec..5075b297 100644 --- a/src/story_protocol_python_sdk/abi/LicenseRegistry/LicenseRegistry_client.py +++ b/src/story_protocol_python_sdk/abi/LicenseRegistry/LicenseRegistry_client.py @@ -49,6 +49,9 @@ def getRoyaltyPercent(self, ipId, licenseTemplate, licenseTermsId): ipId, licenseTemplate, licenseTermsId ).call() + def hasDerivativeIps(self, parentIpId): + return self.contract.functions.hasDerivativeIps(parentIpId).call() + def hasIpAttachedLicenseTerms(self, ipId, licenseTemplate, licenseTermsId): return self.contract.functions.hasIpAttachedLicenseTerms( ipId, licenseTemplate, licenseTermsId diff --git a/src/story_protocol_python_sdk/abi/LicenseToken/LicenseToken_client.py b/src/story_protocol_python_sdk/abi/LicenseToken/LicenseToken_client.py index aa542e69..129bb81b 100644 --- a/src/story_protocol_python_sdk/abi/LicenseToken/LicenseToken_client.py +++ b/src/story_protocol_python_sdk/abi/LicenseToken/LicenseToken_client.py @@ -31,5 +31,10 @@ def __init__(self, web3: Web3): abi = json.load(abi_file) self.contract = self.web3.eth.contract(address=contract_address, abi=abi) + def getTotalTokensByLicensor(self, licensorIpId): + return self.contract.functions.getTotalTokensByLicensor( + licensorIpId + ).call() + def ownerOf(self, tokenId): return self.contract.functions.ownerOf(tokenId).call() diff --git a/src/story_protocol_python_sdk/resources/Group.py b/src/story_protocol_python_sdk/resources/Group.py index 1789f8b3..6bd5aced 100644 --- a/src/story_protocol_python_sdk/resources/Group.py +++ b/src/story_protocol_python_sdk/resources/Group.py @@ -4,6 +4,9 @@ from story_protocol_python_sdk.abi.CoreMetadataModule.CoreMetadataModule_client import ( CoreMetadataModuleClient, ) +from story_protocol_python_sdk.abi.DisputeModule.DisputeModule_client import ( + DisputeModuleClient, +) from story_protocol_python_sdk.abi.GroupingModule.GroupingModule_client import ( GroupingModuleClient, ) @@ -19,6 +22,9 @@ from story_protocol_python_sdk.abi.LicenseRegistry.LicenseRegistry_client import ( LicenseRegistryClient, ) +from story_protocol_python_sdk.abi.LicenseToken.LicenseToken_client import ( + LicenseTokenClient, +) from story_protocol_python_sdk.abi.LicensingModule.LicensingModule_client import ( LicensingModuleClient, ) @@ -58,9 +64,11 @@ def __init__(self, web3: Web3, account, chain_id: int): self.grouping_module_client = GroupingModuleClient(web3) self.grouping_workflows_client = GroupingWorkflowsClient(web3) self.ip_asset_registry_client = IPAssetRegistryClient(web3) + self.dispute_module_client = DisputeModuleClient(web3) self.core_metadata_module_client = CoreMetadataModuleClient(web3) self.licensing_module_client = LicensingModuleClient(web3) self.license_registry_client = LicenseRegistryClient(web3) + self.license_token_client = LicenseTokenClient(web3) self.pi_license_template_client = PILicenseTemplateClient(web3) self.module_registry_client = ModuleRegistryClient(web3) self.sign_util = Sign(web3, self.chain_id, self.account) @@ -453,6 +461,121 @@ def register_group_and_attach_license_and_add_ips( f"Failed to register group and attach license and add IPs: {str(e)}" ) + def add_ips_to_group( + self, + group_ip_id: str, + ip_ids: list, + max_allowed_reward_share_percentage: int = 100, + tx_options: dict | None = None, + ) -> dict: + """ + Add IPs to an existing group IP. + + :param group_ip_id str: The ID of the group IP. + :param ip_ids list: List of IP IDs to add to the group. + :param max_allowed_reward_share_percentage int: [Optional] Maximum allowed reward share percentage (0-100). Default is 100. + :param tx_options dict: [Optional] The transaction options. + :return dict: A dictionary with the transaction hash. + """ + try: + if not self.web3.is_address(group_ip_id): + raise ValueError(f'Group IP ID "{group_ip_id}" is invalid.') + + for ip_id in ip_ids: + if not self.web3.is_address(ip_id): + raise ValueError(f'IP ID "{ip_id}" is invalid.') + + # Contract-level validation: groupId must not be disputed + if self.dispute_module_client.isIpTagged(group_ip_id): + raise ValueError( + f'Disputed group cannot add IP: group "{group_ip_id}" is tagged by dispute module.' + ) + + # Contract-level validation: ipIds must not contain disputed IPs or groups + for ip_id in ip_ids: + if self.dispute_module_client.isIpTagged(ip_id): + raise ValueError( + f'Cannot add disputed IP to group: IP "{ip_id}" is tagged by dispute module.' + ) + if self.ip_asset_registry_client.isRegisteredGroup(ip_id): + raise ValueError( + f'Cannot add group to group: IP "{ip_id}" is a registered group.' + ) + + max_allowed_reward_share = get_revenue_share( + max_allowed_reward_share_percentage, + type=RevShareType.MAX_ALLOWED_REWARD_SHARE, + ) + + response = build_and_send_transaction( + self.web3, + self.account, + self.grouping_module_client.build_addIp_transaction, + group_ip_id, + ip_ids, + max_allowed_reward_share, + tx_options=tx_options, + ) + + result = {"tx_hash": response["tx_hash"]} + if "tx_receipt" in response: + result["tx_receipt"] = response["tx_receipt"] + return result + + except Exception as e: + raise ValueError(f"Failed to add IP to group: {str(e)}") + + def remove_ips_from_group( + self, + group_ip_id: str, + ip_ids: list, + tx_options: dict | None = None, + ) -> dict: + """ + Remove IPs from a group IP. + + :param group_ip_id str: The ID of the group IP. + :param ip_ids list: List of IP IDs to remove from the group. + :param tx_options dict: [Optional] The transaction options. + :return dict: A dictionary with the transaction hash. + """ + try: + if not self.web3.is_address(group_ip_id): + raise ValueError(f'Group IP ID "{group_ip_id}" is invalid.') + + for ip_id in ip_ids: + if not self.web3.is_address(ip_id): + raise ValueError(f'IP ID "{ip_id}" is invalid.') + + # Contract-level validation: group must not have derivative IPs + if self.license_registry_client.hasDerivativeIps(group_ip_id): + raise ValueError( + f'Group frozen: group "{group_ip_id}" has derivative IPs and cannot remove members.' + ) + + # Contract-level validation: group must not have minted license tokens + if self.license_token_client.getTotalTokensByLicensor(group_ip_id) > 0: + raise ValueError( + f'Group frozen: group "{group_ip_id}" has already minted license tokens and cannot remove members.' + ) + + response = build_and_send_transaction( + self.web3, + self.account, + self.grouping_module_client.build_removeIp_transaction, + group_ip_id, + ip_ids, + tx_options=tx_options, + ) + + result = {"tx_hash": response["tx_hash"]} + if "tx_receipt" in response: + result["tx_receipt"] = response["tx_receipt"] + return result + + except Exception as e: + raise ValueError(f"Failed to remove IPs from group: {str(e)}") + def collect_and_distribute_group_royalties( self, group_ip_id: str, @@ -847,3 +970,57 @@ def _parse_tx_royalty_paid_event(self, tx_receipt: dict) -> list: ) return royalties_distributed + + def get_added_ip_to_group_events(self, tx_receipt: dict) -> list: + """ + Parse AddedIpToGroup events from a transaction receipt (for chain-state verification). + + :param tx_receipt dict: The transaction receipt. + :return list: List of dicts with groupId and ipIds (checksum addresses). + """ + events = [] + for log in tx_receipt["logs"]: + try: + event_result = self.grouping_module_client.contract.events.AddedIpToGroup.process_log( + log + ) + args = event_result["args"] + events.append( + { + "groupId": self.web3.to_checksum_address(args["groupId"]), + "ipIds": [ + self.web3.to_checksum_address(addr) + for addr in args["ipIds"] + ], + } + ) + except Exception: + continue + return events + + def get_removed_ip_from_group_events(self, tx_receipt: dict) -> list: + """ + Parse RemovedIpFromGroup events from a transaction receipt (for chain-state verification). + + :param tx_receipt dict: The transaction receipt. + :return list: List of dicts with groupId and ipIds (checksum addresses). + """ + events = [] + for log in tx_receipt["logs"]: + try: + event_result = self.grouping_module_client.contract.events.RemovedIpFromGroup.process_log( + log + ) + args = event_result["args"] + events.append( + { + "groupId": self.web3.to_checksum_address(args["groupId"]), + "ipIds": [ + self.web3.to_checksum_address(addr) + for addr in args["ipIds"] + ], + } + ) + except Exception: + continue + return events diff --git a/src/story_protocol_python_sdk/utils/transaction_utils.py b/src/story_protocol_python_sdk/utils/transaction_utils.py index e203cffd..dfbc74ac 100644 --- a/src/story_protocol_python_sdk/utils/transaction_utils.py +++ b/src/story_protocol_python_sdk/utils/transaction_utils.py @@ -1,6 +1,91 @@ +import time + from web3 import Web3 TRANSACTION_TIMEOUT = 300 +REPLACEMENT_UNDERPRICED_RETRY_DELAY = 5 +REPLACEMENT_GAS_BUMP_RATIO = 1.2 + + +def _validate_nonce(nonce) -> int: + """Validate and return nonce. Raises ValueError if invalid.""" + if not isinstance(nonce, int) or nonce < 0: + raise ValueError( + f"Invalid nonce value: {nonce}. Nonce must be a non-negative integer." + ) + return nonce + + +def _get_transaction_options( + web3: Web3, + account, + tx_options: dict, + *, + nonce_override: int | None = None, + bump_gas: bool = False, +) -> dict: + """ + Build the transaction options dict (from, nonce, value, gas). + Used for both encodedTxDataOnly and send path. + """ + opts = {"from": account.address} + + # Nonce: use override (retry), explicit from tx_options, or fetch from chain + if nonce_override is not None: + opts["nonce"] = nonce_override + elif "nonce" in tx_options: + opts["nonce"] = _validate_nonce(tx_options["nonce"]) + else: + opts["nonce"] = web3.eth.get_transaction_count(account.address) + + if "value" in tx_options: + opts["value"] = tx_options["value"] + + # Gas: bump for replacement, or use tx_options + if bump_gas: + try: + opts["gasPrice"] = int( + web3.eth.gas_price * REPLACEMENT_GAS_BUMP_RATIO + ) + except Exception: + opts["gasPrice"] = web3.to_wei(2, "gwei") + else: + if "gasPrice" in tx_options: + opts["gasPrice"] = web3.to_wei(tx_options["gasPrice"], "gwei") + if "maxFeePerGas" in tx_options: + opts["maxFeePerGas"] = tx_options["maxFeePerGas"] + + return opts + + +def _is_retryable_send_error(exc: Exception) -> bool: + """True if we should retry send (same nonce, higher gas).""" + msg = str(exc).lower() + return ( + "replacement transaction underpriced" in msg + or "nonce too low" in msg + ) + + +def _send_one( + web3: Web3, + account, + client_function, + client_args: tuple, + tx_options: dict, + transaction_options: dict, +) -> dict: + """Build, sign, send one transaction. No retry.""" + transaction = client_function(*client_args, transaction_options) + signed_txn = account.sign_transaction(transaction) + tx_hash = web3.eth.send_raw_transaction(signed_txn.raw_transaction) + + if not tx_options.get("wait_for_receipt", True): + return {"tx_hash": tx_hash.hex()} + + timeout = tx_options.get("timeout", TRANSACTION_TIMEOUT) + tx_receipt = web3.eth.wait_for_transaction_receipt(tx_hash, timeout=timeout) + return {"tx_hash": tx_hash.hex(), "tx_receipt": tx_receipt} def build_and_send_transaction( @@ -13,6 +98,9 @@ def build_and_send_transaction( """ Builds and sends a transaction using the provided client function and arguments. + On "replacement transaction underpriced" or "nonce too low", retries once + after a short delay with the same nonce and higher gas. + :param web3 Web3: An instance of Web3. :param account: The account to use for signing the transaction. :param client_function: The client function to build the transaction. @@ -29,51 +117,44 @@ def build_and_send_transaction( or encoded data if encodedTxDataOnly is True. :raises Exception: If there is an error during the transaction process. """ - try: - tx_options = tx_options or {} - - transaction_options = { - "from": account.address, - } - - if "nonce" in tx_options: - nonce = tx_options["nonce"] - if not isinstance(nonce, int) or nonce < 0: - raise ValueError( - f"Invalid nonce value: {nonce}. Nonce must be a non-negative integer." - ) - transaction_options["nonce"] = nonce - else: - transaction_options["nonce"] = web3.eth.get_transaction_count( - account.address - ) + tx_options = tx_options or {} + client_args = tuple(client_args) - if "value" in tx_options: - transaction_options["value"] = tx_options["value"] + # Encode-only path: build options and return encoded data, no send + if tx_options.get("encodedTxDataOnly"): + opts = _get_transaction_options(web3, account, tx_options) + encoded = client_function(*client_args, opts) + return {"encodedTxData": encoded} - if "gasPrice" in tx_options: - transaction_options["gasPrice"] = web3.to_wei( - tx_options["gasPrice"], "gwei" - ) - if "maxFeePerGas" in tx_options: - transaction_options["maxFeePerGas"] = tx_options["maxFeePerGas"] + # Send path: optionally retry once with same nonce + higher gas + used_nonce = None + last_error = None - transaction = client_function(*client_args, transaction_options) + for attempt in range(2): + opts = _get_transaction_options( + web3, + account, + tx_options, + nonce_override=used_nonce, + bump_gas=(attempt == 1), + ) + if used_nonce is None: + used_nonce = opts["nonce"] - if tx_options.get("encodedTxDataOnly"): - return {"encodedTxData": transaction} - - signed_txn = account.sign_transaction(transaction) - tx_hash = web3.eth.send_raw_transaction(signed_txn.raw_transaction) - - wait_for_receipt = tx_options.get("wait_for_receipt", True) - - if wait_for_receipt: - timeout = tx_options.get("timeout", TRANSACTION_TIMEOUT) - tx_receipt = web3.eth.wait_for_transaction_receipt(tx_hash, timeout=timeout) - return {"tx_hash": tx_hash.hex(), "tx_receipt": tx_receipt} - else: - return {"tx_hash": tx_hash.hex()} + try: + return _send_one( + web3, + account, + client_function, + client_args, + tx_options, + opts, + ) + except Exception as e: + last_error = e + if not _is_retryable_send_error(e): + raise + if attempt == 0: + time.sleep(REPLACEMENT_UNDERPRICED_RETRY_DELAY) - except Exception as e: - raise e + raise last_error diff --git a/tests/integration/test_integration_group.py b/tests/integration/test_integration_group.py index e8729e05..a0b1ab44 100644 --- a/tests/integration/test_integration_group.py +++ b/tests/integration/test_integration_group.py @@ -366,3 +366,205 @@ def test_get_claimable_reward( assert len(claimable_rewards) == 2 assert claimable_rewards[0] == 10 assert claimable_rewards[1] == 10 + + +def _normalize_address(web3, addr: str) -> str: + """Normalize address for comparison (checksum).""" + return web3.to_checksum_address(addr) + + +class TestAddIpsToGroupAndRemoveIpsFromGroup: + """Integration tests for add_ips_to_group and remove_ips_from_group with strict on-chain verification.""" + + def test_add_ips_to_group( + self, story_client: StoryClient, nft_collection: Address + ): + """Test adding IPs to an existing group; verify chain state via AddedIpToGroup event and get_claimable_reward.""" + result1 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + result2 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + ip_id1 = result1["ip_id"] + ip_id2 = result2["ip_id"] + license_terms_id = result1["license_terms_id"] + + group_ip_id = GroupTestHelper.register_group_and_attach_license( + story_client, license_terms_id, [ip_id1] + ) + + result = story_client.Group.add_ips_to_group( + group_ip_id=group_ip_id, + ip_ids=[ip_id2], + ) + + assert "tx_hash" in result + assert isinstance(result["tx_hash"], str) + assert len(result["tx_hash"]) > 0 + # Strict: verify on-chain AddedIpToGroup event + assert "tx_receipt" in result, "add_ips_to_group must return tx_receipt for verification" + added_events = story_client.Group.get_added_ip_to_group_events( + result["tx_receipt"] + ) + assert len(added_events) == 1 + assert _normalize_address(story_client.web3, added_events[0]["groupId"]) == _normalize_address( + story_client.web3, group_ip_id + ) + assert set(_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]) == { + _normalize_address(story_client.web3, ip_id2) + } + # Verify new member is in group: get_claimable_reward for [ip_id1, ip_id2] should succeed + claimable = story_client.Group.get_claimable_reward( + group_ip_id=group_ip_id, + currency_token=MockERC20, + member_ip_ids=[ip_id1, ip_id2], + ) + assert isinstance(claimable, list) + assert len(claimable) == 2 + + def test_add_ips_to_group_with_max_reward_share( + self, story_client: StoryClient, nft_collection: Address + ): + """Test adding IPs to group with custom max_allowed_reward_share_percentage; verify chain via event.""" + result1 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + result2 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + ip_id1 = result1["ip_id"] + ip_id2 = result2["ip_id"] + license_terms_id = result1["license_terms_id"] + + group_ip_id = GroupTestHelper.register_group_and_attach_license( + story_client, license_terms_id, [ip_id1] + ) + + result = story_client.Group.add_ips_to_group( + group_ip_id=group_ip_id, + ip_ids=[ip_id2], + max_allowed_reward_share_percentage=50, + ) + + assert "tx_hash" in result + assert isinstance(result["tx_hash"], str) + assert "tx_receipt" in result + added_events = story_client.Group.get_added_ip_to_group_events( + result["tx_receipt"] + ) + assert len(added_events) == 1 + assert _normalize_address(story_client.web3, added_events[0]["groupId"]) == _normalize_address( + story_client.web3, group_ip_id + ) + assert set(_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]) == { + _normalize_address(story_client.web3, ip_id2) + } + + def test_remove_ips_from_group( + self, story_client: StoryClient, nft_collection: Address + ): + """Test removing IPs from a group; verify chain state via RemovedIpFromGroup event and get_claimable_reward.""" + result1 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + result2 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + ip_id1 = result1["ip_id"] + ip_id2 = result2["ip_id"] + license_terms_id = result1["license_terms_id"] + + group_ip_id = GroupTestHelper.register_group_and_attach_license( + story_client, license_terms_id, [ip_id1, ip_id2] + ) + + result = story_client.Group.remove_ips_from_group( + group_ip_id=group_ip_id, + ip_ids=[ip_id2], + ) + + assert "tx_hash" in result + assert isinstance(result["tx_hash"], str) + assert len(result["tx_hash"]) > 0 + assert "tx_receipt" in result + removed_events = story_client.Group.get_removed_ip_from_group_events( + result["tx_receipt"] + ) + assert len(removed_events) == 1 + assert _normalize_address(story_client.web3, removed_events[0]["groupId"]) == _normalize_address( + story_client.web3, group_ip_id + ) + assert set(_normalize_address(story_client.web3, a) for a in removed_events[0]["ipIds"]) == { + _normalize_address(story_client.web3, ip_id2) + } + # After remove, only ip_id1 remains; get_claimable_reward for [ip_id1] must succeed + claimable = story_client.Group.get_claimable_reward( + group_ip_id=group_ip_id, + currency_token=MockERC20, + member_ip_ids=[ip_id1], + ) + assert isinstance(claimable, list) + assert len(claimable) == 1 + + def test_add_then_remove_ips_from_group( + self, story_client: StoryClient, nft_collection: Address + ): + """Test add then remove in sequence; verify each step via on-chain events and final member list.""" + result1 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + result2 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + result3 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms( + story_client, nft_collection + ) + ip_id1 = result1["ip_id"] + ip_id2 = result2["ip_id"] + ip_id3 = result3["ip_id"] + license_terms_id = result1["license_terms_id"] + + group_ip_id = GroupTestHelper.register_group_and_attach_license( + story_client, license_terms_id, [ip_id1] + ) + + # Add ip_id2 and ip_id3 + add_result = story_client.Group.add_ips_to_group( + group_ip_id=group_ip_id, + ip_ids=[ip_id2, ip_id3], + ) + assert "tx_hash" in add_result + assert "tx_receipt" in add_result + added_events = story_client.Group.get_added_ip_to_group_events( + add_result["tx_receipt"] + ) + assert len(added_events) == 1 + assert set(_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]) == { + _normalize_address(story_client.web3, ip_id2), + _normalize_address(story_client.web3, ip_id3), + } + + # Remove ip_id2 + remove_result = story_client.Group.remove_ips_from_group( + group_ip_id=group_ip_id, + ip_ids=[ip_id2], + ) + assert "tx_hash" in remove_result + assert "tx_receipt" in remove_result + removed_events = story_client.Group.get_removed_ip_from_group_events( + remove_result["tx_receipt"] + ) + assert len(removed_events) == 1 + assert set(_normalize_address(story_client.web3, a) for a in removed_events[0]["ipIds"]) == { + _normalize_address(story_client.web3, ip_id2) + } + + # Final state: only ip_id1 and ip_id3 are members + claimable = story_client.Group.get_claimable_reward( + group_ip_id=group_ip_id, + currency_token=MockERC20, + member_ip_ids=[ip_id1, ip_id3], + ) + assert isinstance(claimable, list) + assert len(claimable) == 2 diff --git a/tests/unit/resources/test_group.py b/tests/unit/resources/test_group.py index 94984de5..8a38dbbc 100644 --- a/tests/unit/resources/test_group.py +++ b/tests/unit/resources/test_group.py @@ -424,6 +424,341 @@ def test_claim_rewards_transaction_build_failure( ) +class TestGroupAddIpsToGroup: + """Test class for Group.add_ips_to_group method""" + + def test_add_ips_to_group_invalid_group_ip_id( + self, group: Group, mock_web3_is_address + ): + """Test add_ips_to_group with invalid group IP ID.""" + invalid_group_ip_id = "invalid_group_ip_id" + with mock_web3_is_address(False): + with pytest.raises( + ValueError, + match="Failed to add IP to group:", + ): + group.add_ips_to_group( + group_ip_id=invalid_group_ip_id, + ip_ids=[IP_ID], + ) + + def test_add_ips_to_group_invalid_ip_id(self, group: Group, mock_web3): + """Test add_ips_to_group with invalid IP ID.""" + invalid_ip_id = "invalid_ip_id" + with patch.object(mock_web3, "is_address") as mock_is_address: + mock_is_address.side_effect = [True, False] + with pytest.raises( + ValueError, + match="Failed to add IP to group:", + ): + group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[invalid_ip_id], + ) + + def test_add_ips_to_group_max_reward_share_exceeds_100( + self, group: Group, mock_web3_is_address + ): + """Test add_ips_to_group rejects max_allowed_reward_share_percentage > 100 (via get_revenue_share).""" + with mock_web3_is_address(): + with patch.object( + group.dispute_module_client, "isIpTagged", return_value=False + ), patch.object( + group.ip_asset_registry_client, + "isRegisteredGroup", + return_value=False, + ): + with pytest.raises( + ValueError, + match="must be between 0 and 100", + ): + group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID], + max_allowed_reward_share_percentage=101, + ) + + def test_add_ips_to_group_disputed_group(self, group: Group, mock_web3_is_address): + """Test add_ips_to_group rejects disputed group.""" + with mock_web3_is_address(): + with patch.object( + group.dispute_module_client, "isIpTagged", return_value=True + ): + with pytest.raises( + ValueError, + match="Disputed group cannot add IP", + ): + group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID], + ) + + def test_add_ips_to_group_disputed_ip(self, group: Group, mock_web3_is_address): + """Test add_ips_to_group rejects disputed IP in ip_ids.""" + with mock_web3_is_address(): + with patch.object( + group.dispute_module_client, + "isIpTagged", + side_effect=[False, True], # group ok, first ip disputed + ): + with pytest.raises( + ValueError, + match="Cannot add disputed IP to group", + ): + group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID, ADDRESS], + ) + + def test_add_ips_to_group_ip_is_registered_group( + self, group: Group, mock_web3_is_address + ): + """Test add_ips_to_group rejects IP that is a registered group.""" + with mock_web3_is_address(): + with patch.object( + group.dispute_module_client, "isIpTagged", return_value=False + ), patch.object( + group.ip_asset_registry_client, + "isRegisteredGroup", + side_effect=[False, True], # first ip ok, second ip is group + ): + with pytest.raises( + ValueError, + match="Cannot add group to group", + ): + group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID, ADDRESS], + ) + + def test_add_ips_to_group_success( + self, + group: Group, + mock_web3_is_address, + ): + """Test successful add_ips_to_group operation.""" + with mock_web3_is_address(): + with patch.object( + group.dispute_module_client, "isIpTagged", return_value=False + ), patch.object( + group.ip_asset_registry_client, + "isRegisteredGroup", + return_value=False, + ), patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={"tx_hash": TX_HASH, "tx_receipt": {}}, + ): + result = group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID, ADDRESS], + ) + + assert "tx_hash" in result + assert result["tx_hash"] == TX_HASH + + def test_add_ips_to_group_default_max_allowed_reward_share_percentage( + self, + group: Group, + mock_web3_is_address, + ): + """Test add_ips_to_group uses default max_allowed_reward_share_percentage 100.""" + with mock_web3_is_address(): + with patch.object( + group.dispute_module_client, "isIpTagged", return_value=False + ), patch.object( + group.ip_asset_registry_client, + "isRegisteredGroup", + return_value=False, + ), patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={"tx_hash": TX_HASH, "tx_receipt": {}}, + ) as mock_build: + group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID], + ) + # 100 -> 100 * 10**6 + call_args = mock_build.call_args[0] + assert call_args[5] == 100 * 10**6 + + def test_add_ips_to_group_max_allowed_reward_share_percentage_zero( + self, + group: Group, + mock_web3_is_address, + ): + """Test add_ips_to_group with max_allowed_reward_share_percentage 0.""" + with mock_web3_is_address(): + with patch.object( + group.dispute_module_client, "isIpTagged", return_value=False + ), patch.object( + group.ip_asset_registry_client, + "isRegisteredGroup", + return_value=False, + ), patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={"tx_hash": TX_HASH, "tx_receipt": {}}, + ) as mock_build: + group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID], + max_allowed_reward_share_percentage=0, + ) + call_args = mock_build.call_args[0] + assert call_args[5] == 0 + + def test_add_ips_to_group_transaction_fails( + self, group: Group, mock_web3_is_address + ): + """Test add_ips_to_group when transaction build/send fails.""" + with mock_web3_is_address(): + with patch.object( + group.dispute_module_client, "isIpTagged", return_value=False + ), patch.object( + group.ip_asset_registry_client, + "isRegisteredGroup", + return_value=False, + ), patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + side_effect=Exception("Transaction build failed"), + ): + with pytest.raises( + ValueError, + match="Failed to add IP to group: Transaction build failed", + ): + group.add_ips_to_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID], + ) + + +class TestGroupRemoveIpsFromGroup: + """Test class for Group.remove_ips_from_group method""" + + def test_remove_ips_from_group_invalid_group_ip_id( + self, group: Group, mock_web3_is_address + ): + """Test remove_ips_from_group with invalid group IP ID.""" + invalid_group_ip_id = "invalid_group_ip_id" + with mock_web3_is_address(False): + with pytest.raises( + ValueError, + match="Failed to remove IPs from group:", + ): + group.remove_ips_from_group( + group_ip_id=invalid_group_ip_id, + ip_ids=[IP_ID], + ) + + def test_remove_ips_from_group_group_has_derivative_ips( + self, group: Group, mock_web3_is_address + ): + """Test remove_ips_from_group rejects when group has derivative IPs.""" + with mock_web3_is_address(): + with patch.object( + group.license_registry_client, + "hasDerivativeIps", + return_value=True, + ): + with pytest.raises( + ValueError, + match="Group frozen:.*has derivative IPs", + ): + group.remove_ips_from_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID], + ) + + def test_remove_ips_from_group_group_has_minted_license_tokens( + self, group: Group, mock_web3_is_address + ): + """Test remove_ips_from_group rejects when group has minted license tokens.""" + with mock_web3_is_address(): + with patch.object( + group.license_registry_client, + "hasDerivativeIps", + return_value=False, + ), patch.object( + group.license_token_client, + "getTotalTokensByLicensor", + return_value=10, + ): + with pytest.raises( + ValueError, + match="Group frozen:.*has already minted license tokens", + ): + group.remove_ips_from_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID], + ) + + def test_remove_ips_from_group_invalid_ip_id(self, group: Group, mock_web3): + """Test remove_ips_from_group with invalid IP ID.""" + invalid_ip_id = "invalid_ip_id" + with patch.object(mock_web3, "is_address") as mock_is_address: + mock_is_address.side_effect = [True, False] + with pytest.raises( + ValueError, + match="Failed to remove IPs from group:", + ): + group.remove_ips_from_group( + group_ip_id=IP_ID, + ip_ids=[invalid_ip_id], + ) + + def test_remove_ips_from_group_success( + self, + group: Group, + mock_web3_is_address, + ): + """Test successful remove_ips_from_group operation.""" + with mock_web3_is_address(): + with patch.object( + group.license_registry_client, + "hasDerivativeIps", + return_value=False, + ), patch.object( + group.license_token_client, + "getTotalTokensByLicensor", + return_value=0, + ), patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + return_value={"tx_hash": TX_HASH, "tx_receipt": {}}, + ): + result = group.remove_ips_from_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID, ADDRESS], + ) + + assert "tx_hash" in result + assert result["tx_hash"] == TX_HASH + + def test_remove_ips_from_group_transaction_fails( + self, group: Group, mock_web3_is_address + ): + """Test remove_ips_from_group when transaction build/send fails.""" + with mock_web3_is_address(): + with patch.object( + group.license_registry_client, + "hasDerivativeIps", + return_value=False, + ), patch.object( + group.license_token_client, + "getTotalTokensByLicensor", + return_value=0, + ), patch( + "story_protocol_python_sdk.resources.Group.build_and_send_transaction", + side_effect=Exception("Transaction build failed"), + ): + with pytest.raises( + ValueError, + match="Failed to remove IPs from group: Transaction build failed", + ): + group.remove_ips_from_group( + group_ip_id=IP_ID, + ip_ids=[IP_ID], + ) + + class TestGroupGetClaimableReward: """Test class for Group.get_claimable_reward method"""