diff --git a/.clang-format b/.clang-format
index bf8ff45ff1f436bb1bb358bded3ebbf5da219bbd..0b9a40282f8a247085847be735807513844c0301 100644
--- a/.clang-format
+++ b/.clang-format
@@ -1,30 +1,134 @@
----
+BasedOnStyle: Mozilla
+Language:        Cpp
+AccessModifierOffset: -2
 AlignAfterOpenBracket: Align
-AlignConsecutiveAssignments: true
+AlignConsecutiveAssignments: false
 AlignConsecutiveDeclarations: true
-AlignEscapedNewlinesLeft: true
+AlignEscapedNewlines: Left
 AlignOperands:   true
 AlignTrailingComments: true
+# clang 9.0 AllowAllArgumentsOnNextLine: true
+# clang 9.0 AllowAllConstructorInitializersOnNextLine: true
+AllowAllParametersOfDeclarationOnNextLine: false
 AllowShortBlocksOnASingleLine: false
 AllowShortCaseLabelsOnASingleLine: false
-AllowShortFunctionsOnASingleLine: false
+AllowShortFunctionsOnASingleLine: Inline
+# clang 9.0 AllowShortLambdasOnASingleLine: All
+# clang 9.0 features AllowShortIfStatementsOnASingleLine: Never
 AllowShortIfStatementsOnASingleLine: false
 AllowShortLoopsOnASingleLine: false
-AlwaysBreakBeforeMultilineStrings: true
-AlwaysBreakTemplateDeclarations: true
-BreakBeforeBraces: Allman
-ColumnLimit:     160
-ConstructorInitializerAllOnOneLineOrOnePerLine: true
+AlwaysBreakAfterDefinitionReturnType: None
+AlwaysBreakAfterReturnType: All
+AlwaysBreakBeforeMultilineStrings: false
+AlwaysBreakTemplateDeclarations: Yes
+BinPackArguments: false
+BinPackParameters: false
+BreakBeforeBraces: Custom
+BraceWrapping:
+  # clang 9.0 feature AfterCaseLabel:  false
+  AfterClass:      true
+  AfterControlStatement: true
+  AfterEnum:       true
+  AfterFunction:   true
+  AfterNamespace:  true
+  AfterObjCDeclaration: true
+  AfterStruct:     true
+  AfterUnion:      true
+  AfterExternBlock: true
+  BeforeCatch:     true
+  BeforeElse:      true
+## This is the big change from historical ITK formatting!
+# Historically ITK used a style similar to https://en.wikipedia.org/wiki/Indentation_style#Whitesmiths_style
+# with indented braces, and not indented code.  This style is very difficult to automatically
+# maintain with code beautification tools.  Not indenting braces is more common among
+# formatting tools.
+  IndentBraces:    false
+  SplitEmptyFunction: false
+  SplitEmptyRecord: false
+  SplitEmptyNamespace: false
+BreakBeforeBinaryOperators: None
+#clang 6.0 BreakBeforeInheritanceComma: true
+BreakInheritanceList: BeforeComma
+BreakBeforeTernaryOperators: true
+#clang 6.0 BreakConstructorInitializersBeforeComma: true
+BreakConstructorInitializers: BeforeComma
+BreakAfterJavaFieldAnnotations: false
+BreakStringLiterals: true
+## The following line allows larger lines in non-documentation code
+ColumnLimit:     120
+CommentPragmas:  '^ IWYU pragma:'
+CompactNamespaces: false
+ConstructorInitializerAllOnOneLineOrOnePerLine: false
 ConstructorInitializerIndentWidth: 2
-Cpp11BracedListStyle: true
+ContinuationIndentWidth: 2
+Cpp11BracedListStyle: false
 DerivePointerAlignment: false
+DisableFormat:   false
+ExperimentalAutoDetectBinPacking: false
+FixNamespaceComments: true
+ForEachMacros:
+  - foreach
+  - Q_FOREACH
+  - BOOST_FOREACH
+IncludeBlocks:   Preserve
+IncludeCategories:
+  - Regex:           '^"(llvm|llvm-c|clang|clang-c)/'
+    Priority:        2
+  - Regex:           '^(<|"(gtest|gmock|isl|json)/)'
+    Priority:        3
+  - Regex:           '.*'
+    Priority:        1
+IncludeIsMainRegex: '(Test)?$'
+IndentCaseLabels: true
+IndentPPDirectives: AfterHash
 IndentWidth:     2
-Language:        Cpp
+IndentWrappedFunctionNames: false
+JavaScriptQuotes: Leave
+JavaScriptWrapImports: true
+KeepEmptyLinesAtTheStartOfBlocks: true
+MacroBlockBegin: ''
+MacroBlockEnd:   ''
 MaxEmptyLinesToKeep: 2
 NamespaceIndentation: None
-PointerAlignment: Left
-SortIncludes: false
+ObjCBinPackProtocolList: Auto
+ObjCBlockIndentWidth: 2
+ObjCSpaceAfterProperty: true
+ObjCSpaceBeforeProtocolList: false
+PenaltyBreakAssignment: 2
+PenaltyBreakBeforeFirstCallParameter: 19
+PenaltyBreakComment: 300
+## The following line allows larger lines in non-documentation code
+PenaltyBreakFirstLessLess: 120
+PenaltyBreakString: 1000
+PenaltyBreakTemplateDeclaration: 10
+PenaltyExcessCharacter: 1000000
+PenaltyReturnTypeOnItsOwnLine: 200
+PointerAlignment: Middle
+ReflowComments:  true
+# We may want to sort the includes as a separate pass
+SortIncludes:    false
+# We may want to revisit this later
+SortUsingDeclarations: false
+SpaceAfterCStyleCast: false
+# SpaceAfterLogicalNot: false
+SpaceAfterTemplateKeyword: true
+SpaceBeforeAssignmentOperators: true
+SpaceBeforeCpp11BracedList: false
+SpaceBeforeCtorInitializerColon: true
+SpaceBeforeInheritanceColon: true
+SpaceBeforeParens: ControlStatements
+SpaceBeforeRangeBasedForLoopColon: true
+SpaceInEmptyParentheses: false
+SpacesBeforeTrailingComments: 1
+SpacesInAngles:  false
+SpacesInContainerLiterals: false
+SpacesInCStyleCastParentheses: false
+SpacesInParentheses: false
+SpacesInSquareBrackets: false
 Standard:        Cpp11
