1 /*
2  * Protocol Buffers - Google's data interchange format
3  * Copyright 2014 Google Inc.  All rights reserved.
4  * https://developers.google.com/protocol-buffers/
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are
8  * met:
9  *
10  *     * Redistributions of source code must retain the above copyright
11  * notice, this list of conditions and the following disclaimer.
12  *     * Redistributions in binary form must reproduce the above
13  * copyright notice, this list of conditions and the following disclaimer
14  * in the documentation and/or other materials provided with the
15  * distribution.
16  *     * Neither the name of Google Inc. nor the names of its
17  * contributors may be used to endorse or promote products derived from
18  * this software without specific prior written permission.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21  * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22  * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23  * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24  * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26  * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31  */
32 
33 package com.google.protobuf.jruby;
34 
35 import com.google.protobuf.Descriptors;
36 import org.jruby.*;
37 import org.jruby.anno.JRubyClass;
38 import org.jruby.anno.JRubyMethod;
39 import org.jruby.runtime.Block;
40 import org.jruby.runtime.ObjectAllocator;
41 import org.jruby.runtime.ThreadContext;
42 import org.jruby.runtime.builtin.IRubyObject;
43 import java.util.Arrays;
44 
45 @JRubyClass(name = "RepeatedClass", include = "Enumerable")
46 public class RubyRepeatedField extends RubyObject {
createRubyRepeatedField(Ruby runtime)47     public static void createRubyRepeatedField(Ruby runtime) {
48         RubyModule mProtobuf = runtime.getClassFromPath("Google::Protobuf");
49         RubyClass cRepeatedField = mProtobuf.defineClassUnder("RepeatedField", runtime.getObject(),
50                 new ObjectAllocator() {
51                     @Override
52                     public IRubyObject allocate(Ruby runtime, RubyClass klazz) {
53                         return new RubyRepeatedField(runtime, klazz);
54                     }
55                 });
56         cRepeatedField.defineAnnotatedMethods(RubyRepeatedField.class);
57         cRepeatedField.includeModule(runtime.getEnumerable());
58     }
59 
RubyRepeatedField(Ruby runtime, RubyClass klazz)60     public RubyRepeatedField(Ruby runtime, RubyClass klazz) {
61         super(runtime, klazz);
62     }
63 
RubyRepeatedField(Ruby runtime, RubyClass klazz, Descriptors.FieldDescriptor.Type fieldType, IRubyObject typeClass)64     public RubyRepeatedField(Ruby runtime, RubyClass klazz, Descriptors.FieldDescriptor.Type fieldType, IRubyObject typeClass) {
65         this(runtime, klazz);
66         this.fieldType = fieldType;
67         this.storage = runtime.newArray();
68         this.typeClass = typeClass;
69     }
70 
71     @JRubyMethod(required = 1, optional = 2)
initialize(ThreadContext context, IRubyObject[] args)72     public IRubyObject initialize(ThreadContext context, IRubyObject[] args) {
73         Ruby runtime = context.runtime;
74         this.storage = runtime.newArray();
75         IRubyObject ary = null;
76         if (!(args[0] instanceof RubySymbol)) {
77             throw runtime.newArgumentError("Expected Symbol for type name");
78         }
79         this.fieldType = Utils.rubyToFieldType(args[0]);
80         if (fieldType == Descriptors.FieldDescriptor.Type.MESSAGE
81                 || fieldType == Descriptors.FieldDescriptor.Type.ENUM) {
82             if (args.length < 2)
83                 throw runtime.newArgumentError("Expected at least 2 arguments for message/enum");
84             typeClass = args[1];
85             if (args.length > 2)
86                 ary = args[2];
87             Utils.validateTypeClass(context, fieldType, typeClass);
88         } else {
89             if (args.length > 2)
90                 throw runtime.newArgumentError("Too many arguments: expected 1 or 2");
91             if (args.length > 1)
92                 ary = args[1];
93         }
94         if (ary != null) {
95             RubyArray arr = ary.convertToArray();
96             for (int i = 0; i < arr.size(); i++) {
97                 this.storage.add(arr.eltInternal(i));
98             }
99         }
100         return this;
101     }
102 
103     /*
104      * call-seq:
105      *     RepeatedField.[]=(index, value)
106      *
107      * Sets the element at the given index. On out-of-bounds assignments, extends
108      * the array and fills the hole (if any) with default values.
109      */
110     @JRubyMethod(name = "[]=")
indexSet(ThreadContext context, IRubyObject index, IRubyObject value)111     public IRubyObject indexSet(ThreadContext context, IRubyObject index, IRubyObject value) {
112         int arrIndex = normalizeArrayIndex(index);
113         Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
114         IRubyObject defaultValue = defaultValue(context);
115         for (int i = this.storage.size(); i < arrIndex; i++) {
116             this.storage.set(i, defaultValue);
117         }
118         this.storage.set(arrIndex, value);
119         return context.runtime.getNil();
120     }
121 
122     /*
123      * call-seq:
124      *     RepeatedField.[](index) => value
125      *
126      * Accesses the element at the given index. Returns nil on out-of-bounds
127      */
128     @JRubyMethod(required=1, optional=1, name = {"at", "[]"})
index(ThreadContext context, IRubyObject[] args)129     public IRubyObject index(ThreadContext context, IRubyObject[] args) {
130         if (args.length == 1){
131             IRubyObject arg = args[0];
132             if (Utils.isRubyNum(arg)) {
133                 /* standard case */
134                 int arrIndex = normalizeArrayIndex(arg);
135                 if (arrIndex < 0 || arrIndex >= this.storage.size()) {
136                     return context.runtime.getNil();
137                 }
138                 return this.storage.eltInternal(arrIndex);
139             } else if (arg instanceof RubyRange) {
140                 RubyRange range = ((RubyRange) arg);
141                 int beg = RubyNumeric.num2int(range.first(context));
142                 int to = RubyNumeric.num2int(range.last(context));
143                 int len = to - beg + 1;
144                 return this.storage.subseq(beg, len);
145             }
146         }
147         /* assume 2 arguments */
148         int beg = RubyNumeric.num2int(args[0]);
149         int len = RubyNumeric.num2int(args[1]);
150         if (beg < 0) {
151             beg += this.storage.size();
152         }
153         if (beg >= this.storage.size()) {
154             return context.runtime.getNil();
155         }
156         return this.storage.subseq(beg, len);
157     }
158 
159     /*
160      * call-seq:
161      *     RepeatedField.push(value)
162      *
163      * Adds a new element to the repeated field.
164      */
165     @JRubyMethod(name = {"push", "<<"})
push(ThreadContext context, IRubyObject value)166     public IRubyObject push(ThreadContext context, IRubyObject value) {
167         if (!(fieldType == Descriptors.FieldDescriptor.Type.MESSAGE &&
168             value == context.runtime.getNil())) {
169             Utils.checkType(context, fieldType, value, (RubyModule) typeClass);
170         }
171         this.storage.add(value);
172         return this.storage;
173     }
174 
175     /*
176      * private Ruby method used by RepeatedField.pop
177      */
178     @JRubyMethod(visibility = org.jruby.runtime.Visibility.PRIVATE)
pop_one(ThreadContext context)179     public IRubyObject pop_one(ThreadContext context) {
180         IRubyObject ret = this.storage.last();
181         this.storage.remove(ret);
182         return ret;
183     }
184 
185     /*
186      * call-seq:
187      *     RepeatedField.replace(list)
188      *
189      * Replaces the contents of the repeated field with the given list of elements.
190      */
191     @JRubyMethod
replace(ThreadContext context, IRubyObject list)192     public IRubyObject replace(ThreadContext context, IRubyObject list) {
193         RubyArray arr = (RubyArray) list;
194         checkArrayElementType(context, arr);
195         this.storage = arr;
196         return this.storage;
197     }
198 
199     /*
200      * call-seq:
201      *     RepeatedField.clear
202      *
203      * Clears (removes all elements from) this repeated field.
204      */
205     @JRubyMethod
clear(ThreadContext context)206     public IRubyObject clear(ThreadContext context) {
207         this.storage.clear();
208         return this.storage;
209     }
210 
211     /*
212      * call-seq:
213      *     RepeatedField.length
214      *
215      * Returns the length of this repeated field.
216      */
217     @JRubyMethod(name = {"length", "size"})
length(ThreadContext context)218     public IRubyObject length(ThreadContext context) {
219         return context.runtime.newFixnum(this.storage.size());
220     }
221 
222     /*
223      * call-seq:
224      *     RepeatedField.+(other) => repeated field
225      *
226      * Returns a new repeated field that contains the concatenated list of this
227      * repeated field's elements and other's elements. The other (second) list may
228      * be either another repeated field or a Ruby array.
229      */
230     @JRubyMethod(name = {"+"})
plus(ThreadContext context, IRubyObject list)231     public IRubyObject plus(ThreadContext context, IRubyObject list) {
232         RubyRepeatedField dup = (RubyRepeatedField) dup(context);
233         if (list instanceof RubyArray) {
234             checkArrayElementType(context, (RubyArray) list);
235             dup.storage.addAll((RubyArray) list);
236         } else {
237             RubyRepeatedField repeatedField = (RubyRepeatedField) list;
238             if (! fieldType.equals(repeatedField.fieldType) || (typeClass != null && !
239                     typeClass.equals(repeatedField.typeClass)))
240                 throw context.runtime.newArgumentError("Attempt to append RepeatedField with different element type.");
241             dup.storage.addAll((RubyArray) repeatedField.toArray(context));
242         }
243         return dup;
244     }
245 
246     /*
247      * call-seq:
248      *     RepeatedField.concat(other) => self
249      *
250      * concats the passed in array to self.  Returns a Ruby array.
251      */
252     @JRubyMethod
concat(ThreadContext context, IRubyObject list)253     public IRubyObject concat(ThreadContext context, IRubyObject list) {
254         if (list instanceof RubyArray) {
255             checkArrayElementType(context, (RubyArray) list);
256             this.storage.addAll((RubyArray) list);
257         } else {
258             RubyRepeatedField repeatedField = (RubyRepeatedField) list;
259             if (! fieldType.equals(repeatedField.fieldType) || (typeClass != null && !
260                     typeClass.equals(repeatedField.typeClass)))
261                 throw context.runtime.newArgumentError("Attempt to append RepeatedField with different element type.");
262             this.storage.addAll((RubyArray) repeatedField.toArray(context));
263         }
264         return this.storage;
265     }
266 
267     /*
268      * call-seq:
269      *     RepeatedField.hash => hash_value
270      *
271      * Returns a hash value computed from this repeated field's elements.
272      */
273     @JRubyMethod
hash(ThreadContext context)274     public IRubyObject hash(ThreadContext context) {
275         int hashCode = this.storage.hashCode();
276         return context.runtime.newFixnum(hashCode);
277     }
278 
279     /*
280      * call-seq:
281      *     RepeatedField.==(other) => boolean
282      *
283      * Compares this repeated field to another. Repeated fields are equal if their
284      * element types are equal, their lengths are equal, and each element is equal.
285      * Elements are compared as per normal Ruby semantics, by calling their :==
286      * methods (or performing a more efficient comparison for primitive types).
287      */
288     @JRubyMethod(name = "==")
eq(ThreadContext context, IRubyObject other)289     public IRubyObject eq(ThreadContext context, IRubyObject other) {
290         return this.toArray(context).op_equal(context, other);
291     }
292 
293     /*
294      * call-seq:
295      *     RepeatedField.each(&block)
296      *
297      * Invokes the block once for each element of the repeated field. RepeatedField
298      * also includes Enumerable; combined with this method, the repeated field thus
299      * acts like an ordinary Ruby sequence.
300      */
301     @JRubyMethod
each(ThreadContext context, Block block)302     public IRubyObject each(ThreadContext context, Block block) {
303         this.storage.each(context, block);
304         return this.storage;
305     }
306 
307 
308     @JRubyMethod(name = {"to_ary", "to_a"})
toArray(ThreadContext context)309     public IRubyObject toArray(ThreadContext context) {
310         return this.storage;
311     }
312 
313     /*
314      * call-seq:
315      *     RepeatedField.dup => repeated_field
316      *
317      * Duplicates this repeated field with a shallow copy. References to all
318      * non-primitive element objects (e.g., submessages) are shared.
319      */
320     @JRubyMethod
dup(ThreadContext context)321     public IRubyObject dup(ThreadContext context) {
322         RubyRepeatedField dup = new RubyRepeatedField(context.runtime, metaClass, fieldType, typeClass);
323         for (int i = 0; i < this.storage.size(); i++) {
324             dup.push(context, this.storage.eltInternal(i));
325         }
326         return dup;
327     }
328 
329     // Java API
get(int index)330     protected IRubyObject get(int index) {
331         return this.storage.eltInternal(index);
332     }
333 
deepCopy(ThreadContext context)334     protected RubyRepeatedField deepCopy(ThreadContext context) {
335         RubyRepeatedField copy = new RubyRepeatedField(context.runtime, metaClass, fieldType, typeClass);
336         for (int i = 0; i < size(); i++) {
337             IRubyObject value = storage.eltInternal(i);
338             if (fieldType == Descriptors.FieldDescriptor.Type.MESSAGE) {
339                 copy.storage.add(((RubyMessage) value).deepCopy(context));
340             } else {
341                 copy.storage.add(value);
342             }
343         }
344         return copy;
345     }
346 
size()347     protected int size() {
348         return this.storage.size();
349     }
350 
defaultValue(ThreadContext context)351     private IRubyObject defaultValue(ThreadContext context) {
352         SentinelOuterClass.Sentinel sentinel = SentinelOuterClass.Sentinel.getDefaultInstance();
353         Object value;
354         switch (fieldType) {
355             case INT32:
356                 value = sentinel.getDefaultInt32();
357                 break;
358             case INT64:
359                 value = sentinel.getDefaultInt64();
360                 break;
361             case UINT32:
362                 value = sentinel.getDefaultUnit32();
363                 break;
364             case UINT64:
365                 value = sentinel.getDefaultUint64();
366                 break;
367             case FLOAT:
368                 value = sentinel.getDefaultFloat();
369                 break;
370             case DOUBLE:
371                 value = sentinel.getDefaultDouble();
372                 break;
373             case BOOL:
374                 value = sentinel.getDefaultBool();
375                 break;
376             case BYTES:
377                 value = sentinel.getDefaultBytes();
378                 break;
379             case STRING:
380                 value = sentinel.getDefaultString();
381                 break;
382             case ENUM:
383                 IRubyObject defaultEnumLoc = context.runtime.newFixnum(0);
384                 return RubyEnum.lookup(context, typeClass, defaultEnumLoc);
385             default:
386                 return context.runtime.getNil();
387         }
388         return Utils.wrapPrimaryValue(context, fieldType, value);
389     }
390 
checkArrayElementType(ThreadContext context, RubyArray arr)391     private void checkArrayElementType(ThreadContext context, RubyArray arr) {
392         for (int i = 0; i < arr.getLength(); i++) {
393             Utils.checkType(context, fieldType, arr.eltInternal(i), (RubyModule) typeClass);
394         }
395     }
396 
normalizeArrayIndex(IRubyObject index)397     private int normalizeArrayIndex(IRubyObject index) {
398         int arrIndex = RubyNumeric.num2int(index);
399         int arrSize = this.storage.size();
400         if (arrIndex < 0 && arrSize > 0) {
401             arrIndex = arrSize + arrIndex;
402         }
403         return arrIndex;
404     }
405 
406     private RubyArray storage;
407     private Descriptors.FieldDescriptor.Type fieldType;
408     private IRubyObject typeClass;
409 }
410