diff --git a/.clang-format b/.clang-format index bf8ff45ff1f436bb1bb358bded3ebbf5da219bbd..0b9a40282f8a247085847be735807513844c0301 100644 --- a/.clang-format +++ b/.clang-format @@ -1,30 +1,134 @@ ---- +BasedOnStyle: Mozilla +Language: Cpp +AccessModifierOffset: -2 AlignAfterOpenBracket: Align -AlignConsecutiveAssignments: true +AlignConsecutiveAssignments: false AlignConsecutiveDeclarations: true -AlignEscapedNewlinesLeft: true +AlignEscapedNewlines: Left AlignOperands: true AlignTrailingComments: true +# clang 9.0 AllowAllArgumentsOnNextLine: true +# clang 9.0 AllowAllConstructorInitializersOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false AllowShortBlocksOnASingleLine: false AllowShortCaseLabelsOnASingleLine: false -AllowShortFunctionsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +# clang 9.0 AllowShortLambdasOnASingleLine: All +# clang 9.0 features AllowShortIfStatementsOnASingleLine: Never AllowShortIfStatementsOnASingleLine: false AllowShortLoopsOnASingleLine: false -AlwaysBreakBeforeMultilineStrings: true -AlwaysBreakTemplateDeclarations: true -BreakBeforeBraces: Allman -ColumnLimit: 160 -ConstructorInitializerAllOnOneLineOrOnePerLine: true +AlwaysBreakAfterDefinitionReturnType: None +AlwaysBreakAfterReturnType: All +AlwaysBreakBeforeMultilineStrings: false +AlwaysBreakTemplateDeclarations: Yes +BinPackArguments: false +BinPackParameters: false +BreakBeforeBraces: Custom +BraceWrapping: + # clang 9.0 feature AfterCaseLabel: false + AfterClass: true + AfterControlStatement: true + AfterEnum: true + AfterFunction: true + AfterNamespace: true + AfterObjCDeclaration: true + AfterStruct: true + AfterUnion: true + AfterExternBlock: true + BeforeCatch: true + BeforeElse: true +## This is the big change from historical ITK formatting! +# Historically ITK used a style similar to https://en.wikipedia.org/wiki/Indentation_style#Whitesmiths_style +# with indented braces, and not indented code. This style is very difficult to automatically +# maintain with code beautification tools. Not indenting braces is more common among +# formatting tools. 
+ IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +#clang 6.0 BreakBeforeInheritanceComma: true +BreakInheritanceList: BeforeComma +BreakBeforeTernaryOperators: true +#clang 6.0 BreakConstructorInitializersBeforeComma: true +BreakConstructorInitializers: BeforeComma +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +## The following line allows larger lines in non-documentation code +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: false ConstructorInitializerIndentWidth: 2 -Cpp11BracedListStyle: true +ContinuationIndentWidth: 2 +Cpp11BracedListStyle: false DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH +IncludeBlocks: Preserve +IncludeCategories: + - Regex: '^"(llvm|llvm-c|clang|clang-c)/' + Priority: 2 + - Regex: '^(<|"(gtest|gmock|isl|json)/)' + Priority: 3 + - Regex: '.*' + Priority: 1 +IncludeIsMainRegex: '(Test)?$' +IndentCaseLabels: true +IndentPPDirectives: AfterHash IndentWidth: 2 -Language: Cpp +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: true +MacroBlockBegin: '' +MacroBlockEnd: '' MaxEmptyLinesToKeep: 2 NamespaceIndentation: None -PointerAlignment: Left -SortIncludes: false +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: false +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 19 +PenaltyBreakComment: 300 +## The following line allows larger lines in non-documentation code +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Middle +ReflowComments: true +# We may want to sort the includes as a separate pass +SortIncludes: false +# We may want to revisit this later +SortUsingDeclarations: false +SpaceAfterCStyleCast: false +# SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: false +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false Standard: Cpp11 +StatementMacros: + - Q_UNUSED + - QT_REQUIRE_VERSION TabWidth: 2 UseTab: Never ... 
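The comment block above marks the main stylistic break with ITK's historical Whitesmiths-like layout. For illustration, a minimal sketch (hypothetical code, not taken from this repository) of roughly what the new configuration produces: opening braces on their own unindented line, the return type broken onto its own line (AlwaysBreakAfterReturnType: All), and middle-aligned references (PointerAlignment: Middle):

std::string
GetGreeting(const std::string & name)
{
  if (name.empty())
  {
    return "hello";
  }
  return "hello " + name;
}

Under the previous Whitesmiths-like style, the braces of the if block would themselves be indented to the level of the statements they enclose, which is the layout the comment notes is hard for automatic formatters to maintain.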
diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index a55d9469def828967e232a131d872c2886122808..4b9fe883d917e245f9f738a8a09b7d595dad5138 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,17 +22,12 @@ stages: before_script: - *update_otbtf_src - - *compile_otbtf - - sudo rm -rf $OTB_BUILD/Testing/Temporary/* # Empty testing temporary folder (old files here) - - pip3 install pytest pytest-cov build: stage: Build allow_failure: false script: - *compile_otbtf - - sudo rm -rf $OTB_BUILD/Testing/Temporary/* # Empty testing temporary folder (old files here) - flake8: stage: Static Analysis @@ -58,6 +53,8 @@ cppcheck: ctest: stage: Test script: + - *compile_otbtf + - sudo rm -rf $OTB_BUILD/Testing/Temporary/* # Empty testing temporary folder (old files here) - cd $OTB_BUILD/ && sudo ctest -L OTBTensorflow # Run ctest after_script: - cp -r $OTB_BUILD/Testing/Temporary $CI_PROJECT_DIR/testing # Copy artifacts (they must be in $CI_PROJECT_DIR) @@ -70,6 +67,8 @@ ctest: sr4rs: stage: Test script: + - *compile_otbtf + - pip3 install pytest pytest-cov - cd $CI_PROJECT_DIR - wget -O sr4rs_sentinel2_bands4328_france2020_savedmodel.zip https://nextcloud.inrae.fr/s/boabW9yCjdpLPGX/download/sr4rs_sentinel2_bands4328_france2020_savedmodel.zip diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index dd69f197f23a980d9fc4759acb59f40eecd412a1..cd39540e7c15ad843a6f52502ea717faf1c0e38a 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -1,8 +1,8 @@ -Remi Cresson -Nicolas Narcon -Benjamin Commandre -Vincent Delbar -Loic Lozac'h -Pratyush Das -Doctor Who -Jordi Inglada +- Remi Cresson +- Nicolas Narcon +- Benjamin Commandre +- Vincent Delbar +- Loic Lozac'h +- Pratyush Das +- Doctor Who +- Jordi Inglada diff --git a/include/otbTensorflowCopyUtils.cxx b/include/otbTensorflowCopyUtils.cxx index 3449e11b976d4c8db86b4d4206247a2da489eccf..7b3cb969cd241a98f7d1411faf072ba06a8b0577 100644 --- a/include/otbTensorflowCopyUtils.cxx +++ b/include/otbTensorflowCopyUtils.cxx @@ -11,27 +11,31 @@ =========================================================================*/ #include "otbTensorflowCopyUtils.h" -namespace otb { -namespace tf { +namespace otb +{ +namespace tf +{ // // Display a TensorShape // -std::string PrintTensorShape(const tensorflow::TensorShape & shp) +std::string +PrintTensorShape(const tensorflow::TensorShape & shp) { std::stringstream s; - unsigned int nDims = shp.dims(); + unsigned int nDims = shp.dims(); s << "{" << shp.dim_size(0); - for (unsigned int d = 1 ; d < nDims ; d++) + for (unsigned int d = 1; d < nDims; d++) s << ", " << shp.dim_size(d); - s << "}" ; + s << "}"; return s.str(); } // // Display infos about a tensor // -std::string PrintTensorInfos(const tensorflow::Tensor & tensor) +std::string +PrintTensorInfos(const tensorflow::Tensor & tensor) { std::stringstream s; s << "Tensor "; @@ -45,11 +49,12 @@ std::string PrintTensorInfos(const tensorflow::Tensor & tensor) // // Create a tensor with the good datatype // -template<class TImage> -tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape) +template <class TImage> +tensorflow::Tensor +CreateTensor(tensorflow::TensorShape & shape) { tensorflow::DataType ts_dt = GetTensorflowDataType<typename TImage::InternalPixelType>(); - tensorflow::Tensor out_tensor(ts_dt, shape); + tensorflow::Tensor out_tensor(ts_dt, shape); return out_tensor; } @@ -58,32 +63,35 @@ tensorflow::Tensor CreateTensor(tensorflow::TensorShape & shape) // Populate a tensor with the buffered region of a vector image using std::copy // Warning: tensor datatype must be 
consistent with the image value type // -template<class TImage> -void PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor) +template <class TImage> +void +PopulateTensorFromBufferedVectorImage(const typename TImage::Pointer bufferedimagePtr, tensorflow::Tensor & out_tensor) { - size_t n_elem = bufferedimagePtr->GetNumberOfComponentsPerPixel() * - bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels(); - std::copy_n(bufferedimagePtr->GetBufferPointer(), - n_elem, - out_tensor.flat<typename TImage::InternalPixelType>().data()); + size_t n_elem = + bufferedimagePtr->GetNumberOfComponentsPerPixel() * bufferedimagePtr->GetBufferedRegion().GetNumberOfPixels(); + std::copy_n( + bufferedimagePtr->GetBufferPointer(), n_elem, out_tensor.flat<typename TImage::InternalPixelType>().data()); } // // Recopy an VectorImage region into a 4D-shaped tensorflow::Tensor ({-1, sz_y, sz_x, sz_bands}) // -template<class TImage, class TValueType=typename TImage::InternalPixelType> -void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, - tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension +template <class TImage, class TValueType = typename TImage::InternalPixelType> +void +RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, + const typename TImage::RegionType & region, + tensorflow::Tensor & tensor, + unsigned int elemIdx) // element position along the 1st dimension { typename itk::ImageRegionConstIterator<TImage> inIt(inputPtr, region); - unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel(); - auto tMap = tensor.tensor<TValueType, 4>(); + unsigned int nBands = inputPtr->GetNumberOfComponentsPerPixel(); + auto tMap = tensor.tensor<TValueType, 4>(); for (inIt.GoToBegin(); !inIt.IsAtEnd(); ++inIt) { const int y = inIt.GetIndex()[1] - region.GetIndex()[1]; const int x = inIt.GetIndex()[0] - region.GetIndex()[0]; - for (unsigned int band = 0 ; band < nBands ; band++) + for (unsigned int band = 0; band < nBands; band++) tMap(elemIdx, y, x, band) = inIt.Get()[band]; } } @@ -92,9 +100,12 @@ void RecopyImageRegionToTensor(const typename TImage::Pointer inputPtr, const ty // Type-agnostic version of the 'RecopyImageRegionToTensor' function // TODO: add some numeric types // -template<class TImage> -void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, const typename TImage::RegionType & region, - tensorflow::Tensor & tensor, unsigned int elemIdx) // element position along the 1st dimension +template <class TImage> +void +RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, + const typename TImage::RegionType & region, + tensorflow::Tensor & tensor, + unsigned int elemIdx) // element position along the 1st dimension { tensorflow::DataType dt = tensor.dtype(); if (dt == tensorflow::DT_FLOAT) @@ -110,21 +121,25 @@ void RecopyImageRegionToTensorWithCast(const typename TImage::Pointer inputPtr, else if (dt == tensorflow::DT_INT32) RecopyImageRegionToTensor<TImage, int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT16) - RecopyImageRegionToTensor<TImage, unsigned short int> (inputPtr, region, tensor, elemIdx); + RecopyImageRegionToTensor<TImage, unsigned short int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_INT16) RecopyImageRegionToTensor<TImage, short int>(inputPtr, region, tensor, elemIdx); else if (dt == tensorflow::DT_UINT8) - 
RecopyImageRegionToTensor<TImage, unsigned char> (inputPtr, region, tensor, elemIdx); + RecopyImageRegionToTensor<TImage, unsigned char>(inputPtr, region, tensor, elemIdx); else - itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); + itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented !"); } // // Sample a centered patch (from index) // -template<class TImage> -void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::IndexType & centerIndex, const typename TImage::SizeType & patchSize, - tensorflow::Tensor & tensor, unsigned int elemIdx) +template <class TImage> +void +SampleCenteredPatch(const typename TImage::Pointer inputPtr, + const typename TImage::IndexType & centerIndex, + const typename TImage::SizeType & patchSize, + tensorflow::Tensor & tensor, + unsigned int elemIdx) { typename TImage::IndexType regionStart; regionStart[0] = centerIndex[0] - patchSize[0] / 2; @@ -136,9 +151,13 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename // // Sample a centered patch (from coordinates) // -template<class TImage> -void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename TImage::PointType & centerCoord, const typename TImage::SizeType & patchSize, - tensorflow::Tensor & tensor, unsigned int elemIdx) +template <class TImage> +void +SampleCenteredPatch(const typename TImage::Pointer inputPtr, + const typename TImage::PointType & centerCoord, + const typename TImage::SizeType & patchSize, + tensorflow::Tensor & tensor, + unsigned int elemIdx) { // Assuming tensor is of shape {-1, sz_y, sz_x, sz_bands} // Get the index of the center @@ -154,10 +173,11 @@ void SampleCenteredPatch(const typename TImage::Pointer inputPtr, const typename // shape {x, y, c} --> c (e.g. a patch) // shape {n, x, y, c} --> c (e.g. some patches) // -tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor) +tensorflow::int64 +GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & tensor) { const tensorflow::TensorShape shape = tensor.shape(); - const int nDims = shape.dims(); + const int nDims = shape.dims(); if (nDims == 1) return 1; return shape.dim_size(nDims - 1); @@ -172,9 +192,13 @@ tensorflow::int64 GetNumberOfChannelsForOutputTensor(const tensorflow::Tensor & // shape {n, c} --> c (e.g. a vector) // shape {x, y, c} --> c (e.g. 
a multichannel image) // -template<class TImage, class TValueType> -void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, - typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset) +template <class TImage, class TValueType> +void +CopyTensorToImageRegion(const tensorflow::Tensor & tensor, + const typename TImage::RegionType & bufferRegion, + typename TImage::Pointer outputPtr, + const typename TImage::RegionType & outputRegion, + int & channelOffset) { // Flatten the tensor @@ -191,15 +215,13 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename T const tensorflow::int64 nElmI = bufferRegion.GetNumberOfPixels() * outputDimSize_C; if (nElmI != nElmT) { - itkGenericExceptionMacro("Number of elements in the tensor is " << nElmT << - " but image outputRegion has " << nElmI << - " values to fill.\nBuffer region:\n" << bufferRegion << - "\nNumber of components: " << outputDimSize_C << - "\nTensor shape:\n " << PrintTensorShape(tensor.shape()) << - "\nPlease check the input(s) field of view (FOV), " << - "the output field of expression (FOE), and the " << - "output spacing scale if you run the model in fully " << - "convolutional mode (how many strides in your model?)"); + itkGenericExceptionMacro("Number of elements in the tensor is " + << nElmT << " but image outputRegion has " << nElmI << " values to fill.\nBuffer region:\n" + << bufferRegion << "\nNumber of components: " << outputDimSize_C << "\nTensor shape:\n " + << PrintTensorShape(tensor.shape()) << "\nPlease check the input(s) field of view (FOV), " + << "the output field of expression (FOE), and the " + << "output spacing scale if you run the model in fully " + << "convolutional mode (how many strides in your model?)"); } // Iterate over the image @@ -212,145 +234,216 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename T // TODO: it could be useful to change the tensor-->image mapping here. 
// e.g use a lambda for "pos" calculation const int pos = outputDimSize_C * (y * nCols + x); - for (unsigned int c = 0 ; c < outputDimSize_C ; c++) - outIt.Get()[channelOffset + c] = tFlat( pos + c); + for (unsigned int c = 0; c < outputDimSize_C; c++) + outIt.Get()[channelOffset + c] = tFlat(pos + c); } // Update the offset channelOffset += outputDimSize_C; - } // // Type-agnostic version of the 'CopyTensorToImageRegion' function -// TODO: add some numeric types // -template<class TImage> -void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, - typename TImage::Pointer outputPtr, const typename TImage::RegionType & region, int & channelOffset) +template <class TImage> +void +CopyTensorToImageRegion(const tensorflow::Tensor & tensor, + const typename TImage::RegionType & bufferRegion, + typename TImage::Pointer outputPtr, + const typename TImage::RegionType & region, + int & channelOffset) { tensorflow::DataType dt = tensor.dtype(); if (dt == tensorflow::DT_FLOAT) - CopyTensorToImageRegion<TImage, float> (tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion<TImage, float>(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_DOUBLE) - CopyTensorToImageRegion<TImage, double> (tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion<TImage, double>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT64) + CopyTensorToImageRegion<TImage, unsigned long long int>(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_INT64) CopyTensorToImageRegion<TImage, long long int>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT32) + CopyTensorToImageRegion<TImage, unsigned int>(tensor, bufferRegion, outputPtr, region, channelOffset); else if (dt == tensorflow::DT_INT32) - CopyTensorToImageRegion<TImage, int> (tensor, bufferRegion, outputPtr, region, channelOffset); + CopyTensorToImageRegion<TImage, int>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT16) + CopyTensorToImageRegion<TImage, unsigned short int>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_INT16) + CopyTensorToImageRegion<TImage, short int>(tensor, bufferRegion, outputPtr, region, channelOffset); + else if (dt == tensorflow::DT_UINT8) + CopyTensorToImageRegion<TImage, unsigned char>(tensor, bufferRegion, outputPtr, region, channelOffset); else - itkGenericExceptionMacro("TF DataType "<< dt << " not currently implemented !"); - + itkGenericExceptionMacro("TF DataType " << dt << " not currently implemented !"); } // // Compare two string lowercase // -bool iequals(const std::string& a, const std::string& b) +bool +iequals(const std::string & a, const std::string & b) { - return std::equal(a.begin(), a.end(), - b.begin(), b.end(), - [](char cha, char chb) { - return tolower(cha) == tolower(chb); - }); + return std::equal( + a.begin(), a.end(), b.begin(), b.end(), [](char cha, char chb) { return tolower(cha) == tolower(chb); }); } -// Convert an expression into a dict -// +// Convert a value into a tensor // Following types are supported: // -bool // -int // -float +// -vector of float +// +// e.g. "true", "0.2", "14", "(1.2, 4.2, 4)" // -// e.g. 
is_training=true, droptout=0.2, nfeat=14 -std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression) +// TODO: we could add some other types (e.g. string) +tensorflow::Tensor +ValueToTensor(std::string value) { - std::pair<std::string, tensorflow::Tensor> dict; + std::vector<std::string> values; - std::size_t found = expression.find("="); - if (found != std::string::npos) - { - // Find name and value - std::string name = expression.substr(0, found); - std::string value = expression.substr(found+1); + // Check if value is a vector or a scalar + const bool has_left = (value[0] == '('); + const bool has_right = value[value.size() - 1] == ')'; - dict.first = name; + // Check consistency + bool is_vec = false; + if (has_left || has_right) + { + is_vec = true; + if (!has_left || !has_right) + itkGenericExceptionMacro("Error parsing vector expression (missing parenthese ?)" << value); + } - // Find type - std::size_t found_dot = value.find(".") != std::string::npos; - std::size_t is_digit = value.find_first_not_of("0123456789.") == std::string::npos; - if (is_digit) + // Scalar --> Vector for generic processing + if (!is_vec) + { + values.push_back(value); + } + else + { + // Remove "(" and ")" chars + std::string trimmed_value = value.substr(1, value.size() - 2); + + // Split string into vector using "," delimiter + std::regex rgx("\\s*,\\s*"); + std::sregex_token_iterator iter{ trimmed_value.begin(), trimmed_value.end(), rgx, -1 }; + std::sregex_token_iterator end; + values = std::vector<std::string>({ iter, end }); + } + + // Find type + bool has_dot = false; + bool is_digit = true; + for (auto & val : values) + { + has_dot = has_dot || val.find(".") != std::string::npos; + is_digit = is_digit && val.find_first_not_of("-0123456789.") == std::string::npos; + } + + // Create tensor + tensorflow::TensorShape shape({}); + tensorflow::Tensor out(tensorflow::DT_BOOL, shape); + if (is_digit) + { + if (has_dot) + out = tensorflow::Tensor(tensorflow::DT_FLOAT, shape); + else + out = tensorflow::Tensor(tensorflow::DT_INT32, shape); + } + + // Fill tensor + unsigned int idx = 0; + for (auto & val : values) + { + + if (is_digit) + { + if (has_dot) { - if (found_dot) + // FLOAT + try { - // FLOAT - try - { - float val = std::stof(value); - tensorflow::Tensor out(tensorflow::DT_FLOAT, tensorflow::TensorShape()); - out.scalar<float>()() = val; - dict.second = out; - - } - catch(...) - { - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as float"); - } - + out.scalar<float>()(idx) = std::stof(val); } - else + catch (...) { - // INT - try - { - int val = std::stoi(value); - tensorflow::Tensor out(tensorflow::DT_INT32, tensorflow::TensorShape()); - out.scalar<int>()() = val; - dict.second = out; - - } - catch(...) - { - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as int"); - } - + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as float"); } } else { - // BOOL - bool val = true; - if (iequals(value, "true")) + // INT + try { - val = true; + out.scalar<int>()(idx) = std::stoi(val); } - else if (iequals(value, "false")) + catch (...) 
{ - val = false; + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as int"); } - else - { - itkGenericExceptionMacro("Error parsing name=" - << name << " with value=" << value << " as bool"); - } - tensorflow::Tensor out(tensorflow::DT_BOOL, tensorflow::TensorShape()); - out.scalar<bool>()() = val; - dict.second = out; } - } else { - itkGenericExceptionMacro("The following expression is not valid: " - << "\n\t" << expression - << ".\nExpression must be in the form int_value=1 or float_value=1.0 or bool_value=true."); + // BOOL + bool ret = true; + if (iequals(val, "true")) + { + ret = true; + } + else if (iequals(val, "false")) + { + ret = false; + } + else + { + itkGenericExceptionMacro("Error parsing value \"" << val << "\" as bool"); + } + out.scalar<bool>()(idx) = ret; } + idx++; + } - return dict; + return out; +} + +// Convert an expression into a dict +// +// Following types are supported: +// -bool +// -int +// -float +// -vector of float +// +// e.g. is_training=true, droptout=0.2, nfeat=14, x=(1.2, 4.2, 4) +std::pair<std::string, tensorflow::Tensor> +ExpressionToTensor(std::string expression) +{ + std::pair<std::string, tensorflow::Tensor> dict; + + + std::size_t found = expression.find("="); + if (found != std::string::npos) + { + // Find name and value + std::string name = expression.substr(0, found); + std::string value = expression.substr(found + 1); + + dict.first = name; + + // Transform value into tensorflow::Tensor + dict.second = ValueToTensor(value); + } + else + { + itkGenericExceptionMacro("The following expression is not valid: " + << "\n\t" << expression << ".\nExpression must be in one of the following form:" + << "\n- int32_value=1 \n- float_value=1.0 \n- bool_value=true." + << "\n- float_vec=(1.0, 5.253, 2)"); + } + return dict; } } // end namespace tf diff --git a/include/otbTensorflowCopyUtils.h b/include/otbTensorflowCopyUtils.h index f21097229d645b5a910b2e3de35dbf24222c7ac7..78a15c8cea5ecf02e718cf11c1a4a9f86ca25cc4 100644 --- a/include/otbTensorflowCopyUtils.h +++ b/include/otbTensorflowCopyUtils.h @@ -28,6 +28,7 @@ // STD #include <string> +#include <regex> namespace otb { namespace tf { @@ -75,6 +76,9 @@ void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, typename TImage: template<class TImage> void CopyTensorToImageRegion(const tensorflow::Tensor & tensor, const typename TImage::RegionType & bufferRegion, typename TImage::Pointer outputPtr, const typename TImage::RegionType & outputRegion, int & channelOffset); +// Convert a value into a tensor +tensorflow::Tensor ValueToTensor(std::string value); + // Convert an expression into a dict std::pair<std::string, tensorflow::Tensor> ExpressionToTensor(std::string expression); diff --git a/python/ckpt2savedmodel.py b/python/ckpt2savedmodel.py index ec757dc77d721b2b5bd9e0cb1680fb3450672bc7..4235be8fabe280b78ea25da0fd7c6ec9a8a2e495 100755 --- a/python/ckpt2savedmodel.py +++ b/python/ckpt2savedmodel.py @@ -18,21 +18,27 @@ # limitations under the License. # # ==========================================================================*/ +""" +This application converts a checkpoint into a SavedModel, that can be used in +TensorflowModelTrain or TensorflowModelServe OTB applications. +This is intended to work mostly with tf.v1 models, since the models in tf.v2 +can be more conveniently exported as SavedModel (see how to build a model with +keras in Tensorflow 2). 
+""" import argparse from tricks import ckpt_to_savedmodel -# Parser -parser = argparse.ArgumentParser() -parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True) -parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+') -parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True, - nargs='+') -parser.add_argument("--model", help="Output directory for SavedModel", required=True) -parser.add_argument('--clear_devices', dest='clear_devices', action='store_true') -parser.set_defaults(clear_devices=False) -params = parser.parse_args() - if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", help="Checkpoint file (without the \".meta\" extension)", required=True) + parser.add_argument("--inputs", help="Inputs names (e.g. [\"x_cnn_1:0\", \"x_cnn_2:0\"])", required=True, nargs='+') + parser.add_argument("--outputs", help="Outputs names (e.g. [\"prediction:0\", \"features:0\"])", required=True, + nargs='+') + parser.add_argument("--model", help="Output directory for SavedModel", required=True) + parser.add_argument('--clear_devices', dest='clear_devices', action='store_true') + parser.set_defaults(clear_devices=False) + params = parser.parse_args() + ckpt_to_savedmodel(ckpt_path=params.ckpt, inputs=params.inputs, outputs=params.outputs, diff --git a/python/otbtf.py b/python/otbtf.py index e8c4e5c9ed54e2c6bbc24f494d6bace6bc88e0a6..988a68208658232699c850ca132bb78599c52f07 100644 --- a/python/otbtf.py +++ b/python/otbtf.py @@ -24,11 +24,11 @@ OTBTF framework. import threading import multiprocessing import time +import logging +from abc import ABC, abstractmethod import numpy as np import tensorflow as tf import gdal -import logging -from abc import ABC, abstractmethod """ @@ -40,7 +40,7 @@ def gdal_open(filename): """ Open a GDAL raster :param filename: raster file - :return: a GDAL ds instance + :return: a GDAL gdal_ds instance """ gdal_ds = gdal.Open(filename) if gdal_ds is None: @@ -84,13 +84,23 @@ class Buffer: self.container = [] def size(self): + """ + Returns the buffer size + """ return len(self.container) - def add(self, x): - self.container.append(x) + def add(self, new_element): + """ + Add an element in the buffer + :param new_element: new element to add + """ + self.container.append(new_element) assert self.size() <= self.max_length def is_complete(self): + """ + Return True if the buffer is at full capacity + """ return self.size() == self.max_length @@ -180,64 +190,61 @@ class PatchesImagesReader(PatchesReaderBase): assert len(filenames_dict.values()) > 0 - # ds dict - self.ds = dict() - for src_key, src_filenames in filenames_dict.items(): - self.ds[src_key] = [] - for src_filename in src_filenames: - self.ds[src_key].append(gdal_open(src_filename)) + # gdal_ds dict + self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()} - if len(set([len(ds_list) for ds_list in self.ds.values()])) != 1: + # check number of patches in each sources + if len({len(ds_list) for ds_list in self.gdal_ds.values()}) != 1: raise Exception("Each source must have the same number of patches images") # streaming on/off self.use_streaming = use_streaming - # ds check - nb_of_patches = {key: 0 for key in self.ds} + # gdal_ds check + nb_of_patches = {key: 0 for key in self.gdal_ds} self.nb_of_channels = dict() - for src_key, ds_list in self.ds.items(): - for ds in 
ds_list: - nb_of_patches[src_key] += self._get_nb_of_patches(ds) + for src_key, ds_list in self.gdal_ds.items(): + for gdal_ds in ds_list: + nb_of_patches[src_key] += self._get_nb_of_patches(gdal_ds) if src_key not in self.nb_of_channels: - self.nb_of_channels[src_key] = ds.RasterCount + self.nb_of_channels[src_key] = gdal_ds.RasterCount else: - if self.nb_of_channels[src_key] != ds.RasterCount: + if self.nb_of_channels[src_key] != gdal_ds.RasterCount: raise Exception("All patches images from one source must have the same number of channels!" "Error happened for source: {}".format(src_key)) if len(set(nb_of_patches.values())) != 1: raise Exception("Sources must have the same number of patches! Number of patches: {}".format(nb_of_patches)) - # ds sizes - src_key_0 = list(self.ds)[0] # first key - self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.ds[src_key_0]] + # gdal_ds sizes + src_key_0 = list(self.gdal_ds)[0] # first key + self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.gdal_ds[src_key_0]] self.size = sum(self.ds_sizes) # if use_streaming is False, we store in memory all patches images if not self.use_streaming: - patches_list = {src_key: [read_as_np_arr(ds) for ds in self.ds[src_key]] for src_key in self.ds} - self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=-1) for src_key in self.ds} + patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds} + self.patches_buffer = {src_key: np.concatenate(patches_list[src_key], axis=0) for src_key in self.gdal_ds} def _get_ds_and_offset_from_index(self, index): offset = index - for index, ds_size in enumerate(self.ds_sizes): + for idx, ds_size in enumerate(self.ds_sizes): if offset < ds_size: break offset -= ds_size - return index, offset + return idx, offset @staticmethod - def _get_nb_of_patches(ds): - return int(ds.RasterYSize / ds.RasterXSize) + def _get_nb_of_patches(gdal_ds): + return int(gdal_ds.RasterYSize / gdal_ds.RasterXSize) @staticmethod - def _read_extract_as_np_arr(ds, offset): - assert ds is not None - psz = ds.RasterXSize + def _read_extract_as_np_arr(gdal_ds, offset): + assert gdal_ds is not None + psz = gdal_ds.RasterXSize yoff = int(offset * psz) - assert yoff + psz <= ds.RasterYSize - buffer = ds.ReadAsArray(0, yoff, psz, psz) + assert yoff + psz <= gdal_ds.RasterYSize + buffer = gdal_ds.ReadAsArray(0, yoff, psz, psz) if len(buffer.shape) == 3: buffer = np.transpose(buffer, axes=(1, 2, 0)) return np.float32(buffer) @@ -252,14 +259,14 @@ class PatchesImagesReader(PatchesReaderBase): ... 
"src_key_M": np.array((psz_y_M, psz_x_M, nb_ch_M))} """ - assert 0 <= index + assert index >= 0 assert index < self.size if not self.use_streaming: - res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.ds} + res = {src_key: self.patches_buffer[src_key][index, :, :, :] for src_key in self.gdal_ds} else: i, offset = self._get_ds_and_offset_from_index(index) - res = {src_key: self._read_extract_as_np_arr(self.ds[src_key][i], offset) for src_key in self.ds} + res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds} return res @@ -282,7 +289,7 @@ class PatchesImagesReader(PatchesReaderBase): axis = (0, 1) # (row, col) def _filled(value): - return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.ds} + return {src_key: value * np.ones((self.nb_of_channels[src_key])) for src_key in self.gdal_ds} _maxs = _filled(0.0) _mins = _filled(float("inf")) @@ -302,7 +309,7 @@ class PatchesImagesReader(PatchesReaderBase): "max": _maxs[src_key], "mean": rsize * _sums[src_key], "std": np.sqrt(rsize * _sqsums[src_key] - np.square(rsize * _sums[src_key])) - } for src_key in self.ds} + } for src_key in self.gdal_ds} logging.info("Stats: {}".format(stats)) return stats @@ -397,6 +404,8 @@ class Dataset: logging.info("output_shapes: {}".format(self.output_shapes)) # buffers + if self.size <= buffer_length: + buffer_length = self.size self.miner_buffer = Buffer(buffer_length) self.mining_lock = multiprocessing.Lock() self.consumer_buffer = Buffer(buffer_length) @@ -438,12 +447,12 @@ class Dataset: This function dumps the miner_buffer into the consumer_buffer, and restart the miner_thread """ # Wait for miner to finish his job - t = time.time() + date_t = time.time() self.miner_thread.join() - self.tot_wait += time.time() - t + self.tot_wait += time.time() - date_t # Copy miner_buffer.container --> consumer_buffer.container - self.consumer_buffer.container = [elem for elem in self.miner_buffer.container] + self.consumer_buffer.container = self.miner_buffer.container.copy() # Clear miner_buffer.container self.miner_buffer.container.clear() @@ -470,15 +479,15 @@ class Dataset: """ Create and starts the thread for the data collect """ - t = threading.Thread(target=self._collect) - t.start() - return t + new_thread = threading.Thread(target=self._collect) + new_thread.start() + return new_thread def _generator(self): """ Generator function, used for the tf dataset """ - for elem in range(self.size): + for _ in range(self.size): yield self.read_one_sample() def get_tf_dataset(self, batch_size, drop_remainder=True): diff --git a/python/tricks.py b/python/tricks.py index 33e98b4fe6faa69d1b42b8feca1d8133be4afc76..b8f85406afe570a33a86beb073a80ac0125b8f2f 100644 --- a/python/tricks.py +++ b/python/tricks.py @@ -23,9 +23,9 @@ and TensorFlow models. Starting from OTBTF >= 3.0.0, tricks is only used as a backward compatible stub for TF 1.X versions. 
""" -from otbtf import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds import tensorflow.compat.v1 as tf from deprecated import deprecated +from otbtf import gdal_open, read_as_np_arr as read_as_np_arr_from_gdal_ds tf.disable_v2_behavior() @@ -90,26 +90,6 @@ def read_samples(filename): return read_image_as_np(filename, as_patches=True) -@deprecated(version="3.0.0", reason="Please consider using TensorFlow >= 2 to build and save your nets") -def CreateSavedModel(sess, inputs, outputs, directory): - """ - Create a SavedModel from TF 1.X graphs - :param sess: The Tensorflow V1 session - :param inputs: List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - :param outputs: List of outputs names (e.g. ["prediction:0", "features:0"]) - :param directory: Path for the generated SavedModel - """ - create_savedmodel(sess, inputs, outputs, directory) - - -@deprecated(version="3.0.0", reason="Please consider using TensorFlow >= 2 to build and save your nets") -def CheckpointToSavedModel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices=False): - """ - Read a Checkpoint and build a SavedModel for TF 1.X graphs - :param ckpt_path: Path to the checkpoint file (without the ".meta" extension) - :param inputs: List of inputs names (e.g. ["x_cnn_1:0", "x_cnn_2:0"]) - :param outputs: List of outputs names (e.g. ["prediction:0", "features:0"]) - :param savedmodel_path: Path for the generated SavedModel - :param clear_devices: Clear TensorFlow devices positioning (True/False) - """ - ckpt_to_savedmodel(ckpt_path, inputs, outputs, savedmodel_path, clear_devices) +# Aliases for backward compatibility +CreateSavedModel = create_savedmodel +CheckpointToSavedModel = ckpt_to_savedmodel diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f255dcb9d26a7a5b5988c62a936701e13bb720d8..f7716a138fc79aa34c9734e8b5d3a44cfce35b39 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,5 +1,24 @@ otb_module_test() +# Unit tests +set(${otb-module}Tests + otbTensorflowTests.cxx + otbTensorflowCopyUtilsTests.cxx) + +add_executable(otbTensorflowTests ${${otb-module}Tests}) + +target_include_directories(otbTensorflowTests PRIVATE ${tensorflow_include_dir}) +target_link_libraries(otbTensorflowTests ${${otb-module}-Test_LIBRARIES} ${TENSORFLOW_CC_LIB} ${TENSORFLOW_FRAMEWORK_LIB}) +otb_module_target_label(otbTensorflowTests) + +# CopyUtilsTests +otb_add_test(NAME floatValueToTensorTest COMMAND otbTensorflowTests floatValueToTensorTest) +otb_add_test(NAME intValueToTensorTest COMMAND otbTensorflowTests intValueToTensorTest) +otb_add_test(NAME boolValueToTensorTest COMMAND otbTensorflowTests boolValueToTensorTest) +otb_add_test(NAME floatVecValueToTensorTest COMMAND otbTensorflowTests floatVecValueToTensorTest) +otb_add_test(NAME intVecValueToTensorTest COMMAND otbTensorflowTests intVecValueToTensorTest) +otb_add_test(NAME boolVecValueToTensorTest COMMAND otbTensorflowTests boolVecValueToTensorTest) + # Directories set(DATADIR ${CMAKE_CURRENT_SOURCE_DIR}/data) set(MODELSDIR ${CMAKE_CURRENT_SOURCE_DIR}/models) diff --git a/test/otbTensorflowCopyUtilsTests.cxx b/test/otbTensorflowCopyUtilsTests.cxx new file mode 100644 index 0000000000000000000000000000000000000000..b4a72973d7d9f6e41b66a1675440bb82632eeb81 --- /dev/null +++ b/test/otbTensorflowCopyUtilsTests.cxx @@ -0,0 +1,108 @@ +/*========================================================================= + + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2020 INRAE + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied 
warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ + +#include "otbTensorflowCopyUtils.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "itkMacro.h" + +template<typename T> +int compare(tensorflow::Tensor & t1, tensorflow::Tensor & t2) +{ + std::cout << "Compare " << t1.DebugString() << " and " << t2.DebugString() << std::endl; + if (t1.dims() != t2.dims()) + { + std::cout << "dims() differ!" << std::endl; + return EXIT_FAILURE; + } + if (t1.dtype() != t2.dtype()) + { + std::cout << "dtype() differ!" << std::endl; + return EXIT_FAILURE; + } + if (t1.NumElements() != t2.NumElements()) + { + std::cout << "NumElements() differ!" << std::endl; + return EXIT_FAILURE; + } + for (unsigned int i = 0; i < t1.NumElements(); i++) + if (t1.scalar<T>()(i) != t2.scalar<T>()(i)) + { + std::cout << "scalar " << i << " differ!" << std::endl; + return EXIT_FAILURE; + } + // Else + std::cout << "Tensors are equals :)" << std::endl; + return EXIT_SUCCESS; +} + +template<typename T> +int genericValueToTensorTest(tensorflow::DataType dt, std::string expr, T value) +{ + tensorflow::Tensor t = otb::tf::ValueToTensor(expr); + tensorflow::Tensor t_ref(dt, tensorflow::TensorShape({})); + t_ref.scalar<T>()() = value; + + return compare<T>(t, t_ref); +} + +int floatValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericValueToTensorTest<float>(tensorflow::DT_FLOAT, "0.1234", 0.1234) + && genericValueToTensorTest<float>(tensorflow::DT_FLOAT, "-0.1234", -0.1234) ; +} + +int intValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericValueToTensorTest<int>(tensorflow::DT_INT32, "1234", 1234) + && genericValueToTensorTest<int>(tensorflow::DT_INT32, "-1234", -1234); +} + +int boolValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "true", true) + && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "True", true) + && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "False", false) + && genericValueToTensorTest<bool>(tensorflow::DT_BOOL, "false", false); +} + +template<typename T> +int genericVecValueToTensorTest(tensorflow::DataType dt, std::string expr, std::vector<T> values) +{ + tensorflow::Tensor t = otb::tf::ValueToTensor(expr); + tensorflow::Tensor t_ref(dt, tensorflow::TensorShape({})); + unsigned int i = 0; + for (auto value: values) + { + t_ref.scalar<T>()(i) = value; + i++; + } + + return compare<T>(t, t_ref); +} + +int floatVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest<float>(tensorflow::DT_FLOAT, "(0.1234, -1,-20,2.56 ,3.5)", std::vector<float>({0.1234, -1, -20, 2.56 ,3.5})); +} + +int intVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest<int>(tensorflow::DT_INT32, "(1234, -1,-20,256 ,35)", std::vector<int>({1234, -1, -20, 256 ,35})); +} + +int boolVecValueToTensorTest(int itkNotUsed(argc), char * itkNotUsed(argv)[]) +{ + return genericVecValueToTensorTest<bool>(tensorflow::DT_BOOL, "(true, false,True, False)", std::vector<bool>({true, false, true, false})); +} + + diff --git a/test/otbTensorflowTests.cxx b/test/otbTensorflowTests.cxx new file mode 100644 index 
0000000000000000000000000000000000000000..135e047434951982ab31768121e23df406e2ed66 --- /dev/null +++ b/test/otbTensorflowTests.cxx @@ -0,0 +1,25 @@ +/*========================================================================= + + Copyright (c) 2018-2019 IRSTEA + Copyright (c) 2020-2020 INRAE + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ + +#include "otbTestMain.h" + +void RegisterTests() +{ + REGISTER_TEST(floatValueToTensorTest); + REGISTER_TEST(intValueToTensorTest); + REGISTER_TEST(boolValueToTensorTest); + REGISTER_TEST(floatVecValueToTensorTest); + REGISTER_TEST(intVecValueToTensorTest); + REGISTER_TEST(boolVecValueToTensorTest); +} + + diff --git a/test/sr4rs_unittest.py b/test/sr4rs_unittest.py index 311e7a7094c0003af90c51fc5da2716a716bcc3d..fbb921f8451cc83b3fd7b9e9e90bf61755511eca 100644 --- a/test/sr4rs_unittest.py +++ b/test/sr4rs_unittest.py @@ -3,12 +3,44 @@ import unittest import os - +from pathlib import Path import gdal import otbApplication as otb -class SR4RSTest(unittest.TestCase): +def command_train_succeed(extra_opts=""): + root_dir = os.environ["CI_PROJECT_DIR"] + ckpt_dir = "/tmp/" + + def _input(file_name): + return "{}/sr4rs_data/input/{}".format(root_dir, file_name) + + command = "python {}/sr4rs/code/train.py ".format(root_dir) + command += "--lr_patches " + command += _input("DIM_SPOT6_MS_202007290959110_ORT_ORTHO-MS-193_posA_s2.jp2 ") + command += _input("DIM_SPOT7_MS_202004111036186_ORT_ORTHO-MS-081_posA_s2.jp2 ") + command += _input("DIM_SPOT7_MS_202006201000507_ORT_ORTHO-MS-054_posA_s2.jp2 ") + command += "--hr_patches " + command += _input("DIM_SPOT6_MS_202007290959110_ORT_ORTHO-MS-193_posA_s6_cal.jp2 ") + command += _input("DIM_SPOT7_MS_202004111036186_ORT_ORTHO-MS-081_posA_s6_cal.jp2 ") + command += _input("DIM_SPOT7_MS_202006201000507_ORT_ORTHO-MS-054_posA_s6_cal.jp2 ") + command += "--save_ckpt {} ".format(ckpt_dir) + command += "--depth 4 " + command += "--nresblocks 1 " + command += "--epochs 1 " + command += extra_opts + os.system(command) + file = Path("{}/checkpoint".format(ckpt_dir)) + return file.is_file() + + +class SR4RSv1Test(unittest.TestCase): + + def test_train_nostream(self): + self.assertTrue(command_train_succeed()) + + def test_train_stream(self): + self.assertTrue(command_train_succeed(extra_opts="--streaming")) def test_inference(self): root_dir = os.environ["CI_PROJECT_DIR"] @@ -27,7 +59,7 @@ class SR4RSTest(unittest.TestCase): self.assertTrue(nbchannels_reconstruct == nbchannels_baseline) - for i in range(1, 1+nbchannels_baseline): + for i in range(1, 1 + nbchannels_baseline): comp = otb.Registry.CreateApplication('CompareImages') comp.SetParameterString('ref.in', baseline) comp.SetParameterInt('ref.channel', i) @@ -41,4 +73,3 @@ class SR4RSTest(unittest.TestCase): if __name__ == '__main__': unittest.main() -
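Taken together, the ValueToTensor and ExpressionToTensor helpers introduced in otbTensorflowCopyUtils.h/.cxx, and exercised by the new ctest targets registered in test/CMakeLists.txt, convert user-supplied strings into tensorflow::Tensor objects. A minimal usage sketch, assuming an OTBTF build where the header and the TensorFlow C++ libraries are on the include and link paths (a hypothetical standalone program, not part of this patch):

#include "otbTensorflowCopyUtils.h"
#include <iostream>

int
main()
{
  // Scalars: the element type is inferred from the literal
  tensorflow::Tensor flag = otb::tf::ValueToTensor("true"); // DT_BOOL
  tensorflow::Tensor rate = otb::tf::ValueToTensor("0.2");  // DT_FLOAT
  tensorflow::Tensor nfeat = otb::tf::ValueToTensor("14");  // DT_INT32

  // Vectors are written between parentheses, comma-separated
  tensorflow::Tensor vec = otb::tf::ValueToTensor("(1.2, 4.2, 4)"); // DT_FLOAT elements

  // "name=value" expressions are split into a (name, tensor) pair
  std::pair<std::string, tensorflow::Tensor> dict = otb::tf::ExpressionToTensor("dropout=0.2");
  std::cout << dict.first << ": " << dict.second.DebugString() << std::endl;

  return 0;
}

This string-to-tensor parsing is the same mechanism checked by floatValueToTensorTest, intVecValueToTensorTest and the other unit tests added in test/otbTensorflowCopyUtilsTests.cxx.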