#!/usr/bin/env python
#
# Copyright (C) 2012-2015 Fanout, Inc.
#
# This file is part of Pushpin.
#
# Pushpin is free software: you can redistribute it and/or modify it under
# the terms of the GNU Affero General Public License as published by the Free
# Software Foundation, either version 3 of the License, or (at your option)
# any later version.
#
# Pushpin is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
# more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import sys
import os
import ConfigParser

version = '1.5.0'

# defaults, overridable from the command line
config_file = "/etc/pushpin/pushpin.conf"
log_file = None
verbose = False

# scan every argument (argv[0] included, which never matches any option)
for arg in sys.argv:
	if arg == "--verbose":
		verbose = True
	elif arg == "--version":
		print('pushpin-handler %s' % version)
		sys.exit(1)
	elif arg.startswith("--config="):
		config_file = arg[len("--config="):]
	elif arg.startswith("--logfile="):
		log_file = arg[len("--logfile="):]

class ConfigWithInclude(object):
	def __init__(self, fname):
		self.config_dir = os.path.dirname(fname)
		self.main = ConfigParser.ConfigParser()
		self.main.read([fname])

		self.include = None
		if self.main.has_option('global', 'include'):
			include_fname = self.main.get('global', 'include')
			if not os.path.isabs(include_fname):
				include_fname = os.path.join(self.config_dir, include_fname)
			self.include = ConfigParser.ConfigParser()
			self.include.read([include_fname])

	def has_option(self, section, key):
		if self.main.has_option(section, key):
			return True
		elif self.include is not None and self.include.has_option(section, key):
			return True
		return False

	def get(self, section, key):
		if self.main.has_option(section, key):
			return self.main.get(section, key)
		elif self.include is not None and self.include.has_option(section, key):
			return self.include.get(section, key)
		raise ConfigParser.NoOptionError(key, section)

config = ConfigWithInclude(config_file)

# optional library directory; when set, its "handler" subdirectory is put
# first on the module search path before the handler modules are imported
libdir = config.get("global", "libdir") if config.has_option("global", "libdir") else None
if libdir:
	sys.path.insert(0, os.path.join(libdir, "handler"))

import time
import threading
import json
import copy
import logging
from logging.handlers import WatchedFileHandler
from base64 import b64decode
import urlparse
from setproctitle import setproctitle
import zmq
import tnetstring
import httpinterface
from validation import validate_publish, validate_http_publish, ValidationError
from conversion import ensure_utf8, convert_json_transport
from statusreasons import get_reason
import rpc

try:
	from sortedcontainers import SortedDict
except ImportError:
	from blist import sorteddict as SortedDict

# name shown for this process in process listings
setproctitle("pushpin-handler")

# replace stdout with an unbuffered file object on the same descriptor
# (python 2: a buffer size of 0 means unbuffered)
sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)

# log to --logfile when given (reopened if rotated), otherwise to stdout
logger = logging.getLogger('handler')
logger_handler = WatchedFileHandler(log_file) if log_file else logging.StreamHandler(stream=sys.stdout)
logger.setLevel(logging.DEBUG if verbose else logging.INFO)
logger_handler.setLevel(logging.DEBUG if verbose else logging.INFO)
formatter = logging.Formatter(fmt='%(levelname)s %(asctime)s.%(msecs)03d %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
logger_handler.setFormatter(formatter)
logger.addHandler(logger_handler)

# connection tracking timing
CONNECTION_TTL = 600
CONNECTION_REFRESH = 540
CONNECTION_LINGER = 60

# subscription tracking timing
SUBSCRIPTION_TTL = 60
SUBSCRIPTION_LINGER = 60

# zmq socket linger
DEFAULT_LINGER = 1000

# send high water mark for the SUB socket; 0 means no limit
SUB_SNDHWM = 0

# pace deliveries so the nic isn't overflowed: pause 1ms after every 20 sends
SEND_BATCH_SIZE = 20
SEND_BATCH_DELAY = 0.001

# identifies this handler process in outgoing messages
instance_id = "pushpin-handler_%d" % os.getpid()

def get_option_raw(section, key, default=None):
	"""Return the raw config value for section/key, or default if unset."""
	if not config.has_option(section, key):
		return default
	return config.get(section, key)

rundir = get_option_raw("global", "rundir")
if rundir is None:
	# fallback to runner section (deprecated)
	rundir = get_option_raw("runner", "rundir")
if not rundir:
	# explicit check instead of assert: asserts are stripped under python -O,
	# which would defer the failure to a confusing error in get_option()
	sys.exit("pushpin-handler: rundir is not configured (set [global] rundir)")

def get_option(section, key, default=None):
	"""Return a config value with every "{rundir}" replaced by the run directory.

	Returns None when the option is unset and default is None.
	"""
	value = get_option_raw(section, key, default=default)
	if value is None:
		return None
	return value.replace('{rundir}', rundir)

ipc_file_mode = get_option("handler", "ipc_file_mode")
if ipc_file_mode is not None:
	# given in octal, like a chmod mode
	ipc_file_mode = int(ipc_file_mode, 8)

# both options are required; report a clear error instead of an
# AttributeError from calling split() on None
m2a_in_stream_specs = get_option("handler", "m2a_in_stream_specs")
m2a_out_specs = get_option("handler", "m2a_out_specs")
if m2a_in_stream_specs is None or m2a_out_specs is None:
	sys.exit("pushpin-handler: [handler] m2a_in_stream_specs and m2a_out_specs must be set")
m2a_in_stream_specs = m2a_in_stream_specs.split(",")
m2a_out_specs = m2a_out_specs.split(",")

share_all = (get_option("handler", "share_all") == "true")

stats_spec = get_option("handler", "stats_spec")
command_spec = get_option("handler", "command_spec")
proxy_inspect_spec = get_option("handler", "proxy_inspect_spec")
proxy_accept_spec = get_option("handler", "proxy_accept_spec")
proxy_retry_out_spec = get_option("handler", "proxy_retry_out_spec")
ws_control_in_spec = get_option("handler", "proxy_ws_control_in_spec")
ws_control_out_spec = get_option("handler", "proxy_ws_control_out_spec")
proxy_stats_spec = get_option("handler", "proxy_stats_spec")
proxy_command_spec = get_option("handler", "proxy_command_spec")
state_spec = get_option("handler", "state_spec")
push_in_spec = get_option("handler", "push_in_spec")
push_in_sub_spec = get_option("handler", "push_in_sub_spec")
push_in_http_addr = get_option("handler", "push_in_http_addr")

push_in_http_port = get_option("handler", "push_in_http_port")
if push_in_http_port is not None:
	push_in_http_port = int(push_in_http_port)

# required options. explicit checks rather than assert, which is stripped
# under python -O
for _opt_name, _opt_ok in (
		("proxy_inspect_spec", bool(proxy_inspect_spec)),
		("proxy_accept_spec", bool(proxy_accept_spec)),
		("proxy_retry_out_spec", bool(proxy_retry_out_spec)),
		("push_in_spec", bool(push_in_spec)),
		("push_in_http_addr", bool(push_in_http_addr)),
		("push_in_http_port", push_in_http_port is not None)):
	if not _opt_ok:
		sys.exit("pushpin-handler: required option [handler] %s is not set" % _opt_name)

# shared zmq context for all sockets in this process
ctx = zmq.Context()

class Hold(object):
	"""A request held open by the handler.

	out_seq and out_credits must only be touched while holding self.lock.
	"""
	def __init__(self, rid, request, mode, response, auto_cross_origin, jsonp_callback):
		# identity and what was handed over with the request
		self.rid = rid
		self.request = request
		self.mode = mode
		self.response = response
		self.auto_cross_origin = auto_cross_origin
		self.jsonp_callback = jsonp_callback
		self.jsonp_extended_response = False

		# output flow control, guarded by self.lock
		self.lock = threading.Lock()
		self.out_seq = None
		self.out_credits = 0

		# timing and keep-alive bookkeeping, filled in by the accepter
		self.expire_time = None
		self.last_keepalive = None
		self.last_send = None
		self.grip_keep_alive = None
		self.grip_keep_alive_timeout = None

class WsSession(object):
	"""Per-connection record for a WebSocket session identified by cid."""
	def __init__(self, cid):
		self.cid = cid
		self.sid = None
		self.channel_prefix = None
		self.expire_time = None

class Subscription(object):
	"""A (mode, channel) subscription tracked by the handler."""
	def __init__(self, mode, channel):
		self.mode, self.channel = mode, channel
		self.expire_time = None
		self.last_keepalive = None

class ConnectionInfo(object):
	"""Stats bookkeeping for one client connection (id and type given by the caller)."""
	def __init__(self, id, type):
		self.id = id
		self.type = type
		# optional details filled in by the caller
		self.route = None
		self.sid = None
		self.peer_address = None
		self.ssl = False
		# liveness tracking
		self.last_keepalive = None
		self.linger = False

class SessionDetectRule(object):
	"""How to find a session id in request bodies for a domain/path prefix.

	sid_ptr is a JSON pointer into the body. When json_param is set, the
	JSON is taken from that form parameter of the body instead.
	"""
	def __init__(self):
		self.domain = self.path_prefix = None
		self.json_param = self.sid_ptr = None

class Session(object):
	"""Minimal session record; only carries an id."""
	def __init__(self):
		# session id, unset until assigned by the caller
		self.id = None

class StatsSubscription(object):
	"""A (mode, channel) subscription as tracked for stats reporting."""
	def __init__(self):
		self.mode, self.channel, self.expire_time = None, None, None

# hold/channel bookkeeping, guarded by `lock`
lock = threading.Lock()
response_channels = {} # channel -> {rid: Hold} for response-mode holds
response_lastids = {}
stream_channels = {} # channel -> {rid: Hold} for stream-mode holds
channels_by_req = {} # rid -> set of (mode, channel)
ws_sessions = {}
ws_channels = {}

# key=(type, channel)
subs = {}

# stats bookkeeping, guarded by `stats_lock`
stats_lock = threading.Lock()
stats_activity = {} # route, count
conns = {} # conn id, ConnectionInfo

# assumes state is locked
# returns list of (mode, channel) that should be unsubscribed
def remove_from_req_channels(rid, response=True, stream=True):
	"""Unbind request rid from its channels of the selected modes.

	A channel whose hold map becomes empty is dropped, and its (mode, channel)
	is returned so the caller can unsubscribe. rid's entry in channels_by_req
	is deleted once it has no channels left.
	"""
	unsub = []
	req_channels = channels_by_req.get(rid)
	if req_channels is None:
		return unsub

	# which global hold maps apply, keyed by mode
	maps_by_mode = {}
	if response:
		maps_by_mode['response'] = response_channels
	if stream:
		maps_by_mode['stream'] = stream_channels

	selected = [key for key in req_channels if key[0] in maps_by_mode]
	for mode, channel in selected:
		holds_map = maps_by_mode[mode]
		holds = holds_map.get(channel)
		if holds is not None and rid in holds:
			del holds[rid]
			if not holds:
				del holds_map[channel]
				unsub.append((mode, channel))
		req_channels.remove((mode, channel))

	if not req_channels:
		del channels_by_req[rid]
	return unsub

# assumes state is locked
def remove_from_response_channels(rid):
	"""Unbind rid from response-mode channels only; returns channels to unsubscribe."""
	return remove_from_req_channels(rid, stream=False)

# assumes state is locked
def remove_from_stream_channels(rid):
	"""Unbind rid from stream-mode channels only; returns channels to unsubscribe."""
	return remove_from_req_channels(rid, response=False)

def headernames_contains(headers, name):
	"""Return True if any header name in headers equals name, ignoring case."""
	wanted = name.lower()
	return any(candidate.lower() == wanted for candidate in headers)

def header_get(headers, name):
	"""Return the value of the first header named name (case-insensitive), or None.

	headers is either a list of [name, value] pairs or a dict.
	"""
	wanted = name.lower()
	items = headers if isinstance(headers, list) else headers.iteritems()
	for item in items:
		if item[0].lower() == wanted:
			return item[1]
	return None

# return list of strings
def header_get_all(headers, name):
	"""Collect every comma-separated element of all headers named name.

	Matching is case-insensitive; elements are stripped and empty ones are
	dropped. headers is a list of [name, value] pairs or a dict.
	"""
	wanted = name.lower()
	items = headers if isinstance(headers, list) else headers.iteritems()
	raw_values = [item[1] for item in items if item[0].lower() == wanted]
	return [piece.strip() for value in raw_values for piece in value.split(',') if piece.strip()]

def header_remove(headers, name):
	"""Remove every header named name (case-insensitive) from headers, in place.

	headers is either a list of [name, value] pairs or a dict. All matches
	are removed, so a later header_set() cannot leave a duplicate that only
	differs in case (e.g. "content-length" next to "Content-Length").
	"""
	lname = name.lower()
	if isinstance(headers, list):
		# slice assignment keeps the caller's list object
		headers[:] = [i for i in headers if i[0].lower() != lname]
	else:
		# collect first: a dict can't be modified while iterating it
		matches = [k for k in headers if k.lower() == lname]
		for k in matches:
			del headers[k]
def header_set(headers, name, value):
	"""Set header name to value, replacing any existing headers of that name.

	Works on dicts and on lists of [name, value] pairs. For a list, a new
	pair is appended; indexing a list by a header name would raise TypeError.
	"""
	header_remove(headers, name)
	if isinstance(headers, list):
		headers.append([name, value])
	else:
		headers[name] = value

def header_names_contains(header_names, name):
	"""Return True if name appears in header_names, ignoring case."""
	target = name.lower()
	for header_name in header_names:
		if header_name.lower() != target:
			continue
		return True
	return False

# return (initial value, params dict)
def _parse_header_params(value):
	parts = value.split(';')
	v = parts[0].strip()
	params = dict()
	for n in range(1, len(parts)):
		part = parts[n].strip()
		at = part.find('=')
		if at != -1:
			pname = part[:at]
			pval = part[at + 1:]
		else:
			pname = part
			pval = None
		if pname:
			params[pname] = pval
	return (v, params)

# return (initial value, params dict) or None
def header_get_parsed(headers, name):
	"""Return _parse_header_params() of the first name header, or None if absent."""
	raw = header_get(headers, name)
	return None if raw is None else _parse_header_params(raw)

# return list of (initial value, params dict)
def header_get_all_parsed(headers, name):
	"""Parse every comma-separated element of the name headers into (value, params)."""
	return [_parse_header_params(element) for element in header_get_all(headers, name)]

# raw HTTP/1.1 response: status line, CRLF-joined header lines, blank line, body
HTTP_FORMAT = "HTTP/1.1 %(code)s %(status)s\r\n%(headers)s\r\n\r\n%(body)s"
# same, for a response without any header lines
HTTP_FORMAT_NOHEADERS = "HTTP/1.1 %(code)s %(status)s\r\n\r\n%(body)s"

def http_response(body, code, status, headers):
	"""Render a raw HTTP/1.1 response, setting Content-Length in headers (mutated)."""
	header_set(headers, "Content-Length", str(len(body)))
	header_lines = ["%s: %s" % item for item in headers.items()]
	return HTTP_FORMAT % {
		"code": code,
		"status": status,
		"body": body,
		"headers": "\r\n".join(header_lines),
	}

def http_response_nolen(body, code, status, headers):
	"""Render a raw HTTP/1.1 response without Content-Length (removed from headers)."""
	header_remove(headers, "Content-Length")
	fields = {"code": code, "status": status, "body": body}
	if not headers:
		return HTTP_FORMAT_NOHEADERS % fields
	fields["headers"] = "\r\n".join(["%s: %s" % item for item in headers.items()])
	return HTTP_FORMAT % fields

def reply_http_old(sock, rid, code, status, headers, body, nolen=False):
	"""Send a raw HTTP response to rid framed as "<sender> <len>:<id>, <response>".

	status, header names/values and body are UTF-8 encoded when unicode.
	With nolen, Content-Length is removed instead of being set.
	"""
	prefix = "%s %d:%s," % (rid[0], len(rid[1]), rid[1])

	def to_utf8(s):
		if isinstance(s, unicode):
			return s.encode("utf-8")
		return s

	status = to_utf8(status)
	headers = dict((to_utf8(k), to_utf8(v)) for k, v in headers.iteritems())
	body = to_utf8(body)

	render = http_response_nolen if nolen else http_response
	m_raw = prefix + " " + render(body, code, status, headers)
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

def reply_http_chunk_old(sock, rid, content):
	"""Send raw content to rid framed as "<sender> <len>:<id>, <content>"."""
	sender, ident = rid[0], rid[1]
	sock.send(" ".join(("%s %d:%s," % (sender, len(ident), ident), content)))

# return body size sent
def reply_http(sock, rid, code, reason, headers, body, nolen=False, seq=None):
	"""Send an HTTP response start to rid as a "<sender> T<tnetstring>" message.

	reason, header names/values and body are UTF-8 encoded when unicode. A
	body of None is treated as empty. Callers pass response.get("body"),
	which would otherwise make len() raise. With nolen, Content-Length is
	removed and the message is marked "more" (the body continues). Returns
	the number of body bytes sent.
	"""
	if isinstance(reason, unicode):
		reason = reason.encode("utf-8")

	# ensure headers are utf-8
	tmp = dict()
	for k, v in headers.iteritems():
		if isinstance(k, unicode):
			k = k.encode("utf-8")
		if isinstance(v, unicode):
			v = v.encode("utf-8")
		tmp[k] = v
	headers = tmp

	if body is None:
		body = ""
	elif isinstance(body, unicode):
		body = body.encode("utf-8")

	if nolen:
		header_remove(headers, "Content-Length")
	else:
		header_set(headers, "Content-Length", str(len(body)))

	out = dict()
	out["from"] = instance_id
	out["id"] = rid[1]
	if seq is not None:
		out["seq"] = seq
	out["code"] = code
	out["reason"] = reason
	out["headers"] = [[k, v] for k, v in headers.iteritems()]
	if body:
		out["body"] = body
	if nolen:
		out["more"] = True

	m_raw = rid[0] + " T" + tnetstring.dumps(out)
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)
	return len(body)

