diff --git a/mongostore.go b/mongostore.go index 76789dbe9..e50af2c16 100644 --- a/mongostore.go +++ b/mongostore.go @@ -6,7 +6,7 @@ import ( "github.com/globalsign/mgo" "github.com/globalsign/mgo/bson" - + "github.com/pkg/errors" "github.com/quickfixgo/quickfix/config" ) @@ -66,7 +66,11 @@ func newMongoStore(sessionID SessionID, mongoURL string, mongoDatabase string, m messagesCollection: messagesCollection, sessionsCollection: sessionsCollection, } - store.cache.Reset() + + if err = store.cache.Reset(); err != nil { + err = errors.Wrap(err, "cache reset") + return + } if store.db, err = mgo.Dial(mongoURL); err != nil { return @@ -139,27 +143,43 @@ func (store *mongoStore) Refresh() error { return store.populateCache() } -func (store *mongoStore) populateCache() (err error) { +func (store *mongoStore) populateCache() error { msgFilter := generateMessageFilter(&store.sessionID) query := store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Find(msgFilter) - if cnt, err := query.Count(); err == nil && cnt > 0 { + cnt, err := query.Count() + if err != nil { + return errors.Wrap(err, "count") + } + + if cnt > 0 { // session record found, load it sessionData := &mongoQuickFixEntryData{} - err = query.One(&sessionData) - if err == nil { - store.cache.creationTime = sessionData.CreationTime - store.cache.SetNextTargetMsgSeqNum(sessionData.IncomingSeqNum) - store.cache.SetNextSenderMsgSeqNum(sessionData.OutgoingSeqNum) + if err = query.One(&sessionData); err != nil { + return errors.Wrap(err, "query one") + } + + store.cache.creationTime = sessionData.CreationTime + if err = store.cache.SetNextTargetMsgSeqNum(sessionData.IncomingSeqNum); err != nil { + return errors.Wrap(err, "cache set next target") } - } else if err == nil && cnt == 0 { - // session record not found, create it - msgFilter.CreationTime = store.cache.creationTime - msgFilter.IncomingSeqNum = store.cache.NextTargetMsgSeqNum() - msgFilter.OutgoingSeqNum = store.cache.NextSenderMsgSeqNum() - err = store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Insert(msgFilter) + + if err = store.cache.SetNextSenderMsgSeqNum(sessionData.OutgoingSeqNum); err != nil { + return errors.Wrap(err, "cache set next sender") + } + + return nil } - return + + // session record not found, create it + msgFilter.CreationTime = store.cache.creationTime + msgFilter.IncomingSeqNum = store.cache.NextTargetMsgSeqNum() + msgFilter.OutgoingSeqNum = store.cache.NextSenderMsgSeqNum() + + if err = store.db.DB(store.mongoDatabase).C(store.sessionsCollection).Insert(msgFilter); err != nil { + return errors.Wrap(err, "insert") + } + return nil } // NextSenderMsgSeqNum returns the next MsgSeqNum that will be sent @@ -200,13 +220,17 @@ func (store *mongoStore) SetNextTargetMsgSeqNum(next int) error { // IncrNextSenderMsgSeqNum increments the next MsgSeqNum that will be sent func (store *mongoStore) IncrNextSenderMsgSeqNum() error { - store.cache.IncrNextSenderMsgSeqNum() + if err := store.cache.IncrNextSenderMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr") + } return store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum()) } // IncrNextTargetMsgSeqNum increments the next MsgSeqNum that should be received func (store *mongoStore) IncrNextTargetMsgSeqNum() error { - store.cache.IncrNextTargetMsgSeqNum() + if err := store.cache.IncrNextTargetMsgSeqNum(); err != nil { + return errors.Wrap(err, "cache incr") + } return store.SetNextTargetMsgSeqNum(store.cache.NextTargetMsgSeqNum()) }