+StatementMacros:
+  - Q_UNUSED
+  - QT_REQUIRE_VERSION
 TabWidth:        2
 UseTab:          Never
 ...
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index a55d9469def828967e232a131d872c2886122808..4b9fe883d917e245f9f738a8a09b7d595dad5138 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -22,17 +22,12 @@ stages:
 
 before_script:
   - *update_otbtf_src
-  - *compile_otbtf
-  - sudo rm -rf $OTB_BUILD/Testing/Temporary/*  # Empty testing temporary folder (old files here)
-  - pip3 install pytest pytest-cov
 
 build:
   stage: Build
   allow_failure: false
   script:
     - *compile_otbtf
-    - sudo rm -rf $OTB_BUILD/Testing/Temporary/*  # Empty testing temporary folder (old files here)
-
 
 flake8:
   stage: Static Analysis
@@ -58,6 +53,8 @@ cppcheck:
 ctest:
   stage: Test
   script:
+    - *compile_otbtf
+    - sudo rm -rf $OTB_BUILD/Testing/Temporary/*  # Empty testing temporary folder (old files here)
     - cd $OTB_BUILD/ && sudo ctest -L OTBTensorflow  # Run ctest
   after_script:
     - cp -r $OTB_BUILD/Testing/Temporary $CI_PROJECT_DIR/testing  # Copy artifacts (they must be in $CI_PROJECT_DIR)
@@ -70,6 +67,8 @@ ctest:
 sr4rs:
   stage: Test
   script:
+    - *compile_otbtf
+    - pip3 install pytest pytest-cov
     - cd $CI_PROJECT_DIR
     - wget -O sr4rs_sentinel2_bands4328_france2020_savedmodel.zip
       https://nextcloud.inrae.fr/s/boabW9yCjdpLPGX/download/sr4rs_sentinel2_bands4328_france2020_savedmodel.zip
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index dd69f197f23a980d9fc4759acb59f40eecd412a1..cd39540e7c15ad843a6f52502ea717faf1c0e38a 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -1,8 +1,8 @@
-Remi Cresson
-Nicolas Narcon
-Benjamin Commandre
-Vincent Delbar
-Loic Lozac'h
-Pratyush Das
-Doctor Who
-Jordi Inglada
+- Remi Cresson
+- Nicolas Narcon
+- Benjamin Commandre
+- Vincent Delbar
+- Loic Lozac'h
+- Pratyush Das
+- Doctor Who
+- Jordi Inglada
diff --git a/include/otbTensorflowCopyUtils.cxx b/include/otbTensorflowCopyUtils.cxx
index 3449e11b976d4c8db86b4d4206247a2da489eccf..7b3cb969cd241a98f7d1411faf072ba06a8b0577 100644
--- a/include/otbTensorflowCopyUtils.cxx
+++ b/include/otbTensorflowCopyUtils.cxx
@@ -11,27 +11,31 @@
 =========================================================================*/
 #include "otbTensorflowCopyUtils.h"
 
-namespace otb {
-namespace tf {
+namespace otb
+{
+namespace tf
+{
 
 //
 // Display a TensorShape
 //
-std::string PrintTensorShape(const tensorflow::TensorShape & shp)
+std::string
+PrintTensorShape(const tensorflow::TensorShape & shp)
 {
   std::stringstream s;
-  unsigned int nDims = shp.dims();
+  unsigned int      nDims = shp.dims();
   s << "{" << shp.dim_size(0);
-  for (unsigned int d = 1 ; d < nDims ; d++)
+  for (unsigned int d = 1; d < nDims; d++)
     s << ", " << shp.dim_size(d);
-  s << "}" ;
+  s << "}";
   return s.str();
 }
 
 //
 // Display infos about a tensor
 //
-std::string PrintTensorInfos(const tensorflow::Tensor & tensor)
+std::string
+PrintTensorInfos(const tensorflow::Tensor & tensor)
 {
   std::stringstream s;
   s << "Tensor ";
@@ -45,11 +49,12 @@ std::string PrintTensorInfos(const tensorflow::Tensor & tensor)
 //
 // Create a tensor with the good datatype
 //
-template<class TImage>
-tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape)
+template <class TImage>
+tensorflow::Tensor
+CreateTensor(tensorflow::TensorShape & shape)
 {
   tensorflow::DataType ts_dt = GetTensorflowDataType<typename TImage::InternalPixelType>();
-  tensorflow::Tensor out_tensor(ts_dt, shape);
+  tensorflow::Tensor   out_tensor(ts_dt, shape);
 
   return out_tensor;
 }
@@ -58,32 +63,35 @@ tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape)
 // Populate a tensor with the buffered region of a vector image using std::copy
 // Warning: tensor datatype must be consistent with the image value type
 //
-template<class TImage>
-void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor)
+template <class TImage>
+void
+PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor)
 {
-  size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() *
-      bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels();
-  std::copy_n(bufferedimagePtr->GetBufferPointer(),
-      n_elem,
-      out_tensor.flat<typename TImage::InternalPixelType>().data());
+  size_t n_elem =
+    bufferedimagePtr->GetNumberOfComponentsPerPixel() * bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels();
+  std::copy_n(
+    bufferedimagePtr->GetBufferPointer(), n_elem, out_tensor.flat<typename TImage::InternalPixelType>().data());
 }
 
 //
 // Recopy an VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands})
 //
-template<class TImage, class TValueType=typename TImage::InternalPixelType>
-void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region,
-    tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension
+template <class TImage, class TValueType = typename TImage::InternalPixelType>
+void
+RecopyImageRegionToTensor(const typename TImage::Pointer      inputPtr,
+                          const typename TImage::RegionType & region,
+                          tensorflow::Tensor &                tensor,
+                          unsigned int                        elemIdx) // element position along the 1st dimension
 {
   typename itk::ImageRegionConstIterator<TImage> inIt(inputPtr, region);
-  unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel();
-  auto tMap = tensor.tensor<TValueType, 4>();
+  unsigned int                                   nBands = inputPtr->GetNumberOfComponentsPerPixel();
+  auto                                           tMap = tensor.tensor<TValueType, 4>();
   for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt)
   {
     const int y = inIt.GetIndex()[1] - region.GetIndex()[1];
     const int x = inIt.GetIndex()[0] - region.GetIndex()[0];
 
-    for (unsigned int band = 0 ; band < nBands ; band++)
+    for (unsigned int band = 0; band < nBands; band++)
       tMap(elemIdx, y, x, band) = inIt.Get()[band];
   }
 }
