|
| 1 | +from boto3 import session |
| 2 | +from botocore.exceptions import ClientError |
| 3 | +from typing import Dict, List |
| 4 | +import os |
| 5 | +import logging |
| 6 | +import uuid |
| 7 | + |
| 8 | +logger = logging.getLogger(__name__) |
| 9 | + |
| 10 | +ENDPOINT_ENV = "aws_endpoint_url" |
| 11 | + |
| 12 | +DEFAULT_BUCKET_TO_DOMAIN = { |
| 13 | + "prod-maven-ga": "maven.repository.redhat.com", |
| 14 | + "prod-maven-ea": "maven.repository.redhat.com", |
| 15 | + "stage-maven-ga": "maven.strage.repository.redhat.com", |
| 16 | + "stage-maven-ea": "maven.strage.repository.redhat.com", |
| 17 | + "prod-npm": "npm.repository.redhat.com", |
| 18 | + "stage-npm": "npm.stage.repository.redhat.com" |
| 19 | +} |
| 20 | + |
| 21 | + |
| 22 | +class CFClient(object): |
| 23 | + """The CFClient is a wrapper of the original boto3 clouldfrong client, |
| 24 | + which will provide CloudFront functions to be used in the charon. |
| 25 | + """ |
| 26 | + |
| 27 | + def __init__( |
| 28 | + self, |
| 29 | + aws_profile=None, |
| 30 | + extra_conf=None |
| 31 | + ) -> None: |
| 32 | + self.__client = self.__init_aws_client(aws_profile, extra_conf) |
| 33 | + |
| 34 | + def __init_aws_client( |
| 35 | + self, aws_profile=None, extra_conf=None |
| 36 | + ): |
| 37 | + if aws_profile: |
| 38 | + logger.debug("Using aws profile: %s", aws_profile) |
| 39 | + cf_session = session.Session(profile_name=aws_profile) |
| 40 | + else: |
| 41 | + cf_session = session.Session() |
| 42 | + endpoint_url = self.__get_endpoint(extra_conf) |
| 43 | + return cf_session.client( |
| 44 | + 'cloudfront', |
| 45 | + endpoint_url=endpoint_url |
| 46 | + ) |
| 47 | + |
| 48 | + def __get_endpoint(self, extra_conf) -> str: |
| 49 | + endpoint_url = os.getenv(ENDPOINT_ENV) |
| 50 | + if not endpoint_url or not endpoint_url.strip(): |
| 51 | + if isinstance(extra_conf, Dict): |
| 52 | + endpoint_url = extra_conf.get(ENDPOINT_ENV, None) |
| 53 | + if endpoint_url: |
| 54 | + logger.info("Using endpoint url for aws client: %s", endpoint_url) |
| 55 | + else: |
| 56 | + logger.debug("No user-specified endpoint url is used.") |
| 57 | + return endpoint_url |
| 58 | + |
| 59 | + def invalidate_paths(self, distr_id: str, paths: List[str]) -> Dict[str, str]: |
| 60 | + """Send a invalidating requests for the paths in distribution to CloudFront. |
| 61 | + This will invalidate the paths in the distribution to enforce the refreshment |
| 62 | + from backend S3 bucket for these paths. For details see: |
| 63 | + https://docs.aws.amazon.com/AmazonCloudFront/latest/DeveloperGuide/Invalidation.html |
| 64 | + * The distr_id is the id for the distribution. This id can be get through |
| 65 | + get_dist_id_by_domain(domain) function |
| 66 | + * Can specify the invalidating paths through paths param. |
| 67 | + """ |
| 68 | + caller_ref = str(uuid.uuid4()) |
| 69 | + logger.debug("[CloudFront] Creating invalidation for paths: %s", paths) |
| 70 | + try: |
| 71 | + response = self.__client.create_invalidation( |
| 72 | + DistributionId=distr_id, |
| 73 | + InvalidationBatch={ |
| 74 | + 'CallerReference': caller_ref, |
| 75 | + 'Paths': { |
| 76 | + 'Quantity': len(paths), |
| 77 | + 'Items': paths |
| 78 | + } |
| 79 | + } |
| 80 | + ) |
| 81 | + if response: |
| 82 | + invalidation = response.get('Invalidation', {}) |
| 83 | + return { |
| 84 | + 'Id': invalidation.get('Id', None), |
| 85 | + 'Status': invalidation.get('Status', None) |
| 86 | + } |
| 87 | + except Exception as err: |
| 88 | + logger.error( |
| 89 | + "[CloudFront] Error occurred while creating invalidation, error: %s", err |
| 90 | + ) |
| 91 | + |
| 92 | + def check_invalidation(self, distr_id: str, invalidation_id: str) -> dict: |
| 93 | + try: |
| 94 | + response = self.__client.get_invalidation( |
| 95 | + DistributionId=distr_id, |
| 96 | + Id=invalidation_id |
| 97 | + ) |
| 98 | + if response: |
| 99 | + invalidation = response.get('Invalidation', {}) |
| 100 | + return { |
| 101 | + 'Id': invalidation.get('Id', None), |
| 102 | + 'CreateTime': invalidation.get('CreateTime', None), |
| 103 | + 'Status': invalidation.get('Status', None) |
| 104 | + } |
| 105 | + except Exception as err: |
| 106 | + logger.error( |
| 107 | + "[CloudFront] Error occurred while check invalidation of id %s, " |
| 108 | + "error: %s", invalidation_id, err |
| 109 | + ) |
| 110 | + |
| 111 | + def get_dist_id_by_domain(self, domain: str) -> str: |
| 112 | + """Get distribution id by a domain name. The id can be used to send invalidating |
| 113 | + request through #invalidate_paths function |
| 114 | + * Domain are Ronda domains, like "maven.repository.redhat.com" |
| 115 | + or "npm.repository.redhat.com" |
| 116 | + """ |
| 117 | + try: |
| 118 | + response = self.__client.list_distributions() |
| 119 | + if response: |
| 120 | + dist_list_items = response.get("DistributionList", {}).get("Items", []) |
| 121 | + for distr in dist_list_items: |
| 122 | + aliases_items = distr.get('Aliases', {}).get('Items', []) |
| 123 | + if aliases_items and domain in aliases_items: |
| 124 | + return distr['Id'] |
| 125 | + logger.error("[CloudFront]: Distribution not found for domain %s", domain) |
| 126 | + except ClientError as err: |
| 127 | + logger.error( |
| 128 | + "[CloudFront]: Error occurred while get distribution for domain %s: %s", |
| 129 | + domain, err |
| 130 | + ) |
| 131 | + return None |
| 132 | + |
| 133 | + def get_domain_by_bucket(self, bucket: str) -> str: |
| 134 | + return DEFAULT_BUCKET_TO_DOMAIN.get(bucket, None) |
0 commit comments