diff --git a/tpu/create_tpu_topology.py b/tpu/create_tpu_topology.py index b29c7fa4302..7d0d7468c67 100644 --- a/tpu/create_tpu_topology.py +++ b/tpu/create_tpu_topology.py @@ -44,7 +44,7 @@ def create_cloud_tpu_with_topology( node = tpu_v2.Node() # Here we are creating a TPU v3-8 with 2x2 topology. node.accelerator_config = tpu_v2.AcceleratorConfig( - type_=tpu_v2.AcceleratorConfig.Type.V3, + type_=tpu_v2.AcceleratorConfig.Type.V2, topology="2x2", ) node.runtime_version = runtime_version diff --git a/tpu/delete_tpu.py b/tpu/delete_tpu.py index b185aed3ac2..f927d83c121 100644 --- a/tpu/delete_tpu.py +++ b/tpu/delete_tpu.py @@ -45,4 +45,4 @@ def delete_cloud_tpu(project_id: str, zone: str, tpu_name: str = "tpu-name") -> if __name__ == "__main__": PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") ZONE = "us-central1-b" - delete_cloud_tpu(PROJECT_ID, ZONE, "tpu-name12") + delete_cloud_tpu(PROJECT_ID, ZONE, "tpu-name") diff --git a/tpu/queued_resources_create.py b/tpu/queued_resources_create.py new file mode 100644 index 00000000000..91dad552bcf --- /dev/null +++ b/tpu/queued_resources_create.py @@ -0,0 +1,80 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from google.cloud.tpu_v2alpha1 import CreateQueuedResourceRequest, Node + + +def create_queued_resource( + project_id: str, + zone: str, + tpu_name: str, + tpu_type: str = "v2-8", + runtime_version: str = "tpu-vm-tf-2.17.0-pjrt", + queued_resource_name: str = "resource-name", +) -> Node: + # [START tpu_queued_resources_create] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + # tpu_name = "tpu-name" + # tpu_type = "v2-8" + # runtime_version = "tpu-vm-tf-2.17.0-pjrt" + # queued_resource_name = "resource-name" + + node = tpu_v2alpha1.Node() + node.accelerator_type = tpu_type + # To see available runtime version use command: + # gcloud compute tpus versions list --zone={ZONE} + node.runtime_version = runtime_version + + node_spec = tpu_v2alpha1.QueuedResource.Tpu.NodeSpec() + node_spec.parent = f"projects/{project_id}/locations/{zone}" + node_spec.node_id = tpu_name + node_spec.node = node + + resource = tpu_v2alpha1.QueuedResource() + resource.tpu = tpu_v2alpha1.QueuedResource.Tpu(node_spec=[node_spec]) + + request = CreateQueuedResourceRequest( + parent=f"projects/{project_id}/locations/{zone}", + queued_resource_id=queued_resource_name, + queued_resource=resource, + ) + + client = tpu_v2alpha1.TpuClient() + operation = client.create_queued_resource(request=request) + + response = operation.result() + print(response.name) + print(response.state.state) + # Example response: + # projects/[project_id]/locations/[zone]/queuedResources/resource-name + # State.WAITING_FOR_RESOURCES + + # [END tpu_queued_resources_create] + return response + + +if __name__ == "__main__": + PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") + ZONE = "us-central1-b" + create_queued_resource( + project_id=PROJECT_ID, + zone=ZONE, + tpu_name="tpu-name", + queued_resource_name="resource-name", + ) diff --git a/tpu/queued_resources_create_network.py b/tpu/queued_resources_create_network.py new file mode 100644 index 00000000000..5061fbed2bb --- /dev/null +++ b/tpu/queued_resources_create_network.py @@ -0,0 +1,90 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from google.cloud.tpu_v2alpha1 import CreateQueuedResourceRequest, Node + + +def create_queued_resource_network( + project_id: str, + zone: str, + tpu_name: str, + tpu_type: str = "v2-8", + runtime_version: str = "tpu-vm-tf-2.17.0-pjrt", + queued_resource_name: str = "resource-name", + network: str = "default", +) -> Node: + # [START tpu_queued_resources_network] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + # tpu_name = "tpu-name" + # tpu_type = "v2-8" + # runtime_version = "tpu-vm-tf-2.17.0-pjrt" + # queued_resource_name = "resource-name" + # network = "default" + + node = tpu_v2alpha1.Node() + node.accelerator_type = tpu_type + node.runtime_version = runtime_version + # Setting network configuration + node.network_config = tpu_v2alpha1.NetworkConfig( + network=network, # Update if you want to use a specific network + subnetwork="default", # Update if you want to use a specific subnetwork + enable_external_ips=True, + can_ip_forward=True, + ) + + node_spec = tpu_v2alpha1.QueuedResource.Tpu.NodeSpec() + node_spec.parent = f"projects/{project_id}/locations/{zone}" + node_spec.node_id = tpu_name + node_spec.node = node + + resource = tpu_v2alpha1.QueuedResource() + resource.tpu = tpu_v2alpha1.QueuedResource.Tpu(node_spec=[node_spec]) + + request = CreateQueuedResourceRequest( + parent=f"projects/{project_id}/locations/{zone}", + queued_resource_id=queued_resource_name, + queued_resource=resource, + ) + + client = tpu_v2alpha1.TpuClient() + operation = client.create_queued_resource(request=request) + + response = operation.result() + print(response.name) + print(response.tpu.node_spec[0].node.network_config) + print(resource.tpu.node_spec[0].node.network_config.network == "default") + # Example response: + # network: "default" + # subnetwork: "default" + # enable_external_ips: true + # can_ip_forward: true + + # [END tpu_queued_resources_network] + return response + + +if __name__ == "__main__": + PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") + ZONE = "us-central1-b" + create_queued_resource_network( + project_id=PROJECT_ID, + zone=ZONE, + tpu_name="tpu-name", + queued_resource_name="resource-name", + ) diff --git a/tpu/queued_resources_create_spot.py b/tpu/queued_resources_create_spot.py new file mode 100644 index 00000000000..59bacc3b031 --- /dev/null +++ b/tpu/queued_resources_create_spot.py @@ -0,0 +1,82 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from google.cloud.tpu_v2alpha1 import CreateQueuedResourceRequest, Node + + +def create_queued_resource_spot( + project_id: str, + zone: str, + tpu_name: str, + tpu_type: str = "v2-8", + runtime_version: str = "tpu-vm-tf-2.17.0-pjrt", + queued_resource_name: str = "resource-name", +) -> Node: + # [START tpu_queued_resources_create_spot] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + # tpu_name = "tpu-name" + # tpu_type = "v2-8" + # runtime_version = "tpu-vm-tf-2.17.0-pjrt" + # queued_resource_name = "resource-name" + + node = tpu_v2alpha1.Node() + node.accelerator_type = tpu_type + # To see available runtime version use command: + # gcloud compute tpus versions list --zone={ZONE} + node.runtime_version = runtime_version + + node_spec = tpu_v2alpha1.QueuedResource.Tpu.NodeSpec() + node_spec.parent = f"projects/{project_id}/locations/{zone}" + node_spec.node_id = tpu_name + node_spec.node = node + + resource = tpu_v2alpha1.QueuedResource() + resource.tpu = tpu_v2alpha1.QueuedResource.Tpu(node_spec=[node_spec]) + # Create a spot resource + resource.spot = tpu_v2alpha1.QueuedResource.Spot() + + request = CreateQueuedResourceRequest( + parent=f"projects/{project_id}/locations/{zone}", + queued_resource_id=queued_resource_name, + queued_resource=resource, + ) + + client = tpu_v2alpha1.TpuClient() + operation = client.create_queued_resource(request=request) + response = operation.result() + + print(response.name) + print(response.state.state) + # Example response: + # projects/[project_id]/locations/[zone]/queuedResources/resource-name + # State.WAITING_FOR_RESOURCES + + # [END tpu_queued_resources_create_spot] + return response + + +if __name__ == "__main__": + PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") + ZONE = "us-central1-b" + create_queued_resource_spot( + project_id=PROJECT_ID, + zone=ZONE, + tpu_name="tpu-name", + queued_resource_name="resource-name", + ) diff --git a/tpu/queued_resources_create_startup_script.py b/tpu/queued_resources_create_startup_script.py new file mode 100644 index 00000000000..44d0ad3fe44 --- /dev/null +++ b/tpu/queued_resources_create_startup_script.py @@ -0,0 +1,93 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from google.cloud.tpu_v2alpha1 import CreateQueuedResourceRequest, Node + + +def create_queued_resource_startup_script( + project_id: str, + zone: str, + tpu_name: str, + tpu_type: str = "v2-8", + runtime_version: str = "tpu-vm-tf-2.17.0-pjrt", + queued_resource_name: str = "resource-name", +) -> Node: + # [START tpu_queued_resources_startup_script] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + # tpu_name = "tpu-name" + # tpu_type = "v2-8" + # runtime_version = "tpu-vm-tf-2.17.0-pjrt" + # queued_resource_name = "resource-name" + + node = tpu_v2alpha1.Node() + node.accelerator_type = tpu_type + # To see available runtime version use command: + # gcloud compute tpus versions list --zone={ZONE} + node.runtime_version = runtime_version + # This startup script updates numpy to the latest version and logs the output to a file. + script = { + "startup-script": """#!/bin/bash + echo "Hello World" > /var/log/hello.log + sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1 + """ + } + node.metadata = script + # Enabling external IPs for internet access from the TPU node for updating numpy + node.network_config = tpu_v2alpha1.NetworkConfig( + enable_external_ips=True, + ) + + node_spec = tpu_v2alpha1.QueuedResource.Tpu.NodeSpec() + node_spec.parent = f"projects/{project_id}/locations/{zone}" + node_spec.node_id = tpu_name + node_spec.node = node + + resource = tpu_v2alpha1.QueuedResource() + resource.tpu = tpu_v2alpha1.QueuedResource.Tpu(node_spec=[node_spec]) + + request = CreateQueuedResourceRequest( + parent=f"projects/{project_id}/locations/{zone}", + queued_resource_id=queued_resource_name, + queued_resource=resource, + ) + + client = tpu_v2alpha1.TpuClient() + operation = client.create_queued_resource(request=request) + + response = operation.result() + print(response.name) + print(response.tpu.node_spec[0].node.metadata) + # Example response: + # projects/[project_id]/locations/[zone]/queuedResources/resource-name + # {'startup-script': '#!/bin/bash\n echo "Hello World" > /var/log/hello.log\n + # sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1\n '} + + # [END tpu_queued_resources_startup_script] + return response + + +if __name__ == "__main__": + PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") + ZONE = "us-central1-b" + create_queued_resource_startup_script( + project_id=PROJECT_ID, + zone=ZONE, + tpu_name="tpu-name", + queued_resource_name="resource-name", + ) diff --git a/tpu/queued_resources_create_time_bound.py b/tpu/queued_resources_create_time_bound.py new file mode 100644 index 00000000000..76b352a4de3 --- /dev/null +++ b/tpu/queued_resources_create_time_bound.py @@ -0,0 +1,90 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from google.cloud.tpu_v2alpha1 import CreateQueuedResourceRequest, Node + + +def create_queued_resource_time_bound( + project_id: str, + zone: str, + tpu_name: str, + tpu_type: str = "v2-8", + runtime_version: str = "tpu-vm-tf-2.17.0-pjrt", + queued_resource_name: str = "resource-name", +) -> Node: + # [START tpu_queued_resources_time_bound] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + # tpu_name = "tpu-name" + # tpu_type = "v2-8" + # runtime_version = "tpu-vm-tf-2.17.0-pjrt" + # queued_resource_name = "resource-name" + + node = tpu_v2alpha1.Node() + node.accelerator_type = tpu_type + # To see available runtime version use command: + # gcloud compute tpus versions list --zone={ZONE} + node.runtime_version = runtime_version + + node_spec = tpu_v2alpha1.QueuedResource.Tpu.NodeSpec() + node_spec.parent = f"projects/{project_id}/locations/{zone}" + node_spec.node_id = tpu_name + node_spec.node = node + + resource = tpu_v2alpha1.QueuedResource() + resource.tpu = tpu_v2alpha1.QueuedResource.Tpu(node_spec=[node_spec]) + + # Use one of the following queueing policies + resource.queueing_policy = tpu_v2alpha1.QueuedResource.QueueingPolicy( + # valid_after_duration = "6000s", # Duration after which a resource should be allocated + valid_until_duration="90s", # Specify how long a queued resource request remains valid + # valid_after_time="2024-10-31T09:00:00Z", # Specify a time after which a resource should be allocated + # valid_until_time="2024-10-29T16:00:00Z", # Specify a time before which the resource should be allocated + ) + + request = CreateQueuedResourceRequest( + parent=f"projects/{project_id}/locations/{zone}", + queued_resource_id=queued_resource_name, + queued_resource=resource, + ) + + client = tpu_v2alpha1.TpuClient() + operation = client.create_queued_resource(request=request) + + response = operation.result() + print(resource.queueing_policy) + print(response.queueing_policy.valid_until_time) + # Example response: + # valid_until_duration { + # seconds: 90 + # } + # 2024-10-29 14:22:53.562090+00:00 + + # [END tpu_queued_resources_time_bound] + return response + + +if __name__ == "__main__": + PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") + ZONE = "us-central1-b" + create_queued_resource_time_bound( + project_id=PROJECT_ID, + zone=ZONE, + tpu_name="tpu-name", + queued_resource_name="resource-name", + ) diff --git a/tpu/queued_resources_delete.py b/tpu/queued_resources_delete.py new file mode 100644 index 00000000000..db503d08719 --- /dev/null +++ b/tpu/queued_resources_delete.py @@ -0,0 +1,47 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + + +def delete_queued_resource( + project_id: str, zone: str, queued_resource_name: str +) -> None: + # [START tpu_queued_resource_delete] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + # queued_resource_name = "resource-name" + + client = tpu_v2alpha1.TpuClient() + name = ( + f"projects/{project_id}/locations/{zone}/queuedResources/{queued_resource_name}" + ) + + try: + op = client.delete_queued_resource(name=name) + op.result() + print(f"Queued resource '{queued_resource_name}' successfully deleted.") + except TypeError as e: + print(f"Error deleting resource: {e}") + print(f"Queued resource '{queued_resource_name}' successfully deleted.") + + # [END tpu_queued_resource_delete] + + +if __name__ == "__main__": + PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"] + ZONE = "us-central1-b" + delete_queued_resource(PROJECT_ID, ZONE, "resource-name") diff --git a/tpu/queued_resources_delete_force.py b/tpu/queued_resources_delete_force.py new file mode 100644 index 00000000000..357a8378a2a --- /dev/null +++ b/tpu/queued_resources_delete_force.py @@ -0,0 +1,48 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + + +def delete_force_queued_resource( + project_id: str, zone: str, queued_resource_name: str +) -> None: + # [START tpu_queued_resources_delete_force] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + # queued_resource_name = "resource-name" + + client = tpu_v2alpha1.TpuClient() + request = tpu_v2alpha1.DeleteQueuedResourceRequest( + name=f"projects/{project_id}/locations/{zone}/queuedResources/{queued_resource_name}", + force=True, # Set force=True to delete the resource with tpu nodes. + ) + + try: + op = client.delete_queued_resource(request=request) + op.result() + print(f"Queued resource '{queued_resource_name}' successfully deleted.") + except TypeError as e: + print(f"Error deleting resource: {e}") + print(f"Queued resource '{queued_resource_name}' successfully deleted.") + + # [END tpu_queued_resources_delete_force] + + +if __name__ == "__main__": + PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"] + ZONE = "us-central1-b" + delete_force_queued_resource(PROJECT_ID, ZONE, "resource-name") diff --git a/tpu/queued_resources_get.py b/tpu/queued_resources_get.py new file mode 100644 index 00000000000..dc19f64291b --- /dev/null +++ b/tpu/queued_resources_get.py @@ -0,0 +1,48 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from google.cloud.tpu_v2alpha1 import QueuedResource + + +def get_queued_resource( + project_id: str, zone: str, queued_resource_name: str +) -> QueuedResource: + # [START tpu_queued_resources_get] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + # queued_resource_name = "resource-name" + + client = tpu_v2alpha1.TpuClient() + name = ( + f"projects/{project_id}/locations/{zone}/queuedResources/{queued_resource_name}" + ) + resource = client.get_queued_resource(name=name) + print("Resource name:", resource.name) + print(resource.state.state) + # Example response: + # Resource name: projects/{project_id}/locations/{zone}/queuedResources/resource-name + # State.ACTIVE + + # [END tpu_queued_resources_get] + return resource + + +if __name__ == "__main__": + PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"] + ZONE = "us-central1-b" + get_queued_resource(PROJECT_ID, ZONE, "resource-name") diff --git a/tpu/queued_resources_list.py b/tpu/queued_resources_list.py new file mode 100644 index 00000000000..4ed56a9596f --- /dev/null +++ b/tpu/queued_resources_list.py @@ -0,0 +1,44 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +from google.cloud.tpu_v2alpha1.services.tpu.pagers import ListQueuedResourcesPager + + +def list_queued_resources(project_id: str, zone: str) -> ListQueuedResourcesPager: + # [START tpu_queued_resources_list] + from google.cloud import tpu_v2alpha1 + + # TODO(developer): Update and un-comment below lines + # project_id = "your-project-id" + # zone = "us-central1-b" + + client = tpu_v2alpha1.TpuClient() + parent = f"projects/{project_id}/locations/{zone}" + resources = client.list_queued_resources(parent=parent) + for resource in resources: + print("Resource name:", resource.name) + print("TPU id:", resource.tpu.node_spec[0].node_id) + # Example response: + # Resource name: projects/{project_id}/locations/{zone}/queuedResources/resource-name + # TPU id: tpu-name + + # [END tpu_queued_resources_list] + return resources + + +if __name__ == "__main__": + PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") + ZONE = "us-central1-b" + list_queued_resources(PROJECT_ID, ZONE) diff --git a/tpu/test_queued_resources.py b/tpu/test_queued_resources.py new file mode 100644 index 00000000000..c80fc9420ab --- /dev/null +++ b/tpu/test_queued_resources.py @@ -0,0 +1,119 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Callable + +import uuid + +from google.cloud.tpu_v2alpha1 import QueuedResource + +import pytest + +import queued_resources_create +import queued_resources_create_network +import queued_resources_create_startup_script +import queued_resources_create_time_bound +import queued_resources_delete_force +import queued_resources_list + + +PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") +ZONE = "us-central1-b" +TPU_TYPE = "v2-8" +TPU_VERSION = "tpu-vm-tf-2.17.0-pjrt" + + +@pytest.fixture(scope="function") +def test_resource_name() -> None: + yield "test-resource-" + uuid.uuid4().hex[:6] + + +@pytest.fixture(scope="function") +def test_tpu_name() -> None: + yield "test-tpu-" + uuid.uuid4().hex[:6] + + +@pytest.fixture(scope="function") +def create_resource() -> Callable: + resources = [] + + def _create_resource( + create_func: Callable, resource_name: str, tpu_name: str + ) -> QueuedResource: + resource = create_func( + PROJECT_ID, + ZONE, + tpu_name, + TPU_TYPE, + TPU_VERSION, + resource_name, + ) + resources.append((resource_name, tpu_name)) + assert resource_name in resource.name + return resource + + yield _create_resource + for resource_name, tpu_name in resources: + queued_resources_delete_force.delete_force_queued_resource( + PROJECT_ID, ZONE, resource_name + ) + + +def test_list_queued_resources( + create_resource: Callable, test_resource_name: str, test_tpu_name: str +) -> None: + create_resource( + queued_resources_create.create_queued_resource, + test_resource_name, + test_tpu_name, + ) + resources = queued_resources_list.list_queued_resources(PROJECT_ID, ZONE) + assert any( + test_resource_name in resource.name for resource in resources + ), f"Resources does not contain '{test_resource_name}'" + + +def test_create_resource_with_network( + create_resource: Callable, test_resource_name: str, test_tpu_name: str +) -> None: + resource = create_resource( + queued_resources_create_network.create_queued_resource_network, + test_resource_name, + test_tpu_name, + ) + assert resource.tpu.node_spec[0].node.network_config.network == "default" + + +def test_create_resource_with_startup_script( + create_resource: Callable, test_resource_name: str, test_tpu_name: str +) -> None: + resource = create_resource( + queued_resources_create_startup_script.create_queued_resource_startup_script, + test_resource_name, + test_tpu_name, + ) + assert ( + "--upgrade numpy" in resource.tpu.node_spec[0].node.metadata["startup-script"] + ) + + +def test_create_queued_resource_time_bound( + create_resource: Callable, test_resource_name: str, test_tpu_name: str +) -> None: + resource = create_resource( + queued_resources_create_time_bound.create_queued_resource_time_bound, + test_resource_name, + test_tpu_name, + ) + assert resource.queueing_policy.valid_until_time