@@ -92,9 +100,12 @@ void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const ty
 // Type-agnostic version of the 'RecopyImageRegionToTensor' function
 // TODO: add some numeric types
 //
-template<class TImage>
-void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region,
-    tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension
+template <class TImage>
+void
+RecopyImageRegionToTensorWithCast(const typename TImage::Pointer      inputPtr,
+                                  const typename TImage::RegionType & region,
+                                  tensorflow::Tensor &                tensor,
+                                  unsigned int elemIdx) // element position along the 1st dimension
 {
   tensorflow::DataType dt = tensor.dtype();
   if (dt == tensorflow::DT_FLOAT)
@@ -110,21 +121,25 @@ void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr,
   else if (dt == tensorflow::DT_INT32)
     RecopyImageRegionToTensor<TImage, int>(inputPtr, region, tensor, elemIdx);
   else if (dt == tensorflow::DT_UINT16)
-    RecopyImageRegionToTensor<TImage, unsigned short int> (inputPtr, region, tensor, elemIdx);
+    RecopyImageRegionToTensor<TImage, unsigned short int>(inputPtr, region, tensor, elemIdx);
   else if (dt == tensorflow::DT_INT16)
     RecopyImageRegionToTensor<TImage, short int>(inputPtr, region, tensor, elemIdx);
   else if (dt == tensorflow::DT_UINT8)
-    RecopyImageRegionToTensor<TImage, unsigned char> (inputPtr, region, tensor, elemIdx);
+    RecopyImageRegionToTensor<TImage, unsigned char>(inputPtr, region, tensor, elemIdx);
   else
-    itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !");
+    itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented !");
 }
 
 //
 // Sample a centered patch (from index)
 //
-template<class TImage>
-void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize,
-    tensorflow::Tensor & tensor, unsigned int elemIdx)
+template <class TImage>
+void
+SampleCenteredPatch(const typename TImage::Pointer     inputPtr,
+                    const typename TImage::IndexType & centerIndex,
+                    const typename TImage::SizeType &  patchSize,
+                    tensorflow::Tensor &               tensor,
+                    unsigned int                       elemIdx)
 {
   typename TImage::IndexType regionStart;
   regionStart[0] = centerIndex[0] - patchSize[0] / 2;
@@ -136,9 +151,13 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename
 //
 // Sample a centered patch (from coordinates)
 //
-template<class TImage>
-void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize,
-    tensorflow::Tensor & tensor, unsigned int elemIdx)
+template <class TImage>
+void
+SampleCenteredPatch(const typename TImage::Pointer     inputPtr,
+                    const typename TImage::PointType & centerCoord,
+                    const typename TImage::SizeType &  patchSize,
+                    tensorflow::Tensor &               tensor,
+                    unsigned int                       elemIdx)
 {
   // Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands}
   // Get the index of the center
@@ -154,10 +173,11 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename
 // shape {x, y, c}    --> c (e.g. a patch)
 // shape {n, x, y, c} --> c (e.g. some patches)
 //
-tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor)
+tensorflow::int64
+GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor)
 {
   const tensorflow::TensorShape shape = tensor.shape();
-  const int nDims = shape.dims();
+  const int                     nDims = shape.dims();
   if (nDims == 1)
     return 1;
   return shape.dim_size(nDims - 1);
@@ -172,9 +192,13 @@ tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor &
 // shape {n, c}       --> c (e.g. a vector)
 // shape {x, y, c}    --> c (e.g. a multichannel image)
 //
-template<class TImage, class TValueType>
-void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion,
-                             typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset)
+template <class TImage, class TValueType>
+void
+CopyTensorToImageRegion(const tensorflow::Tensor &          tensor,
+                        const typename TImage::RegionType & bufferRegion,
+                        typename TImage::Pointer            outputPtr,
+                        const typename TImage::RegionType & outputRegion,
+                        int &                               channelOffset)
 {
 
   // Flatten the tensor
@@ -191,15 +215,13 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename T
   const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C;
   if (nElmI != nElmT)
   {
-    itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT <<
-        " but image outputRegion has " << nElmI <<
-        " values to fill.\nBuffer region:\n" << bufferRegion <<
-        "\nNumber of components: " << outputDimSize_C <<
-        "\nTensor shape:\n " << PrintTensorShape(tensor.shape()) <<
-        "\nPlease check the input(s) field of view (FOV), " <<
-        "the output field of expression (FOE), and the  " <<
-        "output spacing scale if you run the model in fully " <<
-        "convolutional mode (how many strides in your model?)");
+    itkGenericExceptionMacro("Number of elements in the tensor is "
+                             << nElmT << " but image outputRegion has " << nElmI << " values to fill.\nBuffer region:\n"
+                             << bufferRegion << "\nNumber of components: " << outputDimSize_C << "\nTensor shape:\n "
+                             << PrintTensorShape(tensor.shape()) << "\nPlease check the input(s) field of view (FOV), "
+                             << "the output field of expression (FOE), and the  "
+                             << "output spacing scale if you run the model in fully "
+                             << "convolutional mode (how many strides in your model?)");
   }
 
   // Iterate over the image
@@ -212,145 +234,216 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename T
     // TODO: it could be useful to change the tensor-->image mapping here.
     // e.g use a lambda for "pos" calculation
     const int pos = outputDimSize_C * (y * nCols + x);
-    for (unsigned int c = 0 ; c < outputDimSize_C ; c++)
-      outIt.Get()[channelOffset + c] = tFlat( pos + c);
+    for (unsigned int c = 0; c < outputDimSize_C; c++)
+      outIt.Get()[channelOffset + c] = tFlat(pos + c);
   }
 
   // Update the offset
   channelOffset += outputDimSize_C;
-
 }
 
 //
 // Type-agnostic version of the 'CopyTensorToImageRegion' function
-// TODO: add some numeric types
 //
