Tiramisu Compiler
expr.h
Go to the documentation of this file.
1 #ifndef _H_TIRAMISU_EXPR_
2 #define _H_TIRAMISU_EXPR_
3 
4 #include <isl/set.h>
5 #include <isl/map.h>
6 #include <isl/union_map.h>
7 #include <isl/union_set.h>
8 #include <isl/ast_build.h>
9 #include <isl/schedule.h>
10 #include <isl/schedule_node.h>
11 #include <isl/space.h>
12 
13 #include <map>
14 #include <unordered_map>
15 #include <vector>
16 #include <string.h>
17 #include <stdint.h>
18 #include <type_traits>
19 
20 #include <Halide.h>
21 #include <tiramisu/debug.h>
22 #include <tiramisu/type.h>
23 
24 namespace tiramisu
25 {
26 class function;
27 class computation;
28 
29 std::string generate_new_variable_name();
31 std::string str_tiramisu_type_op(tiramisu::op_t type);
33 
34 class buffer;
35 class expr;
36 class var;
37 class sync;
38 class global;
39 
40 template <typename T>
41 using only_integral = typename std::enable_if<std::is_integral<T>::value, expr>::type;
42 
43 /**
44  * A class that holds all the global variables necessary for Tiramisu.
45  * It also holds Tiramisu options.
46  */
47 class global
48 {
49 private:
50  /**
51  * Whether automatic data mapping is enabled
52  * (written by set_auto_data_mapping()).
52  */
53  static bool auto_data_mapping;
54 
55  /**
56  * Type of the loop iterators to generate.
57  */
58  static primitive_t loop_iterator_type;
59 
60  /**
61  * When Tiramisu is initialized, an implicit Tiramisu
62  * function is created. All the computations and buffers
63  * created later are added by default to this function unless
64  * the user indicates otherwise using the Tiramisu API (by providing
65  * a different function as input to the API).
66  */
67  static function *implicit_fct;
68 
69 public:
70 
71  /**
72  * Return the implicit function created during Tiramisu initialization.
73  *
74  * When Tiramisu is initialized, an implicit Tiramisu
75  * function is created. All the computations and buffers
76  * created later are added by default to this function unless
77  * the user indicates otherwise using the Tiramisu API (by using the low
78  * level Tiramisu API and by providing a different function as input to the API).
79  */
80  static function *get_implicit_function()
81  {
82  return global::implicit_fct;
83  }
84 
85  /**
86  * Set the implicit function to \p fct.
87  *
88  * When Tiramisu is initialized, an implicit Tiramisu
89  * function is created. All the computations and buffers
90  * created later are added by default to this function unless
91  * the user indicates otherwise using the Tiramisu API (by using the low
92  * level Tiramisu API and by providing a different function as input to the API).
93  */
94  static void set_implicit_function(function *fct)
95  {
96  global::implicit_fct = fct;
97  }
98 
99  /**
100  * If this option is set to true, Tiramisu automatically
101  * modifies the computation data mapping whenever a new
102  * schedule is applied to a computation.
103  * If it is set to false, it is up to the user to set
104  * the right data mapping before code generation.
105  */
106  static void set_auto_data_mapping(bool v)
107  {
108  global::auto_data_mapping = v;
109  }
110 
111  /**
112  * Return whether auto data mapping is set.
113  * If auto data mapping is set, Tiramisu automatically
114  * modifies the computation data mapping whenever a new
115  * schedule is applied to a computation.
116  * If it is set to false, it is up to the user to set
117  * the right data mapping before code generation.
118  */
 // NOTE(review): the declaration line of this accessor is absent from this
 // listing (numbering jumps from 118 to 120); only its body is visible.
120  {
121  return global::auto_data_mapping;
122  }
123 
 // NOTE(review): declaration line absent from this listing. The body selects
 // p_int32 as loop iterator type and turns auto data mapping on; presumably
 // this is the "reset to default options" routine -- confirm against the header.
125  {
126  global::loop_iterator_type = p_int32;
127  set_auto_data_mapping(true);
128  }
129 
 // NOTE(review): declaration line and opening brace absent from this listing.
 // The body stores `t` as the loop iterator type, so this is presumably the
 // setter matching the getter below -- confirm against the header.
131  global::loop_iterator_type = t;
132  }
133 
 // NOTE(review): declaration line absent from this listing. Returns the
 // currently configured loop iterator type.
135  {
136  return global::loop_iterator_type;
137  }
138 
 // NOTE(review): declaration line and the single body line (141) are absent
 // from this listing; nothing can be said about this member from here.
140  {
142  }
143 };
144 
145 
146 
147 /**
148  * A class to represent tiramisu expressions.
149  */
150 class expr
151 {
152  friend class input;
153  friend class var;
154  friend class sync;
155  friend class computation;
156  friend class generator;
157 
158  /**
159  * The type of the operator.
160  */
161  tiramisu::op_t _operator;
162 
163  /**
164  * The value of the 1st, 2nd and 3rd operands of the expression.
165  * op[0] is the 1st operand, op[1] is the 2nd, ...
166  */
167  std::vector<tiramisu::expr> op;
168 
169  /**
170  * The value of the expression.
171  */
172  union
173  {
174  uint8_t uint8_value;
175  int8_t int8_value;
176  uint16_t uint16_value;
177  int16_t int16_value;
178  uint32_t uint32_value;
179  int32_t int32_value;
180  uint64_t uint64_value;
181  int64_t int64_value;
184  };
185 
186  /**
187  * A vector of expressions representing buffer accesses,
188  * or computation accesses.
189  * For example for the computation C0(i,j), the access is
190  * the vector {i, j}.
191  */
192  std::vector<tiramisu::expr> access_vector;
193 
194  /**
195  * A vector of expressions representing arguments of an
196  * external function.
197  * For example, to call the function foo() with the following
198  * three arguments as input
199  * the integer 1, the result of the computation C1(0,0), and
200  * the computation C0 (i.e., its buffer).
201  * \p vector should be {tiramisu::expr(1), C1(0,0), tiramisu::expr(o_address, tiramisu::var("C0"))}.
202  */
203  std::vector<tiramisu::expr> argument_vector;
204 
205  /**
206  * Is this expression defined?
207  */
208  bool defined;
209 
210 protected:
211  /**
212  * Identifier name.
213  */
214  std::string name;
215 
216  /**
217  * Data type.
218  */
220 
221  /**
222  * The type of the expression.
223  */
225 
226 public:
227 
228  /**
229  * Create an undefined expression.
230  */
232  {
233  this->defined = false;
234 
235  this->_operator = tiramisu::o_none;
236  this->etype = tiramisu::e_none;
237  this->dtype = tiramisu::p_none;
238  }
239 
240  /**
241  * Create a cast expression to type \p t (a unary operator).
242  */
244  {
245  assert((o == tiramisu::o_cast) && "Only support cast operator.");
246 
247  this->_operator = o;
248  this->etype = tiramisu::e_op;
249  this->dtype = dtype;
250  this->defined = true;
251 
252  this->op.push_back(expr0);
253  }
254 
255  /**
256  * Create an expression for a unary operator.
257  */
259  {
260  if ((o == tiramisu::o_floor) &&
261  (expr0.get_data_type() != tiramisu::p_float32) &&
262  (expr0.get_data_type() != tiramisu::p_float64))
263  expr0 = tiramisu::expr(tiramisu::o_cast, p_float32, expr0);
264 
265  this->_operator = o;
266  this->etype = tiramisu::e_op;
267  this->dtype = expr0.get_data_type();
268  this->defined = true;
269 
270  this->op.push_back(expr0);
271  }
272 
273  /**
274  * Create an expression for a unary operator that applies
275  * on a variable. For example: allocate(A) or free(B).
276  */
277  expr(tiramisu::op_t o, std::string name)
278  {
279  this->_operator = o;
280  this->etype = tiramisu::e_op;
281  this->dtype = tiramisu::p_none;
282  this->defined = true;
283 
284  this->name = name;
285  }
286 
287  /**
288  * Construct an expression for a binary operator.
289  */
291  {
292  if (expr0.get_data_type() != expr1.get_data_type())
293  {
294  tiramisu::str_dump("Binary operation between two expressions of different types:\n");
295  expr0.dump(false);
296  tiramisu::str_dump(" and ");
297  expr1.dump(false);
298  tiramisu::str_dump("\n");
299  ERROR("\nThe two expressions should be of the same type. Use casting to elevate the type of one expression to the other.\n", true);
300  }
301 
302  this->_operator = o;
303  this->etype = tiramisu::e_op;
304  this->dtype = expr0.get_data_type();
305  this->defined = true;
306 
307  this->op.push_back(expr0);
308  this->op.push_back(expr1);
309  }
310 
311  /**
312  * Construct an expression for a ternary operator.
313  */
315  {
316  assert(expr1.get_data_type() == expr2.get_data_type() &&
317  "expr1 and expr2 should be of the same type.");
318 
319  this->_operator = o;
320  this->etype = tiramisu::e_op;
321  this->dtype = expr1.get_data_type();
322  this->defined = true;
323 
324  this->op.push_back(expr0);
325  this->op.push_back(expr1);
326  this->op.push_back(expr2);
327  }
328 
329  /**
330  * Construct an access or a call.
331  */
332  expr(tiramisu::op_t o, std::string name,
333  std::vector<tiramisu::expr> vec,
335  {
336  assert(((o == tiramisu::o_access) || (o == tiramisu::o_call) || (o == tiramisu::o_address_of) ||
337  (o == tiramisu::o_lin_index) || (o == tiramisu::o_buffer)) &&
338  "The operator is not an access or a call operator.");
339 
340  assert(vec.size() > 0);
341  assert(name.size() > 0);
342 
343  this->_operator = o;
344  this->etype = tiramisu::e_op;
345  this->dtype = type;
346  this->defined = true;
347 
349  o == tiramisu::o_buffer)
350  {
351  this->set_access(vec);
352  }
353  else if (o == tiramisu::o_call)
354  {
355  this->set_arguments(vec);
356  }
357  else
358  {
359  ERROR("Type of operator is not o_access, o_call, o_address_of, o_buffer, or o_lin_index.", true);
360  }
361 
362  this->name = name;
363  }
364 
365  /**
366  * Construct an unsigned 8-bit integer expression.
367  */
368  expr(uint8_t val)
369  {
370  this->etype = tiramisu::e_val;
371  this->_operator = tiramisu::o_none;
372  this->defined = true;
373 
374  this->dtype = tiramisu::p_uint8;
375  this->uint8_value = val;
376  }
377 
378  /**
379  * Construct a signed 8-bit integer expression.
380  */
381  expr(int8_t val)
382  {
383  this->etype = tiramisu::e_val;
384  this->_operator = tiramisu::o_none;
385  this->defined = true;
386 
387  this->dtype = tiramisu::p_int8;
388  this->int8_value = val;
389  }
390 
391  /**
392  * Construct an unsigned 16-bit integer expression.
393  */
394  expr(uint16_t val)
395  {
396  this->defined = true;
397  this->etype = tiramisu::e_val;
398  this->_operator = tiramisu::o_none;
399 
400  this->dtype = tiramisu::p_uint16;
401  this->uint16_value = val;
402  }
403 
404  /**
405  * Construct a signed 16-bit integer expression.
406  */
407  expr(int16_t val)
408  {
409  this->defined = true;
410  this->etype = tiramisu::e_val;
411  this->_operator = tiramisu::o_none;
412 
413  this->dtype = tiramisu::p_int16;
414  this->int16_value = val;
415  }
416 
417  /**
418  * Construct an unsigned 32-bit integer expression.
419  */
420  expr(uint32_t val)
421  {
422  this->etype = tiramisu::e_val;
423  this->_operator = tiramisu::o_none;
424  this->defined = true;
425 
426  this->dtype = tiramisu::p_uint32;
427  this->uint32_value = val;
428  }
429 
430  /**
431  * Construct a signed 32-bit integer expression.
432  */
433  expr(int32_t val)
434  {
435  this->etype = tiramisu::e_val;
436  this->_operator = tiramisu::o_none;
437  this->defined = true;
438 
439  this->dtype = tiramisu::p_int32;
440  this->int32_value = val;
441  }
442 
443  /**
444  * Construct an unsigned 64-bit integer expression.
445  */
446  expr(uint64_t val)
447  {
448  this->etype = tiramisu::e_val;
449  this->_operator = tiramisu::o_none;
450  this->defined = true;
451 
452  this->dtype = tiramisu::p_uint64;
453  this->uint64_value = val;
454  }
455 
456  /**
457  * Construct a signed 64-bit integer expression.
458  */
459  expr(int64_t val)
460  {
461  this->etype = tiramisu::e_val;
462  this->_operator = tiramisu::o_none;
463  this->defined = true;
464 
465  this->dtype = tiramisu::p_int64;
466  this->int64_value = val;
467  }
468 
469  /**
470  * Construct a 32-bit float expression.
471  */
472  expr(float val)
473  {
474  this->etype = tiramisu::e_val;
475  this->_operator = tiramisu::o_none;
476  this->defined = true;
477 
478  this->dtype = tiramisu::p_float32;
479  this->float32_value = val;
480  }
481 
482  /**
483  * Copy an expression.
484  */
485  tiramisu::expr copy() const;
486 
487  /**
488  * Construct a 64-bit float expression.
489  */
490  expr(double val)
491  {
492  this->etype = tiramisu::e_val;
493  this->_operator = tiramisu::o_none;
494  this->defined = true;
495 
496  this->dtype = tiramisu::p_float64;
497  this->float64_value = val;
498  }
499 
500  /**
501  * Return the actual value of the expression.
502  */
503  // @{
504  uint8_t get_uint8_value() const
505  {
506  assert(this->get_expr_type() == tiramisu::e_val);
507  assert(this->get_data_type() == tiramisu::p_uint8);
508 
509  return uint8_value;
510  }
511 
512  int8_t get_int8_value() const
513  {
514  assert(this->get_expr_type() == tiramisu::e_val);
515  assert(this->get_data_type() == tiramisu::p_int8);
516 
517  return int8_value;
518  }
519 
520  uint16_t get_uint16_value() const
521  {
522  assert(this->get_expr_type() == tiramisu::e_val);
523  assert(this->get_data_type() == tiramisu::p_uint16);
524 
525  return uint16_value;
526  }
527 
528  int16_t get_int16_value() const
529  {
530  assert(this->get_expr_type() == tiramisu::e_val);
531  assert(this->get_data_type() == tiramisu::p_int16);
532 
533  return int16_value;
534  }
535 
536  uint32_t get_uint32_value() const
537  {
538  assert(this->get_expr_type() == tiramisu::e_val);
539  assert(this->get_data_type() == tiramisu::p_uint32);
540 
541  return uint32_value;
542  }
543 
544  int32_t get_int32_value() const
545  {
546  assert(this->get_expr_type() == tiramisu::e_val);
547  assert(this->get_data_type() == tiramisu::p_int32);
548 
549  return int32_value;
550  }
551 
552  uint64_t get_uint64_value() const
553  {
554  assert(this->get_expr_type() == tiramisu::e_val);
555  assert(this->get_data_type() == tiramisu::p_uint64);
556 
557  return uint64_value;
558  }
559 
560  int64_t get_int64_value() const
561  {
562  assert(this->get_expr_type() == tiramisu::e_val);
563  assert(this->get_data_type() == tiramisu::p_int64);
564 
565  return int64_value;
566  }
567 
568  float get_float32_value() const
569  {
570  assert(this->get_expr_type() == tiramisu::e_val);
571  assert(this->get_data_type() == tiramisu::p_float32);
572 
573  return float32_value;
574  }
575 
576  double get_float64_value() const
577  {
578  assert(this->get_expr_type() == tiramisu::e_val);
579  assert(this->get_data_type() == tiramisu::p_float64);
580 
581  return float64_value;
582  }
583  // @}
584 
585  int64_t get_int_val() const
586  {
587  assert(this->get_expr_type() == tiramisu::e_val);
588 
589  int64_t result = 0;
590 
591  if (this->get_data_type() == tiramisu::p_uint8)
592  {
593  result = this->get_uint8_value();
594  }
595  else if (this->get_data_type() == tiramisu::p_int8)
596  {
597  result = this->get_int8_value();
598  }
599  else if (this->get_data_type() == tiramisu::p_uint16)
600  {
601  result = this->get_uint16_value();
602  }
603  else if (this->get_data_type() == tiramisu::p_int16)
604  {
605  result = this->get_int16_value();
606  }
607  else if (this->get_data_type() == tiramisu::p_uint32)
608  {
609  result = this->get_uint32_value();
610  }
611  else if (this->get_data_type() == tiramisu::p_int32)
612  {
613  result = this->get_int32_value();
614  }
615  else if (this->get_data_type() == tiramisu::p_uint64)
616  {
617  result = this->get_uint64_value();
618  }
619  else if (this->get_data_type() == tiramisu::p_int64)
620  {
621  result = this->get_int64_value();
622  }
623  else if (this->get_data_type() == tiramisu::p_float32)
624  {
625  result = this->get_float32_value();
626  }
627  else if (this->get_data_type() == tiramisu::p_float64)
628  {
629  result = this->get_float64_value();
630  }
631  else
632  {
633  ERROR("Calling get_int_val() on a non integer expression.", true);
634  }
635 
636  return result;
637  }
638 
639  double get_double_val() const
640  {
641  assert(this->get_expr_type() == tiramisu::e_val);
642 
643  int64_t result = 0;
644 
645  if (this->get_data_type() == tiramisu::p_float32)
646  {
647  result = this->get_float32_value();
648  }
649  else if (this->get_data_type() == tiramisu::p_float64)
650  {
651  result = this->get_float64_value();
652  }
653  else
654  {
655  ERROR("Calling get_double_val() on a non double expression.", true);
656  }
657 
658  return result;
659  }
660 
661  /**
662  * Return the value of the \p i 'th operand of the expression.
663  * \p i can be 0, 1 or 2.
664  */
665  const tiramisu::expr &get_operand(int i) const
666  {
667  assert(this->get_expr_type() == tiramisu::e_op);
668  assert((i < (int)this->op.size()) && "Operand index is out of bounds.");
669 
670  return this->op[i];
671  }
672 
673  /**
674  * Return the number of arguments of the operator.
675  */
676  int get_n_arg() const
677  {
678  assert(this->get_expr_type() == tiramisu::e_op);
679 
680  return this->op.size();
681  }
682 
683  /**
684  * Return the type of the expression (tiramisu::expr_type).
685  */
687  {
688  return etype;
689  }
690 
691  /**
692  * Get the data type of the expression.
693  */
695  {
696  return dtype;
697  }
698 
699  /**
700  * Get the name of the ID or the variable represented by this expression.
701  */
702  const std::string &get_name() const
703  {
704  assert((this->get_expr_type() == tiramisu::e_var) ||
705  (this->get_op_type() == tiramisu::o_access) ||
706  (this->get_op_type() == tiramisu::o_address) ||
707  (this->get_op_type() == tiramisu::o_call) ||
708  (this->get_op_type() == tiramisu::o_allocate) ||
709  (this->get_op_type() == tiramisu::o_free) ||
710  (this->get_op_type() == tiramisu::o_address_of) ||
711  (this->get_op_type() == tiramisu::o_lin_index) ||
712  (this->get_op_type() == tiramisu::o_buffer) ||
713  (this->get_op_type() == tiramisu::o_dummy));
714 
715  return name;
716  }
717 
718  void set_name(std::string &name)
719  {
720  assert((this->get_expr_type() == tiramisu::e_var) ||
721  (this->get_op_type() == tiramisu::o_access) ||
722  (this->get_op_type() == tiramisu::o_call) ||
723  (this->get_op_type() == tiramisu::o_allocate) ||
724  (this->get_op_type() == tiramisu::o_free) ||
725  (this->get_op_type() == tiramisu::o_address_of) ||
726  (this->get_op_type() == tiramisu::o_lin_index) ||
727  (this->get_op_type() == tiramisu::o_dummy));
728 
729  this->name = name;
730  }
731 
732  tiramisu::expr replace_op_in_expr(const std::string &to_replace,
733  const std::string &replace_with)
734  {
735  if (this->name == to_replace) {
736  this->name = replace_with;
737  return *this;
738  }
739  for (int i = 0; i < this->op.size(); i++) {
740  tiramisu::expr operand = this->get_operand(i);
741  this->op[i] = operand.replace_op_in_expr(to_replace, replace_with);
742  }
743  return *this;
744  }
745 
746  /**
747  * Get the type of the operator (tiramisu::op_t).
748  */
750  {
751  return _operator;
752  }
753 
754  /**
755  * Return a vector of the access of the computation
756  * or array.
757  * For example, for the computation C0(i,j), this
758  * function will return the vector {i, j} where i and j
759  * are both tiramisu expressions.
760  * For a buffer access A[i+1,j], it will return also {i+1, j}.
761  */
762  const std::vector<tiramisu::expr> &get_access() const
763  {
764  assert(this->get_expr_type() == tiramisu::e_op);
765  assert(this->get_op_type() == tiramisu::o_access || this->get_op_type() == tiramisu::o_lin_index ||
766  this->get_op_type() == tiramisu::o_address_of || this->get_op_type() == tiramisu::o_dummy ||
767  this->get_op_type() == tiramisu::o_buffer);
768 
769  return access_vector;
770  }
771 
772  /**
773  * Return the arguments of an external function call.
774  */
775  const std::vector<tiramisu::expr> &get_arguments() const
776  {
777  assert(this->get_expr_type() == tiramisu::e_op);
778  assert(this->get_op_type() == tiramisu::o_call);
779 
780  return argument_vector;
781  }
782 
783  /**
784  * Get the number of dimensions in the access vector.
785  */
786  int get_n_dim_access() const
787  {
788  assert(this->get_expr_type() == tiramisu::e_op);
789  assert(this->get_op_type() == tiramisu::o_access);
790 
791  return access_vector.size();
792  }
793 
794  /**
795  * Return true if the expression is defined.
796  */
797  bool is_defined() const
798  {
799  return defined;
800  }
801 
802  /**
803  * Return true if \p e is identical to this expression.
804  */
805  bool is_equal(tiramisu::expr e) const
806  {
807  bool equal = true;
808 
809  /**
810  * The value of the expression.
811  */
812  union
813  {
814  uint8_t uint8_value;
815  int8_t int8_value;
816  uint16_t uint16_value;
817  int16_t int16_value;
818  uint32_t uint32_value;
819  int32_t int32_value;
820  uint64_t uint64_value;
821  int64_t int64_value;
822  float float32_value;
823  double float64_value;
824  };
825 
826 
827  std::vector<tiramisu::expr> access_vector;
828 
829  std::vector<tiramisu::expr> argument_vector;
830 
831  if ((this->_operator != e._operator) ||
832  (this->op.size() != e.op.size()) ||
833  (this->access_vector.size() != e.access_vector.size()) ||
834  (this->argument_vector.size() != e.argument_vector.size()) ||
835  (this->defined != e.defined) ||
836  (this->name != e.name) ||
837  (this->dtype != e.dtype) ||
838  (this->etype != e.etype))
839  {
840  equal = false;
841  return equal;
842  }
843 
844  for (int i = 0; i < this->access_vector.size(); i++)
845  equal = equal && this->access_vector[i].is_equal(e.access_vector[i]);
846 
847  for (int i = 0; i < this->op.size(); i++)
848  equal = equal && this->op[i].is_equal(e.op[i]);
849 
850  for (int i = 0; i < this->argument_vector.size(); i++)
851  equal = equal && this->argument_vector[i].is_equal(e.argument_vector[i]);
852 
853  if ((this->etype == e_val) && (e.etype == e_val))
854  {
855  if (this->get_int_val() != e.get_int_val())
856  equal = false;
857  if ((this->get_data_type() == tiramisu::p_float32) ||
858  (this->get_data_type() == tiramisu::p_float64))
859  if (this->get_double_val() != e.get_double_val())
860  equal = false;
861  }
862 
863  return equal;
864  }
865 
866  /**
867  * Addition.
868  */
869 
870  expr operator+(tiramisu::expr other) const;
871 
872 
873  /**
874  * Subtraction.
875  */
876  expr operator-(tiramisu::expr other) const;
877 
878  /**
879  * Division.
880  */
881  expr operator/(tiramisu::expr other) const;
882 
883  /**
884  * Multiplication.
885  */
886  expr operator*(tiramisu::expr other) const;
887 
888  /**
889  * Modulo.
890  */
891  expr operator%(tiramisu::expr other) const;
892 
893  /**
894  * Right shift operator.
895  */
896  expr operator>>(tiramisu::expr other) const;
897 
898  /**
899  * Left shift operator.
900  */
901  expr operator<<(tiramisu::expr other) const;
902 
903  /**
904  * Logical and of two expressions.
905  */
907  {
908  return tiramisu::expr(tiramisu::o_logical_and, *this, e1);
909  }
910 
911  /**
912  * Logical or of two expressions.
913  */
915  {
916  return tiramisu::expr(tiramisu::o_logical_or, *this, e1);
917  }
918 
919  /**
920  * Expression multiplied by (-1).
921  */
923  {
924  return tiramisu::expr(tiramisu::o_minus, *this);
925  }
926 
927  /**
928  * Logical NOT of an expression.
929  */
931  {
933  }
934 
935  tiramisu::expr& operator=(tiramisu::expr const &);
936 
937  /**
938  * Comparison operator.
939  */
940  // @{
942  {
943  return tiramisu::expr(tiramisu::o_eq, *this, e1);
944  }
946  {
947  return tiramisu::expr(tiramisu::o_ne, *this, e1);
948  }
949  // @}
950 
951  /**
952  * Less than operator.
953  */
955  {
956  return tiramisu::expr(tiramisu::o_lt, *this, e1);
957  }
958 
959  /**
960  * Less than or equal operator.
961  */
963  {
964  return tiramisu::expr(tiramisu::o_le, *this, e1);
965  }
966 
967  /**
968  * Greater than operator.
969  */
971  {
972  return tiramisu::expr(tiramisu::o_gt, *this, e1);
973  }
974 
975  /**
976  * Greater than or equal operator.
977  */
979  {
980  return tiramisu::expr(tiramisu::o_ge, *this, e1);
981  }
982 
983  /**
984  * Set the access of a computation or an array.
985  * For example, for the computation C0, this
986  * function can set the vector {i, j} as an access vector.
987  * The result is that the computation C0 is accessed
988  * with C0(i,j).
989  */
990  void set_access(std::vector<tiramisu::expr> vector)
991  {
992  access_vector = vector;
993  }
994 
995  /**
996  * Set an element of the vector of accesses of a computation.
997  * This changes only one dimension of the access vector.
998  */
1000  {
1001  assert((i < (int)this->access_vector.size()) && "index is out of bounds.");
1002  access_vector[i] = acc;
1003  }
1004 
1005  /**
1006  * Set the arguments of an external function call.
1007  * For example, for the call my_external(C0, 1, C1(i,j)),
1008  * \p vector should be {C0, 1, C1(i,j)}.
1009  */
1010  void set_arguments(std::vector<tiramisu::expr> vector)
1011  {
1012  argument_vector = vector;
1013  }
1014 
1015  /**
1016  * Dump the object on standard output (dump most of the fields of
1017  * the expression class). This is mainly useful for debugging.
1018  * If \p exhaustive is set to true, all the fields of the class are
1019  * printed. This is useful to find potential initialization problems.
1020  */
1021  void dump(bool exhaustive) const
1022  {
1023  if (this->get_expr_type() != e_none)
1024  {
1025  if (exhaustive == true)
1026  {
1027  if (ENABLE_DEBUG && (this->is_defined()))
1028  {
1029  std::cout << "Expression:" << std::endl;
1030  std::cout << "Expression type:" << str_from_tiramisu_type_expr(this->etype) << std::endl;
1031  switch (this->etype)
1032  {
1033  case tiramisu::e_op:
1034  {
1035  std::cout << "Expression operator type:" << str_tiramisu_type_op(this->_operator) << std::endl;
1036  if (this->get_n_arg() > 0)
1037  {
1038  std::cout << "Number of operands:" << this->get_n_arg() << std::endl;
1039  std::cout << "Dumping the operands:" << std::endl;
1040  for (int i = 0; i < this->get_n_arg(); i++)
1041  {
1042  std::cout << "Operand " << std::to_string(i) << "." << std::endl;
1043  this->op[i].dump(exhaustive);
1044  }
1045  }
1046  if ((this->get_op_type() == tiramisu::o_access))
1047  {
1048  std::cout << "Access to " + this->get_name() + ". Access expressions:" << std::endl;
1049  for (const auto &e : this->get_access())
1050  {
1051  e.dump(exhaustive);
1052  }
1053  }
1054  if ((this->get_op_type() == tiramisu::o_address_of)) {
1055  std::cout << "Address to " + this->get_name() + ". Access expressions:" << std::endl;
1056  for (const auto &e : this->get_access()) {
1057  e.dump(exhaustive);
1058  }
1059  }
1060  if ((this->get_op_type() == tiramisu::o_lin_index)) {
1061  std::cout << "Linear address to " + this->get_name() + ". Access expressions:"
1062  << std::endl;
1063  for (const auto &e : this->get_access()) {
1064  e.dump(exhaustive);
1065  }
1066  }
1067  if ((this->get_op_type() == tiramisu::o_call))
1068  {
1069  std::cout << "call to " + this->get_name() + ". Argument expressions:" << std::endl;
1070  for (const auto &e : this->get_arguments())
1071  {
1072  e.dump(exhaustive);
1073  }
1074  }
1075  if ((this->get_op_type() == tiramisu::o_address))
1076  {
1077  std::cout << "Address of the following access : " << std::endl;
1078  this->get_operand(0).dump(true);
1079  }
1080  if ((this->get_op_type() == tiramisu::o_allocate))
1081  {
1082  std::cout << "allocate(" << this->get_name() << ")" << std::endl;
1083  }
1084  if ((this->get_op_type() == tiramisu::o_free))
1085  {
1086  std::cout << "free(" << this->get_name() << ")" << std::endl;
1087  }
1088  break;
1089  }
1090  case (tiramisu::e_val):
1091  {
1092  std::cout << "Expression value type:" << str_from_tiramisu_type_primitive(this->dtype) << std::endl;
1093 
1094  if (this->get_data_type() == tiramisu::p_uint8)
1095  {
1096  std::cout << "Value:" << this->get_uint8_value() << std::endl;
1097  }
1098  else if (this->get_data_type() == tiramisu::p_int8)
1099  {
1100  std::cout << "Value:" << this->get_int8_value() << std::endl;
1101  }
1102  else if (this->get_data_type() == tiramisu::p_uint16)
1103  {
1104  std::cout << "Value:" << this->get_uint16_value() << std::endl;
1105  }
1106  else if (this->get_data_type() == tiramisu::p_int16)
1107  {
1108  std::cout << "Value:" << this->get_int16_value() << std::endl;
1109  }
1110  else if (this->get_data_type() == tiramisu::p_uint32)
1111  {
1112  std::cout << "Value:" << this->get_uint32_value() << std::endl;
1113  }
1114  else if (this->get_data_type() == tiramisu::p_int32)
1115  {
1116  std::cout << "Value:" << this->get_int32_value() << std::endl;
1117  }
1118  else if (this->get_data_type() == tiramisu::p_uint64)
1119  {
1120  std::cout << "Value:" << this->get_uint64_value() << std::endl;
1121  }
1122  else if (this->get_data_type() == tiramisu::p_int64)
1123  {
1124  std::cout << "Value:" << this->get_int64_value() << std::endl;
1125  }
1126  else if (this->get_data_type() == tiramisu::p_float32)
1127  {
1128  std::cout << "Value:" << this->get_float32_value() << std::endl;
1129  }
1130  else if (this->get_data_type() == tiramisu::p_float64)
1131  {
1132  std::cout << "Value:" << this->get_float64_value() << std::endl;
1133  }
1134  break;
1135  }
1136  case (tiramisu::e_var):
1137  {
1138  std::cout << "Var name:" << this->get_name() << std::endl;
1139  std::cout << "Expression value type:" << str_from_tiramisu_type_primitive(this->dtype) << std::endl;
1140  break;
1141  }
1142  case (tiramisu::e_sync):
1143  std::cout << "Sync object" << std::endl;
1144  break;
1145  default:
1146  ERROR("Expression type not supported.", true);
1147  }
1148  }
1149  }
1150  else
1151  {
1152  std::cout << this->to_str();
1153  }
1154  }
1155  }
1156 
1157  /**
1158  * Return true if this expression is a literal constant (i.e., 0, 1, 2, ...).
1159  **/
1160  bool is_constant() const
1161  {
1162  if (this->get_expr_type() == tiramisu::e_val)
1163  return true;
1164  else
1165  return false;
1166  }
1167 
1168  bool is_unbounded() const
1169  {
1170  if (this->get_name() == "_unbounded")
1171  return true;
1172  else
1173  return false;
1174  }
1175 
1176  /**
1177  * Simplify the expression.
1178  */
1180  {
1181  if (this->get_expr_type() != e_none)
1182  {
1183  switch (this->etype)
1184  {
1185  case tiramisu::e_op:
1186  {
1187  switch (this->get_op_type())
1188  {
1190  return *this;
1192  return *this;
1193  case tiramisu::o_max:
1194  return *this;
1195  case tiramisu::o_min:
1196  return *this;
1197  case tiramisu::o_minus:
1198  return *this;
1199  case tiramisu::o_add:
1200  this->get_operand(0).simplify();
1201  this->get_operand(1).simplify();
1202  if ((this->get_operand(0).get_expr_type() == tiramisu::e_val) && (this->get_operand(1).get_expr_type() == tiramisu::e_val))
1203  if ((this->get_operand(0).get_data_type() == tiramisu::p_int32))
1204  return expr(this->get_operand(0).get_int_val() + this->get_operand(1).get_int_val());
1205  case tiramisu::o_sub:
1206  this->get_operand(0).simplify();
1207  this->get_operand(1).simplify();
1208  if ((this->get_operand(0).get_expr_type() == tiramisu::e_val) && (this->get_operand(1).get_expr_type() == tiramisu::e_val))
1209  if ((this->get_operand(0).get_data_type() == tiramisu::p_int32))
1210  return expr(this->get_operand(0).get_int_val() - this->get_operand(1).get_int_val());
1211  case tiramisu::o_mul:
1212  this->get_operand(0).simplify();
1213  this->get_operand(1).simplify();
1214  if ((this->get_operand(0).get_expr_type() == tiramisu::e_val) && (this->get_operand(1).get_expr_type() == tiramisu::e_val))
1215  if ((this->get_operand(0).get_data_type() == tiramisu::p_int32))
1216  return expr(this->get_operand(0).get_int_val() * this->get_operand(1).get_int_val());
1217  case tiramisu::o_div:
1218  return *this;
1219  case tiramisu::o_mod:
1220  return *this;
1221  case tiramisu::o_select:
1222  return *this;
1223  case tiramisu::o_cond:
1224  return *this;
1225  case tiramisu::o_lerp:
1226  return *this;
1227  case tiramisu::o_le:
1228  return *this;
1229  case tiramisu::o_lt:
1230  return *this;
1231  case tiramisu::o_ge:
1232  return *this;
1233  case tiramisu::o_gt:
1234  return *this;
1236  return *this;
1237  case tiramisu::o_eq:
1238  return *this;
1239  case tiramisu::o_ne:
1240  return *this;
1242  return *this;
1244  return *this;
1245  case tiramisu::o_floor:
1246  return *this;
1247  case tiramisu::o_sin:
1248  return *this;
1249  case tiramisu::o_cos:
1250  return *this;
1251  case tiramisu::o_tan:
1252  return *this;
1253  case tiramisu::o_atan:
1254  return *this;
1255  case tiramisu::o_acos:
1256  return *this;
1257  case tiramisu::o_asin:
1258  return *this;
1259  case tiramisu::o_sinh:
1260  return *this;
1261  case tiramisu::o_cosh:
1262  return *this;
1263  case tiramisu::o_tanh:
1264  return *this;
1265  case tiramisu::o_asinh:
1266  return *this;
1267  case tiramisu::o_acosh:
1268  return *this;
1269  case tiramisu::o_atanh:
1270  return *this;
1271  case tiramisu::o_abs:
1272  return *this;
1273  case tiramisu::o_sqrt:
1274  return *this;
1275  case tiramisu::o_expo:
1276  return *this;
1277  case tiramisu::o_log:
1278  return *this;
1279  case tiramisu::o_ceil:
1280  return *this;
1281  case tiramisu::o_round:
1282  return *this;
1283  case tiramisu::o_trunc:
1284  return *this;
1285  case tiramisu::o_cast:
1286  return *this;
1287  case tiramisu::o_access:
1288  return *this;
1289  case tiramisu::o_call:
1290  return *this;
1291  case tiramisu::o_address:
1292  return *this;
1293  case tiramisu::o_allocate:
1294  return *this;
1295  case tiramisu::o_free:
1296  return *this;
1297  default:
1298  ERROR("Simplifying an unsupported tiramisu expression.", 1);
1299  }
1300  break;
1301  }
1302  case (tiramisu::e_val):
1303  {
1304  return *this;
1305  }
1306  case (tiramisu::e_var):
1307  {
1308  return *this;
1309  }
1310  default:
1311  ERROR("Expression type not supported.", true);
1312  }
1313  }
1314 
1315  return *this;
1316  }
1317 
1318  std::string to_str() const
1319  {
1320  std::string str = std::string("");
1321 
1322  if (this->get_expr_type() != e_none)
1323  {
1324  switch (this->etype)
1325  {
1326  case tiramisu::e_op:
1327  {
1328  switch (this->get_op_type())
1329  {
1331  str += "(";
1332  this->get_operand(0).dump(false);
1333  str += " && ";
1334  str += this->get_operand(1).to_str();
1335  str += ")";
1336  break;
1338  str += "(" + this->get_operand(0).to_str();
1339  str += " || " + this->get_operand(1).to_str();
1340  str += ")";
1341  break;
1342  case tiramisu::o_max:
1343  str += "max(" + this->get_operand(0).to_str();
1344  str += ", " + this->get_operand(1).to_str();
1345  str += ")";
1346  break;
1347  case tiramisu::o_min:
1348  str += "min(" + this->get_operand(0).to_str();
1349  str += ", " + this->get_operand(1).to_str();
1350  str += ")";
1351  break;
1352  case tiramisu::o_minus:
1353  str += "(-" + this->get_operand(0).to_str();
1354  str += ")";
1355  break;
1356  case tiramisu::o_add:
1357  str += "(" + this->get_operand(0).to_str();
1358  str += " + " + this->get_operand(1).to_str();
1359  str += ")";
1360  break;
1361  case tiramisu::o_sub:
1362  str += "(" + this->get_operand(0).to_str();
1363  str += " - " + this->get_operand(1).to_str();
1364  str += ")";
1365  break;
1366  case tiramisu::o_mul:
1367  str += "(" + this->get_operand(0).to_str();
1368  str += " * " + this->get_operand(1).to_str();
1369  str += ")";
1370  break;
1371  case tiramisu::o_div:
1372  str += "(" + this->get_operand(0).to_str();
1373  str += " / " + this->get_operand(1).to_str();
1374  str += ")";
1375  break;
1376  case tiramisu::o_mod:
1377  str += "(" + this->get_operand(0).to_str();
1378  str += " % " + this->get_operand(1).to_str();
1379  str += ")";
1380  break;
1381  case tiramisu::o_memcpy:
1382  str += "memcpy(" + this->get_operand(0).to_str();
1383  str += ", " + this->get_operand(1).to_str();
1384  str += ")";
1385  break;
1386  case tiramisu::o_select:
1387  str += "select(" + this->get_operand(0).to_str();
1388  str += ", " + this->get_operand(1).to_str();
1389  str += ", " + this->get_operand(2).to_str();
1390  str += ")";
1391  break;
1392  case tiramisu::o_cond:
1393  str += "if(" + this->get_operand(0).to_str();
1394  str += "):(" + this->get_operand(1).to_str();
1395  str += ")";
1396  break;
1397  case tiramisu::o_lerp:
1398  str += "lerp(" + this->get_operand(0).to_str();
1399  str += ", " + this->get_operand(1).to_str();
1400  str += ", " + this->get_operand(2).to_str();
1401  str += ")";
1402  break;
1403  case tiramisu::o_le:
1404  str += "(" + this->get_operand(0).to_str();
1405  str += " <= " + this->get_operand(1).to_str();
1406  str += ")";
1407  break;
1408  case tiramisu::o_lt:
1409  str += "(" + this->get_operand(0).to_str();
1410  str += " < " + this->get_operand(1).to_str();
1411  str += ")";
1412  break;
1413  case tiramisu::o_ge:
1414  str += "(" + this->get_operand(0).to_str();
1415  str += " >= " + this->get_operand(1).to_str();
1416  str += ")";
1417  break;
1418  case tiramisu::o_gt:
1419  str += "(" + this->get_operand(0).to_str();
1420  str += " > " + this->get_operand(1).to_str();
1421  str += ")";
1422  break;
1424  str += "(!" + this->get_operand(0).to_str();
1425  str += ")";
1426  break;
1427  case tiramisu::o_eq:
1428  str += "(" + this->get_operand(0).to_str();
1429  str += " == " + this->get_operand(1).to_str();
1430  str += ")";
1431  break;
1432  case tiramisu::o_ne:
1433  str += "(" + this->get_operand(0).to_str();
1434  str += " != " + this->get_operand(1).to_str();
1435  str += ")";
1436  break;
1438  str += "(" + this->get_operand(0).to_str();
1439  str += " >> " + this->get_operand(1).to_str();
1440  str += ")";
1441  break;
1443  str += "(" + this->get_operand(0).to_str();
1444  str += " << " + this->get_operand(1).to_str();
1445  str += ")";
1446  break;
1447  case tiramisu::o_floor:
1448  str += "floor(" + this->get_operand(0).to_str();
1449  str += ") ";
1450  break;
1451  case tiramisu::o_sin:
1452  str += "sin(" + this->get_operand(0).to_str();
1453  str += ") ";
1454  break;
1455  case tiramisu::o_cos:
1456  str += "cos(" + this->get_operand(0).to_str();
1457  str += ") ";
1458  break;
1459  case tiramisu::o_tan:
1460  str += "tan(" + this->get_operand(0).to_str();
1461  str += ") ";
1462  break;
1463  case tiramisu::o_atan:
1464  str += "atan(" + this->get_operand(0).to_str();
1465  str += ") ";
1466  break;
1467  case tiramisu::o_acos:
1468  str += "acos(" + this->get_operand(0).to_str();
1469  str += ") ";
1470  break;
1471  case tiramisu::o_asin:
1472  str += "asin(" + this->get_operand(0).to_str();
1473  str += ") ";
1474  break;
1475  case tiramisu::o_sinh:
1476  str += "sinh(" + this->get_operand(0).to_str();
1477  str += ") ";
1478  break;
1479  case tiramisu::o_cosh:
1480  str += "cosh(" + this->get_operand(0).to_str();
1481  str += ") ";
1482  break;
1483  case tiramisu::o_tanh:
1484  str += "tanh(" + this->get_operand(0).to_str();
1485  str += ") ";
1486  break;
1487  case tiramisu::o_asinh:
1488  str += "asinh(" + this->get_operand(0).to_str();
1489  str += ") ";
1490  break;
1491  case tiramisu::o_acosh:
1492  str += "acosh(" + this->get_operand(0).to_str();
1493  str += ") ";
1494  break;
1495  case tiramisu::o_atanh:
1496  str += "atanh(" + this->get_operand(0).to_str();
1497  str += ") ";
1498  break;
1499  case tiramisu::o_abs:
1500  str += "abs(" + this->get_operand(0).to_str();
1501  str += ") ";
1502  break;
1503  case tiramisu::o_sqrt:
1504  str += "sqrt(" + this->get_operand(0).to_str();
1505  str += ") ";
1506  break;
1507  case tiramisu::o_expo:
1508  str += "exp(" + this->get_operand(0).to_str();
1509  str += ") ";
1510  break;
1511  case tiramisu::o_log:
1512  str += "log(" + this->get_operand(0).to_str();
1513  str += ") ";
1514  break;
1515  case tiramisu::o_ceil:
1516  str += "ceil(" + this->get_operand(0).to_str();
1517  str += ") ";
1518  break;
1519  case tiramisu::o_round:
1520  str += "round(" + this->get_operand(0).to_str();
1521  str += ") ";
1522  break;
1523  case tiramisu::o_trunc:
1524  str += "trunc(" + this->get_operand(0).to_str();
1525  str += ") ";
1526  break;
1527  case tiramisu::o_cast:
1528  str += "cast(" + this->get_operand(0).to_str();
1529  str += ") ";
1530  break;
1531  case tiramisu::o_access:
1533  case tiramisu::o_lin_index:
1534  case tiramisu::o_buffer:
1535  str += this->get_name() + "(";
1536  for (int k = 0; k < this->get_access().size(); k++)
1537  {
1538  if (k != 0)
1539  {
1540  str += ", ";
1541  }
1542  str += this->get_access()[k].to_str();
1543  }
1544  str += ")";
1545  break;
1546  case tiramisu::o_call:
1547  str += this->get_name() + "(";
1548  for (int k = 0; k < this->get_arguments().size(); k++)
1549  {
1550  if (k != 0)
1551  {
1552  str += ", ";
1553  }
1554  str += this->get_arguments()[k].to_str();
1555  }
1556  str += ")";
1557  break;
1558  case tiramisu::o_address:
1559  str += "&" + this->get_operand(0).get_name();
1560  break;
1561  case tiramisu::o_allocate:
1562  str += "allocate(" + this->get_name() + ")";
1563  break;
1564  case tiramisu::o_free:
1565  str += "free(" + this->get_name() + ")";
1566  break;
1567  default:
1568  ERROR("Dumping an unsupported tiramisu expression.", 1);
1569  }
1570  break;
1571  }
1572  case (tiramisu::e_val):
1573  {
1574  if (this->get_data_type() == tiramisu::p_uint8)
1575  {
1576  str += std::to_string((int)this->get_uint8_value());
1577  }
1578  else if (this->get_data_type() == tiramisu::p_int8)
1579  {
1580  str += std::to_string((int)this->get_int8_value());
1581  }
1582  else if (this->get_data_type() == tiramisu::p_uint16)
1583  {
1584  str += std::to_string(this->get_uint16_value());
1585  }
1586  else if (this->get_data_type() == tiramisu::p_int16)
1587  {
1588  str += std::to_string(this->get_int16_value());
1589  }
1590  else if (this->get_data_type() == tiramisu::p_uint32)
1591  {
1592  str += std::to_string(this->get_uint32_value());
1593  }
1594  else if (this->get_data_type() == tiramisu::p_int32)
1595  {
1596  str += std::to_string(this->get_int32_value());
1597  }
1598  else if (this->get_data_type() == tiramisu::p_uint64)
1599  {
1600  str += std::to_string(this->get_uint64_value());
1601  }
1602  else if (this->get_data_type() == tiramisu::p_int64)
1603  {
1604  str += std::to_string(this->get_int64_value());
1605  }
1606  else if (this->get_data_type() == tiramisu::p_float32)
1607  {
1608  str += std::to_string(this->get_float32_value());
1609  }
1610  else if (this->get_data_type() == tiramisu::p_float64)
1611  {
1612  str += std::to_string(this->get_float64_value());
1613  }
1614  break;
1615  }
1616  case (tiramisu::e_var):
1617  {
1618  str += this->get_name();
1619  break;
1620  }
1621  case (tiramisu::e_sync):
1622  {
1623  str += "sync object";
1624  break;
1625  }
1626  default:
1627  ERROR("Expression type not supported.", true);
1628  }
1629  }
1630 
1631  return str;
1632  }
1633 
1634  /**
1635  * Returns a new expression where for every (var, sub) pair in \p substitutions,
1636  * var in the original expression is replaced by sub.
1637  * For example: if \p substitutions is {(i, 5), (j, i)}, and the original expression is
1638  * i + j * 2, then this method returns 5 + i * 2.
1639  */
1640  expr substitute(std::vector<std::pair<var, expr>> substitutions) const;
1641 
1642  /**
1643  * Returns an expression where every access to a computation named
1644  * \p original is replaced with an access to a computation named
1645  * \p substitute, with the same access indices.
1646  * An example where this is useful is when modifying a computation
1647  * that was designed to work with a host buffer to work with a GPU
1648  * buffer.
1649  */
1650  expr substitute_access(std::string original, std::string substitute) const;
1651 
1652  expr apply_to_operands(std::function<expr(const expr &)> f) const
1653  {
1654  tiramisu::expr e{*this};
1655  for (int i = 0; i < access_vector.size(); i++)
1656  e.access_vector[i] = f(e.access_vector[i]);
1657  for (int i = 0; i < op.size(); i++)
1658  e.op[i] = f(e.op[i]);
1659  for (int i = 0; i < argument_vector.size(); i++)
1660  e.argument_vector[i] = f(e.argument_vector[i]);
1661 
1662  return e;
1663  }
1664 
1665  /** Create a variable that can be used that a dimension is unbounded.
1666  * i < tiramisu::expr::unbounded()
1667  * means that i does not have an upper bound.
1668  * i > tiramisu::expr::unbounded()
1669  * means that i does not have a lower bound.
1670  */
1671  static expr unbounded()
1672  {
1673  tiramisu::expr e;
1674  e.name = "_unbounded";
1675  e.etype = tiramisu::e_val;
1676  e._operator = tiramisu::o_none;
1677  e.defined = true;
1678  e.dtype = tiramisu::p_none;
1679  return e;
1680  }
1681 };
1682 
1683 /**
1684  * A class that represents a synchronization object.
1685  * e.g. in the context of GPUs this will get transformed to
1686  * __syncthreads();
1687  */
1688 class sync : public tiramisu::expr
1689 {
1690 public:
1691  sync() : expr()
1692  {
1693  etype = e_sync;
1694  _operator = o_none;
1695  dtype = p_none;
1696  defined = true;
1697 
1698  }
1699 };
1700 
1701 /**
1702  * A class that represents constant variable references
1703  */
1704 class var: public tiramisu::expr
1705 {
1706  friend computation;
1707 private:
1708  // TODO if more than one scope, variables are to be declared per scope
1709  /**
1710  * If a variable gets declared and saved, (either through calling a public constructor,
1711  * or through calling a private constructor with save set to true), then a mapping from
1712  * the name of the variable to the variable object is added.
1713  * The point of this is to make sure that all variables with the same name have the same
1714  * type, and thus are equal.
1715  */
1716  static std::unordered_map<std::string, var> declared_vars;
1717 
1718  /**
1719  * This has the same as the var(name), except that if \p save is false, then whatever
1720  * variable is created, it is not stored in declared_vars, and therefore calling this
1721  * constructor has no effect on the creation of future var objects.
1722  */
1723  var(std::string name, bool save);
1724 
1725  /**
1726  * This has the same as the var(type, name), except that if \p save is false, then whatever
1727  * variable is created, it is not stored in declared_vars, and therefore calling this
1728  * constructor has no effect on the creation of future var objects.
1729  */
1730  var(tiramisu::primitive_t type, std::string name, bool save);
1731 
1732  /**
1733  * lower loop bound when the variable is used as an iterator.
1734  */
1735  expr lower;
1736 
1737  /**
1738  * upper loop bound when the variable is used as an iterator.
1739  */
1740  expr upper;
1741 
1742 public:
1743  /**
1744  * Construct an expression that represents a variable.
1745  *
1746  * \p type is the type of the variable and \p name is its name.
1747  * If a variable with the same name has previously been declared,
1748  * but with a different type, this constructor will fail.
1749  * That way two variables with the same name are necessarily equal.
1750  */
1751  var(tiramisu::primitive_t type, std::string name) : var(type, name, true) {}
1752 
1753  /**
1754  * Construct an expression that represents an untyped variable.
1755  * For example to declare the variable "t", use
1756  * tiramisu::var("t");
1757  * If a variable with the same name has previously been declared, this
1758  * object will have the same type (i.e. it will be equal to the other variable object).
1759  *
1760  */
1761  var(std::string name) : var(name, true) {}
1762 
1763  /**
1764  * Construct a loop iterator that has \p name as a name.
1765  *
1766  * \p lower and \p upper are expressions that represent the lower and upper
1767  * bounds of this iterator. For example, the iterator i in the following
1768  * for loop
1769  *
1770  * \code
1771  * for (i = 0; i < 10; i++)
1772  * \endcode
1773  *
1774  * can be declared as
1775  *
1776  * \code
1777  * var i("i", expr(0), expr(10));
1778  * \endcode
1779  *
1780  */
1781  var(std::string name, expr lower_bound, expr upper_bound) : var(name, true)
1782  {
1783  lower = lower_bound;
1784  upper = upper_bound;
1785  }
1786 
1787  /* Construct an expression that represents an untyped variable.
1788  * The name of the variable is generated automatically.
1789  * For example to declare a variable, use
1790  * tiramisu::var t;
1791  */
1793 };
1794 
1795 /**
1796  * Convert a Tiramisu expression into a Halide expression.
1797  */
1798 Halide::Expr halide_expr_from_tiramisu_expr(
1799  const tiramisu::computation *comp,
1800  std::vector<isl_ast_expr *> &index_expr,
1801  const tiramisu::expr &tiramisu_expr);
1802 
1803 
1804 /**
1805  * Takes in a primitive value \p val, and returns an expression
1806  * of tiramisu type \p tT that represents \p val.
1807  */
1808 template <typename cT>
1810 
1811 // static_assert(std::is_fundamental<cT>::value, "Type must be fundamental");
1812 
1813  switch (tT) {
1814 
1815  case p_int8:
1816  return expr{static_cast<int8_t>(val)};
1817  case p_uint8:
1818  return expr{static_cast<uint8_t>(val)};
1819  case p_int16:
1820  return expr{static_cast<int16_t>(val)};
1821  case p_uint16:
1822  return expr{static_cast<uint16_t>(val)};
1823  case p_int32:
1824  return expr{static_cast<int32_t>(val)};
1825  case p_uint32:
1826  return expr{static_cast<uint32_t>(val)};
1827  case p_int64:
1828  return expr{static_cast<int64_t>(val)};
1829  case p_uint64:
1830  return expr{static_cast<uint64_t>(val)};
1831  case p_float32:
1832  return expr{static_cast<float>(val)};
1833  case p_float64:
1834  return expr{static_cast<double>(val)};
1835  default:
1836  throw std::invalid_argument{"Type not supported"};
1837  }
1838 }
1839 
1840 /**
1841  * Returns an expression that casts \p e to \p tT.
1842  */
1843 expr cast(primitive_t tT, const expr & e);
1844 
1845 
1846 template <typename T>
1848 {
1849  return e + value_cast(e.get_data_type(), val);
1850 }
1851 
1852 template <typename T>
1854 {
1855  return value_cast(e.get_data_type(), val) + e;
1856 }
1857 
1858 template <typename T>
1860 {
1861  return e - value_cast(e.get_data_type(), val);
1862 }
1863 
1864 template <typename T>
1866 {
1867  return value_cast(e.get_data_type(), val) - e;
1868 }
1869 
1870 template <typename T>
1872 {
1873  return e / expr{val};
1874 }
1875 
1876 template <typename T>
1878 {
1879  return expr{val} / e;
1880 }
1881 
1882 template <typename T>
1884 {
1885  return e * value_cast(e.get_data_type(), val);
1886 }
1887 
1888 template <typename T>
1890 {
1891  return value_cast(e.get_data_type(), val) * e;
1892 }
1893 
1894 template <typename T>
1896 {
1897  return e % expr{val};
1898 }
1899 
1900 template <typename T>
1902 {
1903  return expr{val} % e;
1904 }
1905 
1906 template <typename T>
1908 {
1909  return e >> expr{val};
1910 }
1911 
1912 template <typename T>
1914 {
1915  return expr{val} >> e;
1916 }
1917 
1918 template <typename T>
1920 {
1921  return e << expr{val};
1922 }
1923 
1924 template <typename T>
1926 {
1927  return expr{val} << e;
1928 }
1929 
1930 expr memcpy(const buffer& from, const buffer& to);
1931 expr allocate(const buffer& b);
1932 
1933 }
1934 #endif
uint8_t get_uint8_value() const
Return the actual value of the expression.
Definition: expr.h:504
int get_n_arg() const
Return the number of arguments of the operator.
Definition: expr.h:676
bool is_equal(tiramisu::expr e) const
Return true if e is identical to this expression.
Definition: expr.h:805
int16_t int16_value
Definition: expr.h:177
bool is_unbounded() const
Definition: expr.h:1168
static function * get_implicit_function()
Return the implicit function created during Tiramisu initialization.
Definition: expr.h:80
double get_float64_value() const
Return the actual value of the expression.
Definition: expr.h:576
A class that represents computations.
Definition: core.h:1320
only_integral< T > operator/(const tiramisu::expr &e, T val)
Definition: expr.h:1871
expr_t
The possible types of an expression.
Definition: type.h:14
expr cast(primitive_t tT, const expr &e)
Returns an expression that casts e to tT.
expr value_cast(primitive_t tT, cT val)
Takes in a primitive value val, and returns an expression of tiramisu type tT that represents val...
Definition: expr.h:1809
expr(tiramisu::op_t o, tiramisu::primitive_t dtype, tiramisu::expr expr0)
Create a cast expression to type t (a unary operator).
Definition: expr.h:243
tiramisu::expr operator!=(tiramisu::expr e1) const
Comparison operator.
Definition: expr.h:945
std::string str_tiramisu_type_op(tiramisu::op_t type)
float float32_value
Definition: expr.h:182
tiramisu::primitive_t get_data_type() const
Get the data type of the expression.
Definition: expr.h:694
int8_t get_int8_value() const
Return the actual value of the expression.
Definition: expr.h:512
int8_t int8_value
Definition: expr.h:175
primitive_t
tiramisu data types.
Definition: type.h:27
float get_float32_value() const
Return the actual value of the expression.
Definition: expr.h:568
static expr unbounded()
Create a variable that can be used to indicate that a dimension is unbounded.
Definition: expr.h:1671
tiramisu::primitive_t dtype
Data type.
Definition: expr.h:219
static primitive_t get_loop_iterator_data_type()
Definition: expr.h:134
void set_arguments(std::vector< tiramisu::expr > vector)
Set the arguments of an external function call.
Definition: expr.h:1010
Halide::Expr halide_expr_from_tiramisu_expr(const tiramisu::computation *comp, std::vector< isl_ast_expr * > &index_expr, const tiramisu::expr &tiramisu_expr)
Convert a Tiramisu expression into a Halide expression.
const std::vector< tiramisu::expr > & get_arguments() const
Return the arguments of an external function call.
Definition: expr.h:775
expr memcpy(const buffer &from, const buffer &to)
var(tiramisu::primitive_t type, std::string name)
Construct an expression that represents a variable.
Definition: expr.h:1751
expr()
Create an undefined expression.
Definition: expr.h:231
std::string generate_new_variable_name()
tiramisu::expr operator>(tiramisu::expr e1) const
Greater than operator.
Definition: expr.h:970
uint32_t uint32_value
Definition: expr.h:178
void dump(bool exhaustive) const
Dump the object on standard output (dump most of the fields of the expression class).
Definition: expr.h:1021
void set_access(std::vector< tiramisu::expr > vector)
Set the access of a computation or an array.
Definition: expr.h:990
int result
Definition: cuda_ast.h:705
only_integral< T > operator%(const tiramisu::expr &e, T val)
Definition: expr.h:1895
static void set_default_tiramisu_options()
Definition: expr.h:124
A class that represents buffers.
Definition: core.h:1017
bool is_defined() const
Return true if the expression is defined.
Definition: expr.h:797
expr(double val)
Construct a 64-bit float expression.
Definition: expr.h:490
expr(int8_t val)
Construct a signed 8-bit integer expression.
Definition: expr.h:381
A class that represents a synchronization object.
Definition: expr.h:1688
expr(uint8_t val)
Construct an unsigned 8-bit integer expression.
Definition: expr.h:368
only_integral< T > operator*(const tiramisu::expr &e, T val)
Definition: expr.h:1883
only_integral< T > operator>>(const tiramisu::expr &e, T val)
Definition: expr.h:1907
expr(tiramisu::op_t o, std::string name, std::vector< tiramisu::expr > vec, tiramisu::primitive_t type)
Construct an access or a call.
Definition: expr.h:332
const std::vector< tiramisu::expr > & get_access() const
Return a vector of the access of the computation or array.
Definition: expr.h:762
int64_t get_int_val() const
Definition: expr.h:585
uint64_t uint64_value
Definition: expr.h:180
tiramisu::expr simplify() const
Simplify the expression.
Definition: expr.h:1179
expr(tiramisu::op_t o, tiramisu::expr expr0, tiramisu::expr expr1)
Construct an expression for a binary operator.
Definition: expr.h:290
tiramisu::expr operator>=(tiramisu::expr e1) const
Greater than or equal operator.
Definition: expr.h:978
A class to represent tiramisu expressions.
Definition: expr.h:150
tiramisu::expr_t get_expr_type() const
Return the type of the expression (tiramisu::expr_type).
Definition: expr.h:686
tiramisu::expr_t etype
The type of the expression.
Definition: expr.h:224
expr(float val)
Construct a 32-bit float expression.
Definition: expr.h:472
static void set_auto_data_mapping(bool v)
If this option is set to true, Tiramisu automatically modifies the computation data mapping whenever ...
Definition: expr.h:106
only_integral< T > operator-(const tiramisu::expr &e, T val)
Definition: expr.h:1859
static bool is_auto_data_mapping_set()
Return whether auto data mapping is set.
Definition: expr.h:119
expr(uint16_t val)
Construct an unsigned 16-bit integer expression.
Definition: expr.h:394
int get_n_dim_access() const
Get the number of dimensions in the access vector.
Definition: expr.h:786
uint64_t get_uint64_value() const
Return the actual value of the expression.
Definition: expr.h:552
double get_double_val() const
Definition: expr.h:639
tiramisu::expr operator<(tiramisu::expr e1) const
Less than operator.
Definition: expr.h:954
static void set_implicit_function(function *fct)
Set the implicit function, i.e. the function to which computations and buffers are added by default.
Definition: expr.h:94
expr apply_to_operands(std::function< expr(const expr &)> f) const
Definition: expr.h:1652
expr(tiramisu::op_t o, tiramisu::expr expr0)
Create an expression for a unary operator.
Definition: expr.h:258
A class that holds all the global variables necessary for Tiramisu.
Definition: expr.h:47
int16_t get_int16_value() const
Return the actual value of the expression.
Definition: expr.h:528
expr(uint64_t val)
Construct an unsigned 64-bit integer expression.
Definition: expr.h:446
std::string str_from_tiramisu_type_expr(tiramisu::expr_t type)
int32_t int32_value
Definition: expr.h:179
tiramisu::expr operator||(tiramisu::expr e1) const
Logical or of two expressions.
Definition: expr.h:914
const tiramisu::expr & get_operand(int i) const
Return the value of the i'th operand of the expression.
Definition: expr.h:665
uint32_t get_uint32_value() const
Return the actual value of the expression.
Definition: expr.h:536
int32_t get_int32_value() const
Return the actual value of the expression.
Definition: expr.h:544
uint16_t get_uint16_value() const
Return the actual value of the expression.
Definition: expr.h:520
expr(tiramisu::op_t o, std::string name)
Create an expression for a unary operator that applies on a variable.
Definition: expr.h:277
op_t
Types of tiramisu operators.
Definition: type.h:53
std::string str_from_tiramisu_type_primitive(tiramisu::primitive_t type)
tiramisu::expr operator&&(tiramisu::expr e1) const
Logical and of two expressions.
Definition: expr.h:906
tiramisu::op_t get_op_type() const
Get the type of the operator (tiramisu::op_t).
Definition: expr.h:749
expr allocate(const buffer &b)
tiramisu::expr operator<=(tiramisu::expr e1) const
Less than or equal operator.
Definition: expr.h:962
var(std::string name)
Construct an expression that represents an untyped variable.
Definition: expr.h:1761
static void set_loop_iterator_type(primitive_t t)
Definition: expr.h:130
double float64_value
Definition: expr.h:183
expr(int64_t val)
Construct a signed 64-bit integer expression.
Definition: expr.h:459
tiramisu::expr replace_op_in_expr(const std::string &to_replace, const std::string &replace_with)
Definition: expr.h:732
expr(tiramisu::op_t o, tiramisu::expr expr0, tiramisu::expr expr1, tiramisu::expr expr2)
Construct an expression for a ternary operator.
Definition: expr.h:314
std::string name
Identifier name.
Definition: expr.h:214
const std::string & get_name() const
Get the name of the ID or the variable represented by this expression.
Definition: expr.h:702
only_integral< T > operator+(const tiramisu::expr &e, T val)
Definition: expr.h:1847
int64_t get_int64_value() const
Return the actual value of the expression.
Definition: expr.h:560
only_integral< T > operator<<(const tiramisu::expr &e, T val)
Definition: expr.h:1919
void set_name(std::string &name)
Definition: expr.h:718
tiramisu::expr operator==(tiramisu::expr e1) const
Comparison operator.
Definition: expr.h:941
Definition: core.h:27
int64_t int64_value
Definition: expr.h:181
var(std::string name, expr lower_bound, expr upper_bound)
Construct a loop iterator that has name as a name.
Definition: expr.h:1781
typename std::enable_if< std::is_integral< T >::value, expr >::type only_integral
Definition: expr.h:41
tiramisu::expr operator-() const
Expression multiplied by (-1).
Definition: expr.h:922
tiramisu::expr operator!() const
Logical NOT of an expression.
Definition: expr.h:930
A class that represents constant variable references.
Definition: expr.h:1704
expr(int16_t val)
Construct a signed 16-bit integer expression.
Definition: expr.h:407
expr(uint32_t val)
Construct an unsigned 32-bit integer expression.
Definition: expr.h:420
std::string to_str() const
Definition: expr.h:1318
bool is_constant() const
Return true if this expression is a literal constant (i.e., 0, 1, 2, ...).
Definition: expr.h:1160
expr(int32_t val)
Construct a signed 32-bit integer expression.
Definition: expr.h:433
uint8_t uint8_value
Definition: expr.h:174
void set_access_dimension(int i, tiramisu::expr acc)
Set an element of the vector of accesses of a computation.
Definition: expr.h:999
uint16_t uint16_value
Definition: expr.h:176
A class for code generation.
Definition: core.h:4288