def reply_http_chunk(sock, rid, content, seq=None):
	"""Send a body chunk to rid, flagged "more" since the response continues."""
	out = {"from": instance_id, "id": rid[1]}
	if seq is not None:
		out["seq"] = seq
	out.update(body=content, more=True)

	m_raw = rid[0] + " T" + tnetstring.dumps(out)
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

def reply_http_close(sock, rid, seq=None):
	"""Send an empty message to rid without "more", ending the response body."""
	out = {"from": instance_id, "id": rid[1]}
	if seq is not None:
		out["seq"] = seq
	encoded = tnetstring.dumps(out)
	m_raw = rid[0] + " T" + encoded
	logger.debug("OUT publish: %s" % m_raw)
	sock.send(m_raw)

# response headers that are not added to Access-Control-Expose-Headers
# (compared case-insensitively by apply_cors_headers)
simple_headers = set([
	"Cache-control",
	"Content-Language",
	"Content-Length",
	"Content-Type",
	"Expires",
	"Last-Modified",
	"Pragma",
])

# modifies response_headers as needed
def apply_cors_headers(request_headers, response_headers):
	"""Fill in missing CORS headers on response_headers (a dict, mutated in place).

	Any Access-Control-* header already present in the response is kept.
	Otherwise values are derived from the request's preflight headers and
	Origin, falling back to permissive defaults.
	"""
	if not header_get(response_headers, "Access-Control-Allow-Methods"):
		# echo the requested method, else allow a fixed common set
		acr_method = header_get(request_headers, "Access-Control-Request-Method")
		if acr_method:
			header_set(response_headers, "Access-Control-Allow-Methods", acr_method)
		else:
			header_set(response_headers, "Access-Control-Allow-Methods", "OPTIONS, HEAD, GET, POST, PUT, DELETE")

	if not header_get(response_headers, "Access-Control-Allow-Headers"):
		# echo the requested headers, normalized; omitted if none were requested
		acr_headers = header_get(request_headers, "Access-Control-Request-Headers")
		allow_headers = list()
		if acr_headers:
			for name in acr_headers.split(","):
				name = name.strip()
				if name:
					allow_headers.append(name)
		if len(allow_headers) > 0:
			header_set(response_headers, "Access-Control-Allow-Headers", ", ".join(allow_headers))

	if not header_get(response_headers, "Access-Control-Expose-Headers"):
		# expose every non-simple, non-CORS response header, once each.
		# runs after the Allow-* headers above are set; they are skipped
		# by the access-control- prefix check
		expose_headers = list()
		for name in response_headers.keys():
			lname = name.lower()
			if not header_names_contains(simple_headers, name) and not lname.startswith("access-control-") and not header_names_contains(expose_headers, name):
				expose_headers.append(name)
		if len(expose_headers) > 0:
			header_set(response_headers, "Access-Control-Expose-Headers", ", ".join(expose_headers))

	if not header_get(response_headers, "Access-Control-Allow-Credentials"):
		header_set(response_headers, "Access-Control-Allow-Credentials", "true")

	if not header_get(response_headers, "Access-Control-Allow-Origin"):
		# reflect the request Origin, or "*" when the request has none
		origin = header_get(request_headers, "Origin")
		if not origin:
			origin = "*"
		header_set(response_headers, "Access-Control-Allow-Origin", origin)

def bind_spec(sock, spec):
	"""Bind sock to spec; for ipc:// endpoints apply ipc_file_mode if configured."""
	sock.bind(spec)
	ipc_prefix = 'ipc://'
	if ipc_file_mode is None or not spec.startswith(ipc_prefix):
		return
	os.chmod(spec[len(ipc_prefix):], ipc_file_mode)

class JsonPointer(object):
	"""Result of resolving a JSON pointer against a document.

	If the pointer names an existing value, obj is that value and
	child_name/child_index are None. If the last token names a missing dict
	member, obj is the dict and child_name is that token. If the last token
	is "-" on a list, obj is the list and child_index is -1.
	"""
	def __init__(self):
		self.obj = None
		self.child_name = None
		self.child_index = None
		# container holding obj and the key/index obj is stored under;
		# both None when obj is the document root
		self.parent = None
		self.parent_key = None