-template<class TImage>
-void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion,
-                             typename TImage::Pointer outputPtr, const typename TImage::RegionType & region, int & channelOffset)
+template <class TImage>
+void
+CopyTensorToImageRegion(const tensorflow::Tensor &          tensor,
+                        const typename TImage::RegionType & bufferRegion,
+                        typename TImage::Pointer            outputPtr,
+                        const typename TImage::RegionType & region,
+                        int &                               channelOffset)
 {
   tensorflow::DataType dt = tensor.dtype();
   if (dt == tensorflow::DT_FLOAT)
-    CopyTensorToImageRegion<TImage, float>        (tensor, bufferRegion, outputPtr, region, channelOffset);
+    CopyTensorToImageRegion<TImage, float>(tensor, bufferRegion, outputPtr, region, channelOffset);
   else if (dt == tensorflow::DT_DOUBLE)
-    CopyTensorToImageRegion<TImage, double>       (tensor, bufferRegion, outputPtr, region, channelOffset);
+    CopyTensorToImageRegion<TImage, double>(tensor, bufferRegion, outputPtr, region, channelOffset);
+  else if (dt == tensorflow::DT_UINT64)
+    CopyTensorToImageRegion<TImage, unsigned long long int>(tensor, bufferRegion, outputPtr, region, channelOffset);
   else if (dt == tensorflow::DT_INT64)
     CopyTensorToImageRegion<TImage, long long int>(tensor, bufferRegion, outputPtr, region, channelOffset);
+  else if (dt == tensorflow::DT_UINT32)
+    CopyTensorToImageRegion<TImage, unsigned int>(tensor, bufferRegion, outputPtr, region, channelOffset);
   else if (dt == tensorflow::DT_INT32)
-    CopyTensorToImageRegion<TImage, int>          (tensor, bufferRegion, outputPtr, region, channelOffset);
+    CopyTensorToImageRegion<TImage, int>(tensor, bufferRegion, outputPtr, region, channelOffset);
+  else if (dt == tensorflow::DT_UINT16)
+    CopyTensorToImageRegion<TImage, unsigned short int>(tensor, bufferRegion, outputPtr, region, channelOffset);
+  else if (dt == tensorflow::DT_INT16)
+    CopyTensorToImageRegion<TImage, short int>(tensor, bufferRegion, outputPtr, region, channelOffset);
+  else if (dt == tensorflow::DT_UINT8)
+    CopyTensorToImageRegion<TImage, unsigned char>(tensor, bufferRegion, outputPtr, region, channelOffset);
   else
-    itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !");
-
+    itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented !");
 }
 
 //
 // Compare two string lowercase
 //
-bool iequals(const std::string& a, const std::string& b)
+bool
+iequals(const std::string & a, const std::string & b)
 {
-  return std::equal(a.begin(), a.end(),
-      b.begin(), b.end(),
-      [](char cha, char chb) {
-    return tolower(cha) == tolower(chb);
-  });
+  return std::equal(
+    a.begin(), a.end(), b.begin(), b.end(), [](char cha, char chb) { return tolower(cha) == tolower(chb); });
 }
 
-// Convert an expression into a dict
-//
+// Convert a value into a tensor
 // Following types are supported:
 // -bool
 // -int
 // -float
+// -vector of float
+//
+// e.g. "true", "0.2", "14", "(1.2, 4.2, 4)"
 //
-// e.g. is_training=true, droptout=0.2, nfeat=14
-std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression)
+// TODO: we could add some other types (e.g. string)
+tensorflow::Tensor
+ValueToTensor(std::string value)
 {
-  std::pair<std::string, tensorflow::Tensor> dict;
 
+  std::vector<std::string> values;
 
-    std::size_t found = expression.find("=");
-    if (found != std::string::npos)
-    {
-      // Find name and value
-      std::string name = expression.substr(0, found);
-      std::string value = expression.substr(found+1);
+  // Check if value is a vector or a scalar
+  const bool has_left = (value[0] == '(');
+  const bool has_right = value[value.size() - 1] == ')';
 
-      dict.first = name;
+  // Check consistency
+  bool is_vec = false;
+  if (has_left || has_right)
+  {
+    is_vec = true;
+    if (!has_left || !has_right)
+      itkGenericExceptionMacro("Error parsing vector expression (missing parenthese ?)" << value);
+  }
 
-      // Find type
-      std::size_t found_dot = value.find(".") != std::string::npos;
-      std::size_t is_digit = value.find_first_not_of("0123456789.") == std::string::npos;
-      if (is_digit)
+  // Scalar --> Vector for generic processing
+  if (!is_vec)
+  {
+    values.push_back(value);
+  }
+  else
+  {
+    // Remove "(" and ")" chars
+    std::string trimmed_value = value.substr(1, value.size() - 2);
+
+    // Split string into vector using "," delimiter
+    std::regex                 rgx("\\s*,\\s*");
+    std::sregex_token_iterator iter{ trimmed_value.begin(), trimmed_value.end(), rgx, -1 };
+    std::sregex_token_iterator end;
+    values = std::vector<std::string>({ iter, end });
+  }
+
+  // Find type
+  bool has_dot = false;
+  bool is_digit = true;
+  for (auto & val : values)
+  {
+    has_dot = has_dot || val.find(".") != std::string::npos;
+    is_digit = is_digit && val.find_first_not_of("-0123456789.") == std::string::npos;
+  }
+
+  // Create tensor
+  tensorflow::TensorShape shape({});
+  tensorflow::Tensor      out(tensorflow::DT_BOOL, shape);
+  if (is_digit)
+  {
+    if (has_dot)
+      out = tensorflow::Tensor(tensorflow::DT_FLOAT, shape);
+    else
+      out = tensorflow::Tensor(tensorflow::DT_INT32, shape);
+  }
+
+  // Fill tensor
+  unsigned int idx = 0;
+  for (auto & val : values)
+  {
+
+    if (is_digit)
+    {
+      if (has_dot)
       {
-        if (found_dot)
+        // FLOAT
+        try
         {
-          // FLOAT
-          try
-          {
-            float val = std::stof(value);
-            tensorflow::Tensor out(tensorflow::DT_FLOAT, tensorflow::TensorShape());
-            out.scalar<float>()() = val;
-            dict.second = out;
-
-          }
-          catch(...)
-          {
-            itkGenericExceptionMacro("Error parsing name="
-                << name << " with value=" << value << " as float");
-          }
-
+          out.scalar<float>()(idx) = std::stof(val);
         }
-        else
+        catch (...)
         {
-          // INT
-          try
-          {
-            int val = std::stoi(value);
-            tensorflow::Tensor out(tensorflow::DT_INT32, tensorflow::TensorShape());
-            out.scalar<int>()() = val;
-            dict.second = out;
-
-          }
-          catch(...)
-          {
-            itkGenericExceptionMacro("Error parsing name="
-                << name << " with value=" << value << " as int");
-          }
-
+          itkGenericExceptionMacro("Error parsing value \"" << val << "\" as float");
         }
       }
       else
       {
-        // BOOL
-        bool val = true;
-        if (iequals(value, "true"))
+        // INT
+        try
         {
-          val = true;
+          out.scalar<int>()(idx) = std::stoi(val);
         }
-        else if (iequals(value, "false"))
+        catch (...)
         {
-          val = false;
+          itkGenericExceptionMacro("Error parsing value \"" << val << "\" as int");
         }
-        else
-        {
-          itkGenericExceptionMacro("Error parsing name="
-                          << name << " with value=" << value << " as bool");
-        }
-        tensorflow::Tensor out(tensorflow::DT_BOOL, tensorflow::TensorShape());
-        out.scalar<bool>()() = val;
-        dict.second = out;
       }
-
     }
     else
     {
-      itkGenericExceptionMacro("The following expression is not valid: "
-          << "\n\t" << expression
-          << ".\nExpression must be in the form int_value=1 or float_value=1.0 or bool_value=true.");
+      // BOOL
+      bool ret = true;
+      if (iequals(val, "true"))
+      {
+        ret = true;
+      }
+      else if (iequals(val, "false"))
+      {
+        ret = false;
+      }
+      else
+      {
+        itkGenericExceptionMacro("Error parsing value \"" << val << "\" as bool");
+      }
+      out.scalar<bool>()(idx) = ret;
     }
+    idx++;
+  }
 
-    return dict;
+  return out;
+}
+
+// Convert an expression into a dict
+//
+// Following types are supported:
+// -bool
+// -int
+// -float
+// -vector of float
+//
+// e.g. is_training=true, droptout=0.2, nfeat=14, x=(1.2, 4.2, 4)
+std::pair<std::string, tensorflow::Tensor>
+ExpressionToTensor(std::string expression)
+{
+  std::pair<std::string, tensorflow::Tensor> dict;
+
+
+  std::size_t found = expression.find("=");
+  if (found != std::string::npos)
+  {
+    // Find name and value
+    std::string name = expression.substr(0, found);
+    std::string value = expression.substr(found + 1);
+
+    dict.first = name;
+
+    // Transform value into tensorflow::Tensor
+    dict.second = ValueToTensor(value);
+  }
+  else
+  {
+    itkGenericExceptionMacro("The following expression is not valid: "
+                             << "\n\t" << expression << ".\nExpression must be in one of the following form:"
+                             << "\n- int32_value=1 \n- float_value=1.0 \n- bool_value=true."
+                             << "\n- float_vec=(1.0, 5.253, 2)");
+  }
 
+  return dict;
 }
 
 } // end namespace tf
