#!/usr/bin/env python
#
# Copyright 2012 John-Mark Gurney.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
# OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
# OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
# SUCH DAMAGE.
#
#	$Id: //depot/python/pyfp/pyfp-0.5/filepass.py#1 $
#

# The client chooses which file it wants to accept.  After the verified
# key exchange (VKE), the server is the one that confirms the PIN, so that
# it does not send the file to someone it did not intend to.

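# Pure Python version of AES (presumably the source of the fallback `aes`
# module imported below when PyCrypto is unavailable).  Google Code has
# since shut down, so these links are historical:
# https://code.google.com/p/slowaes/source/browse/trunk/python/aes.py
# raw: https://slowaes.googlecode.com/svn/trunk/python/aes.py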

'''This program is used to transfer files securely between machines.  Both
sides display a PIN that is used to verify that the communication is
secure.  The sender must confirm that the receiver has presented the same
PIN.

On the receiver side, to automatically accept the first announcement:
python filepass.py -a

On the sender side:
python filepass.py foobar.txt
'''

# Version string.  The revision keyword here (and the Id line below) appears
# to be expanded by the version control system -- the //depot path suggests
# Perforce.
__version__ = '''$Revision: #1 $'''
# $Id: //depot/python/pyfp/pyfp-0.5/filepass.py#1 $

import StringIO
import array
import errno
import getopt
import os.path
import random
import socket
import struct
import sys
import time

import encstream
import vke

try:
	import Crypto.Cipher.AES

	class AES:
		'''Wrapper around Crypto.Cipher.AES.

		encrypt/decrypt accept any sequence of byte values and return
		an array('B') of the result.'''

		blockSize = 16

		def __init__(self, key):
			# Raw block cipher (ECB).  The mode is passed explicitly:
			# PyCrypto defaulted to ECB, but PyCryptodome requires a
			# mode argument and raises TypeError without one.
			self._aes = Crypto.Cipher.AES.new(key,
			    Crypto.Cipher.AES.MODE_ECB)

		def encrypt(self, data):
			return array.array('B', self._aes.encrypt(array.array('B', data).tostring()))

		def decrypt(self, data):
			return array.array('B', self._aes.decrypt(array.array('B', data).tostring()))
except ImportError:
	# Fall back to the pure Python implementation.
	from aes import AES

# UDP port for announcements: beserver sends 'hello' packets to it and
# beclient binds to it to receive them.
DK_PORT = 39954
# Every DKPkt on the wire starts with this 32-bit value; DKPkt.parsepkt
# rejects packets that do not.
DK_MAGIC = 0xd70951c5

class DKPkt(object):
	'''Class for sending/parsing data packets.

	Wire format: a big-endian header made of DK_MAGIC (unsigned 32-bit)
	and a command number (unsigned 16-bit), followed by a
	command-specific payload.

	The init_ functions are called when the object is constructed.  They
	process the passed-in arguments and set the necessary attributes.

	The cmd_ functions are passed the payload to decode from the wire.
	They must return a tuple of args that will be passed to the matching
	init_ function.  Malformed payloads raise ValueError.

	The enc_ functions encode the payload for sending on the wire.
	'''

	DKPKT_STR = '>IH'

	_cmds = {
		# name: [ command number on the wire, number of init args ]
		'hello': [ 1, 1 ],
		'accept': [ 2, 2 ],
	}
	# reverse map: command number -> name
	_cmdnums = dict((y[0], x) for x, y in _cmds.items())
	assert len(_cmdnums) == len(_cmds)

	name = property(lambda x: x._name)
	port = property(lambda x: x._port)
	typ = property(lambda x: x._type)

	def __init__(self, typ, *args):
		if typ not in self._cmds:
			raise ValueError('invalid type: %r' % (typ,))

		self._args = args
		self._name = None
		self._port = None

		self._type = typ
		if len(args) != self._cmds[typ][1]:
			raise TypeError('%s takes exactly %d additional argument' % (typ, self._cmds[typ][1]))

		getattr(self, 'init_%s' % typ)(*args)

	def init_hello(self, *args):
		self._name = args[0]

	@staticmethod
	def cmd_hello(rpkt):
		# payload is the UTF-8 encoded announced name
		return (rpkt.decode('utf8'), )

	def enc_hello(self):
		return self._name.encode('utf8')

	def init_accept(self, *args):
		self._name = args[0]
		self._port = args[1]

	@staticmethod
	def cmd_accept(rpkt):
		# payload is a big-endian 16-bit TCP port, then the UTF-8 name
		if len(rpkt) < 2:
			raise ValueError('accept packet too short: %d bytes' % len(rpkt))
		port = struct.unpack('>H', rpkt[:2])[0]
		return (rpkt[2:].decode('utf8'), port)

	def enc_accept(self):
		return struct.pack('>H', self._port) + self._name.encode('utf8')

	@classmethod
	def parsepkt(cls, pkt):
		'''Parse a packet received from the wire and create the
		object for it.

		Raises ValueError if the packet is truncated, has the wrong
		magic, an unknown command number or an undecodable payload.'''

		part = struct.calcsize(cls.DKPKT_STR)
		# struct.unpack would raise struct.error here; report a
		# ValueError like every other malformed-packet case.
		if len(pkt) < part:
			raise ValueError('packet too short: %d bytes' % len(pkt))

		magic, cmd = struct.unpack(cls.DKPKT_STR, pkt[:part])
		if magic != DK_MAGIC:
			raise ValueError('Invalid Magic: 0x%08x' % magic)

		if cmd not in cls._cmdnums:
			raise ValueError('Invalid command num: %d' % cmd)

		typ = cls._cmdnums[cmd]
		args = getattr(cls, 'cmd_%s' % typ)(pkt[part:])

		return cls(typ, *args)

	def __str__(self):
		# header followed by the command-specific payload
		return struct.pack(self.DKPKT_STR, DK_MAGIC, self._cmds[self._type][0]) + getattr(self, 'enc_%s' % self._type)()

	def __repr__(self):
		return 'DKPkt(%r, *%r)' % (self._type, self._args)

def betterwritelines(fp, iter_, flushsize=64*1024):
	'''Write the strings produced by iter_ to fp in batches.

	Chunks are collected until their combined length exceeds
	flushsize, then handed to fp.writelines() in a single call and
	flushed.  Whatever remains is written and flushed at the end, so
	fp is always flushed on return.  This replaces
	socket._fileobject.writelines, whose implementation the author
	found inadequate.'''

	batch = []
	batchlen = 0
	for chunk in iter_:
		batch.append(chunk)
		batchlen += len(chunk)
		if batchlen <= flushsize:
			continue

		fp.writelines(batch)
		fp.flush()
		batch, batchlen = [], 0

	# write out (and flush) whatever is left over, possibly nothing
	fp.writelines(batch)
	fp.flush()

def beserver(*files, **kwargs):
	'''Announce that we are sending files.

	Optional keyword argument name will be used as the announced name
	instead of the file names joined together.  This is provided in
	case the file names are sensitive.'''

	timeo = 1
	# make sure we can open all the files
	fps = [ open(x, 'rb') for x in files ]

	name = kwargs.pop('name', None)
	if name is None:
		name = ', '.join(files)

	dest = kwargs.pop('dest', '0')

	if kwargs:
		raise ValueError('unknown kwargs provided: %s' %
		    ', '.join(kwargs))

	s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

	s.setblocking(1)

	lastpkt = time.time() - 2 * timeo
	while True:
		if time.time() >= lastpkt + timeo:
			announcepkt = str(DKPkt('hello', name))
			s.sendto(announcepkt, (dest, DK_PORT))
			print 'sending announcement...'
			lastpkt = time.time()

		# XXX - could be negative if processing takes a long while
		s.settimeout((lastpkt + timeo) - time.time())
		try:
			pkt, addr = s.recvfrom(2048)
		except socket.timeout:
			continue

		# Try to do key exchange
		# The server will be the one to authenticate the pin.  The
		# client was the one to choose which file to accept.
		pkt = DKPkt.parsepkt(pkt)
		if pkt.typ != 'accept' or pkt.name != name:
			continue

		break

	print 'negotiating key...'
	# We have a client that wants our file.
	s2 = socket.socket()
	s2.connect((addr[0], pkt.port))

	kobj = vke.VerifiedKeyExchangeAlice()
	fp = s2.makefile('r+')

	idopts = fp.readline()
	if not idopts.startswith('FILEPASS'):
		raise RuntimeError('client is not filepass!')

	resp = kobj.getfirstmessage()
	fp.writelines([ 'FILEPASS\n', resp, '\n'])
	fp.flush()

	msg = fp.readline()[:-1]
	resp = kobj.processreply(msg)
	fp.writelines([ resp, '\n'])
	fp.flush()

	token = kobj.gentoken('AUTH', 20)
	yn = raw_input('Accept token: %d [y/N]?' % token).upper()
	if not yn or (yn and yn[0] != 'Y'):
		print 'exiting...'
		sys.exit(0)

	fnamehmackeys = kobj.keyiter('NAMEHMAC', blen=32)
	fnamecipherkeys = kobj.keyiter('NAMECIPHER', blen=16)
	fdatahmackeys = kobj.keyiter('FILEHMAC', blen=32)
	fdatacipherkeys = kobj.keyiter('FILECIPHER', blen=16)
	for out in fps:
		fname = os.path.basename(out.name)
		betterwritelines(fp, encstream.encfile(StringIO.StringIO(fname),
		    AES, fnamecipherkeys, fnamehmackeys))
		fp.flush()
		betterwritelines(fp, encstream.encfile(out, AES,
		    fdatacipherkeys, fdatahmackeys))
		fp.flush()

	s2.shutdown(socket.SHUT_WR)
	while True:
		a = fp.read(128)
		if not a:
			break
	fp.close()
	

def waitfor(s, typ, fun=None):
	'''Wait for a DKPkt of type typ to arrive on the datagram socket s.

	If fun is given, it is called with the parsed DKPkt instance and
	the sender's address, and a packet is only accepted once fun
	returns true.  Returns a (DKPkt, addr) tuple for the accepted
	packet.

	Datagrams that do not parse as a DKPkt (wrong magic, unknown
	command, truncated or undecodable payload) are ignored, so stray
	traffic on the port cannot abort the wait.'''

	while True:
		pkt, addr = s.recvfrom(2048)

		try:
			ann = DKPkt.parsepkt(pkt)
		except (ValueError, struct.error):
			continue

		if ann.typ == typ and (fun is None or fun(ann, addr)):
			return ann, addr

def getfiles(ann, addr, rejectedfiles):
	'''Return True if the user wants the file(s) announced by ann.

	Names already present in the dictionary rejectedfiles are refused
	without asking.  Otherwise the user is prompted.  An answer
	beginning with N (case-insensitive) refuses the announcement and
	records ann.name in rejectedfiles so the user is not asked again.
	Any other answer, including an empty one, accepts.  addr is not
	used.'''

	if ann.name not in rejectedfiles:
		answer = raw_input('Accept file(s): %s [Y/n]?' % ann.name).upper()
		if not answer.startswith('N'):
			return True
		rejectedfiles[ann.name] = True

	return False

def beclient(alwaysyes, overwrite=None):
	'''Listen for announcements from the server to receive files.

	If bool(alwaysyes) is True, the function will accept the first
	announcement received.'''

	s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

	s.bind(('0', DK_PORT))

	rejectedfiles = {}

	if alwaysyes:
		checkfun = lambda *args: True
	else:
		checkfun = lambda x, y, z=rejectedfiles: getfiles(x, y, z)
	ann, addr = waitfor(s, 'hello', checkfun)

	s2 = socket.socket()
	s2.bind(('0', 0))
	#s2.bind(('0', random.randint(40000, 41000)))
	s2.listen(1)
	s2port = s2.getsockname()[1]

	s2.settimeout(.5)
	start = time.time()
	while True:
		s.sendto(str(DKPkt('accept', ann.name, s2port)), 0, addr)

		try:
			s3, s3addr = s2.accept()
			break
		except socket.timeout:
			if time.time() - start > 15:
				raise RuntimeError('did not receive connection after 15 seconds')
			continue

	s3.setblocking(True)
	fp = s3.makefile('r+')
	kobj = vke.VerifiedKeyExchangeBob()

	# announce ourselves, we may have options in the future
	fp.write('FILEPASS\n')
	fp.flush()

	# get server options
	idopts = fp.readline()
	if not idopts.startswith('FILEPASS'):
		raise RuntimeError('sender is not filepass!')

	msg = fp.readline()[:-1]
	resp = kobj.processfirstmessage(msg)

	fp.writelines([ resp, '\n' ])
	fp.flush()

	msg = fp.readline()[:-1]
	kobj.processsecondmessage(msg)

	print 'Authentication token: %d' % kobj.gentoken('AUTH', 20)

	fnamehmackeys = kobj.keyiter('NAMEHMAC', blen=32)
	fnamecipherkeys = kobj.keyiter('NAMECIPHER', blen=16)
	fdatahmackeys = kobj.keyiter('FILEHMAC', blen=32)
	fdatacipherkeys = kobj.keyiter('FILECIPHER', blen=16)
	while True:
		#print 'starting file'
		fname = ''.join(encstream.decfile(fp, AES, fnamecipherkeys,
		    fnamehmackeys))
		if not fname:
			print 'exiting...'
			break
		yn = raw_input('Accept filename: %s [Y/n]?' % `fname`).upper()
		if yn and yn[0] == 'N':
			fname = raw_input('New filename [%s]:' % `fname`)

		# XXX - open does not have a way to error if the file already
		# exists, us os.open instead
		while True:
			try:
				fd = os.open(fname,
				    os.O_WRONLY|os.O_CREAT|os.O_EXCL)
				try:
					fpout = os.fdopen(fd, 'w')
				except:
					os.close(fd)
					raise
				break
			except OSError, e:
				if e.errno != 17:
					raise

				if overwrite is None:
					yn = raw_input('File %s already exists, overwrite [y/N]?' % fname).upper()
				elif overwrite:
					yn = 'Y'
				else:
					yn = 'N'
				if yn and yn[0] == 'Y':
					fpout = open(fname, 'w+')
					break

				fname = raw_input('New file name?')
				continue

		betterwritelines(fpout, encstream.decfile(fp, AES,
		    fdatacipherkeys, fdatahmackeys))
		fpout.close()

def usage(exitval=2):
	'''Print the command line usage.

	The first line is the receiver form, the second the sender form.
	If exitval is not None, exit the process with that status
	afterwards.'''

	print('Usage: %s [ -a ] [ -o ]' % sys.argv[0])
	print('       %s [ -d <dest> ] [ -n <name> ] <files> ...' % sys.argv[0])

	if exitval is not None:
		sys.exit(exitval)

def main():
	try:
		opts, args = getopt.getopt(sys.argv[1:], 'ad:n:o')
	except getopt.GetoptError, err:
		print str(err)
		usage()

	alwaysyes = False
	serverargs = { 'dest': '0' }
	baseserverargs = serverargs.copy()
	clientargs = { }
	baseclientargs = clientargs.copy()

	for o, a in opts:
		if o == '-a':
			alwaysyes = True
		elif o == '-d':
			serverargs['dest'] = a
		elif o == '-n':
			serverargs['name'] = a
		elif o == '-o':
			clientargs['overwrite'] = True

	if not args:
		if baseserverargs != serverargs:
			print 'Provided server arguments, but no files were specified.'
			usage()

		beclient(alwaysyes, **clientargs)
	else:
		if baseclientargs != clientargs:
			print 'Provided client arguments, but files were specified.'
			usage()

		beserver(*args, **serverargs)

# Run the command line interface only when executed as a script, not on import.
if __name__ == '__main__':
	main()
