Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions singer/messages.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sys
import simplejson as json

import singer.utils as u
import dateutil
import pytz

class Message(object):
'''Base class for messages.'''
Expand Down Expand Up @@ -37,10 +39,14 @@ class RecordMessage(Message):

'''

# Constructor of the record message class. This span is a rendered diff with
# no +/- markers: two signatures appear back to back, presumably the
# pre-change line followed by its replacement (TODO confirm against the
# commit). The second signature adds an optional `time_extracted` parameter
# that defaults to None.
def __init__(self, stream, record, version=None):
def __init__(self, stream, record, version=None, time_extracted=None):
self.stream = stream
self.record = record
self.version = version
# The attribute is assigned before the validation below runs.
self.time_extracted = time_extracted
# A truthy `time_extracted` whose `tzinfo` attribute is falsy is rejected
# with ValueError; None (or any other falsy value) skips the check.
# NOTE(review): a value without a `tzinfo` attribute (e.g. a plain string)
# would raise AttributeError here rather than ValueError.
if time_extracted and not time_extracted.tzinfo:
raise ValueError("'time_extracted' must be either None " +
"or an aware datetime (with a time zone)")

def asdict(self):
result = {
Expand All @@ -50,6 +56,9 @@ def asdict(self):
}
if self.version is not None:
result['version'] = self.version
if self.time_extracted:
as_utc = self.time_extracted.astimezone(pytz.utc)
result['time_extracted'] = as_utc.strftime(u.DATETIME_FMT)
return result

def __str__(self):
Expand All @@ -76,18 +85,22 @@ class SchemaMessage(Message):
>>> key_properties=['id'])

'''
# Constructor of the schema message class. Two signatures appear back to back
# because this is a rendered diff; presumably the first is the pre-change
# line and the second its replacement (TODO confirm). The second adds an
# optional `bookmark_properties` parameter that defaults to None.
def __init__(self, stream, schema, key_properties):
def __init__(self, stream, schema, key_properties, bookmark_properties=None):
# Every argument is stored unchanged on the instance; no validation or
# normalisation happens here.
self.stream = stream
self.schema = schema
self.key_properties = key_properties
self.bookmark_properties = bookmark_properties

# Build the dictionary form of a schema message. The `return {` and
# `result = {` lines are diff variants of the same statement (presumably old
# then new — TODO confirm); the `result` variant is the one the trailing
# lines depend on.
def asdict(self):
return {
result = {
'type': 'SCHEMA',
'stream': self.stream,
'schema': self.schema,
'key_properties': self.key_properties
}
# 'bookmark_properties' is added only when the attribute is truthy, so both
# None and an empty list leave the key out of the result.
if self.bookmark_properties:
result['bookmark_properties'] = self.bookmark_properties
return result


class StateMessage(Message):
Expand Down Expand Up @@ -157,14 +170,20 @@ def parse_message(msg):
msg_type = _required_key(obj, 'type')

if msg_type == 'RECORD':
time_extracted = obj.get('time_extracted')
if time_extracted:
time_extracted = dateutil.parser.parse(time_extracted)
return RecordMessage(stream=_required_key(obj, 'stream'),
record=_required_key(obj, 'record'),
version=obj.get('version'))
version=obj.get('version'),
time_extracted=time_extracted)


elif msg_type == 'SCHEMA':
return SchemaMessage(stream=_required_key(obj, 'stream'),
schema=_required_key(obj, 'schema'),
key_properties=_required_key(obj, 'key_properties'))
key_properties=_required_key(obj, 'key_properties'),
bookmark_properties=obj.get('bookmark_properties'))

elif msg_type == 'STATE':
return StateMessage(value=_required_key(obj, 'value'))
Expand All @@ -183,12 +202,14 @@ def write_message(message):
sys.stdout.flush()