diff --git a/include/otbTensorflowCopyUtils.h b/include/otbTensorflowCopyUtils.h
index f21097229d645b5a910b2e3de35dbf24222c7ac7..78a15c8cea5ecf02e718cf11c1a4a9f86ca25cc4 100644
--- a/include/otbTensorflowCopyUtils.h
+++ b/include/otbTensorflowCopyUtils.h
@@ -28,6 +28,7 @@
 
 // STD
 #include <string>
+#include <regex>
 
 namespace otb {
 namespace tf {
@@ -75,6 +76,9 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, typename TImage:
 template<class TImage>
 void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset);
 
+// Convert a value into a tensor
+tensorflow::Tensor ValueToTensor(std::string value);
+
 // Convert an expression into a dict
 std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression);
 
diff --git a/python/ckpt2savedmodel.py b/python/ckpt2savedmodel.py
index ec757dc77d721b2b5bd9e0cb1680fb3450672bc7..4235be8fabe280b78ea25da0fd7c6ec9a8a2e495 100755
--- a/python/ckpt2savedmodel.py
+++ b/python/ckpt2savedmodel.py
@@ -18,21 +18,27 @@
 #   limitations under the License.
 #
 # ==========================================================================*/
+"""
+This application converts a checkpoint into a SavedModel, that can be used in
+TensorflowModelTrain or TensorflowModelServe OTB applications.
+This is intended to work mostly with tf.v1 models, since the models in tf.v2
+can be more conveniently exported as SavedModel (see how to build a model with
+keras in Tensorflow 2).
+"""
 import argparse
 from tricks import ckpt_to_savedmodel
 