def resolve_json_pointer(data, p):
	"""Resolve JSON pointer p (RFC 6901 syntax) against data.

	"/" alone refers to the root. Raises ValueError for pointers not starting
	with "/", empty reference tokens, non-integer or out-of-range (including
	negative) array indexes, stepping past a missing member, and indexing
	into non-container values.
	"""
	if not p.startswith('/'):
		raise ValueError('pointer must start with /')

	ptr = JsonPointer()
	ptr.obj = data

	# root
	if len(p) == 1:
		return ptr

	parts = p.split('/')[1:]
	for part in parts:
		if len(part) == 0:
			raise ValueError('reference cannot be empty')

		# unescape; ~1 must be replaced before ~0
		part = part.replace('~1', '/')
		part = part.replace('~0', '~')

		if ptr.child_name is not None or ptr.child_index is not None:
			raise ValueError('cannot step into undefined reference')

		if isinstance(ptr.obj, dict):
			if part in ptr.obj:
				ptr.parent = ptr.obj
				ptr.parent_key = part
				ptr.obj = ptr.obj[part]
			else:
				ptr.child_name = part
		elif isinstance(ptr.obj, list):
			if part == '-':
				ptr.child_index = -1
			else:
				try:
					index = int(part)
				except ValueError:
					raise ValueError('index must be an integer')
				# negative indexes are not valid pointer syntax and would
				# otherwise silently select from the end of the list
				if index < 0 or index >= len(ptr.obj):
					raise ValueError('index out of range')
				ptr.parent = ptr.obj
				ptr.parent_key = index
				ptr.obj = ptr.obj[index]
		else:
			raise ValueError('non-container value cannot have child reference')

	return ptr

def json_patch(data, ops):
	"""Apply JSON patch operations to a deep copy of data and return the copy.

	Only "add" is supported, with RFC 6902 semantics: a new object member is
	created, an existing member is replaced, an existing array index inserts
	before that element, "-" appends, and the root pointer replaces the whole
	document. Raises ValueError for unsupported ops or bad paths.
	"""
	out = copy.deepcopy(data)
	for op in ops:
		otype = op['op']
		if otype != 'add':
			raise ValueError('unsupported op: %s' % otype)

		ptr = resolve_json_pointer(out, op['path'])
		value = op['value']
		if ptr.child_name is not None:
			ptr.obj[ptr.child_name] = value
		elif ptr.child_index is not None:
			# resolve_json_pointer only sets child_index for "-" (end of list)
			ptr.obj.append(value)
		elif ptr.parent is None:
			out = value
		elif isinstance(ptr.parent, list):
			ptr.parent.insert(ptr.parent_key, value)
		else:
			# assign through the parent; rebinding ptr.obj would be a no-op
			ptr.parent[ptr.parent_key] = value
	return out

def session_detect_rules_set(state_rpc, rules):
	"""Store session detection rules via the state RPC "session-detect-rules-set"."""
	def encode(rule):
		encoded = {"domain": rule.domain, "path-prefix": rule.path_prefix, "sid-ptr": rule.sid_ptr}
		if rule.json_param:
			encoded["json-param"] = rule.json_param
		return encoded

	state_rpc.call("session-detect-rules-set", {"rules": [encode(r) for r in rules]})

def session_detect_rules_get(state_rpc, domain, path):
	"""Fetch detection rules for domain/path as a list of SessionDetectRule."""
	def decode(rule_data):
		rule = SessionDetectRule()
		rule.domain = rule_data["domain"]
		rule.path_prefix = rule_data["path-prefix"]
		rule.sid_ptr = rule_data["sid-ptr"]
		rule.json_param = rule_data.get("json-param")
		return rule

	reply = state_rpc.call("session-detect-rules-get", {"domain": domain, "path": path})
	return [decode(item) for item in reply]

def session_create_or_update(state_rpc, sid, last_ids):
	"""Create or update session sid with its channel -> last-id map."""
	args = {"sid": sid, "last-ids": last_ids}
	state_rpc.call("session-create-or-update", args)

# sid_last_ids is dict of {sid: last_ids}
def session_update_many(state_rpc, sid_last_ids):
	"""Update several sessions in one call."""
	args = {"sid-last-ids": sid_last_ids}
	state_rpc.call("session-update-many", args)

def session_get_last_ids(state_rpc, sid):
	"""Return the state service's reply for sid: a dict of channel -> last-id."""
	reply = state_rpc.call("session-get-last-ids", {"sid": sid})
	return reply

def inspect_worker():
	"""Answer "inspect" requests from the proxy, forever.

	Every request is told to proxy (no-proxy False). With share_all, a
	sharing key of "<method>|<uri>" is returned. When the request asks for
	session info and a state service is configured, stored session detection
	rules for the URI's domain/path are applied to the request body. If a
	session id is found, it is returned with that session's last ids.
	"""
	sock = ctx.socket(zmq.REP)
	sock.linger = DEFAULT_LINGER
	sock.connect(proxy_inspect_spec)

	# state service client, only if state_spec is configured
	if state_spec:
		state_rpc = rpc.RpcClient(['inproc://state'], context=ctx)
	else:
		state_rpc = None

	while True:
		m_raw = sock.recv()
		m = tnetstring.loads(m_raw)
		logger.debug("IN inspect: %s" % m)

		id = m["id"]
		args = m["args"]

		method = args["method"]
		uri = args["uri"]
		get_session = args.get("get-session", False)

		# m is reused to hold the reply value
		m = dict()

		# reply saying to always proxy
		m["no-proxy"] = False

		if share_all:
			m["sharing-key"] = method + '|' + uri

		# determine session info
		if get_session and state_rpc:
			try:
				uri = urlparse.urlparse(uri)
				rules = session_detect_rules_get(state_rpc, uri.netloc, uri.path)
				logger.debug("retrieved %d rules", len(rules))
				if rules:
					sid = None
					# first rule whose pointer resolves to an existing value wins
					for r in rules:
						try:
							if r.json_param:
								# JSON carried in a form-encoded body parameter
								params = urlparse.parse_qs(args["body"])
								body = json.loads(params[r.json_param][0])
							else:
								body = json.loads(args["body"])
							ptr = resolve_json_pointer(body, r.sid_ptr)
						except:
							# try next rule
							continue
						if ptr and not ptr.child_name and not ptr.child_index:
							sid = ensure_utf8(ptr.obj)
							break
					if sid:
						try:
							last_ids = session_get_last_ids(state_rpc, sid)
						except rpc.CallError as e:
							# unknown session: report the sid without last ids
							if e.condition == "item-not-found":
								last_ids = None
							else:
								raise
						m["sid"] = sid
						if last_ids:
							m["last-ids"] = last_ids
			except:
				# session detection is best-effort; always answer the inspect
				logger.exception("failed to detect session")

		resp = {"id": id, "success": True, "value": m}
		logger.debug("OUT inspect: %s" % resp)
		m_raw = tnetstring.dumps(resp)
		sock.send(m_raw)

	sock.close()

