summaryrefslogtreecommitdiff
path: root/metaserver/pipelayer.py
blob: 9baf78add3a772cadc80305d8e899847b7ee9804 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
#import os
import struct
from collections import deque
from zlib import crc32


class InvalidPacket(Exception):
    """A datagram that does not parse as a packet (raised by PipeLayer.decode())."""


# Flag byte carried as the first byte of every packet header.
FLAG_NAK1 = 0xE0   # NAK, and the packet right after the missing one did arrive
FLAG_NAK = 0xE1    # NAK: resend everything not yet acknowledged
FLAG_REG = 0xE2    # regular packet
FLAG_CFRM = 0xE3   # request for confirmation (peer should reply right away)

# Accepted flag bytes form the half-open range [START, STOP).
FLAG_RANGE_START = FLAG_NAK1
FLAG_RANGE_STOP = FLAG_CFRM + 1

# Number of sent-but-unacknowledged packets at which new data stalls.
# Must be <= 256, presumably because acknowledgements only carry the low
# 8 bits of a sequence id.
max_old_packets = 200


class PipeLayer(object):
    timeout = 1
    headersize = 4

    def __init__(self, initialcrcs=(0, 0)):
        #self.localid = os.urandom(4)
        #self.remoteid = None
        self.cur_time = 0
        self.out_queue = deque()
        self.out_nextseqid = 0
        self.out_nextrepeattime = None
        self.in_nextseqid = 0
        self.in_outoforder = {}
        self.out_oldpackets = deque()
        self.out_flags = FLAG_REG
        self.out_resend = 0
        self.out_resend_skip = False
        self.in_crc, self.out_crc = initialcrcs

    def queue(self, data):
        if data:
            self.out_queue.appendleft(data)

    def queue_size(self):
        total = 0
        for data in self.out_queue:
            total += len(data)
        return total

    def in_sync(self):
        return not self.out_queue and self.out_nextrepeattime is None

    def settime(self, curtime):
        self.cur_time = curtime
        if self.out_queue:
            if len(self.out_oldpackets) < max_old_packets:
                return 0   # more data to send now
        if self.out_nextrepeattime is not None:
            return max(0, self.out_nextrepeattime - curtime)
        else:
            return None

    def is_congested(self):
        return len(self.out_oldpackets) >= max_old_packets

    def encode(self, maxlength):
        #print ' '*self._dump_indent, '--- OUTQ', self.out_resend, self.out_queue
        if len(self.out_oldpackets) >= max_old_packets:
            # congestion, stalling
            payload = 0
        else:
            payload = maxlength - 8
            if payload <= 0:
                raise ValueError("encode(): buffer too small")
        if (self.out_nextrepeattime is not None and
            self.out_nextrepeattime <= self.cur_time):
            # no ACK received so far, send a packet (possibly empty)
            if not self.out_queue:
                payload = 0
        else:
            if not self.out_queue:   # no more data to send
                return None
            if payload == 0:         # congestion
                return None
        # prepare a packet
        seqid = self.out_nextseqid
        flags = self.out_flags
        self.out_flags = FLAG_REG     # clear out the flags for the next time
        #if flags in (FLAG_NAK, FLAG_NAK1):
        #    print 'out_flags NAK', hex(flags)
        if payload > 0:
            self.out_nextseqid = (seqid + 1) & 0xFFFF
            data = self.out_queue.pop()
            packetlength = len(data)
            if self.out_resend > 0:
                if packetlength > payload + 4:
                    raise ValueError("XXX need constant buffer size for now")
                self.out_resend -= 1
                if self.out_resend_skip:
                    if self.out_resend > 0:
                        self.out_queue.pop()
                        self.out_resend -= 1
                        self.out_nextseqid = (seqid + 2) & 0xFFFF
                    self.out_resend_skip = False
                packetpayload = data
            else:
                packet = []
                while packetlength <= payload:
                    packet.append(data)
                    if not self.out_queue:
                        break
                    data = self.out_queue.pop()
                    packetlength += len(data)
                else:
                    rest = len(data) + payload - packetlength
                    packet.append(data[:rest])
                    self.out_queue.append(data[rest:])
                packetpayload = ''.join(packet)
                self.out_crc = crc32(packetpayload, self.out_crc)
                packetpayload += struct.pack("!I", self.out_crc & 0xffffffff)
                self.out_oldpackets.appendleft(packetpayload)
                #print ' '*self._dump_indent, '--- OLDPK', self.out_oldpackets
        else:
            # a pure ACK packet, no payload
            if self.out_oldpackets and flags == FLAG_REG:
                flags = FLAG_CFRM
            packetpayload = ''
        packet = struct.pack("!BBH", flags,
                             self.in_nextseqid & 0xFF,
                             seqid) + packetpayload
        if self.out_oldpackets:
            self.out_nextrepeattime = self.cur_time + self.timeout
        else:
            self.out_nextrepeattime = None
        #self.dump('OUT', packet)
        return packet

    def decode(self, rawdata):
        if len(rawdata) < 4:
            raise InvalidPacket
        #print ' '*self._dump_indent, '------ out %d (+%d) in %d' % (self.out_nextseqid, self.out_resend, self.in_nextseqid)
        #self.dump('IN ', rawdata)
        in_flags, ack_seqid, in_seqid = struct.unpack("!BBH", rawdata[:4])
        if not (FLAG_RANGE_START <= in_flags < FLAG_RANGE_STOP):
            raise InvalidPacket
        in_diff  = (in_seqid  - self.in_nextseqid ) & 0xFFFF
        ack_diff = (self.out_nextseqid + self.out_resend - ack_seqid) & 0xFF
        if in_diff >= max_old_packets:
            return ''    # invalid, but can occur as a late repetition
        if ack_diff != len(self.out_oldpackets):
            # forget all acknowledged packets
            if ack_diff > len(self.out_oldpackets):
                return ''   # invalid, but can occur with packet reordering
            while len(self.out_oldpackets) > ack_diff:
                #print ' '*self._dump_indent, '--- POP', repr(self.out_oldpackets[-1])
                self.out_oldpackets.pop()
            if self.out_oldpackets:
                self.out_nextrepeattime = self.cur_time + self.timeout
            else:
                self.out_nextrepeattime = None   # all packets ACKed
        if in_flags == FLAG_NAK or in_flags == FLAG_NAK1:
            #print 'recv NAK', hex(in_flags)
            # this is a NAK: resend the old packets as far as they've not
            # also been ACK'ed in the meantime (can occur with reordering)
            while self.out_resend < len(self.out_oldpackets):
                self.out_queue.append(self.out_oldpackets[self.out_resend])
                self.out_resend += 1
                self.out_nextseqid = (self.out_nextseqid - 1) & 0xFFFF
                #print ' '*self._dump_indent, '--- REP', self.out_nextseqid, repr(self.out_queue[-1])
            self.out_resend_skip = in_flags == FLAG_NAK1
        elif in_flags == FLAG_CFRM:
            # this is a CFRM: request for confirmation
            self.out_nextrepeattime = self.cur_time
        # receive this packet's payload if it is the next in the sequence
        if in_diff == 0:
            if len(rawdata) > 8:
                #print ' '*self._dump_indent, 'RECV ', self.in_nextseqid, repr(rawdata[4:])
                payload = rawdata[4:-4]
                crc, = struct.unpack("!I", rawdata[-4:])
                if crc != (crc32(payload, self.in_crc) & 0xffffffff):
                    self.bad_crc()
                    return ''   # bad crc! drop packet
                self.in_nextseqid = (self.in_nextseqid + 1) & 0xFFFF
                self.in_crc = crc
                result = [payload]
                while self.in_nextseqid in self.in_outoforder:
                    rawdata = self.in_outoforder.pop(self.in_nextseqid)
                    payload = rawdata[4:-4]
                    crc, = struct.unpack("!I", rawdata[-4:])
                    if crc != (crc32(payload, self.in_crc) & 0xffffffff):
                        # bad crc! clear all out-of-order packets
                        self.bad_crc()
                        break
                    self.in_nextseqid = (self.in_nextseqid + 1) & 0xFFFF
                    self.in_crc = crc
                    result.append(payload)
                return ''.join(result)
        else:
            # we missed at least one intermediate packet: send a NAK
            if len(rawdata) > 4:
                self.in_outoforder[in_seqid] = rawdata
            if ((self.in_nextseqid + 1) & 0xFFFF) in self.in_outoforder:
                self.out_flags = FLAG_NAK1
            else:
                self.out_flags = FLAG_NAK
            self.out_nextrepeattime = self.cur_time
        return ''

    def bad_crc(self):
        import sys
        print >> sys.stderr, "warning: bad crc on udp connexion"
        self.in_outoforder.clear()
        self.out_flags = FLAG_NAK
        self.out_nextrepeattime = self.cur_time

    _dump_indent = 0
    def dump(self, dir, rawdata):
        in_flags, ack_seqid, in_seqid = struct.unpack("!BBH", rawdata[:4])
        print ' ' * self._dump_indent, dir,
        if in_flags == FLAG_NAK:
            print 'NAK',
        elif in_flags == FLAG_NAK1:
            print 'NAK1',
        elif in_flags == FLAG_CFRM:
            print 'CFRM',
        #print ack_seqid, in_seqid, '(%d bytes)' % (len(rawdata)-4,)
        print ack_seqid, in_seqid, repr(rawdata[4:])