# Wrap one record in a RecordMessage and hand it to write_message. Two
# signatures and two call forms appear because this is a rendered diff;
# presumably the shorter variant is the pre-change code (TODO confirm). The
# longer variant adds an optional `time_extracted` parameter (default None)
# and forwards it to RecordMessage.
def write_record(stream_name, record, stream_alias=None):
def write_record(stream_name, record, stream_alias=None, time_extracted=None):
"""Write a single record for the given stream.

>>> write_record("users", {"id": 2, "email": "mike@stitchdata.com"})
"""
# A truthy `stream_alias` replaces `stream_name` as the message's stream.
# Neither call form passes a `version` argument to RecordMessage.
write_message(RecordMessage(stream=(stream_alias or stream_name), record=record))
write_message(RecordMessage(stream=(stream_alias or stream_name),
record=record,
time_extracted=time_extracted))


def write_records(stream_name, records):
Expand All @@ -202,7 +223,7 @@ def write_records(stream_name, records):
write_record(stream_name, record)


def write_schema(stream_name, schema, key_properties, stream_alias=None):
def write_schema(stream_name, schema, key_properties, bookmark_properties=None, stream_alias=None):
"""Write a schema message.

>>> stream = 'test'
Expand All @@ -214,11 +235,18 @@ def write_schema(stream_name, schema, key_properties, stream_alias=None):
key_properties = [key_properties]
if not isinstance(key_properties, list):
raise Exception("key_properties must be a string or list of strings")

if isinstance(bookmark_properties, (str, bytes)):
bookmark_properties = [bookmark_properties]
if bookmark_properties and not isinstance(bookmark_properties, list):
raise Exception("bookmark_properties must be a string or list of strings")

write_message(
SchemaMessage(
stream=(stream_alias or stream_name),
schema=schema,
key_properties=key_properties))
key_properties=key_properties,
bookmark_properties=bookmark_properties))


def write_state(value):
Expand Down
20 changes: 19 additions & 1 deletion tests/test_singer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import singer
import unittest

import datetime
import dateutil

class TestSinger(unittest.TestCase):
def test_parse_message_record_good(self):
Expand All @@ -17,6 +18,23 @@ def test_parse_message_record_with_version_good(self):
message,
singer.RecordMessage(record={'name': 'foo'}, stream='users', version=2))

# Expects parse_message to raise ValueError (with a message matching the regex
# below) for a RECORD whose "time_extracted" string carries no zone
# designator or UTC offset.
def test_parse_message_record_naive_extraction_time(self):
with self.assertRaisesRegex(ValueError, "must be either None or an aware datetime"):
# The return value is bound to `message` but never read in this test.
message = singer.parse_message(
'{"type": "RECORD", "record": {"name": "foo"}, "stream": "users", "version": 2, "time_extracted": "1970-01-02T00:00:00"}')

# Parses a RECORD whose "time_extracted" string ends in "Z" and compares the
# result with a RecordMessage built directly from the same timestamp.
def test_parse_message_record_aware_extraction_time(self):
message = singer.parse_message(
'{"type": "RECORD", "record": {"name": "foo"}, "stream": "users", "version": 2, "time_extracted": "1970-01-02T00:00:00.000Z"}')
# NOTE(review): `dateutil.parser` is reached here through a bare
# `import dateutil`; confirm that some imported module loads the
# `dateutil.parser` submodule, otherwise this attribute lookup can fail.
expected = singer.RecordMessage(
record={'name': 'foo'},
stream='users',
version=2,
time_extracted=dateutil.parser.parse("1970-01-02T00:00:00.000Z"))
# NOTE(review): these prints look like leftover debugging output — confirm
# whether they are still wanted.
print(message)
print(expected)
# Presumably relies on an equality method defined on the message classes
# outside this view — verify.
self.assertEqual(message, expected)

def test_parse_message_record_missing_record(self):
with self.assertRaises(Exception):
singer.parse_message('{"type": "RECORD", "stream": "users"}')
Expand Down