def accept_worker():
	"""Handle "accept" requests from the proxy and take over held requests, forever.

	For each request the proxy gets an immediate reply. If the response
	carries no hold instructions, it is passed back with Grip-* headers
	stripped. Otherwise "accepted" is returned and each handed-off request is
	bound to its channels in "response" or "stream" mode. Session detection
	rules and last ids are stored via the state service when use-session is
	set, and new subscriptions are announced on the inproc stats socket.
	"""
	sock = ctx.socket(zmq.REP)
	sock.linger = DEFAULT_LINGER
	sock.connect(proxy_accept_spec)

	out_sock = ctx.socket(zmq.PUB)
	out_sock.linger = DEFAULT_LINGER
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	retry_sock = ctx.socket(zmq.PUSH)
	retry_sock.linger = 0
	retry_sock.connect(proxy_retry_out_spec)

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.linger = 0
	stats_sock.connect('inproc://stats_in')

	if state_spec:
		state_rpc = rpc.RpcClient(['inproc://state'], context=ctx)
	else:
		state_rpc = None

	while True:
		m_raw = sock.recv()
		m = tnetstring.loads(m_raw)
		logger.debug("IN accept: %s" % m)

		req_id = m["id"]

		if m['method'] != 'accept':
			mresp = {"id": req_id, "success": False, "condition": "method-not-found"}
			logger.debug("OUT accept: %s" % mresp)
			sock.send(tnetstring.dumps(mresp))
			continue

		m = m["args"]

		if "response" not in m:
			mresp = {"id": req_id, "success": False, "condition": "bad-request"}
			logger.debug("OUT accept: %s" % mresp)
			sock.send(tnetstring.dumps(mresp))
			continue

		resp_headers = m["response"].get("headers") or {}

		# the response may have no Content-Type at all. header_get_parsed
		# then returns None, and unpacking it would kill this worker thread
		content_type = header_get_parsed(resp_headers, "Content-Type")
		if content_type is not None:
			content_type = content_type[0]

		sid = header_get(resp_headers, "Grip-Session-Id")
		last_ids = dict()
		rule_headers = header_get_all_parsed(resp_headers, "Grip-Session-Detect")
		rules = list()
		try:
			for rh in rule_headers:
				params = rh[1]
				# the leading segment is itself a key=value parameter
				at = rh[0].find("=")
				if at != -1:
					params[rh[0][:at]] = rh[0][at + 1:]
				else:
					params[rh[0]] = ""
				if params.get("path-prefix") and params.get("sid-ptr"):
					r = SessionDetectRule()
					uri = m["request-data"]["uri"]
					r.domain = urlparse.urlparse(uri).netloc
					r.path_prefix = params["path-prefix"]
					r.sid_ptr = params["sid-ptr"]
					if "json-param" in params:
						r.json_param = params["json-param"]
					rules.append(r)
			last_headers = header_get_all_parsed(resp_headers, "Grip-Last")
			for lh in last_headers:
				channel = lh[0]
				last_id = lh[1]["last-id"]
				last_ids[channel] = last_id
		except Exception:
			logger.debug("error parsing detect/last headers")

		if m.get("use-session") and state_rpc:
			if rules:
				try:
					session_detect_rules_set(state_rpc, rules)
				except Exception:
					logger.debug("couldn't store detection rules")

			if sid:
				try:
					session_create_or_update(state_rpc, sid, last_ids)
				except Exception:
					logger.debug("couldn't create/update session")

		# TODO: support non-hold grip-instruct

		# no hold instructions: hand the response back with grip headers removed
		if content_type != "application/grip-instruct" and not header_get(resp_headers, "Grip-Hold"):
			new_headers = list()
			for i in resp_headers:
				if i[0].lower().startswith('grip-'):
					continue
				new_headers.append(i)
			m["response"]["headers"] = new_headers
			mresp = {"id": req_id, "success": True, "value": {"response": m["response"]}}
			logger.debug("OUT accept: %s" % mresp)
			sock.send(tnetstring.dumps(mresp))
			continue

		mresp = {"id": req_id, "success": True, "value": {"accepted": True}}
		logger.debug("OUT accept: %s" % mresp)
		sock.send(tnetstring.dumps(mresp))

		reqs = m["requests"]

		try:
			if "route" in m:
				route = m["route"]
			else:
				route = ""
			if "channel-prefix" in m:
				channel_prefix = m["channel-prefix"]
			else:
				channel_prefix = ""

			mode = None
			channels = list()
			timeout = 55
			resp_headers = m['response'].get('headers') or {}

			# headers-based grip
			grip_hold = header_get_parsed(resp_headers, 'Grip-Hold')
			if grip_hold:
				mode = grip_hold[0]
			grip_channels = header_get_all_parsed(resp_headers, 'Grip-Channel')
			if grip_channels:
				for c in grip_channels:
					name = channel_prefix + c[0]
					prev_id = c[1].get('prev-id')
					channels.append((name, prev_id))
			grip_timeout = header_get_parsed(resp_headers, 'Grip-Timeout')
			if grip_timeout:
				timeout = int(grip_timeout[0])
			grip_expose_headers = header_get_all(resp_headers, 'Grip-Expose-Headers')
			grip_keep_alive = header_get_parsed(resp_headers, 'Grip-Keep-Alive')
			keep_alive = None
			keep_alive_timeout = None
			if grip_keep_alive:
				val = grip_keep_alive[0]
				params = grip_keep_alive[1]
				t = params.get('timeout')
				f = params.get('format')
				if f is None:
					f = 'raw'
				if f == 'cstring':
					keep_alive = val.decode('string_escape')
				elif f == 'base64':
					keep_alive = b64decode(val)
				elif f == 'raw':
					keep_alive = val
				else:
					raise ValueError('invalid keep alive format')
				if t is not None:
					keep_alive_timeout = int(t)
					if keep_alive_timeout < 1:
						keep_alive = None
						keep_alive_timeout = None
				else:
					keep_alive_timeout = 55

			# body-based grip (application/grip-instruct)
			instruct = None
			ctype = header_get_parsed(resp_headers, 'Content-Type')
			if ctype and ctype[0] == 'application/grip-instruct':
				if m['response']['code'] != 200:
					raise ValueError('response code for grip-instruct must be 200')
				instruct = json.loads(m['response']['body'])
				hold = instruct['hold']
				mode = hold.get('mode')
				if mode is None:
					mode = 'response'
				for hc in hold['channels']:
					name = channel_prefix + hc['name']
					prev_id = hc.get('prev-id')
					channels.append((name, prev_id))
				if 'timeout' in hold:
					timeout = int(hold['timeout'])
				if 'keep-alive' in hold:
					ka = hold['keep-alive']
					if 'content-bin' in ka:
						keep_alive = b64decode(ka['content-bin'])
					else:
						keep_alive = ka['content'].encode('utf-8')
					keep_alive_timeout = int(ka['timeout'])
				response = instruct.get('response')
				if response is None:
					response = dict()
					response['body'] = ''
				if "headers" in response and isinstance(response["headers"], list):
					d = dict()
					for i in response["headers"]:
						d[i[0]] = i[1]
					response["headers"] = d
				if "body-bin" in response:
					response["body"] = b64decode(response["body-bin"])
					del response["body-bin"]
				elif "body" in response:
					response["body"] = response["body"].encode("utf-8")
				else:
					response["body"] = ""
			else:
				response = m['response']
				if "headers" in response and isinstance(response["headers"], list):
					d = dict()
					for i in response["headers"]:
						if i[0].lower().startswith('grip-'):
							continue
						if grip_expose_headers and not headernames_contains(grip_expose_headers, i[0]):
							continue
						d[i[0]] = i[1]
					response["headers"] = d
				if 'body' not in response:
					response['body'] = ''

			if mode != 'response' and mode != 'stream':
				raise ValueError('bad mode')
		except Exception:
			# we already took over the requests, so answer them ourselves
			logger.debug("failed to parse accept instructions")
			for req in reqs:
				rid = (req["rid"]["sender"], req["rid"]["id"])
				rheaders = dict()
				rheaders['Content-Type'] = 'text/plain'
				reply_http(out_sock, rid, 502, 'Bad Gateway', rheaders, 'Error while proxying to origin.\n')
			continue

		logger.debug("accepting %d requests" % len(reqs))

		for req in reqs:
			rid = (req["rid"]["sender"], req["rid"]["id"])

			stats_lock.acquire()
			ci = ConnectionInfo("%s:%s" % (rid[0], rid[1]), "http")
			ci.route = route
			if "peer-address" in req:
				ci.peer_address = req["peer-address"]
			if "https" in req:
				ci.ssl = req["https"]
			ci.last_keepalive = int(time.time())
			ci.sid = sid
			# note: if we had a lingering connection, this will replace it
			conns[ci.id] = ci
			stats_lock.release()

			h = Hold(rid, m["request-data"], mode, response, req.get("auto-cross-origin"), req.get("jsonp-callback"))
			now = int(time.time())
			h.last_keepalive = now
			h.out_seq = req["out-seq"]
			h.out_credits = req["out-credits"]
			h.jsonp_extended_response = req.get("jsonp-extended-response", False)
			h.grip_keep_alive = keep_alive
			h.grip_keep_alive_timeout = keep_alive_timeout
			h.last_send = now

			notify_subs = set()
			if mode == "response":
				lock.acquire()

				# verify every prev-id before binding anything. checking while
				# binding would leave this hold registered on earlier channels
				# (with unannounced subscriptions) when a later channel fails
				# and the request is retried elsewhere
				mismatch = None
				for channel, prev_id in channels:
					if prev_id is not None:
						last_id = response_lastids.get(channel)
						if last_id is not None and last_id != prev_id:
							mismatch = (channel, prev_id, last_id)
							break

				if mismatch is not None:
					channel, prev_id, last_id = mismatch
					del response_lastids[channel]
					lock.release()
					stats_lock.acquire()
					ci.linger = True
					ci.last_keepalive = int(time.time())
					stats_lock.release()
					# note: we don't need to do a handoff here because we didn't ack to take over yet
					logger.debug("lastid inconsistency (got=%s, expected=%s), retrying" % (prev_id, last_id))
					r = dict()
					r["requests"] = [req] # only retry the request that failed the check
					r["request-data"] = m["request-data"]
					if "inspect" in m:
						r["inspect"] = m["inspect"]
					logger.debug("OUT retry: %s" % r)
					r_raw = tnetstring.dumps(r)
					retry_sock.send(r_raw)
					continue

				h.expire_time = int(time.time()) + timeout

				# bind channels
				for channel, prev_id in channels:
					logger.debug("adding response hold on %s" % channel)
					hchannel = response_channels.get(channel)
					if not hchannel:
						hchannel = dict()
						response_channels[channel] = hchannel
					hchannel[rid] = h
					req_channels = channels_by_req.get(rid)
					if not req_channels:
						req_channels = set()
						channels_by_req[rid] = req_channels
					req_channels.add((mode, channel))
					sub_key = (mode, channel)
					sub = subs.get(sub_key)
					if not sub:
						sub = Subscription(mode, channel)
						sub.last_keepalive = now
						subs[sub_key] = sub
						notify_subs.add(sub_key)
					sub.expire_time = None
				lock.release()

				# ack
				out = dict()
				out['from'] = instance_id
				out['id'] = rid[1]
				out['type'] = 'keep-alive'
				m_raw = rid[0] + ' T' + tnetstring.dumps(out)
				logger.debug('OUT publish (ack response accept): %s' % m_raw)
				out_sock.send(m_raw)
			else: # stream
				# initial reply
				if "code" in response:
					rcode = response["code"]
				else:
					rcode = 200

				if "reason" in response:
					rreason = response["reason"]
				else:
					rreason = get_reason(rcode)

				if "headers" in response:
					rheaders = response["headers"]
				else:
					rheaders = dict()

				if h.auto_cross_origin:
					apply_cors_headers(h.request["headers"], rheaders)

				h.lock.acquire()
				body_size = reply_http(out_sock, rid, rcode, rreason, rheaders, response.get("body"), True, h.out_seq)
				h.out_credits -= body_size
				h.out_seq += 1
				h.lock.release()

				# bind channels
				lock.acquire()
				for channel, prev_id in channels:
					logger.debug("adding stream hold on %s" % channel)
					hchannel = stream_channels.get(channel)
					if not hchannel:
						hchannel = dict()
						stream_channels[channel] = hchannel
					hchannel[rid] = h
					req_channels = channels_by_req.get(rid)
					if not req_channels:
						req_channels = set()
						channels_by_req[rid] = req_channels
					req_channels.add((mode, channel))
					sub_key = (mode, channel)
					sub = subs.get(sub_key)
					if not sub:
						sub = Subscription(mode, channel)
						sub.last_keepalive = now
						subs[sub_key] = sub
						notify_subs.add(sub_key)
					sub.expire_time = None
				lock.release()

			# announce subscriptions created for this request
			for sub_key in notify_subs:
				out = dict()
				out['from'] = instance_id
				out['mode'] = ensure_utf8(sub_key[0])
				out['channel'] = ensure_utf8(sub_key[1])
				out['ttl'] = SUBSCRIPTION_TTL
				stats_sock.send('sub ' + tnetstring.dumps(out))

	sock.close()

def push_in_zmq_worker():
	# Accept publish items on the configured PULL endpoint, validate each one,
	# and hand the survivors to push_in_worker over inproc://push_in.
	# Items that fail parsing or validation are logged and dropped.
	source = ctx.socket(zmq.PULL)
	bind_spec(source, push_in_spec)

	sink = ctx.socket(zmq.PUSH)
	sink.linger = DEFAULT_LINGER
	sink.connect("inproc://push_in")

	while True:
		raw = source.recv()
		try:
			try:
				item = tnetstring.loads(raw)
			except:
				raise ValidationError("bad format (not a tnetstring)")
			item = validate_publish(item)
		except ValidationError as err:
			logger.debug("warning: %s, dropping" % err.message)
		else:
			sink.send(tnetstring.dumps(item))

	sink.close()

def push_in_sub_zmq_worker():
	# Receive publish items on the SUB endpoint and forward them to
	# push_in_worker over inproc://push_in. Each message is multipart: the
	# first frame is the channel and the second is the tnetstring payload.
	#
	# Also applies subscribe/unsubscribe commands arriving on
	# inproc://sub_cmd_in (sent by stats_worker), so the SUB socket only
	# receives channels this instance currently has subscribers for.
	in_sock = ctx.socket(zmq.SUB)
	in_sock.sndhwm = SUB_SNDHWM
	in_sock.linger = DEFAULT_LINGER
	bind_spec(in_sock, push_in_sub_spec)

	sub_cmd_in_sock = ctx.socket(zmq.PULL)
	sub_cmd_in_sock.connect("inproc://sub_cmd_in")

	out_sock = ctx.socket(zmq.PUSH)
	out_sock.linger = DEFAULT_LINGER
	out_sock.connect("inproc://push_in")

	poller = zmq.Poller()
	poller.register(in_sock, zmq.POLLIN)
	poller.register(sub_cmd_in_sock, zmq.POLLIN)

	while True:
		socks = dict(poller.poll())

		# Service every ready socket on each pass. With an if/elif, sustained
		# publish traffic keeps in_sock readable forever and the
		# subscription commands would never be applied.
		if socks.get(in_sock, 0) & zmq.POLLIN:
			m_raw = in_sock.recv_multipart()

			try:
				try:
					# a missing payload frame raises IndexError, also caught here
					m = tnetstring.loads(m_raw[1])
				except Exception:
					raise ValidationError("bad format (not a tnetstring)")

				# the channel is attached as a key, so the payload must be a
				# dict; otherwise the assignment below would kill this thread
				if not isinstance(m, dict):
					raise ValidationError("bad format (not a dict)")

				m['channel'] = m_raw[0]
				m = validate_publish(m)
			except ValidationError as e:
				logger.debug("warning: %s, dropping" % e.message)
			else:
				out_sock.send(tnetstring.dumps(m))

		if socks.get(sub_cmd_in_sock, 0) & zmq.POLLIN:
			m = tnetstring.loads(sub_cmd_in_sock.recv())
			mtype = m['type']
			channel = m['channel']
			if mtype == 'subscribe':
				logger.debug('SUB socket subscribe: %s' % channel)
				in_sock.setsockopt(zmq.SUBSCRIBE, channel)
			elif mtype == 'unsubscribe':
				logger.debug('SUB socket unsubscribe: %s' % channel)
				in_sock.setsockopt(zmq.UNSUBSCRIBE, channel)

	in_sock.close()
	out_sock.close()

