diff --git a/apns.py b/apns.py index fdcec08..463cbe7 100644 --- a/apns.py +++ b/apns.py @@ -24,7 +24,9 @@ # SOFTWARE. from binascii import a2b_hex, b2a_hex -from datetime import datetime +from datetime import datetime, timedelta +from time import mktime +from random import getrandbits from socket import socket, AF_INET, SOCK_STREAM from struct import pack, unpack @@ -142,7 +144,7 @@ def write(self, string): class PayloadAlert(object): - def __init__(self, body, action_loc_key=None, loc_key=None, + def __init__(self, body=None, action_loc_key=None, loc_key=None, loc_args=None, launch_image=None): super(PayloadAlert, self).__init__() self.body = body @@ -152,7 +154,9 @@ def __init__(self, body, action_loc_key=None, loc_key=None, self.launch_image = launch_image def dict(self): - d = { 'body': self.body } + d = {} + if self.body: + d['body'] = self.body if self.action_loc_key: d['action-loc-key'] = self.action_loc_key if self.loc_key: @@ -263,6 +267,30 @@ def items(self): # some more data and append to buffer break +class UnknownResponse(Exception): + def __init__(self): + super(UnknownResponse, self).__init__() + +class UnknownError(Exception): + def __init__(self): + super(UnknownError, self).__init__() + +class ProcessingError(Exception): + def __init__(self): + super(ProcessingError, self).__init__() + +class InvalidTokenSizeError(Exception): + def __init__(self): + super(InvalidTokenSizeError, self).__init__() + +class InvalidPayloadSizeError(Exception): + def __init__(self): + super(InvalidPayloadSizeError, self).__init__() + +class InvalidTokenError(Exception): + def __init__(self): + super(InvalidTokenError, self).__init__() + class GatewayConnection(APNsConnection): """ A class that represents a connection to the APNs gateway server @@ -274,21 +302,44 @@ def __init__(self, use_sandbox=False, **kwargs): 'gateway.sandbox.push.apple.com')[use_sandbox] self.port = 2195 - def _get_notification(self, token_hex, payload): + def _get_notification(self, token_hex, payload, identifier, expiry): """ Takes a token as a hex string and a payload as a Python dict and sends the notification """ + identifier_bin = identifier[:4] + expiry_bin = APNs.packed_uint_big_endian(int(mktime(expiry.timetuple()))) token_bin = a2b_hex(token_hex) token_length_bin = APNs.packed_ushort_big_endian(len(token_bin)) payload_json = payload.json() payload_length_bin = APNs.packed_ushort_big_endian(len(payload_json)) - notification = ('\0' + token_length_bin + token_bin + notification = ('\x01' + identifier_bin + expiry_bin + token_length_bin + token_bin + payload_length_bin + payload_json) return notification - def send_notification(self, token_hex, payload): - self.write(self._get_notification(token_hex, payload)) + def send_notification(self, token_hex, payload, expiry=None): + if expiry is None: + expiry = datetime.now() + timedelta(30) + + identifier = pack('>I', getrandbits(32)) + self.write(self._get_notification(token_hex, payload, identifier, expiry)) + + error_response = self.read(6) + if error_response != '': + command = error_response[0] + status = ord(error_response[1]) + response_identifier = error_response[2:6] + + if command != '\x08' or response_identifier != identifier: + raise UnknownResponse() + + if status == 0: + return + + raise {1: ProcessingError, + 5: InvalidTokenSizeError, + 7: InvalidPayloadSizeError, + 8: InvalidTokenError}.get(status, UnknownError)() diff --git a/tests.py b/tests.py index ff5af97..e1962f2 100644 --- a/tests.py +++ b/tests.py @@ -3,6 +3,7 @@ from apns import * from binascii import a2b_hex from random import random +from datetime import datetime import hashlib import os @@ -77,10 +78,14 @@ def testGatewayServer(self): sound = "default", badge = 4 ) - notification = gateway_server._get_notification(token_hex, payload) + identifier = 'abcd' + expiry = datetime(2000, 01, 01, 00, 00, 00) + notification = gateway_server._get_notification(token_hex, payload, identifier, expiry) expected_length = ( - 1 # leading null byte + 1 # leading command byte + + 4 # Identifier as a 4 bytes buffer + + 4 # Expiry timestamp as a packed integer + 2 # length of token as a packed short + len(token_hex) / 2 # length of token as binary string + 2 # length of payload as a packed short @@ -88,7 +93,7 @@ def testGatewayServer(self): ) self.assertEqual(len(notification), expected_length) - self.assertEqual(notification[0], '\0') + self.assertEqual(notification[0], '\x01') # Enhanched format command byte def testFeedbackServer(self): pem_file = TEST_CERTIFICATE