def pipe_over_udp(udpsock, send_fd=-1, recv_fd=-1,
                  timeout=1.0, inactivity_timeout=None):
    """Example: send all data showing up in send_fd over the given UDP
    socket, and write incoming data into recv_fd.  The send_fd and
    recv_fd are plain file descriptors.  When an EOF is read from
    send_fd, this function returns (after making sure that all data was
    received by the remote side).

    A negative send_fd means there is nothing to send; a negative recv_fd
    means incoming data is discarded.  timeout is the resend delay in
    seconds.  If inactivity_timeout is given, the function also returns
    when nothing is scheduled to be sent and nothing arrives within that
    many seconds.  Datagrams that are not valid packets are ignored.
    """
    import os
    from select import select
    from time import time
    p = PipeLayer()
    p.timeout = timeout
    iwtdlist = [udpsock]
    if send_fd >= 0:
        iwtdlist.append(send_fd)
    running = True
    while running or not p.in_sync():
        delay = delay1 = p.settime(time())
        if delay is None:
            delay = inactivity_timeout
        iwtd, _, _ = select(iwtdlist, [], [], delay)
        if iwtd:
            if send_fd in iwtd:
                data = os.read(send_fd, 1500 - p.headersize)
                if not data:
                    # EOF: stop reading, but keep looping until everything
                    # sent has been acknowledged (p.in_sync()).
                    iwtdlist.remove(send_fd)
                    running = False
                else:
                    p.queue(data)
            if udpsock in iwtd:
                packet = udpsock.recv(65535)
                p.settime(time())
                try:
                    data = p.decode(packet)
                except InvalidPacket:
                    # A stray or malformed datagram must not abort the
                    # whole transfer: drop it.
                    data = ''
                if recv_fd >= 0:
                    i = 0
                    while i < len(data):
                        i += os.write(recv_fd, data[i:])
        elif delay1 is None:
            break    # long inactivity
        p.settime(time())
        packet = p.encode(1500)
        if packet:
            udpsock.send(packet)