# return None for success or string on error
def push_in_http_handler(context, m):
	"""Validate an HTTP publish request and queue each of its items.

	context["out_sock"] is the socket items are pushed onto (connected to
	inproc://push_in by push_in_http_worker). Returns None when the request
	passes validation, otherwise the validation error message. An individual
	item that fails conversion is logged and skipped; it does not fail the
	whole request.
	"""
	sink = context["out_sock"]

	try:
		m = validate_http_publish(m)
	except ValidationError as err:
		return err.message

	transports = ("http-response", "http-stream", "ws-message")

	for item in m["items"]:
		try:
			msg = dict()

			# optional routing fields are copied under the same key names
			for field in ("channel", "id", "prev-id"):
				value = item.get(field)
				if value is not None:
					msg[field] = ensure_utf8(value)

			# formats may be grouped under "formats" or given inline on the item
			formats = item.get("formats")
			if formats is None:
				formats = dict((t, item[t]) for t in transports if t in item)

			msg["formats"] = dict(
				(ensure_utf8(k), convert_json_transport(k, v))
				for k, v in formats.iteritems()
				if k in transports)

			sink.send(tnetstring.dumps(msg))
		except Exception as err:
			logger.debug("failed to process item: %s", err.message)

def push_in_http_worker():
	# Run the HTTP publish endpoint. push_in_http_handler pushes every accepted
	# item onto inproc://push_in for push_in_worker to deliver.
	sink = ctx.socket(zmq.PUSH)
	sink.linger = DEFAULT_LINGER
	sink.connect("inproc://push_in")

	handler_context = {"out_sock": sink}
	httpinterface.run(push_in_http_addr, push_in_http_port, push_in_http_handler, handler_context)

	sink.close()

def push_in_worker(c):
	# Central publish dispatcher. Reads publish items from inproc://push_in,
	# which is fed by the zmq, zmq-sub and http publish workers. Each item is
	# delivered to the http-response holds, http-stream holds and ws sessions
	# bound to its channel. Per-transport message counts, finished connections
	# and dropped subscriptions are reported to the stats worker.
	#
	# c: threading.Condition, notified once the inproc bind has succeeded so
	#    the starting thread can safely let other workers connect to it.
	in_sock = ctx.socket(zmq.PULL)
	in_sock.bind("inproc://push_in")
	c.acquire()
	c.notify()
	c.release()

	out_sock = ctx.socket(zmq.PUB)
	out_sock.linger = DEFAULT_LINGER
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	ws_control_out_sock = ctx.socket(zmq.PUSH)
	ws_control_out_sock.linger = DEFAULT_LINGER
	ws_control_out_sock.connect(ws_control_out_spec)

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.linger = 0
	stats_sock.connect('inproc://stats_in')

	if state_spec:
		state_rpc = rpc.RpcClient(['inproc://state'], context=ctx)
	else:
		state_rpc = None

	while True:
		m_raw = in_sock.recv()
		m = tnetstring.loads(m_raw)
		logger.debug("IN publish: %s" % m)
		channel = m["channel"]
		item_id = m.get("id")
		formats = m["formats"]

		response_holds = list()
		stream_holds = list()
		ws_cids = list()
		sids = set()
		notify_unsubs = set()
		do_close = False

		if "http-response" in formats:
			# Response holds are answered at most once, so every hold on this
			# channel is unbound now. Its subscriptions are only flagged to
			# expire after SUBSCRIPTION_LINGER rather than removed outright.
			lock.acquire()
			hchannel = response_channels.get(channel)
			if hchannel:
				response_holds = hchannel.values()
				unsubs = set()
				for h in response_holds:
					unsub_list = remove_from_response_channels(h.rid)
					for sub_key in unsub_list:
						unsubs.add(sub_key)
				assert(channel not in response_channels)
				for sub_key in unsubs:
					sub = subs.get(sub_key)
					if sub and sub.expire_time is None:
						# flag for deletion soon
						sub.expire_time = int(time.time()) + SUBSCRIPTION_LINGER
			if item_id is not None:
				response_lastids[channel] = item_id
			lock.release()
			stats_lock.acquire()
			for h in response_holds:
				ci = conns.get("%s:%s" % (h.rid[0], h.rid[1]))
				if ci is not None and ci.sid:
					sids.add(ci.sid)
			stats_lock.release()

		if "http-stream" in formats:
			do_close = (formats["http-stream"].get("action") == "close")

			lock.acquire()
			hchannel = stream_channels.get(channel)
			if hchannel:
				stream_holds = hchannel.values()
				if do_close:
					# closing ends the streams: unbind the holds and drop their
					# subscriptions immediately
					unsubs = set()
					for h in stream_holds:
						unsub_list = remove_from_stream_channels(h.rid)
						for sub_key in unsub_list:
							unsubs.add(sub_key)
					assert(channel not in stream_channels)
					for sub_key in unsubs:
						sub = subs.get(sub_key)
						if sub:
							del subs[sub_key]
							notify_unsubs.add(sub_key)
			lock.release()
			stats_lock.acquire()
			for h in stream_holds:
				ci = conns.get("%s:%s" % (h.rid[0], h.rid[1]))
				if ci is not None and ci.sid:
					sids.add(ci.sid)
			stats_lock.release()

		if "ws-message" in formats:
			lock.acquire()
			hchannel = ws_channels.get(channel)
			if hchannel:
				for cid, sess in hchannel.iteritems():
					ws_cids.append(cid)
					if sess.sid:
						sids.add(sess.sid)
			lock.release()

		# update sessions' last-id
		if item_id is not None and sids:
			sid_last_ids = dict()
			for sid in sids:
				sid_last_ids[sid] = {channel: item_id}

			if sid_last_ids and state_rpc:
				try:
					session_update_many(state_rpc, sid_last_ids)
				except:
					logger.debug("couldn't update sessions")

		if response_holds:
			logger.debug("relaying to %d http-response subscribers" % len(response_holds))
			http_response = formats["http-response"]

			if "code" in http_response:
				pcode = http_response["code"]
			else:
				pcode = 200

			if "reason" in http_response:
				preason = http_response["reason"]
			else:
				preason = get_reason(pcode)

			if "headers" in http_response:
				pheaders = http_response["headers"]
				# headers may arrive as a list of [name, value] pairs
				if isinstance(pheaders, list):
					d = dict()
					for i in pheaders:
						d[i[0]] = i[1]
					pheaders = d
			else:
				pheaders = dict()

			if "body" in http_response:
				pbody = http_response["body"]
			else:
				pbody = ""

			if "body-patch" in http_response:
				pbody_patch = http_response["body-patch"]
			else:
				pbody_patch = None

			grip_expose_headers = header_get_all(pheaders, 'Grip-Expose-Headers')
			header_remove(pheaders, 'Grip-Expose-Headers')

			for n, h in enumerate(response_holds):
				# inherit any headers from the timeout response
				if 'headers' in h.response:
					rheaders = copy.deepcopy(h.response['headers'])
				else:
					rheaders = dict()

				# apply the headers from the pushed message
				for k, v in pheaders.iteritems():
					header_set(rheaders, k, v)

				# if Grip-Expose-Headers was provided in the pushed message, filter the results
				if grip_expose_headers:
					rkeys = rheaders.keys()
					for k in rkeys:
						if not headernames_contains(grip_expose_headers, k):
							del rheaders[k]

				# if body patch specified, inherit body from timeout response
				if pbody_patch is not None:
					try:
						rbody = json.loads(h.response['body'])
						rbody = json_patch(rbody, pbody_patch)
						rbody = json.dumps(rbody)
					except Exception as e:
						logger.debug("failed to parse json patch: %s", e.message)
						rbody = ''
				else:
					rbody = pbody

				headers = dict()
				if h.jsonp_callback:
					if h.jsonp_extended_response:
						result = dict()
						result["code"] = pcode
						result["reason"] = preason
						result["headers"] = dict()
						if rheaders:
							for k, v in rheaders.iteritems():
								result["headers"][k] = v
						# the length must describe the body actually embedded,
						# which differs from pbody when a body patch was applied
						header_set(result["headers"], "Content-Length", str(len(rbody)))
						result["body"] = rbody

						body = h.jsonp_callback + "(" + json.dumps(result) + ");\n"
					else:
						body = h.jsonp_callback + "(" + rbody + ");\n"

					header_set(headers, "Content-Type", "application/javascript")
					header_set(headers, "Content-Length", str(len(body)))
					reply_http(out_sock, h.rid, 200, "OK", headers, body)
				else:
					if rheaders:
						for k, v in rheaders.iteritems():
							headers[k] = v

					if h.auto_cross_origin:
						apply_cors_headers(h.request["headers"], headers)

					reply_http(out_sock, h.rid, pcode, preason, headers, rbody)

				# report request done

				stats_lock.acquire()
				ci = conns.get("%s:%s" % (h.rid[0], h.rid[1]))
				if ci is not None:
					ci = copy.deepcopy(ci)
					del conns[ci.id]
				stats_lock.release()

				if ci is not None:
					out = dict()
					out['from'] = instance_id
					if ci.route:
						out['route'] = ci.route
					out['id'] = ci.id
					out['unavailable'] = True
					stats_sock.send('conn ' + tnetstring.dumps(out))

				if n % SEND_BATCH_SIZE == 0:
					time.sleep(SEND_BATCH_DELAY)

			rcount = len(response_holds)
			if rcount > 0:
				out = dict()
				out['from'] = instance_id
				out['channel'] = ensure_utf8(channel)
				if 'id' in m:
					out['item-id'] = ensure_utf8(m['id'])
				out['count'] = rcount
				out['transport'] = 'http-response'
				stats_sock.send('message ' + tnetstring.dumps(out))

		if stream_holds:
			content = formats["http-stream"].get("content")
			if content or do_close:
				logger.debug("relaying to %d http-stream subscribers" % len(stream_holds))
				for n, h in enumerate(stream_holds):
					if content:
						# Check and spend the send credits atomically under the
						# hold's lock, so a concurrent sender (e.g. a grip
						# keep-alive) cannot overdraw them.
						sent = False
						h.lock.acquire()
						try:
							if h.out_credits >= len(content):
								reply_http_chunk(out_sock, h.rid, content, h.out_seq)
								h.out_credits -= len(content)
								h.out_seq += 1
								sent = True
						finally:
							h.lock.release()
						if sent:
							lock.acquire()
							h.last_send = int(time.time())
							lock.release()
						else:
							logger.debug('not enough send credits, dropping')

					# On close, the hold was already unbound from all of its
					# channels above, so it must be closed even when the
					# content was dropped. Otherwise the request is orphaned.
					if do_close:
						reply_http_close(out_sock, h.rid)

						# report request done

						stats_lock.acquire()
						ci = conns.get("%s:%s" % (h.rid[0], h.rid[1]))
						if ci is not None:
							ci = copy.deepcopy(ci)
							del conns[ci.id]
						stats_lock.release()

						if ci is not None:
							out = dict()
							out['from'] = instance_id
							if ci.route:
								out['route'] = ci.route
							out['id'] = ci.id
							out['unavailable'] = True
							stats_sock.send('conn ' + tnetstring.dumps(out))

					if n % SEND_BATCH_SIZE == 0:
						time.sleep(SEND_BATCH_DELAY)

			rcount = len(stream_holds)
			if rcount > 0:
				out = dict()
				out['from'] = instance_id
				out['channel'] = ensure_utf8(channel)
				if 'id' in m:
					out['item-id'] = ensure_utf8(m['id'])
				out['count'] = rcount
				out['transport'] = 'http-stream'
				stats_sock.send('message ' + tnetstring.dumps(out))

		if ws_cids:
			logger.debug("relaying to %d ws-message subscribers" % len(ws_cids))
			for n, cid in enumerate(ws_cids):
				t = formats['ws-message']
				# binary content takes precedence over text content
				if 'content-bin' in t:
					content_type = 'binary'
					content = t['content-bin']
				elif "content" in t:
					content_type = 'text'
					content = t['content']
				else:
					content = None

				if content is not None:
					item = dict()
					item['cid'] = cid
					item['type'] = 'send'
					item['content-type'] = content_type
					item['message'] = content
					out = dict()
					out['items'] = [item]
					logger.debug('OUT wscontrol: %s' % out)
					ws_control_out_sock.send(tnetstring.dumps(out))

				if n % SEND_BATCH_SIZE == 0:
					time.sleep(SEND_BATCH_DELAY)

			rcount = len(ws_cids)
			if rcount > 0:
				out = dict()
				out['from'] = instance_id
				out['channel'] = ensure_utf8(channel)
				if 'id' in m:
					out['item-id'] = ensure_utf8(m['id'])
				out['count'] = rcount
				out['transport'] = 'ws-message'
				stats_sock.send('message ' + tnetstring.dumps(out))

		for sub_key in notify_unsubs:
			out = dict()
			out['from'] = instance_id
			out['mode'] = ensure_utf8(sub_key[0])
			out['channel'] = ensure_utf8(sub_key[1])
			out['unavailable'] = True
			stats_sock.send('sub ' + tnetstring.dumps(out))

	in_sock.close()