-# Parser
-parser = argparse.ArgumentParser()
-parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True)
-parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+')
-parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True,
-                    nargs='+')
-parser.add_argument("--model", help="Output directory for SavedModel", required=True)
-parser.add_argument('--clear_devices', dest='clear_devices', action='store_true')
-parser.set_defaults(clear_devices=False)
-params = parser.parse_args()
-
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True)
+    parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+')
+    parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True,
+                        nargs='+')
+    parser.add_argument("--model", help="Output directory for SavedModel", required=True)
+    parser.add_argument('--clear_devices', dest='clear_devices', action='store_true')
+    parser.set_defaults(clear_devices=False)
+    params = parser.parse_args()
+
     ckpt_to_savedmodel(ckpt_path=params.ckpt,
                        inputs=params.inputs,
                        outputs=params.outputs,
diff --git a/python/otbtf.py b/python/otbtf.py
index e8c4e5c9ed54e2c6bbc24f494d6bace6bc88e0a6..988a68208658232699c850ca132bb78599c52f07 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -24,11 +24,11 @@ OTBTF framework.
 import threading
 import multiprocessing
 import time
+import logging
+from abc import ABC, abstractmethod
 import numpy as np
 import tensorflow as tf
 import gdal
-import logging
-from abc import ABC, abstractmethod
 
 
 """
@@ -40,7 +40,7 @@ def gdal_open(filename):
     """
     Open a GDAL raster
     :param filename: raster file
-    :return: a GDAL ds instance
+    :return: a GDAL gdal_ds instance
     """
     gdal_ds = gdal.Open(filename)
     if gdal_ds is None:
@@ -84,13 +84,23 @@ class Buffer:
         self.container = []
 
     def size(self):
+        """
+        Returns the buffer size
+        """
         return len(self.container)
 
-    def add(self, x):
-        self.container.append(x)
+    def add(self, new_element):
+        """
+        Add an element in the buffer
+        :param new_element: new element to add
+        """
+        self.container.append(new_element)
         assert self.size() <= self.max_length
 
     def is_complete(self):
+        """
+        Return True if the buffer is at full capacity
+        """
         return self.size() == self.max_length
 
 
@@ -180,64 +190,61 @@ class PatchesImagesReader(PatchesReaderBase):
 
         assert len(filenames_dict.values()) > 0
 
-        # ds dict
-        self.ds = dict()
-        for src_key, src_filenames in filenames_dict.items():
-            self.ds[src_key] = []
-            for src_filename in src_filenames:
-                self.ds[src_key].append(gdal_open(src_filename))
+        # gdal_ds dict
+        self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()}
 
-        if len(set([len(ds_list) for ds_list in self.ds.values()])) != 1:
+        # check number of patches in each sources
+        if len({len(ds_list) for ds_list in self.gdal_ds.values()}) != 1:
             raise Exception("Each source must have the same number of patches images")
 
         # streaming on/off
         self.use_streaming = use_streaming
 
-        # ds check
-        nb_of_patches = {key: 0 for key in self.ds}
+        # gdal_ds check
+        nb_of_patches = {key: 0 for key in self.gdal_ds}
         self.nb_of_channels = dict()
-        for src_key, ds_list in self.ds.items():
-            for ds in ds_list:
-                nb_of_patches[src_key] += self._get_nb_of_patches(ds)
+        for src_key, ds_list in self.gdal_ds.items():
+            for gdal_ds in ds_list:
+                nb_of_patches[src_key] += self._get_nb_of_patches(gdal_ds)
                 if src_key not in self.nb_of_channels:
-                    self.nb_of_channels[src_key] = ds.RasterCount
+                    self.nb_of_channels[src_key] = gdal_ds.RasterCount
                 else:
-                    if self.nb_of_channels[src_key] != ds.RasterCount:
+                    if self.nb_of_channels[src_key] != gdal_ds.RasterCount:
                         raise Exception("All patches images from one source must have the same number of channels!"
                                         "Error happened for source: {}".format(src_key))
         if len(set(nb_of_patches.values())) != 1:
             raise Exception("Sources must have the same number of patches! Number of patches: {}".format(nb_of_patches))
 
-        # ds sizes
-        src_key_0 = list(self.ds)[0]  # first key
-        self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.ds[src_key_0]]
+        # gdal_ds sizes
+        src_key_0 = list(self.gdal_ds)[0]  # first key
+        self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.gdal_ds[src_key_0]]
         self.size = sum(self.ds_sizes)
 
         # if use_streaming is False, we store in memory all patches images
         if not self.use_streaming:
-            patches_list = {src_key: [read_as_np_arr(ds) for ds in self.ds[src_key]] for src_key in self.ds}
-            self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=-1) for src_key in self.ds}
+            patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds}
+            self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds}
 
     def _get_ds_and_offset_from_index(self, index):
         offset = index
-        for index, ds_size in enumerate(self.ds_sizes):
+        for idx, ds_size in enumerate(self.ds_sizes):
             if offset < ds_size:
                 break
             offset -= ds_size
 
-        return index, offset
+        return idx, offset
 
     @staticmethod
-    def _get_nb_of_patches(ds):
-        return int(ds.RasterYSize / ds.RasterXSize)
+    def _get_nb_of_patches(gdal_ds):
+        return int(gdal_ds.RasterYSize / gdal_ds.RasterXSize)
 
     @staticmethod
-    def _read_extract_as_np_arr(ds, offset):
-        assert ds is not None
-        psz = ds.RasterXSize
+    def _read_extract_as_np_arr(gdal_ds, offset):
+        assert gdal_ds is not None
+        psz = gdal_ds.RasterXSize
         yoff = int(offset * psz)
-        assert yoff + psz <= ds.RasterYSize
-        buffer = ds.ReadAsArray(0, yoff, psz, psz)
+        assert yoff + psz <= gdal_ds.RasterYSize
+        buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz)
         if len(buffer.shape) == 3:
             buffer = np.transpose(buffer, axes=(1, 2, 0))
         return np.float32(buffer)
@@ -252,14 +259,14 @@ class PatchesImagesReader(PatchesReaderBase):
              ...
              "src_key_M": np.array((psz_y_M, psz_x_M, nb_ch_M))}
         """
-        assert 0 <= index
+        assert index >= 0
         assert index < self.size
 
         if not self.use_streaming:
-            res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.ds}
+            res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds}
         else:
             i, offset = self._get_ds_and_offset_from_index(index)
-            res = {src_key: self._read_extract_as_np_arr(self.ds[src_key][i], offset) for src_key in self.ds}
+            res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds}
 
         return res
 
@@ -282,7 +289,7 @@ class PatchesImagesReader(PatchesReaderBase):
             axis = (0, 1)  # (row, col)
 
             def _filled(value):
-                return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.ds}
+                return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.gdal_ds}
 
             _maxs = _filled(0.0)
             _mins = _filled(float("inf"))
@@ -302,7 +309,7 @@ class PatchesImagesReader(PatchesReaderBase):
                                "max": _maxs[src_key],
                                "mean": rsize * _sums[src_key],
                                "std": np.sqrt(rsize * _sqsums[src_key] - np.square(rsize * _sums[src_key]))
-                               } for src_key in self.ds}
+                               } for src_key in self.gdal_ds}
         logging.info("Stats: {}".format(stats))
         return stats
 
@@ -397,6 +404,8 @@ class Dataset:
         logging.info("output_shapes: {}".format(self.output_shapes))
 
         # buffers
+        if self.size <= buffer_length:
+            buffer_length = self.size
         self.miner_buffer = Buffer(buffer_length)
         self.mining_lock = multiprocessing.Lock()
         self.consumer_buffer = Buffer(buffer_length)