class PipeOverUdp(object):
    """Socket-like object running pipe_over_udp() in a background thread.

    Bytes written with send()/sendall() go into a local pipe whose read
    end the thread forwards over udpsock; bytes the thread receives are
    written into a second pipe, read back with recv()/recvall().
    fileno() exposes the readable end and ofileno() the writable end.
    """

    def __init__(self, udpsock, timeout=1.0):
        import thread, os
        # Kept on the instance, presumably so __del__ can still reach os
        # during interpreter shutdown when module globals may be gone.
        self.os = os
        self.sendpipe = os.pipe()   # (thread reads, send() writes)
        self.recvpipe = os.pipe()   # (recv() reads, thread writes)
        thread.start_new_thread(pipe_over_udp, (udpsock,
                                                self.sendpipe[0],
                                                self.recvpipe[1],
                                                timeout))

    def __del__(self):
        """Close both pipes.  Idempotent, and safe on an instance whose
        __init__ failed part-way (missing attributes are skipped instead
        of raising AttributeError from the finalizer).

        NOTE(review): this also closes the descriptors used by the
        background thread while it may still be running; the thread will
        likely fail on them.  Confirm this abrupt shutdown is intended.
        """
        os = getattr(self, 'os', None)
        for name in ('sendpipe', 'recvpipe'):
            fds = getattr(self, name, None)
            if fds:
                os.close(fds[0])
                os.close(fds[1])
                setattr(self, name, None)

    close = __del__

    def send(self, data):
        """Write data to the outgoing pipe; return the number of bytes written."""
        if not self.sendpipe:
            raise IOError("I/O operation on a closed PipeOverUdp")
        return self.os.write(self.sendpipe[1], data)

    def sendall(self, data):
        """Write all of data, looping over partial writes."""
        i = 0
        while i < len(data):
            i += self.send(data[i:])

    def recv(self, bufsize):
        """Read up to bufsize bytes of received data (blocks if none)."""
        if not self.recvpipe:
            raise IOError("I/O operation on a closed PipeOverUdp")
        return self.os.read(self.recvpipe[0], bufsize)

    def recvall(self, bufsize):
        """Read exactly bufsize bytes, looping over partial reads."""
        buf = []
        while bufsize > 0:
            data = self.recv(bufsize)
            buf.append(data)
            bufsize -= len(data)
        return ''.join(buf)

    def fileno(self):
        """Descriptor to read received data from (usable with select())."""
        if not self.recvpipe:
            raise IOError("I/O operation on a closed PipeOverUdp")
        return self.recvpipe[0]

    def ofileno(self):
        """Descriptor to write outgoing data to."""
        if not self.sendpipe:
            raise IOError("I/O operation on a closed PipeOverUdp")
        return self.sendpipe[1]