def session_worker():
	# Handle session-level messages sent by the proxy for held requests:
	#  - 'error' / 'cancel': unbind the request from its channels, expire or
	#    drop its subscriptions, and report the connection as gone
	#  - 'credit': add send credits to the matching hold
	#  - any other typed message for an unknown request: reply with 'cancel'
	in_sock = ctx.socket(zmq.DEALER)
	in_sock.identity = instance_id
	for spec in m2a_in_stream_specs:
		in_sock.connect(spec)

	out_sock = ctx.socket(zmq.PUB)
	out_sock.linger = DEFAULT_LINGER
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.linger = 0
	stats_sock.connect('inproc://stats_in')

	while True:
		m_list = in_sock.recv_multipart()
		# the payload frame carries a one-character prefix before the tnetstring
		m = tnetstring.loads(m_list[1][1:])
		logger.debug('IN session: %s' % m)
		mtype = m.get('type')
		notify_unsubs = set()
		if mtype is not None and (mtype == 'error' or mtype == 'cancel'):
			rid = (m['from'], m['id'])
			logger.debug('cleaning up subscriber %s' % repr(rid))
			now = int(time.time())
			lock.acquire()
			unsub_list = remove_from_req_channels(rid)
			for sub_key in unsub_list:
				sub = subs.get(sub_key)
				if sub:
					if sub.mode == 'response' and sub.expire_time is None:
						# flag for deletion soon
						sub.expire_time = now + SUBSCRIPTION_LINGER
					elif sub.mode == 'stream':
						del subs[sub_key]
						notify_unsubs.add(sub_key)
			lock.release()

			stats_lock.acquire()
			ci = conns.get("%s:%s" % (rid[0], rid[1]))
			if ci is not None:
				ci = copy.deepcopy(ci)
				del conns[ci.id]
			stats_lock.release()

			if ci is not None:
				out = dict()
				out['from'] = instance_id
				if ci.route:
					out['route'] = ci.route
				out['id'] = ci.id
				out['unavailable'] = True
				stats_sock.send('conn ' + tnetstring.dumps(out))

		elif mtype is not None:
			# is this a known session?
			rid = (m['from'], m['id'])

			h = None
			lock.acquire()
			for hchannels in response_channels.itervalues():
				if rid in hchannels:
					h = hchannels[rid]
					break
			if h is None:
				for hchannels in stream_channels.itervalues():
					if rid in hchannels:
						h = hchannels[rid]
						break
			lock.release()

			if h is not None:
				if mtype == 'credit':
					credits = m['credits']
					# out_credits is spent under h.lock everywhere else, so it
					# must be replenished under the same lock, or concurrent
					# read-modify-write updates can be lost
					h.lock.acquire()
					h.out_credits += credits
					total_credits = h.out_credits
					h.lock.release()
					logger.debug('received %d credits, now %d' % (credits, total_credits))
			else:
				# no such session, send cancel
				out = dict()
				out['from'] = instance_id
				out['id'] = m['id']
				out['type'] = 'cancel'
				m_raw = m['from'] + ' T' + tnetstring.dumps(out)
				logger.debug('OUT publish: %s' % m_raw)
				out_sock.send(m_raw)

		for sub_key in notify_unsubs:
			out = dict()
			out['from'] = instance_id
			out['mode'] = ensure_utf8(sub_key[0])
			out['channel'] = ensure_utf8(sub_key[1])
			out['unavailable'] = True
			stats_sock.send('sub ' + tnetstring.dumps(out))

def ws_control_worker():
	# Process WebSocket control messages from the proxy:
	#  - 'here': create or refresh a ws session (expires 60s after last 'here')
	#  - 'gone' / 'cancel': drop the session and any channel it left empty
	#  - 'grip': a GRIP control message from the backend (subscribe,
	#    unsubscribe, detach, session)
	in_sock = ctx.socket(zmq.PULL)
	in_sock.connect(ws_control_in_spec)

	out_sock = ctx.socket(zmq.PUSH)
	out_sock.linger = DEFAULT_LINGER
	out_sock.connect(ws_control_out_spec)

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.linger = 0
	stats_sock.connect('inproc://stats_in')

	if state_spec:
		state_rpc = rpc.RpcClient(['inproc://state'], context=ctx)
	else:
		state_rpc = None

	while True:
		m_raw = in_sock.recv()
		m = tnetstring.loads(m_raw)
		logger.debug('IN wscontrol: %s' % m)

		now = int(time.time())

		for item in m['items']:
			mtype = item.get('type')
			if mtype == 'here':
				cid = item['cid']
				channel_prefix = item.get('channel-prefix')
				lock.acquire()
				s = ws_sessions.get(cid)
				if not s:
					s = WsSession(cid)
					s.channel_prefix = channel_prefix
					ws_sessions[cid] = s
					logger.debug('added ws session: %s' % cid)
				s.expire_time = now + 60
				lock.release()
			elif mtype == 'gone' or mtype == 'cancel':
				cid = item['cid']

				notify_unsubs = set()

				lock.acquire()
				s = ws_sessions.get(cid)
				if s:
					del ws_sessions[cid]
					channels = set()
					for channel, hchannels in ws_channels.iteritems():
						channels.add(channel)
						if cid in hchannels:
							del hchannels[cid]
					# a channel's subscription goes away only once nobody is left on it
					for channel in channels:
						if channel in ws_channels and len(ws_channels[channel]) == 0:
							del ws_channels[channel]
							sub_key = ('ws', channel)
							sub = subs.get(sub_key)
							if sub:
								del subs[sub_key]
								notify_unsubs.add(sub_key)
					logger.debug('removed ws session: %s' % cid)
				lock.release()

				for sub_key in notify_unsubs:
					out = dict()
					out['from'] = instance_id
					out['mode'] = ensure_utf8(sub_key[0])
					out['channel'] = ensure_utf8(sub_key[1])
					out['unavailable'] = True
					stats_sock.send('sub ' + tnetstring.dumps(out))

			elif mtype == 'grip':
				cid = item['cid']

				# parse and validate; any malformed message is ignored
				try:
					gm = json.loads(item['message'])
					gtype = gm['type']
					channel = None
					sid = None
					if gtype == 'subscribe' or gtype == 'unsubscribe':
						channel = gm['channel']
						if not isinstance(channel, basestring) or len(channel) < 1:
							raise ValueError('invalid channel')
					elif gtype == 'session':
						sid = gm['id']
						if not isinstance(sid, basestring) or len(sid) < 1:
							raise ValueError('invalid id')
						sid = ensure_utf8(sid)
				except Exception:
					gm = None

				notify_sub = False
				notify_unsub = False
				notify_detach = False
				save_sid = False

				if gm:
					lock.acquire()
					s = ws_sessions.get(cid)
					if s:
						if channel is not None and s.channel_prefix:
							channel = s.channel_prefix + channel
						if gtype == 'subscribe':
							hchannel = ws_channels.get(channel)
							if not hchannel:
								hchannel = dict()
								ws_channels[channel] = hchannel
							hchannel[cid] = s
							sub_key = ('ws', channel)
							sub = subs.get(sub_key)
							if not sub:
								sub = Subscription('ws', channel)
								sub.last_keepalive = now
								subs[sub_key] = sub
								notify_sub = True
							sub.expire_time = None
						elif gtype == 'unsubscribe':
							hchannel = ws_channels.get(channel)
							if hchannel:
								if cid in hchannel:
									del hchannel[cid]
									# Drop the channel's subscription only when
									# the last session leaves, matching the
									# 'gone' and timeout paths. Dropping it
									# earlier would tell stats (and the SUB
									# socket) to stop delivering the channel
									# to the sessions still on it.
									if len(hchannel) == 0:
										del ws_channels[channel]
										sub_key = ('ws', channel)
										sub = subs.get(sub_key)
										if sub:
											del subs[sub_key]
											notify_unsub = True
						elif gtype == 'detach':
							notify_detach = True
						elif gtype == 'session':
							s.sid = sid
							save_sid = True
					lock.release()

				if state_rpc and save_sid:
					try:
						session_create_or_update(state_rpc, sid, {})
					except:
						logger.debug("couldn't create/update session")

				if notify_sub:
					logger.debug('ws session %s subscribed to %s' % (cid, channel))
					out = dict()
					out['from'] = instance_id
					out['mode'] = 'ws'
					out['channel'] = ensure_utf8(channel)
					out['ttl'] = SUBSCRIPTION_TTL
					stats_sock.send('sub ' + tnetstring.dumps(out))

				if notify_unsub:
					out = dict()
					out['from'] = instance_id
					out['mode'] = 'ws'
					out['channel'] = ensure_utf8(channel)
					out['unavailable'] = True
					stats_sock.send('sub ' + tnetstring.dumps(out))

				if notify_detach:
					item = dict()
					item['cid'] = cid
					item['type'] = 'detach'
					out = dict()
					out['items'] = [item]
					logger.debug('OUT wscontrol: %s' % out)
					out_sock.send(tnetstring.dumps(out))