@@ -438,12 +447,12 @@ class Dataset:
         This function dumps the miner_buffer into the consumer_buffer, and restart the miner_thread
         """
         # Wait for miner to finish his job
-        t = time.time()
+        date_t = time.time()
         self.miner_thread.join()
-        self.tot_wait += time.time() - t
+        self.tot_wait += time.time() - date_t
 
         # Copy miner_buffer.container --> consumer_buffer.container
-        self.consumer_buffer.container = [elem for elem in self.miner_buffer.container]
+        self.consumer_buffer.container = self.miner_buffer.container.copy()
 
         # Clear miner_buffer.container
         self.miner_buffer.container.clear()
@@ -470,15 +479,15 @@ class Dataset:
         """
         Create and starts the thread for the data collect
         """
-        t = threading.Thread(target=self._collect)
-        t.start()
-        return t
+        new_thread = threading.Thread(target=self._collect)
+        new_thread.start()
+        return new_thread
 
     def _generator(self):
         """
         Generator function, used for the tf dataset
         """
-        for elem in range(self.size):
+        for _ in range(self.size):
             yield self.read_one_sample()
 
     def get_tf_dataset(self, batch_size, drop_remainder=True):
diff --git a/python/tricks.py b/python/tricks.py
index 33e98b4fe6faa69d1b42b8feca1d8133be4afc76..b8f85406afe570a33a86beb073a80ac0125b8f2f 100644
--- a/python/tricks.py
+++ b/python/tricks.py
@@ -23,9 +23,9 @@ and TensorFlow models.
 Starting from OTBTF >= 3.0.0, tricks is only used as a backward compatible stub
 for TF 1.X versions.
 """
-from otbtf import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds
 import tensorflow.compat.v1 as tf
 from deprecated import deprecated
+from otbtf import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds
 tf.disable_v2_behavior()
 
 
@@ -90,26 +90,6 @@ def read_samples(filename):
     return read_image_as_np(filename, as_patches=True)
 
 
-@deprecated(version="3.0.0", reason="Please consider using TensorFlow >= 2 to build and save your nets")
-def CreateSavedModel(sess, inputs, outputs, directory):
-    """
-    Create a SavedModel from TF 1.X graphs
-    :param sess: The Tensorflow V1 session
-    :param inputs: List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
-    :param outputs: List of outputs names (e.g. ["prediction:0", "features:0"])
-    :param directory: Path for the generated SavedModel
-    """
-    create_savedmodel(sess, inputs, outputs, directory)
-
-
-@deprecated(version="3.0.0", reason="Please consider using TensorFlow >= 2 to build and save your nets")
-def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False):
-    """
-    Read a Checkpoint and build a SavedModel for TF 1.X graphs
-    :param ckpt_path: Path to the checkpoint file (without the ".meta" extension)
-    :param inputs: List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"])
-    :param outputs: List of outputs names (e.g. ["prediction:0", "features:0"])
-    :param savedmodel_path: Path for the generated SavedModel
-    :param clear_devices: Clear TensorFlow devices positioning (True/False)
-    """
-    ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices)
+# Aliases for backward compatibility
+CreateSavedModel = create_savedmodel
+CheckpointToSavedModel = ckpt_to_savedmodel
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index f255dcb9d26a7a5b5988c62a936701e13bb720d8..f7716a138fc79aa34c9734e8b5d3a44cfce35b39 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -1,5 +1,24 @@
 otb_module_test()
 
+# Unit tests
+set(${otb-module}Tests 
+  otbTensorflowTests.cxx
+  otbTensorflowCopyUtilsTests.cxx)
+
+add_executable(otbTensorflowTests ${${otb-module}Tests})
+
+target_include_directories(otbTensorflowTests PRIVATE ${tensorflow_include_dir})
+target_link_libraries(otbTensorflowTests ${${otb-module}-Test_LIBRARIES} ${TENSORFLOW_CC_LIB} ${TENSORFLOW_FRAMEWORK_LIB})
+otb_module_target_label(otbTensorflowTests)
+
+# CopyUtilsTests
+otb_add_test(NAME floatValueToTensorTest COMMAND otbTensorflowTests floatValueToTensorTest)
+otb_add_test(NAME intValueToTensorTest COMMAND otbTensorflowTests intValueToTensorTest)
+otb_add_test(NAME boolValueToTensorTest COMMAND otbTensorflowTests boolValueToTensorTest)
+otb_add_test(NAME floatVecValueToTensorTest COMMAND otbTensorflowTests floatVecValueToTensorTest)
+otb_add_test(NAME intVecValueToTensorTest COMMAND otbTensorflowTests intVecValueToTensorTest)
+otb_add_test(NAME boolVecValueToTensorTest COMMAND otbTensorflowTests boolVecValueToTensorTest)
+
 # Directories
 set(DATADIR ${CMAKE_CURRENT_SOURCE_DIR}/data)
 set(MODELSDIR ${CMAKE_CURRENT_SOURCE_DIR}/models)