def timeout_worker():
	# Periodic housekeeping loop, run about once per second (see the sleep at
	# the end). Each pass:
	#  - answers expired http-response holds with their stored timeout response
	#  - sends grip keep-alive chunks to idle http-stream holds, and zhttp
	#    'keep-alive' messages to holds without one in the last 30 seconds
	#  - expires ws sessions whose expire_time has passed
	#  - removes expired subscriptions and refreshes live ones with stats
	#  - refreshes or removes tracked connections and flushes activity counts
	out_sock = ctx.socket(zmq.PUB)
	out_sock.linger = DEFAULT_LINGER
	for spec in m2a_out_specs:
		out_sock.connect(spec)

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.linger = 0
	stats_sock.connect('inproc://stats_in')

	if state_spec:
		state_rpc = rpc.RpcClient(['inproc://state'], context=ctx)
	else:
		state_rpc = None

	while True:
		now = int(time.time())

		lock.acquire()
		holds = list()
		# A hold bound to several channels appears under each of them. Collect
		# it once, or it would be answered with the timeout response once per
		# channel.
		expired_rids = set()
		for hchannels in response_channels.itervalues():
			for h in hchannels.values():
				if h.expire_time and now >= h.expire_time and h.rid not in expired_rids:
					expired_rids.add(h.rid)
					holds.append(h)
		for h in holds:
			unsub_list = remove_from_response_channels(h.rid)
			for sub_key in unsub_list:
				sub = subs.get(sub_key)
				if sub and sub.expire_time is None:
					# flag for deletion soon
					sub.expire_time = now + SUBSCRIPTION_LINGER
		lock.release()

		if len(holds) > 0:
			logger.debug("timing out %d subscribers" % len(holds))

			for h in holds:
				if "code" in h.response:
					pcode = h.response["code"]
				else:
					pcode = 200

				if "reason" in h.response:
					preason = h.response["reason"]
				else:
					preason = get_reason(pcode)

				if "headers" in h.response:
					pheaders = h.response["headers"]
				else:
					pheaders = dict()

				if "body" in h.response:
					pbody = h.response["body"]
				else:
					pbody = ""

				headers = dict()
				if h.jsonp_callback:
					if h.jsonp_extended_response:
						result = dict()
						result["code"] = pcode
						result["reason"] = preason
						result["headers"] = dict()
						if pheaders:
							for k, v in pheaders.iteritems():
								result["headers"][k] = v
						header_set(result["headers"], "Content-Length", str(len(pbody)))
						result["body"] = pbody

						body = h.jsonp_callback + "(" + json.dumps(result) + ");\n"
					else:
						body = h.jsonp_callback + "(" + pbody + ");\n"

					header_set(headers, "Content-Type", "application/javascript")
					header_set(headers, "Content-Length", str(len(body)))
					reply_http(out_sock, h.rid, 200, "OK", headers, body)
				else:
					if pheaders:
						for k, v in pheaders.iteritems():
							headers[k] = v

					if h.auto_cross_origin:
						apply_cors_headers(h.request["headers"], headers)

					reply_http(out_sock, h.rid, pcode, preason, headers, pbody)

				stats_lock.acquire()
				ci = conns.get("%s:%s" % (h.rid[0], h.rid[1]))
				if ci is not None:
					ci = copy.deepcopy(ci)
					del conns[ci.id]
				stats_lock.release()

				if ci is not None:
					out = dict()
					out['from'] = instance_id
					if ci.route:
						out['route'] = ci.route
					out['id'] = ci.id
					out['unavailable'] = True
					stats_sock.send('conn ' + tnetstring.dumps(out))

		now = int(time.time())
		grip_ka_rids = dict() # (hold, content)
		ka_rids = set()
		lock.acquire()
		for channel, hchannels in response_channels.iteritems():
			for h in hchannels.values():
				if h.last_keepalive is None or h.last_keepalive + 30 < now:
					if h.rid not in ka_rids:
						h.last_keepalive = now
						ka_rids.add(h.rid)
		for channel, hchannels in stream_channels.iteritems():
			for h in hchannels.values():
				if h.grip_keep_alive and (h.last_send is None or h.last_send + h.grip_keep_alive_timeout < now):
					if h.rid not in grip_ka_rids:
						content = h.grip_keep_alive
						if h.out_credits < len(content):
							# Skip only the grip chunk. The zhttp keep-alive
							# below costs no credits and must still be sent.
							logger.debug('not enough send credits, skipping keep alive')
						else:
							h.last_send = now
							h.last_keepalive = now
							grip_ka_rids[h.rid] = (h, content)
				if h.last_keepalive is None or h.last_keepalive + 30 < now:
					if h.rid not in grip_ka_rids and h.rid not in ka_rids:
						h.last_keepalive = now
						ka_rids.add(h.rid)
		lock.release()

		if len(grip_ka_rids) > 0:
			logger.debug("keep-aliving (grip) %d subscribers" % len(grip_ka_rids))
			for rid, v in grip_ka_rids.iteritems():
				h, content = v
				h.lock.acquire()
				reply_http_chunk(out_sock, rid, content, h.out_seq)
				h.out_credits -= len(content)
				h.out_seq += 1
				h.lock.release()

		if len(ka_rids) > 0:
			logger.debug("keep-aliving %d subscribers" % len(ka_rids))
			for rid in ka_rids:
				out = dict()
				out['from'] = instance_id
				out['id'] = rid[1]
				out['type'] = 'keep-alive'
				m_raw = rid[0] + ' T' + tnetstring.dumps(out)
				logger.debug('OUT publish: %s' % m_raw)
				out_sock.send(m_raw)

		now = int(time.time())
		cids = set()
		lock.acquire()
		for cid, s in ws_sessions.iteritems():
			if s.expire_time and now >= s.expire_time:
				cids.add(cid)
		for cid in cids:
			del ws_sessions[cid]
			channels = set()
			for channel, hchannels in ws_channels.iteritems():
				channels.add(channel)
				if cid in hchannels:
					del hchannels[cid]
			for channel in channels:
				if channel in ws_channels and len(ws_channels[channel]) == 0:
					del ws_channels[channel]
					sub_key = ('ws', channel)
					sub = subs.get(sub_key)
					if sub:
						# use expire_time to flag for removal
						sub.expire_time = now
		lock.release()
		if len(cids) > 0:
			logger.debug("timing out %d ws sessions" % len(cids))

		notify_subs = set()
		notify_unsubs = set()
		lock.acquire()
		for sub_key, sub in subs.iteritems():
			if sub.expire_time is not None and now >= sub.expire_time:
				notify_unsubs.add(sub_key)
			elif sub.last_keepalive is None or sub.last_keepalive + 30 < now:
				sub.last_keepalive = now
				notify_subs.add(sub_key)
		for sub_key in notify_unsubs:
			del subs[sub_key]
		lock.release()

		for sub_key in notify_unsubs:
			out = dict()
			out['from'] = instance_id
			out['mode'] = ensure_utf8(sub_key[0])
			out['channel'] = ensure_utf8(sub_key[1])
			out['unavailable'] = True
			stats_sock.send('sub ' + tnetstring.dumps(out))

		if len(notify_subs) > 0:
			logger.debug('keep-aliving %d subscriptions' % len(notify_subs))
			for sub_key in notify_subs:
				out = dict()
				out['from'] = instance_id
				out['mode'] = ensure_utf8(sub_key[0])
				out['channel'] = ensure_utf8(sub_key[1])
				out['ttl'] = SUBSCRIPTION_TTL
				stats_sock.send('sub ' + tnetstring.dumps(out))

		refresh_conns = list()
		send_activity = list()
		stats_lock.acquire()
		remove_cids = set()
		for cid, ci in conns.iteritems():
			# lingering conns are removed after CONNECTION_LINGER; live ones
			# are re-announced every CONNECTION_REFRESH
			if ci.last_keepalive is None or (not ci.linger and ci.last_keepalive + CONNECTION_REFRESH < now) or (ci.linger and ci.last_keepalive + CONNECTION_LINGER < now):
				if ci.linger:
					remove_cids.add(cid)
				else:
					ci.last_keepalive = now
					refresh_conns.append(copy.deepcopy(ci))
		for cid in remove_cids:
			del conns[cid]
		for route, activity in stats_activity.iteritems():
			send_activity.append((route, activity))
		stats_activity.clear()
		stats_lock.release()

		refresh_sids = list()

		for ci in refresh_conns:
			out = dict()
			out['from'] = instance_id
			if ci.route:
				out['route'] = ci.route
			out['id'] = ci.id
			out['type'] = ci.type
			if ci.peer_address:
				out['peer-address'] = ci.peer_address
			if ci.ssl:
				out['ssl'] = True
			out['ttl'] = CONNECTION_TTL
			stats_sock.send('conn ' + tnetstring.dumps(out))

			if ci.sid:
				refresh_sids.append(ci.sid)

		for i in send_activity:
			out = dict()
			out['from'] = instance_id
			if i[0]:
				out['route'] = i[0]
			out['count'] = i[1]
			stats_sock.send('activity ' + tnetstring.dumps(out))

		if refresh_sids and state_rpc:
			try:
				sid_last_ids = dict()
				for sid in refresh_sids:
					sid_last_ids[sid] = dict()
				session_update_many(state_rpc, sid_last_ids)
			except:
				logger.debug("couldn't update sessions")

		time.sleep(1)

def command_handler(method, args, data):
	"""Handle an RPC call arriving on the command interface.

	Methods:
	  conncheck    -- args['ids'] is a list of connection ids; returns the
	                  subset that are known here, or known to the proxy
	                  (asked via data['proxy_client'] for ids not found
	                  locally).
	  get-zmq-uris -- returns the configured command/publish zmq endpoints.

	Raises rpc.CallError('bad-format') for malformed conncheck arguments,
	rpc.CallError('proxy-request-failed') if the proxy query fails, and
	rpc.CallError('method-not-found') for any other method.
	"""
	proxy_client = data.get('proxy_client')

	logger.debug('IN command: %s args=%s' % (method, args))

	if method == 'conncheck':
		if 'ids' not in args or not isinstance(args['ids'], list):
			raise rpc.CallError('bad-format')
		try:
			cids = set(args['ids'])
		except TypeError:
			# unhashable entries (e.g. nested lists/dicts) are a format error
			raise rpc.CallError('bad-format')

		stats_lock.acquire()
		try:
			missing = set(cid for cid in cids if cid not in conns)
		finally:
			stats_lock.release()

		if len(missing) > 0:
			try:
				found = proxy_client.call('conncheck', {'ids': list(missing)})
			except Exception:
				raise rpc.CallError('proxy-request-failed')
			# discard, not remove: ids the proxy reports that weren't asked
			# about must not raise KeyError out of the handler
			for cid in found:
				missing.discard(cid)

		cids.difference_update(missing)

		return list(cids)
	elif method == 'get-zmq-uris':
		out = dict()
		if command_spec:
			out['command'] = command_spec
		if push_in_spec:
			out['publish-pull'] = push_in_spec
		if push_in_sub_spec:
			out['publish-sub'] = push_in_sub_spec
		return out
	else:
		raise rpc.CallError('method-not-found')

# RpcServer for the command interface; assigned by command_worker once the
# server has been created, None until then.
command_server = None

def command_worker():
	# Serve the command RPC interface with command_handler. When a proxy
	# command endpoint is configured, a client for it is passed to the
	# handler so 'conncheck' can ask the proxy about unknown connections.
	#
	# The global declaration is needed for the assignment below to publish
	# the server at module level. Without it, only a local was bound and the
	# module-level command_server stayed None.
	global command_server

	data = dict()
	if proxy_command_spec:
		data['proxy_client'] = rpc.RpcClient([proxy_command_spec], context=ctx)
	command_server = rpc.RpcServer(command_spec, context=ctx, ipc_file_mode=ipc_file_mode)
	command_server.run(command_handler, data)

def proxy_stats_worker():
	# Consume the proxy's stats feed. Messages are "<type> <tnetstring>".
	#  - 'activity': accumulate per-route counts in stats_activity, which
	#    timeout_worker flushes periodically
	#  - 'conn': relay to the stats worker (re-stamped with this instance's
	#    id) and, for ws connections with a known session id, refresh that
	#    session in the state store
	# A malformed message is logged and skipped. Without this, one bad
	# message would kill the thread, or leave the global lock held forever.
	in_sock = ctx.socket(zmq.SUB)
	in_sock.setsockopt(zmq.SUBSCRIBE, '')
	in_sock.connect(proxy_stats_spec)

	stats_sock = ctx.socket(zmq.PUSH)
	stats_sock.linger = 0
	stats_sock.connect('inproc://stats_in')

	if state_spec:
		state_rpc = rpc.RpcClient(['inproc://state'], context=ctx)
	else:
		state_rpc = None

	while True:
		try:
			m_raw = in_sock.recv()
			at = m_raw.find(' ')
			if at == -1:
				logger.debug('proxy stats: missing type separator, dropping')
				continue
			mtype = m_raw[:at]
			m = tnetstring.loads(m_raw[at + 1:])
			#logger.debug("IN proxy stats: %s %s" % (mtype, m))
			if mtype == 'activity':
				route = m.get('route')
				if not route:
					route = ''
				count = m['count']
				stats_lock.acquire()
				try:
					if route in stats_activity:
						stats_activity[route] += count
					else:
						stats_activity[route] = count
				finally:
					stats_lock.release()
			elif mtype == 'conn':
				# get sid
				sid = None
				lock.acquire()
				try:
					if m.get('type') == 'ws':
						s = ws_sessions.get(m['id'])
						if s is not None:
							sid = s.sid
				finally:
					lock.release()

				# relay
				m['from'] = instance_id
				stats_sock.send(mtype + ' ' + tnetstring.dumps(m))

				# update session
				if sid and not m.get('unavailable') and state_rpc:
					try:
						session_update_many(state_rpc, {sid: {}})
					except:
						logger.debug("couldn't update session")
		except zmq.ContextTerminated:
			raise
		except Exception:
			logger.exception('failed')

def stats_worker(c):
	# Single sink for stats messages from the other workers, received on
	# inproc://stats_in as "<type> <tnetstring>".
	#  - Every message is republished on the stats PUB socket (when stats_spec
	#    is set), with ' T' between the type and the payload.
	#  - When push_in_sub_spec is set, 'sub' messages are also used to track
	#    which channels have live subscriptions. Subscribe/unsubscribe
	#    commands are sent over inproc://sub_cmd_in so the publish SUB socket
	#    only filters in channels that are in use.
	#
	# c: threading.Condition notified once both inproc binds are done.
	in_sock = ctx.socket(zmq.PULL)
	in_sock.bind('inproc://stats_in')

	if stats_spec:
		out_sock = ctx.socket(zmq.PUB)
		out_sock.linger = 0
		bind_spec(out_sock, stats_spec)
	else:
		out_sock = None

	if push_in_sub_spec:
		subs_modes_by_channel = dict() # key=channel, value={mode: sub}
		# expire times are in milliseconds (see `now` below); assumes
		# SortedDict iterates keys in ascending order, so the first key is the
		# earliest expiry
		subs_by_exp = SortedDict() # key=expire_time, value=set(sub)
		sub_sock = ctx.socket(zmq.PUSH)
		sub_sock.linger = 0
		sub_sock.bind('inproc://sub_cmd_in')
	else:
		sub_sock = None

	c.acquire()
	c.notify()
	c.release()

	while True:
		try:
			m_raw = in_sock.recv()
			at = m_raw.find(' ')
			mtype = m_raw[:at]
			mdata = m_raw[at + 1:]

			now = int(time.time() * 1000)

			if sub_sock and mtype == 'sub':
				m = tnetstring.loads(mdata)
				mode = m['mode']
				channel = m['channel']
				here = not m.get('unavailable', False)

				if here:
					# (re)arm the subscription's expiry from its ttl (seconds)
					ttl = m['ttl']
					notify = False
					subs_modes = subs_modes_by_channel.get(channel)
					if subs_modes is None:
						# first mode on this channel: subscribe the SUB socket
						subs_modes = {}
						subs_modes_by_channel[channel] = subs_modes
						notify = True
					sub = subs_modes.get(mode)
					if sub is not None:
						# unlink from its old expiry bucket before re-arming
						subs_exp = subs_by_exp[sub.expire_time]
						subs_exp.remove(sub)
						if len(subs_exp) == 0:
							del subs_by_exp[sub.expire_time]
					else:
						sub = StatsSubscription()
						sub.mode = mode
						sub.channel = channel
						subs_modes[mode] = sub
					sub.expire_time = now + (ttl * 1000)
					subs_exp = subs_by_exp.get(sub.expire_time)
					if subs_exp is None:
						subs_exp = set()
						subs_by_exp[sub.expire_time] = subs_exp
					subs_exp.add(sub)
					if notify:
						sm = {'channel': channel}
						sm['type'] = 'subscribe'
						sub_sock.send(tnetstring.dumps(sm))
				else:
					# explicit removal; unsubscribe once no modes remain
					subs_modes = subs_modes_by_channel.get(channel)
					if subs_modes is not None:
						sub = subs_modes.get(mode)
						if sub is not None:
							subs_exp = subs_by_exp[sub.expire_time]
							subs_exp.remove(sub)
							if len(subs_exp) == 0:
								del subs_by_exp[sub.expire_time]
							del subs_modes[mode]
							if len(subs_modes) == 0:
								del subs_modes_by_channel[channel]
								sm = {'channel': channel}
								sm['type'] = 'unsubscribe'
								sub_sock.send(tnetstring.dumps(sm))

			if out_sock:
				m_raw = mtype + ' T' + mdata
				logger.debug('OUT stats: %s' % m_raw)
				out_sock.send(m_raw)

			# Expire overdue subscriptions. This only runs when some stats
			# message arrives, since the loop blocks in recv() otherwise.
			if sub_sock:
				while len(subs_by_exp) > 0:
					next_exp = iter(subs_by_exp).next()
					if next_exp > now:
						break

					subs_exp = subs_by_exp[next_exp]
					del subs_by_exp[next_exp]
					for sub in subs_exp:
						logger.debug('stats_worker: expiring %s %s' % (sub.mode, sub.channel))
						subs_modes = subs_modes_by_channel.get(sub.channel)
						del subs_modes[sub.mode]
						if len(subs_modes) == 0:
							del subs_modes_by_channel[sub.channel]
							sm = {'channel': sub.channel}
							sm['type'] = 'unsubscribe'
							sub_sock.send(tnetstring.dumps(sm))
		except zmq.ContextTerminated:
			raise
		except:
			logger.exception('failed')

def state_internal_handler(method, args, data):
	# RPC handler for the inproc 'state' server: relays the request
	# (method + args) unchanged to the state client kept in the handler
	# data dict under 'state_client', with timeout=1000 (units as defined
	# by the rpc module's RpcClient.call), and hands back whatever it returns.
	client = data.get('state_client')
	reply = client.call(method, args, timeout=1000)
	return reply

def state_internal_worker(c):
	# Thread body hosting the 'inproc://state' RPC server.
	#
	# When state_spec is configured, an RpcClient for the external state
	# service is created (bind=True, on the shared zmq context) and stored
	# under 'state_client' in the dict handed to state_internal_handler.
	# After the inproc server object has been created, the starter blocked on
	# condition c is woken up; the thread then serves requests via
	# RpcServer.run (presumably for the lifetime of the thread).
	handler_data = {}
	if state_spec:
		handler_data['state_client'] = rpc.RpcClient(
			[state_spec],
			bind=True,
			context=ctx,
			ipc_file_mode=ipc_file_mode)
	server = rpc.RpcServer('inproc://state', context=ctx)

	# signal the starter that the inproc endpoint now exists
	with c:
		c.notify()

	server.run(state_internal_handler, handler_data)

# --- Worker start-up -------------------------------------------------------
# Threads are started in a fixed order. The three Condition handshakes below
# block the main thread until the started worker calls notify() on the
# condition, so later threads can rely on that worker's inproc endpoint.
#
# Each handshake works because the main thread holds the condition's lock
# while starting the worker. The worker's c.acquire() can only succeed once
# c.wait() has released that lock, so the notify cannot be missed.
# NOTE(review): c.wait() has no timeout. If a worker raises before it
# notifies, start-up hangs here.

# Optional state service: relay thread for 'inproc://state'.
if state_spec:
	# we use a condition here to ensure the inproc bind succeeds before progressing
	c = threading.Condition()
	c.acquire()
	state_internal_thread = threading.Thread(target=state_internal_worker, args=(c,))
	state_internal_thread.daemon = True
	state_internal_thread.start()
	c.wait()
	c.release()

# Not marked daemon, so it keeps the process alive until inspect_worker returns.
inspect_thread = threading.Thread(target=inspect_worker)
inspect_thread.start()

# we use a condition here to ensure the inproc bind succeeds before progressing
c = threading.Condition()
c.acquire()
stats_thread = threading.Thread(target=stats_worker, args=(c,))
stats_thread.daemon = True
stats_thread.start()
c.wait()
c.release()

# Non-daemon.
accept_thread = threading.Thread(target=accept_worker)
accept_thread.start()

# we use a condition here to ensure the inproc bind succeeds before progressing
# (push_in_thread is NOT marked daemon, unlike the stats thread above)
c = threading.Condition()
c.acquire()
push_in_thread = threading.Thread(target=push_in_worker, args=(c,))
push_in_thread.start()
c.wait()
c.release()

# Non-daemon; started only after push_in_worker has signalled readiness above.
push_in_zmq_thread = threading.Thread(target=push_in_zmq_worker)
push_in_zmq_thread.start()

# Optional, non-daemon.
if push_in_sub_spec:
	push_in_sub_zmq_thread = threading.Thread(target=push_in_sub_zmq_worker)
	push_in_sub_zmq_thread.start()

# Non-daemon.
push_in_http_thread = threading.Thread(target=push_in_http_worker)
push_in_http_thread.start()

# The remaining workers are daemon threads. They do not keep the process
# alive on their own.
session_thread = threading.Thread(target=session_worker)
session_thread.daemon = True
session_thread.start()

# WebSocket control channel: only when both in and out specs are configured.
if ws_control_in_spec and ws_control_out_spec:
	ws_control_thread = threading.Thread(target=ws_control_worker)
	ws_control_thread.daemon = True
	ws_control_thread.start()

if proxy_stats_spec:
	proxy_stats_thread = threading.Thread(target=proxy_stats_worker)
	proxy_stats_thread.daemon = True
	proxy_stats_thread.start()

if command_spec:
	command_thread = threading.Thread(target=command_worker)
	command_thread.daemon = True
	command_thread.start()

timeout_thread = threading.Thread(target=timeout_worker)
timeout_thread.daemon = True
timeout_thread.start()

# --- Main loop / shutdown --------------------------------------------------
# Idle the main thread until Ctrl-C (KeyboardInterrupt).
# NOTE(review): presumably a timed sleep loop (rather than joining a thread)
# so that KeyboardInterrupt is delivered promptly. Other termination signals
# skip the cleanup below unless a handler is installed elsewhere; confirm.
try:
	while True:
		time.sleep(60)
except KeyboardInterrupt:
	pass

# command_server is defined outside this chunk. It is presumably set by
# command_worker when command_spec is configured, and falsy otherwise; verify.
if command_server:
	command_server.stop()

httpinterface.stop()
# Terminate the shared zmq context. Workers blocked on its sockets are
# expected to get zmq.ContextTerminated (stats_worker re-raises it), which
# lets the non-daemon threads started above exit so the process can finish.
ctx.term()