diff --git a/test/otbTensorflowCopyUtilsTests.cxx b/test/otbTensorflowCopyUtilsTests.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..b4a72973d7d9f6e41b66a1675440bb82632eeb81
--- /dev/null
+++ b/test/otbTensorflowCopyUtilsTests.cxx
@@ -0,0 +1,108 @@
+/*=========================================================================
+
+     Copyright (c) 2018-2019 IRSTEA
+     Copyright (c) 2020-2020 INRAE
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+
+#include "otbTensorflowCopyUtils.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "itkMacro.h"
+
+template<typename T>
+int compare(tensorflow::Tensor & t1, tensorflow::Tensor & t2)
+{
+  std::cout << "Compare " << t1.DebugString() << " and " << t2.DebugString() << std::endl;
+  if (t1.dims() != t2.dims())
+    {
+    std::cout << "dims() differ!" << std::endl;
+    return EXIT_FAILURE;
+    }
+  if (t1.dtype() != t2.dtype())
+    {
+    std::cout << "dtype() differ!" << std::endl;
+    return EXIT_FAILURE;
+    }
+  if (t1.NumElements() != t2.NumElements())
+    {
+    std::cout << "NumElements() differ!" << std::endl;
+    return EXIT_FAILURE;
+    }
+  for (unsigned int i = 0; i < t1.NumElements(); i++)
+    if (t1.scalar<T>()(i) != t2.scalar<T>()(i))
+      {
+      std::cout << "scalar " << i << " differ!" << std::endl;
+      return EXIT_FAILURE;
+      }
+  // Else
+  std::cout << "Tensors are equals :)" << std::endl;
+  return EXIT_SUCCESS;
+}
+
+template<typename T>
+int genericValueToTensorTest(tensorflow::DataType dt, std::string expr, T value)
+{
+  tensorflow::Tensor t = otb::tf::ValueToTensor(expr);
+  tensorflow::Tensor t_ref(dt, tensorflow::TensorShape({}));
+  t_ref.scalar<T>()() = value;
+
+  return compare<T>(t, t_ref);
+}
+
+int floatValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[])
+{
+  return genericValueToTensorTest<float>(tensorflow::DT_FLOAT, "0.1234", 0.1234)
+      && genericValueToTensorTest<float>(tensorflow::DT_FLOAT, "-0.1234", -0.1234) ;
+}
+
+int intValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[])
+{
+  return genericValueToTensorTest<int>(tensorflow::DT_INT32, "1234", 1234)
+      && genericValueToTensorTest<int>(tensorflow::DT_INT32, "-1234", -1234);
+}
+
+int boolValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[])
+{
+  return genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "true", true)
+      && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "True", true)
+      && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "False", false)
+      && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "false", false);
+}
+
+template<typename T>
+int genericVecValueToTensorTest(tensorflow::DataType dt, std::string expr, std::vector<T> values)
+{
+  tensorflow::Tensor t = otb::tf::ValueToTensor(expr);
+  tensorflow::Tensor t_ref(dt, tensorflow::TensorShape({}));
+  unsigned int i = 0;
+  for (auto value: values)
+    {
+    t_ref.scalar<T>()(i) = value;
+    i++;
+    }
+
+  return compare<T>(t, t_ref);
+}
+
+int floatVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[])
+{
+  return genericVecValueToTensorTest<float>(tensorflow::DT_FLOAT, "(0.1234, -1,-20,2.56 ,3.5)", std::vector<float>({0.1234, -1, -20, 2.56 ,3.5}));
+}
+
+int intVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[])
+{
+  return genericVecValueToTensorTest<int>(tensorflow::DT_INT32, "(1234, -1,-20,256 ,35)", std::vector<int>({1234, -1, -20, 256 ,35}));
+}
+
+int boolVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[])
+{
+  return genericVecValueToTensorTest<bool>(tensorflow::DT_BOOL, "(true, false,True, False)", std::vector<bool>({true, false, true, false}));
+}
+
+
diff --git a/test/otbTensorflowTests.cxx b/test/otbTensorflowTests.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..135e047434951982ab31768121e23df406e2ed66
--- /dev/null
+++ b/test/otbTensorflowTests.cxx
@@ -0,0 +1,25 @@
+/*=========================================================================
+
+     Copyright (c) 2018-2019 IRSTEA
+     Copyright (c) 2020-2020 INRAE
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+
+#include "otbTestMain.h"
+
+void RegisterTests()
+{
+  REGISTER_TEST(floatValueToTensorTest);
+  REGISTER_TEST(intValueToTensorTest);
+  REGISTER_TEST(boolValueToTensorTest);
+  REGISTER_TEST(floatVecValueToTensorTest);
+  REGISTER_TEST(intVecValueToTensorTest);
+  REGISTER_TEST(boolVecValueToTensorTest);
+}
+
+
diff --git a/test/sr4rs_unittest.py b/test/sr4rs_unittest.py
index 311e7a7094c0003af90c51fc5da2716a716bcc3d..fbb921f8451cc83b3fd7b9e9e90bf61755511eca 100644
--- a/test/sr4rs_unittest.py
+++ b/test/sr4rs_unittest.py
@@ -3,12 +3,44 @@
 
 import unittest
 import os
-
+from pathlib import Path
 import gdal
 import otbApplication as otb
 
 
-class SR4RSTest(unittest.TestCase):
+def command_train_succeed(extra_opts=""):
+    root_dir = os.environ["CI_PROJECT_DIR"]
+    ckpt_dir = "/tmp/"
+
+    def _input(file_name):
+        return "{}/sr4rs_data/input/{}".format(root_dir, file_name)
+
+    command = "python {}/sr4rs/code/train.py ".format(root_dir)
+    command += "--lr_patches "
+    command += _input("DIM_SPOT6_MS_202007290959110_ORT_ORTHO-MS-193_posA_s2.jp2 ")
+    command += _input("DIM_SPOT7_MS_202004111036186_ORT_ORTHO-MS-081_posA_s2.jp2 ")
+    command += _input("DIM_SPOT7_MS_202006201000507_ORT_ORTHO-MS-054_posA_s2.jp2 ")
+    command += "--hr_patches "
+    command += _input("DIM_SPOT6_MS_202007290959110_ORT_ORTHO-MS-193_posA_s6_cal.jp2 ")
+    command += _input("DIM_SPOT7_MS_202004111036186_ORT_ORTHO-MS-081_posA_s6_cal.jp2 ")
+    command += _input("DIM_SPOT7_MS_202006201000507_ORT_ORTHO-MS-054_posA_s6_cal.jp2 ")
+    command += "--save_ckpt {} ".format(ckpt_dir)
+    command += "--depth 4 "
+    command += "--nresblocks 1 "
+    command += "--epochs 1 "
+    command += extra_opts
+    os.system(command)
+    file = Path("{}/checkpoint".format(ckpt_dir))
+    return file.is_file()
+
+
+class SR4RSv1Test(unittest.TestCase):
+
+    def test_train_nostream(self):
+        self.assertTrue(command_train_succeed())
+
+    def test_train_stream(self):
+        self.assertTrue(command_train_succeed(extra_opts="--streaming"))
 
     def test_inference(self):
         root_dir = os.environ["CI_PROJECT_DIR"]
@@ -27,7 +59,7 @@ class SR4RSTest(unittest.TestCase):
 
         self.assertTrue(nbchannels_reconstruct == nbchannels_baseline)
 
-        for i in range(1, 1+nbchannels_baseline):
+        for i in range(1, 1 + nbchannels_baseline):
             comp = otb.Registry.CreateApplication('CompareImages')
             comp.SetParameterString('ref.in', baseline)
             comp.SetParameterInt('ref.channel', i)
@@ -41,4 +73,3 @@ class SR4RSTest(unittest.TestCase):
 
 if __name__ == '__main__':
     unittest.main()
-